Add report_generator/tests/test_core.py
This commit is contained in:
parent
62fa3e2966
commit
9624413464
1 changed files with 85 additions and 0 deletions
85
report_generator/tests/test_core.py
Normal file
85
report_generator/tests/test_core.py
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Da der Modulname im api_contract 'core' ist und das Paket report_generator heißt
|
||||||
|
from report_generator import core
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_output(monkeypatch, tmp_path):
|
||||||
|
output_path = tmp_path / 'drift_report.json'
|
||||||
|
# monkeypatch output path if used inside core
|
||||||
|
monkeypatch.setattr(core.Path, 'cwd', lambda: tmp_path)
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_report(path: str) -> dict:
|
||||||
|
with open(path, 'r', encoding='utf-8') as fh:
|
||||||
|
return json.load(fh)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"warn_count,total_runs,threshold,expected_alert",
|
||||||
|
[
|
||||||
|
(1, 10, 0.3, False), # warn_rate < threshold
|
||||||
|
(3, 10, 0.3, False), # warn_rate == threshold (0.3)
|
||||||
|
(4, 10, 0.3, True), # warn_rate > threshold
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_generate_report_threshold_logic(tmp_path, warn_count, total_runs, threshold, expected_alert):
|
||||||
|
# invoke function
|
||||||
|
report_path = core.generate_report(warn_count, total_runs, threshold)
|
||||||
|
|
||||||
|
# output verification
|
||||||
|
assert isinstance(report_path, str), "generate_report should return path as string"
|
||||||
|
assert os.path.exists(report_path), f"Report file {report_path} does not exist"
|
||||||
|
|
||||||
|
# check JSON structure
|
||||||
|
report = load_report(report_path)
|
||||||
|
|
||||||
|
# Required fields in DriftReport according to data_models
|
||||||
|
for field in ["warn_rate", "threshold", "alert", "timestamp"]:
|
||||||
|
assert field in report, f"Missing field: {field}"
|
||||||
|
|
||||||
|
# semantic checks
|
||||||
|
expected_rate = pytest.approx(warn_count / total_runs)
|
||||||
|
assert report["warn_rate"] == expected_rate
|
||||||
|
assert float(report["threshold"]) == threshold
|
||||||
|
assert bool(report["alert"]) == expected_alert
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_invalid_input_types():
|
||||||
|
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||||
|
core.generate_report('warns', 10, 0.3)
|
||||||
|
|
||||||
|
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||||
|
core.generate_report(5, 'total', 0.3)
|
||||||
|
|
||||||
|
with pytest.raises((AssertionError, ValueError, TypeError)):
|
||||||
|
core.generate_report(2, 5, 'threshold')
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_zero_total():
|
||||||
|
with pytest.raises((ZeroDivisionError, ValueError, AssertionError)):
|
||||||
|
core.generate_report(1, 0, 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_report_file_content_consistency(tmp_path):
|
||||||
|
warn_count, total_runs, threshold = 2, 10, 0.2
|
||||||
|
path = core.generate_report(warn_count, total_runs, threshold)
|
||||||
|
data = load_report(path)
|
||||||
|
|
||||||
|
# check timestamp format ISO-like
|
||||||
|
assert isinstance(data["timestamp"], str)
|
||||||
|
assert 'T' in data["timestamp"], "timestamp should be in ISO format"
|
||||||
|
|
||||||
|
# check global consistency
|
||||||
|
assert pytest.approx(data["warn_rate"], rel=1e-6) == warn_count / total_runs
|
||||||
|
assert data["threshold"] == threshold
|
||||||
|
assert isinstance(data["alert"], bool)
|
||||||
|
|
||||||
|
# Cleanup for CI safety
|
||||||
|
os.remove(path)
|
||||||
|
assert not os.path.exists(path)
|
||||||
Loading…
Reference in a new issue