diff --git a/rollout_data_analysis/src/rollout_data_analysis/cli.py b/rollout_data_analysis/src/rollout_data_analysis/cli.py new file mode 100644 index 0000000..9ed3d09 --- /dev/null +++ b/rollout_data_analysis/src/rollout_data_analysis/cli.py @@ -0,0 +1,80 @@ +import argparse +import json +from pathlib import Path +from typing import Any +import pandas as pd + +from rollout_data_analysis import core + + +class CLIError(Exception): + """Custom exception for CLI errors.""" + pass + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Analysewerkzeug zur statistischen Auswertung von Rollout-Daten." + ) + parser.add_argument( + "--input", required=True, type=str, help="Pfad zur CSV-Datei mit Rollout-Daten" + ) + parser.add_argument( + "--whitelist", required=False, type=str, help="Optionaler Pfad zur Whitelist-Datei im JSON-Format" + ) + parser.add_argument( + "--output", required=True, type=str, help="Pfad zur Ausgabe der Analyseergebnisse im JSON-Format" + ) + return parser.parse_args() + + +def load_rollout_data(csv_path: Path) -> list[dict[str, Any]]: + if not csv_path.exists(): + raise CLIError(f"Eingabedatei nicht gefunden: {csv_path}") + + df = pd.read_csv(csv_path) + required_columns = {"run_id", "unknown_rate", "warn_rate", "policy_hash", "pinned"} + missing = required_columns - set(df.columns) + if missing: + raise CLIError(f"Fehlende Spalten in CSV: {', '.join(missing)}") + + data = df.to_dict(orient="records") + assert isinstance(data, list) and all(isinstance(row, dict) for row in data) + return data + + +def save_results(output_path: Path, results: Any) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + json.dump(results.__dict__, f, indent=2, ensure_ascii=False) + + +def main() -> None: + args = parse_args() + + csv_path = Path(args.input) + output_path = Path(args.output) + + rollout_data = load_rollout_data(csv_path) + + # analyze_data expected input: list[dict] + analysis_results = core.analyze_data(rollout_data) + + save_results(output_path, analysis_results) + print(f"Analyse abgeschlossen. Ergebnisse gespeichert unter: {output_path}") + + +if __name__ == "__main__": + try: + main() + except CLIError as err: + import sys + + print(f"[Fehler] {err}", file=sys.stderr) + sys.exit(1) + except Exception as exc: + import sys + import traceback + + traceback.print_exc(file=sys.stderr) + sys.exit(2)