Add audit_drift_script/src/audit_drift_script/core.py

Mika 2026-02-03 17:11:44 +00:00
commit 4d283df368


@@ -0,0 +1,183 @@
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any

import pandas as pd

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s')


class SchemaValidationError(Exception):
    """Custom exception for invalid report schema."""


class ReportData:
    """Represents a single drift_report.json file."""

    def __init__(self, timestamp: datetime, run_id: str, stratum: str,
                 decision: str, warn_rate: float, unknown_rate: float):
        self.timestamp = timestamp
        self.run_id = run_id
        self.stratum = stratum
        self.decision = decision
        self.warn_rate = warn_rate
        self.unknown_rate = unknown_rate

    @classmethod
    def from_json(cls, json_dict: Dict[str, Any]) -> 'ReportData':
        """Build a ReportData from a parsed JSON dict, validating the schema first."""
        if not validate_report_schema(json_dict):
            raise SchemaValidationError("Invalid report schema encountered.")
        return cls(
            timestamp=datetime.fromisoformat(json_dict['timestamp']),
            run_id=json_dict['run_id'],
            stratum=json_dict['stratum'],
            decision=json_dict['decision'],
            warn_rate=float(json_dict['warn_rate']),
            unknown_rate=float(json_dict['unknown_rate'])
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize back to the on-disk JSON field layout."""
        return {
            'timestamp': self.timestamp.isoformat(),
            'run_id': self.run_id,
            'stratum': self.stratum,
            'decision': self.decision,
            'warn_rate': self.warn_rate,
            'unknown_rate': self.unknown_rate
        }
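
# A hypothetical drift_report.json payload that ReportData.from_json accepts;
# the field names and allowed values mirror validate_report_schema below:
#
# {
#     "timestamp": "2026-02-03T17:11:44+00:00",
#     "run_id": "run-001",
#     "stratum": "pinned",
#     "decision": "PASS",
#     "warn_rate": 0.05,
#     "unknown_rate": 0.01
# }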


def validate_report_schema(report_json: Dict[str, Any]) -> bool:
    """Validate required fields, their types, and value ranges."""
    required_fields = {
        'timestamp': str,
        'run_id': str,
        'stratum': str,
        'decision': str,
        'warn_rate': (int, float),
        'unknown_rate': (int, float),
    }
    valid_strata = {'pinned', 'unpinned'}
    valid_decisions = {'PASS', 'WARN', 'FAIL'}
    for field, expected_type in required_fields.items():
        if field not in report_json:
            logging.error(f"Missing required field: {field}")
            return False
        if not isinstance(report_json[field], expected_type):
            logging.error(f"Invalid type for field {field}")
            return False
    if report_json['stratum'] not in valid_strata:
        logging.error(f"Invalid stratum value: {report_json['stratum']}")
        return False
    if report_json['decision'] not in valid_decisions:
        logging.error(f"Invalid decision value: {report_json['decision']}")
        return False
    try:
        datetime.fromisoformat(report_json['timestamp'])
    except ValueError:
        # fromisoformat raises ValueError for malformed strings; the type
        # check above already guarantees we pass it a str.
        logging.error("Invalid timestamp format.")
        return False
    if not 0.0 <= float(report_json['warn_rate']) <= 1.0:
        logging.error("warn_rate out of range.")
        return False
    if not 0.0 <= float(report_json['unknown_rate']) <= 1.0:
        logging.error("unknown_rate out of range.")
        return False
    return True
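
# Illustrative failure mode: a payload with stratum 'other' logs
# "Invalid stratum value: other" and returns False, which from_json
# turns into a SchemaValidationError.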


def aggregate_statistics(reports: List[ReportData]) -> Dict[str, Any]:
    """Aggregate results per stratum, plus an overall 'total' entry."""
    if not reports:
        return {}
    df = pd.DataFrame([r.to_dict() for r in reports])
    agg = {}
    for stratum, group in df.groupby('stratum'):
        agg[stratum] = {
            'count': len(group),
            'mean_warn_rate': group['warn_rate'].mean(),
            'mean_unknown_rate': group['unknown_rate'].mean(),
            'decision_counts': group['decision'].value_counts().to_dict()
        }
    agg['total'] = {
        'count': len(df),
        'mean_warn_rate': df['warn_rate'].mean(),
        'mean_unknown_rate': df['unknown_rate'].mean(),
        'decision_counts': df['decision'].value_counts().to_dict()
    }
    return agg
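
# Shape of the returned mapping, one key per stratum plus 'total'
# (numbers are illustrative, not real output):
#
# {'pinned': {'count': 3, 'mean_warn_rate': 0.04, 'mean_unknown_rate': 0.01,
#             'decision_counts': {'PASS': 2, 'WARN': 1}},
#  'unpinned': {...},
#  'total': {...}}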


def analyze_reports(directory_path: str, output_dir: str = 'output') -> Dict[str, Any]:
    """Scan a directory recursively for drift_report.json files, aggregate the
    valid ones, and write CSV/Markdown summaries to output_dir."""
    directory = Path(directory_path)
    if not directory.is_dir():
        raise FileNotFoundError(f"Directory not found: {directory_path}")
    reports: List[ReportData] = []
    for file_path in directory.rglob('drift_report.json'):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                report_json = json.load(f)
            # from_json runs validate_report_schema and raises on failure,
            # so no separate validation pass is needed here.
            reports.append(ReportData.from_json(report_json))
        except SchemaValidationError:
            logging.warning(f"Skipping invalid report: {file_path}")
        except Exception as e:
            logging.error(f"Failed to process {file_path}: {e}")
    agg_result = aggregate_statistics(reports)
    # Generate outputs
    if reports:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        csv_path = out_dir / 'audit.csv'
        md_path = out_dir / 'drift_report_agg.md'
        # CSV output: one row per stratum (and 'total'), with decision counts
        # flattened into decision_<NAME> columns
        rows = []
        for stratum, stats in agg_result.items():
            row = {'stratum': stratum, **{k: v for k, v in stats.items() if k != 'decision_counts'}}
            for decision, count in stats['decision_counts'].items():
                row[f'decision_{decision}'] = count
            rows.append(row)
        pd.DataFrame(rows).to_csv(csv_path, index=False)
        # Markdown output
        with open(md_path, 'w', encoding='utf-8') as md:
            md.write('# Drift Report Aggregation Summary\n\n')
            for stratum, stats in agg_result.items():
                md.write(f"## {stratum.capitalize()}\n")
                md.write(f"Total Reports: {stats['count']}\n\n")
                md.write(f"Mean Warn Rate: {stats['mean_warn_rate']:.3f}\n\n")
                md.write(f"Mean Unknown Rate: {stats['mean_unknown_rate']:.3f}\n\n")
                md.write('### Decision Counts\n')
                for d, c in stats['decision_counts'].items():
                    md.write(f"- {d}: {c}\n")
                md.write('\n')
    return agg_result
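
# The resulting audit.csv holds one row per stratum plus 'total'; a
# decision_<NAME> column appears only for decisions actually observed, e.g.:
#
#   stratum,count,mean_warn_rate,mean_unknown_rate,decision_PASS,decision_WARN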


def _main():
    parser = argparse.ArgumentParser(description='Analyze drift_report.json files.')
    parser.add_argument('--input', required=True, help='Path to input directory containing drift_report.json files.')
    parser.add_argument('--out', default='output', help='Output directory for generated files.')
    args = parser.parse_args()
    logging.info(f"Analyzing drift reports in {args.input}")
    # Forward --out so the output directory is no longer hardcoded.
    results = analyze_reports(args.input, output_dir=args.out)
    logging.info("Analysis completed. Results:")
    for k, v in results.items():
        logging.info(f"{k}: {v}")


if __name__ == '__main__':
    _main()
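
# Example invocation (paths illustrative; assumes the package layout puts
# audit_drift_script on the import path):
#
#   python -m audit_drift_script.core --input ./reports --out ./output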