Add decision_engine/src/decision_engine/core.py
This commit is contained in:
parent
8cefeee1b4
commit
e2ce445ffa
1 changed files with 113 additions and 0 deletions
113
decision_engine/src/decision_engine/core.py
Normal file
113
decision_engine/src/decision_engine/core.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, Literal
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationError(Exception):
|
||||||
|
"""Custom exception raised when evaluation cannot be performed due to invalid input."""
|
||||||
|
|
||||||
|
|
||||||
|
class InputValidationError(EvaluationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Decision = Literal["PASS", "WARN", "FAIL"]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_run_data(run_data: Dict[str, Any]) -> None:
|
||||||
|
required_fields = [
|
||||||
|
"warn_rate",
|
||||||
|
"unknown_rate",
|
||||||
|
"pinned",
|
||||||
|
"unknown_class",
|
||||||
|
]
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in run_data:
|
||||||
|
raise InputValidationError(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
if not isinstance(run_data["warn_rate"], (int, float)):
|
||||||
|
raise InputValidationError("warn_rate must be a numeric value.")
|
||||||
|
if not isinstance(run_data["unknown_rate"], (int, float)):
|
||||||
|
raise InputValidationError("unknown_rate must be a numeric value.")
|
||||||
|
if not isinstance(run_data["pinned"], bool):
|
||||||
|
raise InputValidationError("pinned must be a boolean value.")
|
||||||
|
if not isinstance(run_data["unknown_class"], str):
|
||||||
|
raise InputValidationError("unknown_class must be a string.")
|
||||||
|
|
||||||
|
|
||||||
|
# Placeholder thresholds for logical example; actual thresholds are read externally in CLI.
|
||||||
|
_default_thresholds = {
|
||||||
|
"warn_rate": {"warn": 5.0, "fail": 10.0},
|
||||||
|
"unknown_rate": {"warn": 2.0, "fail": 5.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_run(run_data: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
"""Evaluate a single run's metrics and decide PASS/WARN/FAIL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_data: dict - contains fields warn_rate, unknown_rate, pinned, unknown_class, prev_label(optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {"final_decision": str, "reason": str}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InputValidationError: if run_data fields are missing or of invalid types.
|
||||||
|
"""
|
||||||
|
_validate_run_data(run_data)
|
||||||
|
|
||||||
|
warn_rate = float(run_data["warn_rate"])
|
||||||
|
unknown_rate = float(run_data["unknown_rate"])
|
||||||
|
pinned = run_data["pinned"]
|
||||||
|
unknown_class = run_data["unknown_class"]
|
||||||
|
prev_label = run_data.get("prev_label")
|
||||||
|
|
||||||
|
try:
|
||||||
|
w_warn = _default_thresholds["warn_rate"]["warn"]
|
||||||
|
w_fail = _default_thresholds["warn_rate"]["fail"]
|
||||||
|
u_warn = _default_thresholds["unknown_rate"]["warn"]
|
||||||
|
u_fail = _default_thresholds["unknown_rate"]["fail"]
|
||||||
|
except KeyError as e:
|
||||||
|
raise EvaluationError(f"Threshold configuration missing key: {e}") from e
|
||||||
|
|
||||||
|
# Decision logic
|
||||||
|
decision: Decision
|
||||||
|
reason: str
|
||||||
|
|
||||||
|
if warn_rate >= w_fail or unknown_rate >= u_fail:
|
||||||
|
decision = "FAIL"
|
||||||
|
reason = f"Run exceeded fail thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
|
||||||
|
elif warn_rate >= w_warn or unknown_rate >= u_warn:
|
||||||
|
decision = "WARN"
|
||||||
|
reason = f"Run exceeded warning thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
|
||||||
|
else:
|
||||||
|
decision = "PASS"
|
||||||
|
reason = "Run metrics are within acceptable thresholds."
|
||||||
|
|
||||||
|
# Adjustments for pinned or unknown_class abnormalities
|
||||||
|
if pinned and decision == "FAIL":
|
||||||
|
decision = "WARN"
|
||||||
|
reason += " Pinned run downgraded from FAIL to WARN."
|
||||||
|
if "violation" in unknown_class.lower() and decision == "PASS":
|
||||||
|
decision = "WARN"
|
||||||
|
reason += " Unknown class contains violation, upgraded to WARN."
|
||||||
|
|
||||||
|
# Integrate previous label as context info (no override logic by default)
|
||||||
|
if prev_label:
|
||||||
|
reason += f" Previous label: {prev_label}."
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Evaluated run → decision=%s, reason=%s | data=%s",
|
||||||
|
decision,
|
||||||
|
reason,
|
||||||
|
run_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert decision in {"PASS", "WARN", "FAIL"}, "Invalid decision value computed."
|
||||||
|
|
||||||
|
return {"final_decision": decision, "reason": reason}
|
||||||
Loading…
Reference in a new issue