Add rollout_data_analysis/tests/test_core.py
This commit is contained in:
parent
27da182255
commit
aaca3d29df
1 changed files with 71 additions and 0 deletions
71
rollout_data_analysis/tests/test_core.py
Normal file
71
rollout_data_analysis/tests/test_core.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
import pytest
|
||||
import math
|
||||
from rollout_data_analysis import core
|
||||
|
||||
class DummyAnalysisResults:
|
||||
def __init__(self, min_val, median_val, p95_val, max_val):
|
||||
self.min = min_val
|
||||
self.median = median_val
|
||||
self.p95 = p95_val
|
||||
self.max = max_val
|
||||
|
||||
@pytest.fixture
|
||||
def sample_rollout_data():
|
||||
return [
|
||||
{"run_id": "r1", "unknown_rate": 0.1, "warn_rate": 0.05, "policy_hash": "abc", "pinned": True},
|
||||
{"run_id": "r2", "unknown_rate": 0.2, "warn_rate": 0.1, "policy_hash": "abc", "pinned": False},
|
||||
{"run_id": "r3", "unknown_rate": 0.4, "warn_rate": 0.2, "policy_hash": "def", "pinned": False},
|
||||
{"run_id": "r4", "unknown_rate": 0.3, "warn_rate": 0.15, "policy_hash": "ghi", "pinned": True}
|
||||
]
|
||||
|
||||
|
||||
def approx_equal(a, b, tol=1e-6):
|
||||
return math.isclose(a, b, rel_tol=tol, abs_tol=tol)
|
||||
|
||||
|
||||
def test_analyze_data_basic(sample_rollout_data):
|
||||
results = core.analyze_data(sample_rollout_data)
|
||||
assert hasattr(results, 'min') and hasattr(results, 'median')
|
||||
assert hasattr(results, 'p95') and hasattr(results, 'max')
|
||||
values = [v["unknown_rate"] for v in sample_rollout_data]
|
||||
assert approx_equal(results.min, min(values))
|
||||
assert approx_equal(results.max, max(values))
|
||||
# Median: Zwischen 0.2 und 0.3 -> 0.25
|
||||
assert approx_equal(results.median, 0.25)
|
||||
# P95 sollte nahe dem maximum liegen, aber nicht exakt
|
||||
assert results.p95 <= results.max
|
||||
|
||||
|
||||
def test_analyze_data_empty():
|
||||
with pytest.raises(ValueError):
|
||||
core.analyze_data([])
|
||||
|
||||
|
||||
def test_analyze_data_invalid_input():
|
||||
# Fehlt Feld unknown_rate
|
||||
invalid_data = [{"run_id": "r1", "warn_rate": 0.1, "policy_hash": "x", "pinned": False}]
|
||||
with pytest.raises((KeyError, ValueError, TypeError)):
|
||||
core.analyze_data(invalid_data)
|
||||
|
||||
|
||||
def test_analyze_data_float_accuracy(sample_rollout_data):
|
||||
results = core.analyze_data(sample_rollout_data)
|
||||
assert isinstance(results.min, float)
|
||||
assert isinstance(results.median, float)
|
||||
assert isinstance(results.p95, float)
|
||||
assert isinstance(results.max, float)
|
||||
# Check monotonicity: min <= median <= p95 <= max
|
||||
assert results.min <= results.median <= results.p95 <= results.max
|
||||
|
||||
|
||||
def test_analyze_data_with_duplicate_values():
|
||||
rollout_data = [
|
||||
{"run_id": f"r{i}", "unknown_rate": 0.25, "warn_rate": 0.05, "policy_hash": "h{i}", "pinned": False}
|
||||
for i in range(10)
|
||||
]
|
||||
results = core.analyze_data(rollout_data)
|
||||
expected_val = 0.25
|
||||
assert all(
|
||||
approx_equal(getattr(results, attr), expected_val)
|
||||
for attr in ('min', 'median', 'p95', 'max')
|
||||
)
|
||||
Loading…
Reference in a new issue