Add trace_agg/src/trace_agg/cli.py

Mika 2026-01-19 12:48:36 +00:00
parent e9449a30aa
commit 622d6c5037


@@ -0,0 +1,103 @@
import argparse
import json
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import List

from trace_agg import core


def setup_logging() -> logging.Logger:
    """Configure and return the CLI logger; safe to call more than once."""
    logger = logging.getLogger("trace_agg.cli")
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    # Attach the handler only once so repeated calls do not duplicate output.
    if not logger.handlers:
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger
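

# Illustrative usage of the guard above: logging.getLogger returns a shared
# instance, so repeated setup_logging() calls reuse one handler.
#
#     logger_a = setup_logging()
#     logger_b = setup_logging()
#     assert logger_a is logger_b and len(logger_a.handlers) == 1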


def _validate_run_data_entries(data: list) -> List[dict]:
    """Validate that each entry is a dict with the expected keys and types."""
    if not isinstance(data, list):
        raise TypeError(f"Run data must be a list, got {type(data)}")
    required_keys = {"run_id", "retry_free_reads", "window_duration"}
    validated = []
    for entry in data:
        if not isinstance(entry, dict):
            raise TypeError(f"Run data entry must be dict, got {type(entry)}")
        if not required_keys.issubset(entry):
            raise ValueError(f"Missing required keys in {entry}")
        if not isinstance(entry["run_id"], str):
            raise TypeError("run_id must be str")
        if not isinstance(entry["retry_free_reads"], int):
            raise TypeError("retry_free_reads must be int")
        if not isinstance(entry["window_duration"], (int, float)):
            raise TypeError("window_duration must be number")
        validated.append(entry)
    # Raise instead of assert: assertions are stripped under `python -O`,
    # and an empty input is a caller error, not an internal invariant.
    if not validated:
        raise ValueError("Input data list cannot be empty")
    return validated
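

# Illustrative input accepted by the validator above (field values are made
# up for the example):
#
#     [
#         {"run_id": "run-001", "retry_free_reads": 42, "window_duration": 1.5}
#     ]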


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Command-line interface for run set analysis and retry classification."
    )
    parser.add_argument("--input", required=True, help="Path to input JSON with run data.")
    parser.add_argument(
        "--output", required=True, help="Path to output JSON file for classification results."
    )
    args = parser.parse_args()
    logger = setup_logging()
    try:
        input_path = Path(args.input)
        output_path = Path(args.output)
        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")
        logger.info(f"Loading run data from {input_path}")
        with input_path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        validated_data = _validate_run_data_entries(data)
        # Wrap each validated dict in a lightweight namespace object so that
        # core.classify_runs can read the fields as attributes (e.g. obj.run_id).
        run_data_objs = [SimpleNamespace(**d) for d in validated_data]
logger.info("Classifying runs and computing correlations...")
results = core.classify_runs(run_data_objs)
# Results serialization.
output_serializable = []
for r in results:
output_serializable.append(
{
"run_id": getattr(r, "run_id", None),
"step_stability_score": getattr(r, "step_stability_score", None),
"correlation_coefficient": getattr(r, "correlation_coefficient", None),
}
)
logger.info(f"Writing classification results to {output_path}")
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as f:
json.dump(output_serializable, f, indent=2)
logger.info("Run classification completed successfully.")
except Exception as exc:
logger.exception(f"Execution failed: {exc}")
sys.exit(1)


if __name__ == "__main__":
    main()
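

# A minimal end-to-end sketch, assuming the package layout makes this module
# importable as trace_agg.cli and that core.classify_runs returns objects
# exposing the three attributes read back above:
#
#     $ cat runs.json
#     [{"run_id": "run-001", "retry_free_reads": 42, "window_duration": 1.5}]
#     $ python -m trace_agg.cli --input runs.json --output results.json
#     $ cat results.json
#     [
#       {
#         "run_id": "run-001",
#         "step_stability_score": ...,
#         "correlation_coefficient": ...
#       }
#     ]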