Add artifact.scatter_plot/tests/test_core.py

This commit is contained in:
Mika 2026-04-03 10:57:08 +00:00
parent 976929ba27
commit d8473bb84b

View file

@ -0,0 +1,66 @@
import io
import json
import pytest
import matplotlib
matplotlib.use('Agg')
from pathlib import Path
from artifact_scatter_plot import core
@pytest.fixture
def sample_data():
return [
core.ScatterData(band_width=1.5, near_expiry_unpinned=0.25),
core.ScatterData(band_width=2.0, near_expiry_unpinned=0.35),
core.ScatterData(band_width=3.3, near_expiry_unpinned=0.55)
]
def test_scatterdata_creation_and_fields(sample_data):
for item in sample_data:
assert isinstance(item.band_width, float)
assert isinstance(item.near_expiry_unpinned, float)
assert 0.0 <= item.near_expiry_unpinned <= 1.0 or item.near_expiry_unpinned > 1.0 # allow any float > 0
def test_scatterdata_from_json(tmp_path: Path):
json_data = [
{"band_width": 4.2, "near_expiry_unpinned": 0.6},
{"band_width": 2.8, "near_expiry_unpinned": 0.45}
]
json_path = tmp_path / "input.json"
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f)
with open(json_path, 'r', encoding='utf-8') as f:
raw = json.load(f)
data_objects = [core.ScatterData(**entry) for entry in raw]
assert all(isinstance(d, core.ScatterData) for d in data_objects)
assert pytest.approx(data_objects[0].band_width, rel=1e-6) == 4.2
def test_create_scatter_plot_runs_without_error(sample_data, tmp_path):
output_file = tmp_path / "scatter_plot.png"
# The function returns None but should create a plot file
core.create_scatter_plot(sample_data)
core.create_scatter_plot(sample_data) # call twice to ensure idempotence
# Save to a buffer-like object to simulate file saving
buf = io.BytesIO()
import matplotlib.pyplot as plt
plt.scatter([d.band_width for d in sample_data], [d.near_expiry_unpinned for d in sample_data])
plt.savefig(buf, format='png')
buf.seek(0)
assert buf.read(8).startswith(b'\x89PNG') # PNG header check
def test_create_scatter_plot_invalid_data():
bad_data = [
{'band_width': 'abc', 'near_expiry_unpinned': 0.5},
{'band_width': 2.5, 'near_expiry_unpinned': 'xyz'}
]
# Manual validation: expect TypeError if misused
with pytest.raises((TypeError, AttributeError, ValueError)):
objects = [core.ScatterData(**entry) for entry in bad_data]
core.create_scatter_plot(objects)