Add rollout_report_generator/tests/test_core.py
This commit is contained in:
parent
20205707ee
commit
623a8b465f
1 changed files with 68 additions and 0 deletions
68
rollout_report_generator/tests/test_core.py
Normal file
68
rollout_report_generator/tests/test_core.py
Normal file
|
|
@ -0,0 +1,68 @@
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.rollout_report_generator import core
|
||||||
|
|
||||||
|
def _create_sample_csv(tmp_path: Path) -> Path:
|
||||||
|
data = pd.DataFrame(
|
||||||
|
[
|
||||||
|
{"unknown_rate": 0.02, "warn_rate": 0.03, "policy_hash": "abc123", "runs": 10},
|
||||||
|
{"unknown_rate": 0.05, "warn_rate": 0.08, "policy_hash": "def456", "runs": 12},
|
||||||
|
{"unknown_rate": 0.01, "warn_rate": 0.02, "policy_hash": "ghi789", "runs": 9},
|
||||||
|
{"unknown_rate": 0.10, "warn_rate": 0.15, "policy_hash": "xyz000", "runs": 7},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
csv_path = tmp_path / "rollout_series.csv"
|
||||||
|
data.to_csv(csv_path, index=False)
|
||||||
|
return csv_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_creates_markdown_file(tmp_path: Path):
|
||||||
|
csv_file = _create_sample_csv(tmp_path)
|
||||||
|
md_file = tmp_path / "report.md"
|
||||||
|
|
||||||
|
core.generate_report(str(csv_file), str(md_file))
|
||||||
|
|
||||||
|
assert md_file.exists(), "Markdown file should be created."
|
||||||
|
content = md_file.read_text().strip()
|
||||||
|
|
||||||
|
assert len(content) > 0, "Markdown content should not be empty."
|
||||||
|
assert "unknown_rate" in content
|
||||||
|
assert "warn_rate" in content
|
||||||
|
assert "Median" in content or "median" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_with_invalid_csv_raises(tmp_path: Path):
|
||||||
|
invalid_csv = tmp_path / "invalid.csv"
|
||||||
|
invalid_csv.write_text("not,a,valid,csv\n1,2,3")
|
||||||
|
md_file = tmp_path / "report.md"
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
core.generate_report(str(invalid_csv), str(md_file))
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_empty_csv(tmp_path: Path):
|
||||||
|
empty_csv = tmp_path / "empty.csv"
|
||||||
|
pd.DataFrame(columns=["unknown_rate", "warn_rate", "policy_hash", "runs"]).to_csv(empty_csv, index=False)
|
||||||
|
out_md = tmp_path / "report.md"
|
||||||
|
|
||||||
|
core.generate_report(str(empty_csv), str(out_md))
|
||||||
|
|
||||||
|
content = out_md.read_text()
|
||||||
|
assert "No data" in content or len(content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_statistics_are_correct(tmp_path: Path):
|
||||||
|
csv_file = _create_sample_csv(tmp_path)
|
||||||
|
out_md = tmp_path / "report.md"
|
||||||
|
|
||||||
|
core.generate_report(str(csv_file), str(out_md))
|
||||||
|
content = out_md.read_text()
|
||||||
|
|
||||||
|
# Check that numeric summaries exist in Markdown report
|
||||||
|
assert any(term in content for term in ["min", "p95", "max"])
|
||||||
|
assert "policy_hash" in content
|
||||||
|
assert "runs" in content
|
||||||
Loading…
Reference in a new issue