Add rollout_data_analysis/tests/test_core.py

This commit is contained in:
Mika 2026-02-21 15:27:05 +00:00
parent 27da182255
commit aaca3d29df

View file

@ -0,0 +1,71 @@
import pytest
import math
from rollout_data_analysis import core
class DummyAnalysisResults:
def __init__(self, min_val, median_val, p95_val, max_val):
self.min = min_val
self.median = median_val
self.p95 = p95_val
self.max = max_val
@pytest.fixture
def sample_rollout_data():
return [
{"run_id": "r1", "unknown_rate": 0.1, "warn_rate": 0.05, "policy_hash": "abc", "pinned": True},
{"run_id": "r2", "unknown_rate": 0.2, "warn_rate": 0.1, "policy_hash": "abc", "pinned": False},
{"run_id": "r3", "unknown_rate": 0.4, "warn_rate": 0.2, "policy_hash": "def", "pinned": False},
{"run_id": "r4", "unknown_rate": 0.3, "warn_rate": 0.15, "policy_hash": "ghi", "pinned": True}
]
def approx_equal(a, b, tol=1e-6):
return math.isclose(a, b, rel_tol=tol, abs_tol=tol)
def test_analyze_data_basic(sample_rollout_data):
results = core.analyze_data(sample_rollout_data)
assert hasattr(results, 'min') and hasattr(results, 'median')
assert hasattr(results, 'p95') and hasattr(results, 'max')
values = [v["unknown_rate"] for v in sample_rollout_data]
assert approx_equal(results.min, min(values))
assert approx_equal(results.max, max(values))
# Median: Zwischen 0.2 und 0.3 -> 0.25
assert approx_equal(results.median, 0.25)
# P95 sollte nahe dem maximum liegen, aber nicht exakt
assert results.p95 <= results.max
def test_analyze_data_empty():
with pytest.raises(ValueError):
core.analyze_data([])
def test_analyze_data_invalid_input():
# Fehlt Feld unknown_rate
invalid_data = [{"run_id": "r1", "warn_rate": 0.1, "policy_hash": "x", "pinned": False}]
with pytest.raises((KeyError, ValueError, TypeError)):
core.analyze_data(invalid_data)
def test_analyze_data_float_accuracy(sample_rollout_data):
results = core.analyze_data(sample_rollout_data)
assert isinstance(results.min, float)
assert isinstance(results.median, float)
assert isinstance(results.p95, float)
assert isinstance(results.max, float)
# Check monotonicity: min <= median <= p95 <= max
assert results.min <= results.median <= results.p95 <= results.max
def test_analyze_data_with_duplicate_values():
rollout_data = [
{"run_id": f"r{i}", "unknown_rate": 0.25, "warn_rate": 0.05, "policy_hash": "h{i}", "pinned": False}
for i in range(10)
]
results = core.analyze_data(rollout_data)
expected_val = 0.25
assert all(
approx_equal(getattr(results, attr), expected_val)
for attr in ('min', 'median', 'p95', 'max')
)