Add report_generator/tests/test_core.py
This commit is contained in:
parent
62fa3e2966
commit
9624413464
1 changed files with 85 additions and 0 deletions
85
report_generator/tests/test_core.py
Normal file
85
report_generator/tests/test_core.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
# Da der Modulname im api_contract 'core' ist und das Paket report_generator heißt
|
||||
from report_generator import core
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_output(monkeypatch, tmp_path):
|
||||
output_path = tmp_path / 'drift_report.json'
|
||||
# monkeypatch output path if used inside core
|
||||
monkeypatch.setattr(core.Path, 'cwd', lambda: tmp_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def load_report(path: str) -> dict:
|
||||
with open(path, 'r', encoding='utf-8') as fh:
|
||||
return json.load(fh)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"warn_count,total_runs,threshold,expected_alert",
|
||||
[
|
||||
(1, 10, 0.3, False), # warn_rate < threshold
|
||||
(3, 10, 0.3, False), # warn_rate == threshold (0.3)
|
||||
(4, 10, 0.3, True), # warn_rate > threshold
|
||||
]
|
||||
)
|
||||
def test_generate_report_threshold_logic(tmp_path, warn_count, total_runs, threshold, expected_alert):
|
||||
# invoke function
|
||||
report_path = core.generate_report(warn_count, total_runs, threshold)
|
||||
|
||||
# output verification
|
||||
assert isinstance(report_path, str), "generate_report should return path as string"
|
||||
assert os.path.exists(report_path), f"Report file {report_path} does not exist"
|
||||
|
||||
# check JSON structure
|
||||
report = load_report(report_path)
|
||||
|
||||
# Required fields in DriftReport according to data_models
|
||||
for field in ["warn_rate", "threshold", "alert", "timestamp"]:
|
||||
assert field in report, f"Missing field: {field}"
|
||||
|
||||
# semantic checks
|
||||
expected_rate = pytest.approx(warn_count / total_runs)
|
||||
assert report["warn_rate"] == expected_rate
|
||||
assert float(report["threshold"]) == threshold
|
||||
assert bool(report["alert"]) == expected_alert
|
||||
|
||||
|
||||
def test_generate_report_invalid_input_types():
|
||||
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||
core.generate_report('warns', 10, 0.3)
|
||||
|
||||
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||
core.generate_report(5, 'total', 0.3)
|
||||
|
||||
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||
core.generate_report(2, 5, 'threshold')
|
||||
|
||||
|
||||
def test_generate_report_zero_total():
|
||||
with pytest.raises((ZeroDivisionError, ValueError, AssertionError)):
|
||||
core.generate_report(1, 0, 0.3)
|
||||
|
||||
|
||||
def test_generate_report_file_content_consistency(tmp_path):
|
||||
warn_count, total_runs, threshold = 2, 10, 0.2
|
||||
path = core.generate_report(warn_count, total_runs, threshold)
|
||||
data = load_report(path)
|
||||
|
||||
# check timestamp format ISO-like
|
||||
assert isinstance(data["timestamp"], str)
|
||||
assert 'T' in data["timestamp"], "timestamp should be in ISO format"
|
||||
|
||||
# check global consistency
|
||||
assert pytest.approx(data["warn_rate"], rel=1e-6) == warn_count / total_runs
|
||||
assert data["threshold"] == threshold
|
||||
assert isinstance(data["alert"], bool)
|
||||
|
||||
# Cleanup for CI safety
|
||||
os.remove(path)
|
||||
assert not os.path.exists(path)
|
||||
Loading…
Reference in a new issue