Add policy_eval/src/policy_eval/core.py

This commit is contained in:
Mika 2026-02-07 11:55:55 +00:00
commit ade86428c6

View file

@ -0,0 +1,69 @@
from __future__ import annotations
import json
from collections.abc import Mapping
from dataclasses import dataclass, asdict
from typing import Any, Dict
@dataclass
class EvaluationMetrics:
total_warn: int
total_fail: int
unknowns: int
manual_overrides: int
def to_dict(self) -> Dict[str, int]:
"""Return a dict with keys sorted alphabetically for diff-friendly output."""
data = asdict(self)
return {k: data[k] for k in sorted(data.keys())}
def evaluate_policy(drift_report: Dict[str, Any]) -> Dict[str, int]:
"""Aggregiert Metriken aus einem Drift-Report.
Erwartet ein Dictionary, das Metriken wie Warnungen, Fehler,
unbekannte Elemente oder manuelle Overrides enthält.
Fehlende Felder werden als 'unknown' gezählt.
Args:
drift_report: Strukturierter JSON-Report über Policy-Drifts.
Returns:
Dict mit aggregierten Metriken (total_warn, total_fail, unknowns, manual_overrides).
"""
if not isinstance(drift_report, Mapping):
raise TypeError("drift_report must be a dict-like object")
# Extract known keys safely
warn = drift_report.get("warn") or drift_report.get("warnings") or 0
fail = drift_report.get("fail") or drift_report.get("errors") or 0
overrides = drift_report.get("manual_overrides") or drift_report.get("overrides") or 0
# Identify unknowns: fields not in known keys and missing expected ones
known_keys = {"warn", "warnings", "fail", "errors", "manual_overrides", "overrides"}
unknown_fields = [k for k in drift_report.keys() if k not in known_keys]
# Missing expected keys are also counted as unknown
missing_expected = [key for key in ("warn", "fail", "manual_overrides") if key not in drift_report]
unknown_total = len(unknown_fields) + len(missing_expected)
# Ensure numeric consistency
try:
metrics = EvaluationMetrics(
total_warn=int(warn),
total_fail=int(fail),
unknowns=int(unknown_total),
manual_overrides=int(overrides)
)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid field type in drift_report: {e}") from e
result = metrics.to_dict()
# Basic assertions for CI consistency
assert all(isinstance(v, int) for v in result.values()), "All metric values must be int"
assert all(k in {"manual_overrides", "total_fail", "total_warn", "unknowns"} for k in result), (
"Unexpected keys in result dict"
)
return result