Add rollout_data_analysis/src/rollout_data_analysis/core.py

This commit is contained in:
Mika 2026-02-21 15:27:05 +00:00
parent c191a78745
commit a446e943f5

View file

@ -0,0 +1,66 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Dict, Any
import pandas as pd
logger = logging.getLogger(__name__)
@dataclass
class AnalysisResults:
"""Repräsentiert die Ergebnisstatistik einer Rollout-Datenanalyse."""
min: float
median: float
p95: float
max: float
def _validate_rollout_data(rollout_data: List[Dict[str, Any]]) -> None:
if not isinstance(rollout_data, list):
raise TypeError("rollout_data muss eine Liste von Dictionaries sein.")
if not rollout_data:
raise ValueError("rollout_data darf nicht leer sein.")
required_keys = {"unknown_rate", "warn_rate"}
for idx, row in enumerate(rollout_data):
if not isinstance(row, dict):
raise TypeError(f"Eintrag {idx} ist kein Dictionary.")
missing = required_keys - set(row)
if missing:
raise ValueError(f"Fehlende Keys in Eintrag {idx}: {missing}")
for key in required_keys:
val = row[key]
if not isinstance(val, (int, float)):
raise TypeError(f"Wert für '{key}' in Eintrag {idx} muss numerisch sein.")
if val < 0:
raise ValueError(f"Wert für '{key}' in Eintrag {idx} darf nicht negativ sein.")
def analyze_data(rollout_data: List[Dict[str, Any]]) -> AnalysisResults:
"""Analysiert Rollout-Daten und berechnet statistische Kennzahlen (min, median, p95, max).
Args:
rollout_data: Liste von Dictionaries, die Messwerte wie 'unknown_rate' und 'warn_rate' enthalten.
Returns:
AnalysisResults: Objekt mit den berechneten Statistikwerten.
"""
_validate_rollout_data(rollout_data)
df = pd.DataFrame(rollout_data)
# Auswahl der Metrik Priorität: unknown_rate, fallback warn_rate
metric_col = 'unknown_rate' if 'unknown_rate' in df.columns else 'warn_rate'
logger.debug("Berechne Statistik für Metrik: %s", metric_col)
series = df[metric_col].astype(float)
min_val = float(series.min())
median_val = float(series.median())
p95_val = float(series.quantile(0.95))
max_val = float(series.max())
results = AnalysisResults(min=min_val, median=median_val, p95=p95_val, max=max_val)
logger.debug("Analyseergebnisse: %s", results)
assert results.min <= results.median <= results.max, "Median außerhalb Wertebereich"
assert results.p95 <= results.max, "p95 überschreitet Maximalwert"
return results