Add audit_drift_script/src/audit_drift_script/core.py
This commit is contained in:
commit
4d283df368
1 changed files with 183 additions and 0 deletions
183
audit_drift_script/src/audit_drift_script/core.py
Normal file
183
audit_drift_script/src/audit_drift_script/core.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import os
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s')
|
||||
|
||||
|
||||
class SchemaValidationError(Exception):
    """Raised when a drift report JSON dict fails schema validation."""
|
||||
|
||||
|
||||
class ReportData:
    """In-memory representation of a single ``drift_report.json`` file.

    Attributes mirror the JSON fields one-to-one, except that ``timestamp``
    is held as a parsed :class:`datetime.datetime` instead of an ISO string.
    """

    def __init__(self, timestamp: datetime, run_id: str, stratum: str, decision: str, warn_rate: float, unknown_rate: float):
        self.timestamp = timestamp
        # Opaque identifier of the run that produced the report.
        self.run_id = run_id
        # Expected to be 'pinned' or 'unpinned' (enforced by validate_report_schema).
        self.stratum = stratum
        # Expected to be 'PASS', 'WARN' or 'FAIL' (enforced by validate_report_schema).
        self.decision = decision
        self.warn_rate = warn_rate
        self.unknown_rate = unknown_rate

    @classmethod
    def from_json(cls, json_dict: Dict[str, Any]) -> 'ReportData':
        """Build a ReportData from a parsed JSON dict.

        Raises:
            SchemaValidationError: if ``json_dict`` fails validate_report_schema.
        """
        if not validate_report_schema(json_dict):
            raise SchemaValidationError("Invalid report schema encountered.")
        parsed_ts = datetime.fromisoformat(json_dict['timestamp'])
        return cls(
            parsed_ts,
            json_dict['run_id'],
            json_dict['stratum'],
            json_dict['decision'],
            float(json_dict['warn_rate']),
            float(json_dict['unknown_rate']),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize back to a JSON-compatible dict (timestamp as ISO string)."""
        serialized: Dict[str, Any] = {'timestamp': self.timestamp.isoformat()}
        serialized['run_id'] = self.run_id
        serialized['stratum'] = self.stratum
        serialized['decision'] = self.decision
        serialized['warn_rate'] = self.warn_rate
        serialized['unknown_rate'] = self.unknown_rate
        return serialized
|
||||
|
||||
|
||||
def validate_report_schema(report_json: Dict[str, Any]) -> bool:
    """Validate required fields and value ranges of a drift report dict.

    Checks, in order:
      * every required field is present and has the expected type
        (bools are explicitly rejected for the rate fields, since ``bool``
        is a subclass of ``int`` and would otherwise pass);
      * ``stratum`` is one of {'pinned', 'unpinned'};
      * ``decision`` is one of {'PASS', 'WARN', 'FAIL'};
      * ``timestamp`` parses as ISO-8601;
      * ``warn_rate`` and ``unknown_rate`` lie in [0.0, 1.0].

    Returns True for a valid report, False otherwise; the first failure is
    logged and short-circuits the remaining checks.
    """
    required_fields = {'timestamp': str, 'run_id': str, 'stratum': str, 'decision': str, 'warn_rate': (int, float), 'unknown_rate': (int, float)}
    valid_strata = {'pinned', 'unpinned'}
    valid_decisions = {'PASS', 'WARN', 'FAIL'}

    for field, expected_type in required_fields.items():
        if field not in report_json:
            logging.error(f"Missing required field: {field}")
            return False
        value = report_json[field]
        # isinstance(True, int) is True, so without the explicit bool check
        # a JSON `true`/`false` would be accepted as a numeric rate.
        if isinstance(value, bool) or not isinstance(value, expected_type):
            logging.error(f"Invalid type for field {field}")
            return False

    if report_json['stratum'] not in valid_strata:
        logging.error(f"Invalid stratum value: {report_json['stratum']}")
        return False
    if report_json['decision'] not in valid_decisions:
        logging.error(f"Invalid decision value: {report_json['decision']}")
        return False
    try:
        datetime.fromisoformat(report_json['timestamp'])
    except ValueError:
        # The type check above guarantees a str, so only parse failures remain.
        logging.error("Invalid timestamp format.")
        return False
    if not (0.0 <= float(report_json['warn_rate']) <= 1.0):
        logging.error("warn_rate out of range.")
        return False
    if not (0.0 <= float(report_json['unknown_rate']) <= 1.0):
        logging.error("unknown_rate out of range.")
        return False
    return True
|
||||
|
||||
|
||||
def aggregate_statistics(reports: List[ReportData]) -> Dict[str, Any]:
    """Aggregate results per stratum and overall metrics.

    Returns a mapping ``{stratum: stats, ..., 'total': stats}`` where each
    stats dict holds the report count, the mean warn/unknown rates, and a
    decision -> count histogram. An empty input yields an empty dict.
    """
    if not reports:
        return {}

    frame = pd.DataFrame([report.to_dict() for report in reports])

    def _summarize(subset: pd.DataFrame) -> Dict[str, Any]:
        # One stats record for a slice of the reports.
        return {
            'count': len(subset),
            'mean_warn_rate': subset['warn_rate'].mean(),
            'mean_unknown_rate': subset['unknown_rate'].mean(),
            'decision_counts': subset['decision'].value_counts().to_dict(),
        }

    summary = {name: _summarize(group) for name, group in frame.groupby('stratum')}
    # 'total' cannot collide with a per-stratum key: validate_report_schema
    # restricts strata to {'pinned', 'unpinned'}.
    summary['total'] = _summarize(frame)
    return summary
|
||||
|
||||
|
||||
def analyze_reports(directory_path: str, output_dir: str = 'output') -> Dict[str, Any]:
    """Scan *directory_path* recursively for ``drift_report.json`` files,
    validate and aggregate them, and write CSV/Markdown summaries.

    Args:
        directory_path: root directory to search (recursively).
        output_dir: directory that ``audit.csv`` and ``drift_report_agg.md``
            are written to (created if missing). Defaults to ``'output'``,
            matching the previously hard-coded location.

    Returns:
        The aggregated statistics dict from ``aggregate_statistics``
        (empty when no valid report was found; no files are written then).

    Raises:
        FileNotFoundError: if *directory_path* does not exist or is not a
            directory.
    """
    directory = Path(directory_path)
    # Path.is_dir() is False for nonexistent paths, so one check suffices.
    if not directory.is_dir():
        raise FileNotFoundError(f"Directory not found: {directory_path}")

    reports: List[ReportData] = []
    for file_path in directory.rglob('drift_report.json'):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                report_json = json.load(f)
            # from_json validates the schema itself (raising
            # SchemaValidationError), so a separate validate_report_schema
            # call here would only repeat the work.
            reports.append(ReportData.from_json(report_json))
        except SchemaValidationError:
            logging.warning(f"Skipping invalid report: {file_path}")
        except Exception as e:
            # Best-effort scan: one unreadable/corrupt file must not abort the audit.
            logging.error(f"Failed to process {file_path}: {e}")

    agg_result = aggregate_statistics(reports)

    if reports:
        out_root = Path(output_dir)
        out_root.mkdir(parents=True, exist_ok=True)
        _write_csv(out_root / 'audit.csv', agg_result)
        _write_markdown(out_root / 'drift_report_agg.md', agg_result)

    return agg_result


def _write_csv(csv_path: Path, agg_result: Dict[str, Any]) -> None:
    """Flatten the aggregate dict into one CSV row per stratum (plus 'total')."""
    rows = []
    for stratum, stats in agg_result.items():
        row = {'stratum': stratum, **{k: v for k, v in stats.items() if k != 'decision_counts'}}
        for decision, count in stats['decision_counts'].items():
            row[f'decision_{decision}'] = count
        rows.append(row)
    pd.DataFrame(rows).to_csv(csv_path, index=False)


def _write_markdown(md_path: Path, agg_result: Dict[str, Any]) -> None:
    """Render a human-readable Markdown summary of the aggregate dict."""
    with open(md_path, 'w', encoding='utf-8') as md:
        md.write('# Drift Report Aggregation Summary\n\n')
        for stratum, stats in agg_result.items():
            md.write(f"## {stratum.capitalize()}\n")
            md.write(f"Total Reports: {stats['count']}\n\n")
            md.write(f"Mean Warn Rate: {stats['mean_warn_rate']:.3f}\n\n")
            md.write(f"Mean Unknown Rate: {stats['mean_unknown_rate']:.3f}\n\n")
            md.write('### Decision Counts\n')
            for decision, count in stats['decision_counts'].items():
                md.write(f"- {decision}: {count}\n")
            md.write('\n')
|
||||
|
||||
|
||||
def _main():
    """CLI entry point: parse arguments, run the analysis, log the results."""
    parser = argparse.ArgumentParser(description='Analyze drift_report.json files.')
    parser.add_argument('--input', required=True, help='Path to input directory containing drift_report.json files.')
    # NOTE(review): --out is parsed but never forwarded; the output location
    # is decided inside analyze_reports. Wire it through or drop the flag.
    parser.add_argument('--out', required=False, default='output', help='Output directory for generated files.')
    args = parser.parse_args()

    # Lazy %-style arguments so formatting is skipped when INFO is disabled.
    logging.info("Analyzing drift reports in %s", args.input)
    results = analyze_reports(args.input)
    logging.info("Analysis completed. Results:")
    for key, value in results.items():
        logging.info("%s: %s", key, value)
|
||||
|
||||
|
||||
# Script entry point: delegate to _main() only when executed directly.
if __name__ == '__main__':
    _main()
|
||||
Loading…
Reference in a new issue