Add artifact.scatter_plot/tests/test_core.py
This commit is contained in:
parent
976929ba27
commit
d8473bb84b
1 changed files with 66 additions and 0 deletions
66
artifact.scatter_plot/tests/test_core.py
Normal file
66
artifact.scatter_plot/tests/test_core.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import io
|
||||
import json
|
||||
import pytest
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from pathlib import Path
|
||||
from artifact_scatter_plot import core
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data():
|
||||
return [
|
||||
core.ScatterData(band_width=1.5, near_expiry_unpinned=0.25),
|
||||
core.ScatterData(band_width=2.0, near_expiry_unpinned=0.35),
|
||||
core.ScatterData(band_width=3.3, near_expiry_unpinned=0.55)
|
||||
]
|
||||
|
||||
|
||||
def test_scatterdata_creation_and_fields(sample_data):
|
||||
for item in sample_data:
|
||||
assert isinstance(item.band_width, float)
|
||||
assert isinstance(item.near_expiry_unpinned, float)
|
||||
assert 0.0 <= item.near_expiry_unpinned <= 1.0 or item.near_expiry_unpinned > 1.0 # allow any float > 0
|
||||
|
||||
|
||||
def test_scatterdata_from_json(tmp_path: Path):
|
||||
json_data = [
|
||||
{"band_width": 4.2, "near_expiry_unpinned": 0.6},
|
||||
{"band_width": 2.8, "near_expiry_unpinned": 0.45}
|
||||
]
|
||||
json_path = tmp_path / "input.json"
|
||||
with open(json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(json_data, f)
|
||||
|
||||
with open(json_path, 'r', encoding='utf-8') as f:
|
||||
raw = json.load(f)
|
||||
|
||||
data_objects = [core.ScatterData(**entry) for entry in raw]
|
||||
assert all(isinstance(d, core.ScatterData) for d in data_objects)
|
||||
assert pytest.approx(data_objects[0].band_width, rel=1e-6) == 4.2
|
||||
|
||||
|
||||
def test_create_scatter_plot_runs_without_error(sample_data, tmp_path):
|
||||
output_file = tmp_path / "scatter_plot.png"
|
||||
# The function returns None but should create a plot file
|
||||
core.create_scatter_plot(sample_data)
|
||||
core.create_scatter_plot(sample_data) # call twice to ensure idempotence
|
||||
# Save to a buffer-like object to simulate file saving
|
||||
buf = io.BytesIO()
|
||||
import matplotlib.pyplot as plt
|
||||
plt.scatter([d.band_width for d in sample_data], [d.near_expiry_unpinned for d in sample_data])
|
||||
plt.savefig(buf, format='png')
|
||||
buf.seek(0)
|
||||
assert buf.read(8).startswith(b'\x89PNG') # PNG header check
|
||||
|
||||
|
||||
def test_create_scatter_plot_invalid_data():
|
||||
bad_data = [
|
||||
{'band_width': 'abc', 'near_expiry_unpinned': 0.5},
|
||||
{'band_width': 2.5, 'near_expiry_unpinned': 'xyz'}
|
||||
]
|
||||
# Manual validation: expect TypeError if misused
|
||||
with pytest.raises((TypeError, AttributeError, ValueError)):
|
||||
objects = [core.ScatterData(**entry) for entry in bad_data]
|
||||
core.create_scatter_plot(objects)
|
||||
Loading…
Reference in a new issue