Add trace_agg/src/trace_agg/cli.py
parent
e9449a30aa
commit
622d6c5037
1 changed file with 103 additions and 0 deletions
103
trace_agg/src/trace_agg/cli.py
Normal file
@@ -0,0 +1,103 @@
import argparse
import json
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import List

from trace_agg import core


def setup_logging() -> logging.Logger:
    logger = logging.getLogger("trace_agg.cli")
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        fmt="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    # Guard against duplicate handlers when setup_logging() is called twice.
    if not logger.handlers:
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger


def _validate_run_data_entries(data: list) -> List[dict]:
    required_keys = {"run_id", "retry_free_reads", "window_duration"}
    validated = []
    for entry in data:
        if not isinstance(entry, dict):
            raise TypeError(f"Run data entry must be dict, got {type(entry)}")
        if not required_keys.issubset(entry):
            raise ValueError(f"Missing required keys in {entry}")
        if not isinstance(entry["run_id"], str):
            raise TypeError("run_id must be str")
        if not isinstance(entry["retry_free_reads"], int):
            raise TypeError("retry_free_reads must be int")
        if not isinstance(entry["window_duration"], (int, float)):
            raise TypeError("window_duration must be number")
        validated.append(entry)
    # Raise rather than assert: asserts are stripped when Python runs with -O.
    if not validated:
        raise ValueError("Input data list cannot be empty")
    return validated


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Command-line interface for run set analysis and retry classification."
    )
    parser.add_argument("--input", required=True, help="Path to input JSON with run data.")
    parser.add_argument(
        "--output", required=True, help="Path to output JSON file for classification results."
    )
    args = parser.parse_args()

    logger = setup_logging()

    try:
        input_path = Path(args.input)
        output_path = Path(args.output)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        logger.info(f"Loading run data from {input_path}")
        with input_path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        validated_data = _validate_run_data_entries(data)

        # Wrap each validated dict in a SimpleNamespace so core.classify_runs
        # can use attribute access (entry.run_id instead of entry["run_id"]).
        run_data_objs = [SimpleNamespace(**d) for d in validated_data]

        logger.info("Classifying runs and computing correlations...")
        results = core.classify_runs(run_data_objs)

        # Serialize results, tolerating attributes missing on a result object.
        output_serializable = []
        for r in results:
            output_serializable.append(
                {
                    "run_id": getattr(r, "run_id", None),
                    "step_stability_score": getattr(r, "step_stability_score", None),
                    "correlation_coefficient": getattr(r, "correlation_coefficient", None),
                }
            )

        logger.info(f"Writing classification results to {output_path}")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("w", encoding="utf-8") as f:
            json.dump(output_serializable, f, indent=2)

        logger.info("Run classification completed successfully.")

    except Exception as exc:
        logger.exception(f"Execution failed: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()