"""Scan a tree of drift_report.json files, validate them, and aggregate results.

Produces per-stratum and overall statistics, plus CSV and Markdown summaries.
"""

import os
import json
import argparse
import logging
from datetime import datetime
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Any

import pandas as pd

# NOTE(review): `os` and `defaultdict` appear unused in this module — confirm
# no downstream consumer relies on them being re-exported before removing.

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)


class SchemaValidationError(Exception):
    """Raised when a drift report JSON does not match the expected schema."""


class ReportData:
    """One parsed drift_report.json record."""

    def __init__(self, timestamp: datetime, run_id: str, stratum: str,
                 decision: str, warn_rate: float, unknown_rate: float):
        self.timestamp = timestamp
        self.run_id = run_id
        self.stratum = stratum
        self.decision = decision
        self.warn_rate = warn_rate
        self.unknown_rate = unknown_rate

    @classmethod
    def from_json(cls, json_dict: Dict[str, Any]) -> 'ReportData':
        """Build a ReportData from a raw JSON dict.

        Raises:
            SchemaValidationError: if the dict fails schema validation.
        """
        if not validate_report_schema(json_dict):
            raise SchemaValidationError("Invalid report schema encountered.")
        timestamp = datetime.fromisoformat(json_dict['timestamp'])
        return cls(
            timestamp=timestamp,
            run_id=json_dict['run_id'],
            stratum=json_dict['stratum'],
            decision=json_dict['decision'],
            warn_rate=float(json_dict['warn_rate']),
            unknown_rate=float(json_dict['unknown_rate'])
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict mirroring the input schema."""
        return {
            'timestamp': self.timestamp.isoformat(),
            'run_id': self.run_id,
            'stratum': self.stratum,
            'decision': self.decision,
            'warn_rate': self.warn_rate,
            'unknown_rate': self.unknown_rate
        }


def validate_report_schema(report_json: Dict[str, Any]) -> bool:
    """Validate required fields, enum values, timestamp format and rate ranges.

    Returns True if the dict is a well-formed drift report; logs the first
    problem found and returns False otherwise.
    """
    required_fields = {'timestamp': str, 'run_id': str, 'stratum': str,
                       'decision': str, 'warn_rate': (int, float),
                       'unknown_rate': (int, float)}
    valid_strata = {'pinned', 'unpinned'}
    valid_decisions = {'PASS', 'WARN', 'FAIL'}

    for field, expected_type in required_fields.items():
        if field not in report_json:
            logger.error("Missing required field: %s", field)
            return False
        if not isinstance(report_json[field], expected_type):
            logger.error("Invalid type for field %s", field)
            return False

    if report_json['stratum'] not in valid_strata:
        logger.error("Invalid stratum value: %s", report_json['stratum'])
        return False
    if report_json['decision'] not in valid_decisions:
        logger.error("Invalid decision value: %s", report_json['decision'])
        return False
    try:
        datetime.fromisoformat(report_json['timestamp'])
    except (TypeError, ValueError):
        # Narrowed from bare Exception: fromisoformat only raises these two.
        logger.error("Invalid timestamp format.")
        return False
    if not (0.0 <= float(report_json['warn_rate']) <= 1.0):
        logger.error("warn_rate out of range.")
        return False
    if not (0.0 <= float(report_json['unknown_rate']) <= 1.0):
        logger.error("unknown_rate out of range.")
        return False
    return True


def aggregate_statistics(reports: List[ReportData]) -> Dict[str, Any]:
    """Aggregate results per stratum plus an overall 'total' entry.

    Returns {} for an empty report list. Each value holds report count, mean
    warn/unknown rates, and a decision -> count mapping.
    """
    if not reports:
        return {}

    df = pd.DataFrame([r.to_dict() for r in reports])
    agg: Dict[str, Any] = {}

    for stratum, group in df.groupby('stratum'):
        agg[stratum] = {
            'count': len(group),
            'mean_warn_rate': group['warn_rate'].mean(),
            'mean_unknown_rate': group['unknown_rate'].mean(),
            'decision_counts': group['decision'].value_counts().to_dict()
        }

    agg['total'] = {
        'count': len(df),
        'mean_warn_rate': df['warn_rate'].mean(),
        'mean_unknown_rate': df['unknown_rate'].mean(),
        'decision_counts': df['decision'].value_counts().to_dict()
    }

    return agg


def _write_csv(agg_result: Dict[str, Any], csv_path: Path) -> None:
    """Flatten the aggregation dict into one CSV row per stratum."""
    rows = []
    for stratum, stats in agg_result.items():
        row = {'stratum': stratum,
               **{k: v for k, v in stats.items() if k != 'decision_counts'}}
        for decision, count in stats['decision_counts'].items():
            row[f'decision_{decision}'] = count
        rows.append(row)
    pd.DataFrame(rows).to_csv(csv_path, index=False)


def _write_markdown(agg_result: Dict[str, Any], md_path: Path) -> None:
    """Render the aggregation dict as a human-readable Markdown summary."""
    with open(md_path, 'w', encoding='utf-8') as md:
        md.write('# Drift Report Aggregation Summary\n\n')
        for stratum, stats in agg_result.items():
            md.write(f"## {stratum.capitalize()}\n")
            md.write(f"Total Reports: {stats['count']}\n\n")
            md.write(f"Mean Warn Rate: {stats['mean_warn_rate']:.3f}\n\n")
            md.write(f"Mean Unknown Rate: {stats['mean_unknown_rate']:.3f}\n\n")
            md.write('### Decision Counts\n')
            for d, c in stats['decision_counts'].items():
                md.write(f"- {d}: {c}\n")
            md.write('\n')


def analyze_reports(directory_path: str, output_dir: str = 'output') -> Dict[str, Any]:
    """Scan *directory_path* recursively for drift_report.json files and aggregate them.

    Invalid or unreadable reports are logged and skipped. If any valid report
    is found, writes audit.csv and drift_report_agg.md under *output_dir*
    (previously hard-coded to 'output'; the default keeps old behavior).

    Raises:
        FileNotFoundError: if *directory_path* does not exist or is not a directory.
    """
    directory = Path(directory_path)
    if not directory.exists() or not directory.is_dir():
        raise FileNotFoundError(f"Directory not found: {directory_path}")

    reports: List[ReportData] = []

    for file_path in directory.rglob('drift_report.json'):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                report_json = json.load(f)
            # from_json validates internally; the original validated twice,
            # logging every failure two times.
            reports.append(ReportData.from_json(report_json))
        except SchemaValidationError:
            logger.warning("Skipping invalid report: %s", file_path)
        except Exception as e:
            # Best-effort scan: a single unreadable file must not abort the run.
            logger.error("Failed to process %s: %s", file_path, e)

    agg_result = aggregate_statistics(reports)

    if reports:
        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)
        _write_csv(agg_result, out / 'audit.csv')
        _write_markdown(agg_result, out / 'drift_report_agg.md')

    return agg_result


def _main():
    """CLI entry point: parse args, run the analysis, log the results."""
    parser = argparse.ArgumentParser(description='Analyze drift_report.json files.')
    parser.add_argument('--input', required=True,
                        help='Path to input directory containing drift_report.json files.')
    parser.add_argument('--out', required=False, default='output',
                        help='Output directory for generated files.')
    args = parser.parse_args()

    logger.info("Analyzing drift reports in %s", args.input)
    # Bug fix: --out was parsed but ignored; outputs always went to ./output.
    results = analyze_reports(args.input, output_dir=args.out)
    logger.info("Analysis completed. Results:")
    for k, v in results.items():
        logger.info("%s: %s", k, v)


if __name__ == '__main__':
    _main()