Add rollup_rollout/src/rollup_rollout/core.py
This commit is contained in:
commit
39a8082bb3
1 changed files with 82 additions and 0 deletions
82
rollup_rollout/src/rollup_rollout/core.py
Normal file
82
rollup_rollout/src/rollup_rollout/core.py
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, List, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RolloutData:
    """Structured CI run data of a single Gate V1 run.

    Attributes:
        policy_hash: Policy hash of the run, as a string.
        outcome: Outcome of the run, as a string.
        unknown_rate: Unknown rate of the run, stored as a float.
        top_reasons: Top reasons of the run, as a single string.
    """

    policy_hash: str
    outcome: str
    unknown_rate: float
    top_reasons: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "RolloutData":
        """Validate a dictionary and build an instance from it.

        Args:
            data: Mapping that must contain the keys ``policy_hash``,
                ``outcome``, ``unknown_rate`` and ``top_reasons``.

        Returns:
            A new ``RolloutData`` whose ``unknown_rate`` is converted to float.

        Raises:
            ValueError: If ``data`` is not a dict or a required key is missing.
            TypeError: If a field has the wrong type. Booleans are rejected
                for ``unknown_rate`` even though ``bool`` subclasses ``int``.
        """
        if not isinstance(data, dict):
            raise ValueError("Input must be a dictionary.")

        # Presence is checked for every field before any type check so that
        # a missing key is always reported as ValueError, in declaration order.
        for field in ("policy_hash", "outcome", "unknown_rate", "top_reasons"):
            if field not in data:
                raise ValueError(f"Missing required field: {field}")

        if not isinstance(data["policy_hash"], str):
            raise TypeError("policy_hash must be a string")
        if not isinstance(data["outcome"], str):
            raise TypeError("outcome must be a string")

        rate = data["unknown_rate"]
        # bool is a subclass of int; without the explicit exclusion, JSON
        # true/false would silently become 1.0/0.0.
        if isinstance(rate, bool) or not isinstance(rate, (float, int)):
            raise TypeError("unknown_rate must be a float or int")

        if not isinstance(data["top_reasons"], str):
            raise TypeError("top_reasons must be a string")

        return cls(
            policy_hash=data["policy_hash"],
            outcome=data["outcome"],
            unknown_rate=float(rate),
            top_reasons=data["top_reasons"],
        )
|
||||||
|
|
||||||
|
|
||||||
|
def generate_rollout_series(input_file: str, output_file: str) -> None:
    """Convert a JSON list of CI results into a sorted CSV file.

    The input file is read as UTF-8 JSON whose root must be a list. Each
    element is validated through ``RolloutData.from_dict``. The entries are
    sorted by ``(policy_hash, outcome)`` and written with a header row;
    ``unknown_rate`` is formatted with four decimal places.

    Args:
        input_file: Path to the JSON input file with CI results.
        output_file: Path to the CSV output file. Missing parent directories
            are created.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        ValueError: If the file is not valid JSON, the JSON root is not a
            list, or an entry fails validation in ``RolloutData.from_dict``.
        TypeError: If an entry has a field of the wrong type.
        AssertionError: If the output file is missing or empty after writing.
    """
    input_path = Path(input_file)
    output_path = Path(output_file)

    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with input_path.open("r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError as e:
            # Chain the cause so the original decode position stays visible.
            raise ValueError(f"Invalid JSON format: {e}") from e

    if not isinstance(data, list):
        raise ValueError("JSON root must be a list of CI result objects.")

    # Validate every entry first; nothing is written if any entry is invalid.
    rollout_entries: List[RolloutData] = sorted(
        (RolloutData.from_dict(entry) for entry in data),
        key=lambda x: (x.policy_hash, x.outcome),
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["policy_hash", "outcome", "unknown_rate", "top_reasons"])
        writer.writerows(
            [
                entry.policy_hash,
                entry.outcome,
                f"{entry.unknown_rate:.4f}",
                entry.top_reasons,
            ]
            for entry in rollout_entries
        )

    # Explicit raise instead of an `assert` statement: asserts are removed
    # under `python -O`, which would silently disable this post-condition.
    if not (output_path.exists() and output_path.stat().st_size > 0):
        raise AssertionError("Output CSV was not created or is empty.")
|
||||||
Loading…
Reference in a new issue