Add rollout_report_generator/tests/test_core.py

This commit is contained in:
Mika 2026-02-21 15:27:04 +00:00
parent 20205707ee
commit 623a8b465f

View file

@ -0,0 +1,68 @@
import io
import tempfile
from pathlib import Path
import pandas as pd
import pytest
from src.rollout_report_generator import core
def _create_sample_csv(tmp_path: Path) -> Path:
data = pd.DataFrame(
[
{"unknown_rate": 0.02, "warn_rate": 0.03, "policy_hash": "abc123", "runs": 10},
{"unknown_rate": 0.05, "warn_rate": 0.08, "policy_hash": "def456", "runs": 12},
{"unknown_rate": 0.01, "warn_rate": 0.02, "policy_hash": "ghi789", "runs": 9},
{"unknown_rate": 0.10, "warn_rate": 0.15, "policy_hash": "xyz000", "runs": 7},
]
)
csv_path = tmp_path / "rollout_series.csv"
data.to_csv(csv_path, index=False)
return csv_path
def test_generate_report_creates_markdown_file(tmp_path: Path):
csv_file = _create_sample_csv(tmp_path)
md_file = tmp_path / "report.md"
core.generate_report(str(csv_file), str(md_file))
assert md_file.exists(), "Markdown file should be created."
content = md_file.read_text().strip()
assert len(content) > 0, "Markdown content should not be empty."
assert "unknown_rate" in content
assert "warn_rate" in content
assert "Median" in content or "median" in content
def test_generate_report_with_invalid_csv_raises(tmp_path: Path):
invalid_csv = tmp_path / "invalid.csv"
invalid_csv.write_text("not,a,valid,csv\n1,2,3")
md_file = tmp_path / "report.md"
with pytest.raises(Exception):
core.generate_report(str(invalid_csv), str(md_file))
def test_generate_report_empty_csv(tmp_path: Path):
empty_csv = tmp_path / "empty.csv"
pd.DataFrame(columns=["unknown_rate", "warn_rate", "policy_hash", "runs"]).to_csv(empty_csv, index=False)
out_md = tmp_path / "report.md"
core.generate_report(str(empty_csv), str(out_md))
content = out_md.read_text()
assert "No data" in content or len(content) > 0
def test_generate_report_statistics_are_correct(tmp_path: Path):
csv_file = _create_sample_csv(tmp_path)
out_md = tmp_path / "report.md"
core.generate_report(str(csv_file), str(out_md))
content = out_md.read_text()
# Check that numeric summaries exist in Markdown report
assert any(term in content for term in ["min", "p95", "max"])
assert "policy_hash" in content
assert "runs" in content