"""Command-line interface for run set analysis and retry classification.

Reads a JSON list of run entries, validates them, classifies them via
``trace_agg.core.classify_runs``, and writes the classification results
as JSON.
"""

import argparse
import json
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import List

from trace_agg import core


def setup_logging() -> logging.Logger:
    """Configure and return the CLI logger.

    Idempotent: the stream handler is attached only once even if this is
    called repeatedly.
    """
    logger = logging.getLogger("trace_agg.cli")
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    # Guard against duplicate handlers on repeated invocation.
    if not logger.handlers:
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger


def _validate_run_data_entries(data: list) -> List[dict]:
    """Validate raw JSON run entries and return them unchanged.

    Each entry must be a dict containing a ``str`` ``run_id``, an ``int``
    ``retry_free_reads`` and a numeric ``window_duration``.

    Raises:
        TypeError: if an entry or one of its fields has the wrong type.
        ValueError: if required keys are missing or the list is empty.
    """
    required_keys = {"run_id", "retry_free_reads", "window_duration"}
    validated = []
    for entry in data:
        if not isinstance(entry, dict):
            raise TypeError(f"Run data entry must be dict, got {type(entry)}")
        if not required_keys.issubset(entry):
            raise ValueError(f"Missing required keys in {entry}")
        if not isinstance(entry["run_id"], str):
            raise TypeError("run_id must be str")
        # bool is a subclass of int, so a bare isinstance(..., int) check
        # would accept True/False; reject bools explicitly.
        if isinstance(entry["retry_free_reads"], bool) or not isinstance(
            entry["retry_free_reads"], int
        ):
            raise TypeError("retry_free_reads must be int")
        if isinstance(entry["window_duration"], bool) or not isinstance(
            entry["window_duration"], (int, float)
        ):
            raise TypeError("window_duration must be number")
        validated.append(entry)
    # Formerly an `assert`, which is silently stripped under `python -O`;
    # raise so empty input always fails loudly.
    if not validated:
        raise ValueError("Input data list cannot be empty")
    return validated


def main() -> None:
    """CLI entry point: load run JSON, classify runs, write results JSON.

    Exits with status 1 on any failure, logging the full traceback.
    """
    parser = argparse.ArgumentParser(
        description="Command-line interface for run set analysis and retry classification."
    )
    parser.add_argument("--input", required=True, help="Path to input JSON with run data.")
    parser.add_argument(
        "--output", required=True, help="Path to output JSON file for classification results."
    )
    args = parser.parse_args()

    logger = setup_logging()

    try:
        input_path = Path(args.input)
        output_path = Path(args.output)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        logger.info("Loading run data from %s", input_path)
        with input_path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        validated_data = _validate_run_data_entries(data)

        # Build attribute-access objects for core.classify_runs.
        # The original used type("RunData", (), d), which creates a new
        # *class* per entry (fields become class attributes) instead of an
        # instance; SimpleNamespace yields a proper per-run object with the
        # same attribute interface.
        run_data_objs = [SimpleNamespace(**d) for d in validated_data]

        logger.info("Classifying runs and computing correlations...")
        results = core.classify_runs(run_data_objs)

        # Serialize only the fields the CLI contract exposes; attributes
        # absent on a result become null in the output JSON.
        output_serializable = [
            {
                "run_id": getattr(r, "run_id", None),
                "step_stability_score": getattr(r, "step_stability_score", None),
                "correlation_coefficient": getattr(r, "correlation_coefficient", None),
            }
            for r in results
        ]

        logger.info("Writing classification results to %s", output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(output_serializable, f, indent=2)

        logger.info("Run classification completed successfully.")

    except Exception:
        # Top-level error boundary: logger.exception records the traceback,
        # so no exception text needs to be interpolated into the message.
        logger.exception("Execution failed")
        sys.exit(1)


if __name__ == "__main__":
    main()