# rollout_data_analysis/tests/test_core.py
"""Tests for ``rollout_data_analysis.core.analyze_data``.

The tests feed lists of rollout records (dicts carrying ``run_id``,
``unknown_rate``, ``warn_rate``, ``policy_hash`` and ``pinned``) into
``core.analyze_data`` and check the ``min``, ``median``, ``p95`` and ``max``
attributes of the returned object against the ``unknown_rate`` values.
"""
import math

import pytest

from rollout_data_analysis import core


class DummyAnalysisResults:
    """Plain container mirroring the four attributes the tests read.

    NOTE(review): no test in this module instantiates this class; it is
    presumably a stand-in for the real result type -- confirm before removing.
    """

    def __init__(self, min_val, median_val, p95_val, max_val):
        self.min = min_val
        self.median = median_val
        self.p95 = p95_val
        self.max = max_val


@pytest.fixture
def sample_rollout_data():
    """Return four rollout records with unknown_rate 0.1, 0.2, 0.4 and 0.3."""
    return [
        {"run_id": "r1", "unknown_rate": 0.1, "warn_rate": 0.05, "policy_hash": "abc", "pinned": True},
        {"run_id": "r2", "unknown_rate": 0.2, "warn_rate": 0.1, "policy_hash": "abc", "pinned": False},
        {"run_id": "r3", "unknown_rate": 0.4, "warn_rate": 0.2, "policy_hash": "def", "pinned": False},
        {"run_id": "r4", "unknown_rate": 0.3, "warn_rate": 0.15, "policy_hash": "ghi", "pinned": True},
    ]


def approx_equal(a, b, tol=1e-6):
    """Return True when *a* and *b* agree within *tol*.

    *tol* is applied both as the relative and as the absolute tolerance, so
    values close to zero still compare equal.
    """
    return math.isclose(a, b, rel_tol=tol, abs_tol=tol)


def test_analyze_data_basic(sample_rollout_data):
    """The result exposes all four statistics with the expected values."""
    results = core.analyze_data(sample_rollout_data)
    assert hasattr(results, 'min') and hasattr(results, 'median')
    assert hasattr(results, 'p95') and hasattr(results, 'max')
    values = [v["unknown_rate"] for v in sample_rollout_data]
    assert approx_equal(results.min, min(values))
    assert approx_equal(results.max, max(values))
    # Median: even sample count, so the midpoint of 0.2 and 0.3 -> 0.25.
    assert approx_equal(results.median, 0.25)
    # Only the upper bound is checked: p95 must never exceed the maximum.
    assert results.p95 <= results.max


def test_analyze_data_empty():
    """An empty input list is rejected with ValueError."""
    with pytest.raises(ValueError):
        core.analyze_data([])


def test_analyze_data_invalid_input():
    """A record lacking ``unknown_rate`` raises one of the accepted errors."""
    # The "unknown_rate" field is missing from this record.
    invalid_data = [{"run_id": "r1", "warn_rate": 0.1, "policy_hash": "x", "pinned": False}]
    with pytest.raises((KeyError, ValueError, TypeError)):
        core.analyze_data(invalid_data)


def test_analyze_data_float_accuracy(sample_rollout_data):
    """Every statistic is a ``float`` and the four are ordered."""
    results = core.analyze_data(sample_rollout_data)
    assert isinstance(results.min, float)
    assert isinstance(results.median, float)
    assert isinstance(results.p95, float)
    assert isinstance(results.max, float)
    # Check monotonicity: min <= median <= p95 <= max
    assert results.min <= results.median <= results.p95 <= results.max


def test_analyze_data_with_duplicate_values():
    """Ten identical ``unknown_rate`` values collapse every statistic to 0.25."""
    rollout_data = [
        # f-prefix on policy_hash so each record gets its own hash, matching
        # the per-record run_id.
        {"run_id": f"r{i}", "unknown_rate": 0.25, "warn_rate": 0.05, "policy_hash": f"h{i}", "pinned": False}
        for i in range(10)
    ]
    results = core.analyze_data(rollout_data)
    expected_val = 0.25
    assert all(
        approx_equal(getattr(results, attr), expected_val)
        for attr in ('min', 'median', 'p95', 'max')
    )