Add decision_engine/src/decision_engine/core.py

This commit is contained in:
Mika 2026-02-05 13:42:03 +00:00
parent 8cefeee1b4
commit e2ce445ffa


@@ -0,0 +1,113 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Literal

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class EvaluationError(Exception):
    """Base exception for failures during run evaluation."""


class InputValidationError(EvaluationError):
    """Raised when run_data is missing required fields or has fields of invalid types."""


Decision = Literal["PASS", "WARN", "FAIL"]
def _validate_run_data(run_data: Dict[str, Any]) -> None:
    required_fields = [
        "warn_rate",
        "unknown_rate",
        "pinned",
        "unknown_class",
    ]
    for field in required_fields:
        if field not in run_data:
            raise InputValidationError(f"Missing required field: {field}")
    # bool is a subclass of int, so reject it explicitly in the numeric checks.
    if isinstance(run_data["warn_rate"], bool) or not isinstance(run_data["warn_rate"], (int, float)):
        raise InputValidationError("warn_rate must be a numeric value.")
    if isinstance(run_data["unknown_rate"], bool) or not isinstance(run_data["unknown_rate"], (int, float)):
        raise InputValidationError("unknown_rate must be a numeric value.")
    if not isinstance(run_data["pinned"], bool):
        raise InputValidationError("pinned must be a boolean value.")
    if not isinstance(run_data["unknown_class"], str):
        raise InputValidationError("unknown_class must be a string.")


# Placeholder thresholds for this logical example; actual thresholds are read
# externally in the CLI.
_default_thresholds = {
    "warn_rate": {"warn": 5.0, "fail": 10.0},
    "unknown_rate": {"warn": 2.0, "fail": 5.0},
}
def evaluate_run(run_data: Dict[str, Any]) -> Dict[str, str]:
    """Evaluate a single run's metrics and decide PASS/WARN/FAIL.

    Args:
        run_data: Mapping with the fields ``warn_rate``, ``unknown_rate``,
            ``pinned``, ``unknown_class``, and optionally ``prev_label``.

    Returns:
        dict: {"final_decision": str, "reason": str}

    Raises:
        InputValidationError: If run_data fields are missing or of invalid types.
    """
    _validate_run_data(run_data)
    warn_rate = float(run_data["warn_rate"])
    unknown_rate = float(run_data["unknown_rate"])
    pinned = run_data["pinned"]
    unknown_class = run_data["unknown_class"]
    prev_label = run_data.get("prev_label")
    try:
        w_warn = _default_thresholds["warn_rate"]["warn"]
        w_fail = _default_thresholds["warn_rate"]["fail"]
        u_warn = _default_thresholds["unknown_rate"]["warn"]
        u_fail = _default_thresholds["unknown_rate"]["fail"]
    except KeyError as e:
        raise EvaluationError(f"Threshold configuration missing key: {e}") from e

    # Decision logic: fail thresholds take precedence over warn thresholds.
    decision: Decision
    reason: str
    if warn_rate >= w_fail or unknown_rate >= u_fail:
        decision = "FAIL"
        reason = f"Run exceeded fail thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
    elif warn_rate >= w_warn or unknown_rate >= u_warn:
        decision = "WARN"
        reason = f"Run exceeded warning thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
    else:
        decision = "PASS"
        reason = "Run metrics are within acceptable thresholds."
    # Adjustments for pinned runs or unknown_class abnormalities
    if pinned and decision == "FAIL":
        decision = "WARN"
        reason += " Pinned run downgraded from FAIL to WARN."
    if "violation" in unknown_class.lower() and decision == "PASS":
        decision = "WARN"
        reason += " Unknown class contains violation, upgraded to WARN."

    # Integrate previous label as context info (no override logic by default)
    if prev_label:
        reason += f" Previous label: {prev_label}."

    logger.debug(
        "Evaluated run → decision=%s, reason=%s | data=%s",
        decision,
        reason,
        run_data,
    )
    assert decision in {"PASS", "WARN", "FAIL"}, "Invalid decision value computed."
    return {"final_decision": decision, "reason": reason}
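
For reference, a minimal usage sketch of evaluate_run (not part of the committed file; the import path assumes the src layout above is installed as the decision_engine package, and the field values are illustrative):

from decision_engine.core import InputValidationError, evaluate_run

# Illustrative run: warn_rate sits between the default warn (5.0) and fail (10.0) thresholds.
run = {
    "warn_rate": 6.2,
    "unknown_rate": 0.4,
    "pinned": False,
    "unknown_class": "none",
}

try:
    result = evaluate_run(run)
    print(result["final_decision"], "-", result["reason"])
    # -> WARN - Run exceeded warning thresholds (warn_rate=6.2, unknown_rate=0.4).
except InputValidationError as exc:
    print(f"Invalid run data: {exc}")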