diff --git a/rerun_analysis_tool/tests/test_core.py b/rerun_analysis_tool/tests/test_core.py new file mode 100644 index 0000000..d3156a0 --- /dev/null +++ b/rerun_analysis_tool/tests/test_core.py @@ -0,0 +1,53 @@ +import pytest +import pandas as pd +from rerun_analysis_tool import core + +@pytest.fixture +def sample_runs_data(): + return [ + { + 'run_id': 'r1', 'status': 'WARN', 'unknown_rate': 0.2, + 'rerun_helps': 1, 'rerun_shifts': 0, 'rerun_hurts': 0 + }, + { + 'run_id': 'r2', 'status': 'PASS', 'unknown_rate': 0.1, + 'rerun_helps': 0, 'rerun_shifts': 1, 'rerun_hurts': 0 + }, + { + 'run_id': 'r3', 'status': 'FAIL', 'unknown_rate': 0.5, + 'rerun_helps': 0, 'rerun_shifts': 0, 'rerun_hurts': 1 + } + ] + +def test_analyze_runs_nominal_case(sample_runs_data): + result = core.analyze_runs(sample_runs_data, threshold=0.3, rerun_budget=1) + assert isinstance(result, dict) + assert set(result.keys()) >= {'rerun_helps', 'rerun_shifts', 'rerun_hurts'} + assert result['rerun_helps'] == 1 + assert result['rerun_shifts'] == 1 + assert result['rerun_hurts'] == 1 + +def test_analyze_runs_empty_data(): + result = core.analyze_runs([], threshold=0.3, rerun_budget=1) + assert result['rerun_helps'] == 0 + assert result['rerun_shifts'] == 0 + assert result['rerun_hurts'] == 0 + +def test_analyze_runs_threshold_effect(sample_runs_data): + res_low = core.analyze_runs(sample_runs_data, threshold=0.1, rerun_budget=1) + res_high = core.analyze_runs(sample_runs_data, threshold=0.9, rerun_budget=1) + assert isinstance(res_low, dict) + assert isinstance(res_high, dict) + # Threshold should not break output schema + for key in ('rerun_helps', 'rerun_shifts', 'rerun_hurts'): + assert key in res_low and key in res_high + +def test_invalid_input_type(): + with pytest.raises((TypeError, ValueError)): + core.analyze_runs('notalist', threshold=0.3, rerun_budget=1) + +@pytest.mark.parametrize('budget', [0, 1, 2]) +def test_valid_budget_values(sample_runs_data, budget): + result = core.analyze_runs(sample_runs_data, threshold=0.3, rerun_budget=budget) + assert isinstance(result, dict) + assert all(isinstance(v, int) for v in result.values()) \ No newline at end of file