Add trace_agg/src/trace_agg/core.py
This commit is contained in:
parent
4dc6bfa3b4
commit
e9449a30aa
1 changed file with 118 additions and 0 deletions
trace_agg/src/trace_agg/core.py  Normal file  (118 lines)

@@ -0,0 +1,118 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List
from math import fabs
from statistics import median

from scipy.stats import spearmanr

__all__ = [
    'RunData',
    'ClassificationResults',
    'classify_runs',
    'compute_spearman_correlation',
    'calculate_edit_distance'
]


@dataclass
class RunData:
    run_id: str
    retry_free_reads: int
    window_duration: float

    def __post_init__(self) -> None:
        assert isinstance(self.run_id, str), 'run_id must be a string'
        assert isinstance(self.retry_free_reads, int), 'retry_free_reads must be an int'
        assert isinstance(self.window_duration, (float, int)), 'window_duration must be a number'


@dataclass
class ClassificationResults:
    run_id: str
    step_stability_score: float
    correlation_coefficient: float

    def __post_init__(self) -> None:
        assert isinstance(self.run_id, str), 'run_id must be a string'
        assert isinstance(self.step_stability_score, (float, int)), 'step_stability_score must be a float'
        assert isinstance(self.correlation_coefficient, (float, int)), 'correlation_coefficient must be a float'


class ClassificationError(Exception):
    pass


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def compute_spearman_correlation(retry_count: List[int], window_duration: List[float]) -> float:
    """Compute Spearman correlation between retry_count and window_duration."""
    if not retry_count or not window_duration:
        raise ValueError('Input lists must not be empty')
    if len(retry_count) != len(window_duration):
        raise ValueError('Input lists must have equal length')

    try:
        coef, _ = spearmanr(retry_count, window_duration)
        return float(coef) if coef is not None else 0.0
    except Exception as e:
        logger.error('Failed to compute Spearman correlation: %s', e)
        raise ClassificationError('Spearman correlation computation failed') from e
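
# Illustrative behaviour with hypothetical inputs (values chosen only for demonstration):
# perfectly monotone pairs yield +/-1.0.
#   compute_spearman_correlation([1, 2, 3, 4], [10.0, 20.0, 30.0, 40.0])  # -> 1.0
#   compute_spearman_correlation([1, 2, 3, 4], [40.0, 30.0, 20.0, 10.0])  # -> -1.0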


def calculate_edit_distance(step_sequence: List[str], pinned_median_sequence: List[str]) -> float:
    """Compute normalized edit distance between two step sequences."""
    if not step_sequence and not pinned_median_sequence:
        return 0.0

    n, m = len(step_sequence), len(pinned_median_sequence)
    dp = [[0] * (m + 1) for _ in range(n + 1)]

    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if step_sequence[i - 1] == pinned_median_sequence[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # deletion
                dp[i][j - 1] + 1,        # insertion
                dp[i - 1][j - 1] + cost  # substitution
            )

    max_len = max(n, m, 1)
    return dp[n][m] / max_len
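
# Illustrative, hypothetical examples of the normalization: the Levenshtein count is
# divided by the longer sequence length, so one substitution among three steps gives 1/3.
#   calculate_edit_distance(['load', 'parse', 'write'], ['load', 'parse', 'flush'])  # -> 0.333...
#   calculate_edit_distance(['load', 'parse'], ['load', 'parse'])                    # -> 0.0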


def classify_runs(run_data: List[RunData]) -> List[ClassificationResults]:
    """Classify runs using correlation and stability metrics."""
    if not run_data:
        raise ValueError('run_data must not be empty')
    for item in run_data:
        if not isinstance(item, RunData):
            raise TypeError('All elements in run_data must be of type RunData')

    retry_counts = [r.retry_free_reads for r in run_data]
    window_durations = [r.window_duration for r in run_data]
    correlation = compute_spearman_correlation(retry_counts, window_durations)

    # create synthetic pinned median reference for stability placeholder
    median_window = median(window_durations)
    pinned_seq = [f'step_{int(median_window)}']

    results: List[ClassificationResults] = []
    for r in run_data:
        seq = [f'step_{r.retry_free_reads}', f'dur_{int(r.window_duration)}']
        stability = calculate_edit_distance(seq, pinned_seq)
        result = ClassificationResults(
            run_id=r.run_id,
            step_stability_score=stability,
            correlation_coefficient=correlation,
        )
        results.append(result)

    return results
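
A minimal usage sketch, not part of this commit: it assumes the package is installed so that trace_agg.core is importable, and the run values are hypothetical, chosen only to exercise classify_runs end to end.

    from trace_agg.core import RunData, classify_runs

    # Hypothetical sample runs; classify_runs returns one ClassificationResults per run,
    # all sharing the Spearman coefficient computed across the batch.
    runs = [
        RunData(run_id='run-a', retry_free_reads=12, window_duration=3.4),
        RunData(run_id='run-b', retry_free_reads=7, window_duration=5.1),
        RunData(run_id='run-c', retry_free_reads=19, window_duration=2.2),
    ]

    for result in classify_runs(runs):
        print(result.run_id, result.step_stability_score, result.correlation_coefficient)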