Add trace_agg/src/trace_agg/core.py

This commit is contained in:
Mika 2026-01-19 12:48:36 +00:00
parent 4dc6bfa3b4
commit e9449a30aa

View file

@ -0,0 +1,118 @@
from __future__ import annotations

import logging
import math
from dataclasses import dataclass
from math import fabs
from statistics import median
from typing import List

from scipy.stats import spearmanr
# Public API of this module. ClassificationError is part of the contract —
# compute_spearman_correlation raises it — so it is exported here too.
__all__ = [
    'RunData',
    'ClassificationResults',
    'ClassificationError',
    'classify_runs',
    'compute_spearman_correlation',
    'calculate_edit_distance',
]
@dataclass
class RunData:
    """Raw metrics observed for a single run.

    Attributes:
        run_id: Unique identifier of the run.
        retry_free_reads: Count of reads completed without a retry.
        window_duration: Length of the observation window (units not
            specified in this module — presumably seconds; confirm at caller).
    """

    run_id: str
    retry_free_reads: int
    window_duration: float

    def __post_init__(self) -> None:
        """Validate field types eagerly.

        Raises TypeError instead of using ``assert``: assertions are
        stripped under ``python -O`` and must not guard input validation.
        """
        if not isinstance(self.run_id, str):
            raise TypeError('run_id must be a string')
        if not isinstance(self.retry_free_reads, int):
            raise TypeError('retry_free_reads must be an int')
        if not isinstance(self.window_duration, (float, int)):
            raise TypeError('window_duration must be a number')
@dataclass
class ClassificationResults:
    """Per-run classification output produced by ``classify_runs``.

    Attributes:
        run_id: Identifier of the classified run.
        step_stability_score: Normalized edit distance of the run's step
            sequence against the pinned reference (0.0 = identical).
        correlation_coefficient: Spearman correlation shared by the batch
            the run was classified with.
    """

    run_id: str
    step_stability_score: float
    correlation_coefficient: float

    def __post_init__(self) -> None:
        """Validate field types eagerly.

        Raises TypeError instead of using ``assert``: assertions are
        stripped under ``python -O`` and must not guard input validation.
        """
        if not isinstance(self.run_id, str):
            raise TypeError('run_id must be a string')
        if not isinstance(self.step_stability_score, (float, int)):
            raise TypeError('step_stability_score must be a float')
        if not isinstance(self.correlation_coefficient, (float, int)):
            raise TypeError('correlation_coefficient must be a float')
class ClassificationError(Exception):
    """Raised when a classification metric cannot be computed."""
# Module-level logger, named after the module per the logging convention.
logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides the host
# application's logging configuration — consider leaving the level unset
# and letting the application configure it.
logger.setLevel(logging.INFO)
def compute_spearman_correlation(retry_count: List[int], window_duration: List[float]) -> float:
    """Compute the Spearman rank correlation between the two series.

    Args:
        retry_count: Per-run retry-free read counts.
        window_duration: Per-run window durations, same length as ``retry_count``.

    Returns:
        The correlation coefficient in [-1.0, 1.0], or 0.0 when the
        correlation is undefined (e.g. a constant input series).

    Raises:
        ValueError: If either list is empty or the lengths differ.
        ClassificationError: If scipy fails to compute the correlation.
    """
    if not retry_count or not window_duration:
        raise ValueError('Input lists must not be empty')
    if len(retry_count) != len(window_duration):
        raise ValueError('Input lists must have equal length')
    # Keep the try body to the single call that can raise.
    try:
        coef, _ = spearmanr(retry_count, window_duration)
    except Exception as e:
        logger.error('Failed to compute Spearman correlation: %s', e)
        raise ClassificationError('Spearman correlation computation failed') from e
    # spearmanr signals degenerate input (e.g. a constant series) with NaN,
    # not None — the previous None check was dead code and let NaN leak to
    # callers. Normalize both cases to 0.0.
    if coef is None or math.isnan(coef):
        return 0.0
    return float(coef)
def calculate_edit_distance(step_sequence: List[str], pinned_median_sequence: List[str]) -> float:
    """Return the Levenshtein distance between the two step sequences,
    normalized by the length of the longer sequence.

    0.0 means the sequences are identical; 1.0 means every position must
    be edited. Two empty sequences compare as identical (0.0).
    """
    if not step_sequence and not pinned_median_sequence:
        return 0.0
    src, ref = step_sequence, pinned_median_sequence
    # Two-row dynamic program: `prev` holds edit distances for the previous
    # prefix of `src`; `curr` is filled for the current prefix.
    prev = list(range(len(ref) + 1))
    for row, src_token in enumerate(src, start=1):
        curr = [row]
        for col, ref_token in enumerate(ref, start=1):
            substitution = prev[col - 1] + (src_token != ref_token)
            curr.append(min(prev[col] + 1,       # deletion
                            curr[col - 1] + 1,   # insertion
                            substitution))
        prev = curr
    longest = max(len(src), len(ref), 1)
    return prev[len(ref)] / longest
def classify_runs(run_data: List[RunData]) -> List[ClassificationResults]:
    """Classify runs using correlation and stability metrics.

    One Spearman correlation is computed for the whole batch and shared by
    every result; the stability score is a normalized edit distance of each
    run's synthetic step sequence against a pinned reference.

    Raises:
        ValueError: If ``run_data`` is empty.
        TypeError: If any element is not a ``RunData``.
    """
    if not run_data:
        raise ValueError('run_data must not be empty')
    for entry in run_data:
        if not isinstance(entry, RunData):
            raise TypeError('All elements in run_data must be of type RunData')
    durations = [entry.window_duration for entry in run_data]
    correlation = compute_spearman_correlation(
        [entry.retry_free_reads for entry in run_data],
        durations,
    )
    # Synthetic pinned reference derived from the median window duration —
    # a placeholder until a real pinned step sequence is available.
    reference = [f'step_{int(median(durations))}']
    return [
        ClassificationResults(
            run_id=entry.run_id,
            step_stability_score=calculate_edit_distance(
                [f'step_{entry.retry_free_reads}', f'dur_{int(entry.window_duration)}'],
                reference,
            ),
            correlation_coefficient=correlation,
        )
        for entry in run_data
    ]