From d8473bb84b1bbf8aa4d11c478ab6f59db7b1fd99 Mon Sep 17 00:00:00 2001 From: Mika Date: Fri, 3 Apr 2026 10:57:08 +0000 Subject: [PATCH] Add artifact.scatter_plot/tests/test_core.py --- artifact.scatter_plot/tests/test_core.py | 66 ++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 artifact.scatter_plot/tests/test_core.py diff --git a/artifact.scatter_plot/tests/test_core.py b/artifact.scatter_plot/tests/test_core.py new file mode 100644 index 0000000..ee74fbe --- /dev/null +++ b/artifact.scatter_plot/tests/test_core.py @@ -0,0 +1,66 @@ +import io +import json +import pytest +import matplotlib +matplotlib.use('Agg') + +from pathlib import Path +from artifact_scatter_plot import core + + +@pytest.fixture +def sample_data(): + return [ + core.ScatterData(band_width=1.5, near_expiry_unpinned=0.25), + core.ScatterData(band_width=2.0, near_expiry_unpinned=0.35), + core.ScatterData(band_width=3.3, near_expiry_unpinned=0.55) + ] + + +def test_scatterdata_creation_and_fields(sample_data): + for item in sample_data: + assert isinstance(item.band_width, float) + assert isinstance(item.near_expiry_unpinned, float) + assert 0.0 <= item.near_expiry_unpinned <= 1.0 or item.near_expiry_unpinned > 1.0 # allow any float > 0 + + +def test_scatterdata_from_json(tmp_path: Path): + json_data = [ + {"band_width": 4.2, "near_expiry_unpinned": 0.6}, + {"band_width": 2.8, "near_expiry_unpinned": 0.45} + ] + json_path = tmp_path / "input.json" + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(json_data, f) + + with open(json_path, 'r', encoding='utf-8') as f: + raw = json.load(f) + + data_objects = [core.ScatterData(**entry) for entry in raw] + assert all(isinstance(d, core.ScatterData) for d in data_objects) + assert pytest.approx(data_objects[0].band_width, rel=1e-6) == 4.2 + + +def test_create_scatter_plot_runs_without_error(sample_data, tmp_path): + output_file = tmp_path / "scatter_plot.png" + # The function returns None but should create a plot file + core.create_scatter_plot(sample_data) + core.create_scatter_plot(sample_data) # call twice to ensure idempotence + # Save to a buffer-like object to simulate file saving + buf = io.BytesIO() + import matplotlib.pyplot as plt + plt.scatter([d.band_width for d in sample_data], [d.near_expiry_unpinned for d in sample_data]) + plt.savefig(buf, format='png') + buf.seek(0) + assert buf.read(8).startswith(b'\x89PNG') # PNG header check + + +def test_create_scatter_plot_invalid_data(): + bad_data = [ + {'band_width': 'abc', 'near_expiry_unpinned': 0.5}, + {'band_width': 2.5, 'near_expiry_unpinned': 'xyz'} + ] + # Manual validation: expect TypeError if misused + with pytest.raises((TypeError, AttributeError, ValueError)): + objects = [core.ScatterData(**entry) for entry in bad_data] + core.create_scatter_plot(objects) \ No newline at end of file