Add policy_grid_evaluation/tests/test_core.py

This commit is contained in:
Mika 2026-02-17 16:16:29 +00:00
parent 554513373b
commit 27af6e6c74

View file

@ -0,0 +1,83 @@
import os
import tempfile
import csv
import pytest
from typing import List
import pandas as pd
from src.policy_grid_evaluation import core
def _validate_gridresult(obj):
assert isinstance(obj.policy, str)
assert isinstance(obj.unknown_conversion, float)
assert isinstance(obj.real_missing_cases, int)
assert isinstance(obj.additional_wait_time, float)
def test_evaluate_grid_returns_list_of_gridresults():
grace_values = [1, 2]
delay_values = [10, 20]
policies = ["baseline", "combined"]
results = core.evaluate_grid(grace_values, delay_values, policies)
assert isinstance(results, list)
assert results, "Results list should not be empty"
for res in results:
_validate_gridresult(res)
all_policies = set(r.policy for r in results)
expected_policies = set(policies)
assert all_policies == expected_policies or all_policies.issubset(expected_policies)
def test_save_results_to_csv_creates_valid_csv(tmp_path):
# Prepare sample results
sample_results: List[core.GridResult] = [
core.GridResult(policy="combined", unknown_conversion=0.95, real_missing_cases=4, additional_wait_time=0.6),
core.GridResult(policy="baseline", unknown_conversion=0.88, real_missing_cases=6, additional_wait_time=1.1),
]
csv_file = tmp_path / "grid_output.csv"
core.save_results_to_csv(sample_results, str(csv_file))
assert csv_file.exists(), "CSV file should be created"
df = pd.read_csv(csv_file)
expected_columns = [
"policy",
"unknown_conversion",
"real_missing_cases",
"additional_wait_time",
]
for col in expected_columns:
assert col in df.columns
assert len(df) == 2
assert set(df["policy"]) == {"combined", "baseline"}
def test_save_and_evaluate_integration(tmp_path):
# Simulated integration to ensure both functions cooperate
results = core.evaluate_grid([1], [5], ["baseline"])
out_file = tmp_path / "out.csv"
core.save_results_to_csv(results, str(out_file))
assert out_file.exists()
with open(out_file, newline="") as f:
reader = csv.DictReader(f)
rows = list(reader)
assert len(rows) == len(results)
for row in rows:
assert set(row.keys()) == {"policy", "unknown_conversion", "real_missing_cases", "additional_wait_time"}
def test_evaluate_grid_input_validation():
with pytest.raises((TypeError, ValueError)):
core.evaluate_grid("123", [1], ["baseline"]) # invalid grace_values
with pytest.raises((TypeError, ValueError)):
core.evaluate_grid([1], None, ["baseline"]) # invalid delay_values
with pytest.raises((TypeError, ValueError)):
core.evaluate_grid([1], [1], None) # invalid policies