Add rollout_data_analysis/src/rollout_data_analysis/core.py
This commit is contained in:
parent
c191a78745
commit
a446e943f5
1 changed files with 66 additions and 0 deletions
66
rollout_data_analysis/src/rollout_data_analysis/core.py
Normal file
66
rollout_data_analysis/src/rollout_data_analysis/core.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class AnalysisResults:
|
||||
"""Repräsentiert die Ergebnisstatistik einer Rollout-Datenanalyse."""
|
||||
min: float
|
||||
median: float
|
||||
p95: float
|
||||
max: float
|
||||
|
||||
def _validate_rollout_data(rollout_data: List[Dict[str, Any]]) -> None:
|
||||
if not isinstance(rollout_data, list):
|
||||
raise TypeError("rollout_data muss eine Liste von Dictionaries sein.")
|
||||
if not rollout_data:
|
||||
raise ValueError("rollout_data darf nicht leer sein.")
|
||||
required_keys = {"unknown_rate", "warn_rate"}
|
||||
for idx, row in enumerate(rollout_data):
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Eintrag {idx} ist kein Dictionary.")
|
||||
missing = required_keys - set(row)
|
||||
if missing:
|
||||
raise ValueError(f"Fehlende Keys in Eintrag {idx}: {missing}")
|
||||
for key in required_keys:
|
||||
val = row[key]
|
||||
if not isinstance(val, (int, float)):
|
||||
raise TypeError(f"Wert für '{key}' in Eintrag {idx} muss numerisch sein.")
|
||||
if val < 0:
|
||||
raise ValueError(f"Wert für '{key}' in Eintrag {idx} darf nicht negativ sein.")
|
||||
|
||||
def analyze_data(rollout_data: List[Dict[str, Any]]) -> AnalysisResults:
|
||||
"""Analysiert Rollout-Daten und berechnet statistische Kennzahlen (min, median, p95, max).
|
||||
|
||||
Args:
|
||||
rollout_data: Liste von Dictionaries, die Messwerte wie 'unknown_rate' und 'warn_rate' enthalten.
|
||||
|
||||
Returns:
|
||||
AnalysisResults: Objekt mit den berechneten Statistikwerten.
|
||||
"""
|
||||
_validate_rollout_data(rollout_data)
|
||||
|
||||
df = pd.DataFrame(rollout_data)
|
||||
|
||||
# Auswahl der Metrik – Priorität: unknown_rate, fallback warn_rate
|
||||
metric_col = 'unknown_rate' if 'unknown_rate' in df.columns else 'warn_rate'
|
||||
|
||||
logger.debug("Berechne Statistik für Metrik: %s", metric_col)
|
||||
|
||||
series = df[metric_col].astype(float)
|
||||
min_val = float(series.min())
|
||||
median_val = float(series.median())
|
||||
p95_val = float(series.quantile(0.95))
|
||||
max_val = float(series.max())
|
||||
|
||||
results = AnalysisResults(min=min_val, median=median_val, p95=p95_val, max=max_val)
|
||||
logger.debug("Analyseergebnisse: %s", results)
|
||||
|
||||
assert results.min <= results.median <= results.max, "Median außerhalb Wertebereich"
|
||||
assert results.p95 <= results.max, "p95 überschreitet Maximalwert"
|
||||
|
||||
return results
|
||||
Loading…
Reference in a new issue