Add artifact.scatter_plot/tests/test_core.py
This commit is contained in:
parent
976929ba27
commit
d8473bb84b
1 changed files with 66 additions and 0 deletions
66
artifact.scatter_plot/tests/test_core.py
Normal file
66
artifact.scatter_plot/tests/test_core.py
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from artifact_scatter_plot import core
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_data():
|
||||||
|
return [
|
||||||
|
core.ScatterData(band_width=1.5, near_expiry_unpinned=0.25),
|
||||||
|
core.ScatterData(band_width=2.0, near_expiry_unpinned=0.35),
|
||||||
|
core.ScatterData(band_width=3.3, near_expiry_unpinned=0.55)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatterdata_creation_and_fields(sample_data):
|
||||||
|
for item in sample_data:
|
||||||
|
assert isinstance(item.band_width, float)
|
||||||
|
assert isinstance(item.near_expiry_unpinned, float)
|
||||||
|
assert 0.0 <= item.near_expiry_unpinned <= 1.0 or item.near_expiry_unpinned > 1.0 # allow any float > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_scatterdata_from_json(tmp_path: Path):
|
||||||
|
json_data = [
|
||||||
|
{"band_width": 4.2, "near_expiry_unpinned": 0.6},
|
||||||
|
{"band_width": 2.8, "near_expiry_unpinned": 0.45}
|
||||||
|
]
|
||||||
|
json_path = tmp_path / "input.json"
|
||||||
|
with open(json_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(json_data, f)
|
||||||
|
|
||||||
|
with open(json_path, 'r', encoding='utf-8') as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
|
||||||
|
data_objects = [core.ScatterData(**entry) for entry in raw]
|
||||||
|
assert all(isinstance(d, core.ScatterData) for d in data_objects)
|
||||||
|
assert pytest.approx(data_objects[0].band_width, rel=1e-6) == 4.2
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_scatter_plot_runs_without_error(sample_data, tmp_path):
|
||||||
|
output_file = tmp_path / "scatter_plot.png"
|
||||||
|
# The function returns None but should create a plot file
|
||||||
|
core.create_scatter_plot(sample_data)
|
||||||
|
core.create_scatter_plot(sample_data) # call twice to ensure idempotence
|
||||||
|
# Save to a buffer-like object to simulate file saving
|
||||||
|
buf = io.BytesIO()
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
plt.scatter([d.band_width for d in sample_data], [d.near_expiry_unpinned for d in sample_data])
|
||||||
|
plt.savefig(buf, format='png')
|
||||||
|
buf.seek(0)
|
||||||
|
assert buf.read(8).startswith(b'\x89PNG') # PNG header check
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_scatter_plot_invalid_data():
|
||||||
|
bad_data = [
|
||||||
|
{'band_width': 'abc', 'near_expiry_unpinned': 0.5},
|
||||||
|
{'band_width': 2.5, 'near_expiry_unpinned': 'xyz'}
|
||||||
|
]
|
||||||
|
# Manual validation: expect TypeError if misused
|
||||||
|
with pytest.raises((TypeError, AttributeError, ValueError)):
|
||||||
|
objects = [core.ScatterData(**entry) for entry in bad_data]
|
||||||
|
core.create_scatter_plot(objects)
|
||||||
Loading…
Reference in a new issue