Add rollup_rollout/tests/test_core.py
This commit is contained in:
parent
bcc6acc3f6
commit
e5b7376c74
1 changed files with 73 additions and 0 deletions
73
rollup_rollout/tests/test_core.py
Normal file
73
rollup_rollout/tests/test_core.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import src.rollup_rollout.core as core
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_json_file(tmp_path):
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"policy_hash": "abc123",
|
||||||
|
"outcome": "PASS",
|
||||||
|
"unknown_rate": 0.0,
|
||||||
|
"top_reasons": "[]"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"policy_hash": "def456",
|
||||||
|
"outcome": "FAIL",
|
||||||
|
"unknown_rate": 0.25,
|
||||||
|
"top_reasons": "[\"timeout\"]"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
path = tmp_path / "gate_result.json"
|
||||||
|
with open(path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_rollout_series_creates_csv(tmp_json_file, tmp_path):
|
||||||
|
output_path = tmp_path / "rollout_series.csv"
|
||||||
|
core.generate_rollout_series(str(tmp_json_file), str(output_path))
|
||||||
|
|
||||||
|
assert output_path.exists(), "Output CSV file should be created."
|
||||||
|
content = output_path.read_text(encoding='utf-8').strip().splitlines()
|
||||||
|
assert content[0].split(',') == ["policy_hash", "outcome", "unknown_rate", "top_reasons"]
|
||||||
|
assert any("abc123" in line for line in content)
|
||||||
|
assert any("def456" in line for line in content)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_rollout_series_invalid_json(tmp_path):
|
||||||
|
bad_json_file = tmp_path / "bad.json"
|
||||||
|
bad_json_file.write_text('{invalid json}', encoding='utf-8')
|
||||||
|
output_path = tmp_path / "rollout.csv"
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
core.generate_rollout_series(str(bad_json_file), str(output_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_rollout_series_empty_input(tmp_path):
|
||||||
|
empty_json = tmp_path / "empty.json"
|
||||||
|
with open(empty_json, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump([], f)
|
||||||
|
output_csv = tmp_path / "out.csv"
|
||||||
|
|
||||||
|
core.generate_rollout_series(str(empty_json), str(output_csv))
|
||||||
|
lines = output_csv.read_text(encoding='utf-8').splitlines()
|
||||||
|
assert len(lines) == 1, "Only header should exist for empty input."
|
||||||
|
|
||||||
|
|
||||||
|
def test_rolloutdata_fields():
|
||||||
|
data = core.RolloutData(
|
||||||
|
policy_hash="xyz000",
|
||||||
|
outcome="WARN",
|
||||||
|
unknown_rate=0.1,
|
||||||
|
top_reasons="[\"unstable\"]"
|
||||||
|
)
|
||||||
|
assert data.policy_hash == "xyz000"
|
||||||
|
assert isinstance(data.outcome, str)
|
||||||
|
assert isinstance(data.unknown_rate, float)
|
||||||
|
assert "unstable" in data.top_reasons
|
||||||
Loading…
Reference in a new issue