diff --git a/decision_engine/src/decision_engine/core.py b/decision_engine/src/decision_engine/core.py
new file mode 100644
index 0000000..59208d4
--- /dev/null
+++ b/decision_engine/src/decision_engine/core.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import logging
+from typing import Dict, Any, Literal
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class EvaluationError(Exception):
+    """Base error raised when a run cannot be evaluated (bad input or missing configuration)."""
+
+
+class InputValidationError(EvaluationError):
+    pass
+
+
+Decision = Literal["PASS", "WARN", "FAIL"]
+
+
+def _validate_run_data(run_data: Dict[str, Any]) -> None:
+    required_fields = [
+        "warn_rate",
+        "unknown_rate",
+        "pinned",
+        "unknown_class",
+    ]
+
+    for field in required_fields:
+        if field not in run_data:
+            raise InputValidationError(f"Missing required field: {field}")
+
+    # bool is a subclass of int, so exclude it explicitly from the numeric checks.
+    if isinstance(run_data["warn_rate"], bool) or not isinstance(run_data["warn_rate"], (int, float)):
+        raise InputValidationError("warn_rate must be a numeric value.")
+    if isinstance(run_data["unknown_rate"], bool) or not isinstance(run_data["unknown_rate"], (int, float)):
+        raise InputValidationError("unknown_rate must be a numeric value.")
+    if not isinstance(run_data["pinned"], bool):
+        raise InputValidationError("pinned must be a boolean value.")
+    if not isinstance(run_data["unknown_class"], str):
+        raise InputValidationError("unknown_class must be a string.")
+
+
+# Placeholder thresholds for this module's logic; real thresholds are loaded externally by the CLI.
+_default_thresholds = {
+    "warn_rate": {"warn": 5.0, "fail": 10.0},
+    "unknown_rate": {"warn": 2.0, "fail": 5.0},
+}
+
+
+def evaluate_run(run_data: Dict[str, Any]) -> Dict[str, str]:
+    """Evaluate a single run's metrics and decide PASS/WARN/FAIL.
+
+    Args:
+        run_data: dict with fields warn_rate, unknown_rate, pinned, unknown_class, and optional prev_label.
+
+    Returns:
+        dict: {"final_decision": str, "reason": str}
+
+    Raises:
+        InputValidationError: if run_data fields are missing or of invalid types.
+    """
+    _validate_run_data(run_data)
+
+    warn_rate = float(run_data["warn_rate"])
+    unknown_rate = float(run_data["unknown_rate"])
+    pinned = run_data["pinned"]
+    unknown_class = run_data["unknown_class"]
+    prev_label = run_data.get("prev_label")
+
+    try:
+        w_warn = _default_thresholds["warn_rate"]["warn"]
+        w_fail = _default_thresholds["warn_rate"]["fail"]
+        u_warn = _default_thresholds["unknown_rate"]["warn"]
+        u_fail = _default_thresholds["unknown_rate"]["fail"]
+    except KeyError as e:
+        raise EvaluationError(f"Threshold configuration missing key: {e}") from e
+
+    # Decision logic
+    decision: Decision
+    reason: str
+
+    if warn_rate >= w_fail or unknown_rate >= u_fail:
+        decision = "FAIL"
+        reason = f"Run exceeded fail thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
+    elif warn_rate >= w_warn or unknown_rate >= u_warn:
+        decision = "WARN"
+        reason = f"Run exceeded warning thresholds (warn_rate={warn_rate}, unknown_rate={unknown_rate})."
+    else:
+        decision = "PASS"
+        reason = "Run metrics are within acceptable thresholds."
+
+    # Adjustments for pinned or unknown_class abnormalities
+    if pinned and decision == "FAIL":
+        decision = "WARN"
+        reason += " Pinned run downgraded from FAIL to WARN."
+    if "violation" in unknown_class.lower() and decision == "PASS":
+        decision = "WARN"
+        reason += " Unknown class contains 'violation'; upgraded to WARN."
+
+    # Integrate previous label as context info (no override logic by default)
+    if prev_label:
+        reason += f" Previous label: {prev_label}."
+
+    logger.debug(
+        "Evaluated run → decision=%s, reason=%s | data=%s",
+        decision,
+        reason,
+        run_data,
+    )
+
+    assert decision in {"PASS", "WARN", "FAIL"}, "Invalid decision value computed."
+
+    return {"final_decision": decision, "reason": reason}
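
For reviewers: a minimal usage sketch of the new `evaluate_run` API. It assumes the package is importable as `decision_engine.core` (per the src layout above); the metric values are invented for illustration and land in the WARN band of the placeholder thresholds:

```python
# Hypothetical caller; values are made up to fall between warn (5.0) and fail (10.0).
from decision_engine.core import evaluate_run, InputValidationError

run = {
    "warn_rate": 6.5,        # above the 5.0 warn threshold, below the 10.0 fail threshold
    "unknown_rate": 0.4,     # within limits
    "pinned": False,
    "unknown_class": "none",
    "prev_label": "PASS",    # optional; appended to the reason for context
}

try:
    result = evaluate_run(run)
    print(result["final_decision"])  # WARN
    print(result["reason"])
except InputValidationError as exc:
    print(f"Bad input: {exc}")
```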
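
The post-hoc adjustments (a pinned FAIL downgraded to WARN, a "violation" unknown_class upgrading PASS to WARN) are the subtle part of the decision logic, so here is a hedged pytest sketch of how they might be exercised. The helper and test names are illustrative, not part of this change:

```python
# Illustrative tests under the same import assumption as above.
import pytest

from decision_engine.core import evaluate_run, InputValidationError


def _base_run(**overrides):
    # Baseline that passes validation and lands in the PASS band.
    run = {"warn_rate": 0.0, "unknown_rate": 0.0, "pinned": False, "unknown_class": "none"}
    run.update(overrides)
    return run


def test_pinned_fail_is_downgraded_to_warn():
    result = evaluate_run(_base_run(warn_rate=12.0, pinned=True))
    assert result["final_decision"] == "WARN"


def test_violation_class_upgrades_pass_to_warn():
    result = evaluate_run(_base_run(unknown_class="policy_violation"))
    assert result["final_decision"] == "WARN"


def test_missing_field_raises():
    with pytest.raises(InputValidationError):
        evaluate_run({"warn_rate": 1.0})
```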