Add rollout_data_analysis/src/rollout_data_analysis/core.py
This commit is contained in:
parent
c191a78745
commit
a446e943f5
1 changed files with 66 additions and 0 deletions
66
rollout_data_analysis/src/rollout_data_analysis/core.py
Normal file
66
rollout_data_analysis/src/rollout_data_analysis/core.py
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
class AnalysisResults:
    """Summary statistics produced by one rollout data analysis.

    Attributes:
        min: Smallest observed value of the analysed metric.
        median: Median of the analysed metric.
        p95: 95th percentile of the analysed metric.
        max: Largest observed value of the analysed metric.
    """

    # Field order is part of the positional constructor signature; keep it.
    min: float
    median: float
    p95: float
    max: float
|
||||||
|
|
||||||
|
def _validate_rollout_data(rollout_data: List[Dict[str, Any]]) -> None:
|
||||||
|
if not isinstance(rollout_data, list):
|
||||||
|
raise TypeError("rollout_data muss eine Liste von Dictionaries sein.")
|
||||||
|
if not rollout_data:
|
||||||
|
raise ValueError("rollout_data darf nicht leer sein.")
|
||||||
|
required_keys = {"unknown_rate", "warn_rate"}
|
||||||
|
for idx, row in enumerate(rollout_data):
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
raise TypeError(f"Eintrag {idx} ist kein Dictionary.")
|
||||||
|
missing = required_keys - set(row)
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Fehlende Keys in Eintrag {idx}: {missing}")
|
||||||
|
for key in required_keys:
|
||||||
|
val = row[key]
|
||||||
|
if not isinstance(val, (int, float)):
|
||||||
|
raise TypeError(f"Wert für '{key}' in Eintrag {idx} muss numerisch sein.")
|
||||||
|
if val < 0:
|
||||||
|
raise ValueError(f"Wert für '{key}' in Eintrag {idx} darf nicht negativ sein.")
|
||||||
|
|
||||||
|
def analyze_data(rollout_data: List[Dict[str, Any]]) -> AnalysisResults:
    """Compute min, median, 95th percentile and max for rollout data.

    The input is validated first, then loaded into a DataFrame. The
    statistics are taken from the 'unknown_rate' column when it is present,
    otherwise from 'warn_rate'.

    Args:
        rollout_data: List of dictionaries holding measurements such as
            'unknown_rate' and 'warn_rate'.

    Returns:
        AnalysisResults: The computed statistics for the selected metric.

    Raises:
        TypeError: Propagated from validation for malformed input.
        ValueError: Propagated from validation for empty or invalid input.
        AssertionError: If the computed statistics violate the ordering
            sanity checks at the end of the function.
    """
    _validate_rollout_data(rollout_data)

    frame = pd.DataFrame(rollout_data)

    # Metric selection: prefer unknown_rate, fall back to warn_rate.
    if 'unknown_rate' in frame.columns:
        metric_col = 'unknown_rate'
    else:
        metric_col = 'warn_rate'
    logger.debug("Berechne Statistik für Metrik: %s", metric_col)

    values = frame[metric_col].astype(float)
    stats = AnalysisResults(
        min=float(values.min()),
        median=float(values.median()),
        p95=float(values.quantile(0.95)),
        max=float(values.max()),
    )
    logger.debug("Analyseergebnisse: %s", stats)

    # Sanity checks on the ordering of the computed statistics.
    assert stats.min <= stats.median <= stats.max, "Median außerhalb Wertebereich"
    assert stats.p95 <= stats.max, "p95 überschreitet Maximalwert"

    return stats
|
||||||
Loading…
Reference in a new issue