Add decision_engine/src/decision_engine/core.py
parent 8cefeee1b4
commit e2ce445ffa
1 changed file with 116 additions and 0 deletions
116
decision_engine/src/decision_engine/core.py
Normal file
@@ -0,0 +1,116 @@
from __future__ import annotations

import logging
from typing import Dict, Any, Literal

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class EvaluationError(Exception):
    """Custom exception raised when evaluation cannot be performed due to invalid input."""


class InputValidationError(EvaluationError):
    pass


Decision = Literal["PASS", "WARN", "FAIL"]


def _validate_run_data(run_data: Dict[str, Any]) -> None:
    required_fields = [
        "warn_rate",
        "unknown_rate",
        "pinned",
        "unknown_class",
    ]

    for field in required_fields:
        if field not in run_data:
            raise InputValidationError(f"Missing required field: {field}")

    # bool is a subclass of int, so the numeric checks must exclude booleans
    # explicitly; otherwise a boolean warn_rate would pass as numeric.
    if not isinstance(run_data["warn_rate"], (int, float)) or isinstance(run_data["warn_rate"], bool):
        raise InputValidationError("warn_rate must be a numeric value.")
    if not isinstance(run_data["unknown_rate"], (int, float)) or isinstance(run_data["unknown_rate"], bool):
        raise InputValidationError("unknown_rate must be a numeric value.")
    if not isinstance(run_data["pinned"], bool):
        raise InputValidationError("pinned must be a boolean value.")
    if not isinstance(run_data["unknown_class"], str):
        raise InputValidationError("unknown_class must be a string.")


# Placeholder thresholds for a minimal example; actual thresholds are read externally by the CLI.
_default_thresholds = {
    "warn_rate": {"warn": 5.0, "fail": 10.0},
    "unknown_rate": {"warn": 2.0, "fail": 5.0},
}


def evaluate_run(run_data: Dict[str, Any]) -> Dict[str, str]:
    """Evaluate a single run's metrics and decide PASS/WARN/FAIL.

    Args:
        run_data: dict containing the fields warn_rate, unknown_rate, pinned,
            unknown_class, and optionally prev_label.

    Returns:
        dict: {"final_decision": str, "reason": str}

    Raises:
        InputValidationError: if run_data fields are missing or of invalid types.
    """
    _validate_run_data(run_data)

    warn_rate = float(run_data["warn_rate"])
    unknown_rate = float(run_data["unknown_rate"])
    pinned = run_data["pinned"]
    unknown_class = run_data["unknown_class"]
    prev_label = run_data.get("prev_label")

    try:
        w_warn = _default_thresholds["warn_rate"]["warn"]
        w_fail = _default_thresholds["warn_rate"]["fail"]
        u_warn = _default_thresholds["unknown_rate"]["warn"]
        u_fail = _default_thresholds["unknown_rate"]["fail"]
    except KeyError as e:
        raise EvaluationError(f"Threshold configuration missing key: {e}") from e

    # Decision logic
    decision: Decision
    reason: str

    if warn_rate >= w_fail or unknown_rate >= u_fail:
        decision = "FAIL"
        reason = f"Run exceeded fail thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
    elif warn_rate >= w_warn or unknown_rate >= u_warn:
        decision = "WARN"
        reason = f"Run exceeded warning thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
    else:
        decision = "PASS"
        reason = "Run metrics are within acceptable thresholds."

    # Adjustments for pinned or unknown_class abnormalities
    if pinned and decision == "FAIL":
        decision = "WARN"
        reason += " Pinned run downgraded from FAIL to WARN."
    if "violation" in unknown_class.lower() and decision == "PASS":
        decision = "WARN"
        reason += " Unknown class contains a violation; upgraded to WARN."

    # Integrate previous label as context info (no override logic by default)
    if prev_label:
        reason += f" Previous label: {prev_label}."

    logger.debug(
        "Evaluated run → decision=%s, reason=%s | data=%s",
        decision,
        reason,
        run_data,
    )

    assert decision in {"PASS", "WARN", "FAIL"}, "Invalid decision value computed."

    return {"final_decision": decision, "reason": reason}
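
A quick usage sketch (illustrative, not part of this commit; it assumes the src layout resolves the module as decision_engine.core, and the values are chosen only to exercise each branch):

# sketch.py - illustrative usage of decision_engine.core, not part of this commit
from decision_engine.core import evaluate_run

# Both metrics under the warn thresholds -> PASS
result = evaluate_run({
    "warn_rate": 1.0,
    "unknown_rate": 0.5,
    "pinned": False,
    "unknown_class": "none",
})
print(result["final_decision"])  # PASS

# warn_rate exceeds the fail threshold, but the run is pinned -> downgraded to WARN
result = evaluate_run({
    "warn_rate": 12.0,
    "unknown_rate": 0.5,
    "pinned": True,
    "unknown_class": "none",
    "prev_label": "PASS",
})
print(result["final_decision"])  # WARN
print(result["reason"])
# Run exceeded fail thresholds (warn_rate=12.0, unknown_rate=0.5).
#  Pinned run downgraded from FAIL to WARN. Previous label: PASS.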
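Validation failures surface as InputValidationError, which subclasses EvaluationError, so callers can catch at either level; a sketch of that contract, under the same import assumption:

from decision_engine.core import EvaluationError, evaluate_run

try:
    # Missing unknown_rate, pinned, and unknown_class
    evaluate_run({"warn_rate": 3.0})
except EvaluationError as exc:
    print(f"rejected: {exc}")  # rejected: Missing required field: unknown_rate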