Add policy_grid_evaluation/tests/test_core.py
This commit is contained in:
parent
554513373b
commit
27af6e6c74
1 changed file with 83 additions and 0 deletions
83
policy_grid_evaluation/tests/test_core.py
Normal file
83
policy_grid_evaluation/tests/test_core.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import os
|
||||
import tempfile
|
||||
import csv
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from src.policy_grid_evaluation import core
|
||||
|
||||
|
||||
def _validate_gridresult(obj):
|
||||
assert isinstance(obj.policy, str)
|
||||
assert isinstance(obj.unknown_conversion, float)
|
||||
assert isinstance(obj.real_missing_cases, int)
|
||||
assert isinstance(obj.additional_wait_time, float)
|
||||
|
||||
|
||||
def test_evaluate_grid_returns_list_of_gridresults():
    """evaluate_grid returns a non-empty list of GridResult-like objects whose
    policies are drawn from the requested policy set."""
    grace_values = [1, 2]
    delay_values = [10, 20]
    policies = ["baseline", "combined"]

    results = core.evaluate_grid(grace_values, delay_values, policies)

    assert isinstance(results, list)
    assert results, "Results list should not be empty"
    for res in results:
        _validate_gridresult(res)

    # Every reported policy must come from the requested set.  The original
    # check (`== expected or issubset(expected)`) was redundant: set equality
    # already implies subset, so a single subset test is equivalent.
    assert {r.policy for r in results} <= set(policies)
|
||||
|
||||
|
||||
def test_save_results_to_csv_creates_valid_csv(tmp_path):
    """save_results_to_csv writes a CSV whose columns and rows mirror the inputs."""
    # Two hand-built results covering both policies.
    sample_results: List[core.GridResult] = [
        core.GridResult(policy="combined", unknown_conversion=0.95, real_missing_cases=4, additional_wait_time=0.6),
        core.GridResult(policy="baseline", unknown_conversion=0.88, real_missing_cases=6, additional_wait_time=1.1),
    ]

    csv_file = tmp_path / "grid_output.csv"
    core.save_results_to_csv(sample_results, str(csv_file))

    assert csv_file.exists(), "CSV file should be created"

    # Re-read with pandas and verify header and contents survived the round trip.
    df = pd.read_csv(csv_file)
    for col in ("policy", "unknown_conversion", "real_missing_cases", "additional_wait_time"):
        assert col in df.columns
    assert len(df) == 2
    assert set(df["policy"]) == {"combined", "baseline"}
|
||||
|
||||
|
||||
def test_save_and_evaluate_integration(tmp_path):
    """End-to-end: results from evaluate_grid round-trip through save_results_to_csv."""
    results = core.evaluate_grid([1], [5], ["baseline"])
    out_file = tmp_path / "out.csv"
    core.save_results_to_csv(results, str(out_file))

    assert out_file.exists()

    # Parse the file back with the stdlib csv module (independent of pandas).
    with open(out_file, newline="") as handle:
        rows = list(csv.DictReader(handle))

    assert len(rows) == len(results)
    expected_keys = {"policy", "unknown_conversion", "real_missing_cases", "additional_wait_time"}
    for row in rows:
        assert set(row.keys()) == expected_keys
|
||||
|
||||
|
||||
def test_evaluate_grid_input_validation():
    """evaluate_grid must reject a non-sequence for any of its three axes."""
    bad_argument_sets = [
        ("123", [1], ["baseline"]),  # invalid grace_values
        ([1], None, ["baseline"]),   # invalid delay_values
        ([1], [1], None),            # invalid policies
    ]
    for grace, delay, pols in bad_argument_sets:
        with pytest.raises((TypeError, ValueError)):
            core.evaluate_grid(grace, delay, pols)
|
||||
Loading…
Reference in a new issue