Add decision_table_generator/src/decision_table_generator/core.py
This commit is contained in:
parent
0e9a9e7266
commit
0f775944be
1 changed files with 107 additions and 0 deletions
107
decision_table_generator/src/decision_table_generator/core.py
Normal file
107
decision_table_generator/src/decision_table_generator/core.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# Module-level logger named after this module. The level is pinned to INFO
# here, so the DEBUG-level per-row messages emitted during table generation
# are suppressed unless a caller lowers this logger's level explicitly.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class DecisionTableError(Exception):
    """Custom exception for decision table generation errors."""
|
||||
|
||||
|
||||
@dataclass
class DecisionConfig:
    """Validated configuration for decision table generation.

    Attributes:
        N_values: List of integers (booleans are not accepted).
        warn_threshold: Numeric threshold, stored as a float.
        rerun_options: List of strings.
        unknown_handling: A string selecting how unknowns are handled.
    """

    N_values: List[int]
    warn_threshold: float
    rerun_options: List[str]
    unknown_handling: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DecisionConfig":
        """Build a ``DecisionConfig`` from a plain mapping, validating it.

        Args:
            data: Mapping that must contain the keys ``N_values``,
                ``warn_threshold``, ``rerun_options`` and
                ``unknown_handling``. Extra keys are ignored.

        Returns:
            A new ``DecisionConfig`` holding copies of the list fields and
            ``warn_threshold`` converted to ``float``.

        Raises:
            ValueError: If any required key is missing.
            TypeError: If a field has the wrong type.
        """
        required_keys = {"N_values", "warn_threshold", "rerun_options", "unknown_handling"}
        missing = required_keys - data.keys()
        if missing:
            # Sorted so the message is stable across runs (set order is not).
            raise ValueError(f"Missing fields in DecisionConfig: {sorted(missing)}")

        N_values = data["N_values"]
        warn_threshold = data["warn_threshold"]
        rerun_options = data["rerun_options"]
        unknown_handling = data["unknown_handling"]

        # Type validation. ``bool`` is a subclass of ``int``, so it has to be
        # excluded explicitly or True/False would be accepted as numbers.
        if not isinstance(N_values, list) or not all(
            isinstance(n, int) and not isinstance(n, bool) for n in N_values
        ):
            raise TypeError("N_values must be a list of integers.")
        if isinstance(warn_threshold, bool) or not isinstance(warn_threshold, (float, int)):
            raise TypeError("warn_threshold must be a float.")
        if not isinstance(rerun_options, list) or not all(
            isinstance(r, str) for r in rerun_options
        ):
            raise TypeError("rerun_options must be a list of strings.")
        if not isinstance(unknown_handling, str):
            raise TypeError("unknown_handling must be a string.")

        # Copy the lists so later mutation of ``data`` cannot alter the config.
        return cls(
            N_values=list(N_values),
            warn_threshold=float(warn_threshold),
            rerun_options=list(rerun_options),
            unknown_handling=str(unknown_handling),
        )
|
||||
|
||||
|
||||
def generate_decision_table(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Generate a structured decision table based on provided configuration.

    One row is produced for every combination of an ``N`` value and a rerun
    option. The metrics in each row are simple deterministic functions of
    ``N`` and the configuration (no randomness), so results are reproducible.

    Args:
        config: A dictionary with keys matching DecisionConfig fields.

    Returns:
        A list of decision entries summarizing metrics per configuration.
        The list is empty when ``N_values`` or ``rerun_options`` is empty.

    Raises:
        DecisionTableError: If the configuration is invalid or the generated
            records do not carry exactly the expected fields.
    """
    logger.info("Starting generation of decision table.")
    try:
        cfg = DecisionConfig.from_dict(config)
    except Exception as e:
        # Boundary handler: any validation failure is logged and re-raised as
        # the module's own error type, keeping the original cause chained.
        logger.error("Invalid configuration: %s", e)
        raise DecisionTableError(str(e)) from e

    # Fixed output column order.
    columns = [
        'N', 'warn_threshold', 'rerun_option', 'unknown_handling',
        'warn_count', 'rerun_helps', 'rerun_shifts', 'rerun_hurts'
    ]

    # Construct decision rows (simulate with deterministic metrics for reproducibility)
    rows: List[Dict[str, Any]] = []
    separate_gate = cfg.unknown_handling == 'separate_gate'

    for N in cfg.N_values:
        # Multiply before dividing: (threshold / 100) * N can land just below
        # the exact value (e.g. 0.29 * 100 == 28.999999999999996) and would
        # then be truncated one too low by int().
        warn_count = int(max(0, (cfg.warn_threshold * N) / 100.0))
        # Depends only on N and unknown_handling, so compute once per N.
        rerun_shifts = (N // 20) if separate_gate else (N // 30)

        for rerun_option in cfg.rerun_options:
            rerun_helps = (N // 10) if rerun_option == 'on' else 0
            rerun_hurts = max(0, (N // 40) - rerun_helps // 2)

            row = {
                'N': N,
                'warn_threshold': cfg.warn_threshold,
                'rerun_option': rerun_option,
                'unknown_handling': cfg.unknown_handling,
                'warn_count': warn_count,
                'rerun_helps': rerun_helps,
                'rerun_shifts': rerun_shifts,
                'rerun_hurts': rerun_hurts,
            }
            rows.append(row)
            logger.debug("Generated row: %s", row)

    # Passing ``columns`` up front fixes the column order and keeps the frame
    # well-formed when there are no rows; selecting df[columns] on a frame
    # built from an empty list would raise KeyError.
    df = pd.DataFrame(rows, columns=columns)
    if df.empty:
        logger.warning("Generated decision table is empty.")

    result = df.to_dict(orient='records')

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    expected_fields = set(columns)
    if any(set(rec.keys()) != expected_fields for rec in result):
        raise DecisionTableError("Mismatch in expected fields.")

    logger.info("Decision table generated with %d entries.", len(result))
    return result
|
||||
Loading…
Reference in a new issue