diff --git a/python_sampling_tool/tests/test_main.py b/python_sampling_tool/tests/test_main.py
new file mode 100644
index 0000000..5af7c83
--- /dev/null
+++ b/python_sampling_tool/tests/test_main.py
@@ -0,0 +1,44 @@
+import pytest
+from python_sampling_tool import main
+
+
+def test_stratified_sampling_total_size():
+    sample_size = 20
+    num_jobs = 5
+    samples = main.perform_stratified_sampling(sample_size, num_jobs)
+    assert isinstance(samples, list)
+    # The total number of data points should equal the sample size
+    total_points = sum(s['data_points'] for s in samples)
+    assert total_points == sample_size
+    # Each entry should contain a job identifier
+    for s in samples:
+        assert 'job_id' in s
+        assert 'data_points' in s
+        assert isinstance(s['data_points'], int)
+        assert isinstance(s['job_id'], str)
+
+
+def test_stratified_sampling_balanced_distribution():
+    sample_size = 15
+    num_jobs = 3
+    samples = main.perform_stratified_sampling(sample_size, num_jobs)
+    # Check that all jobs are accounted for
+    job_ids = [s['job_id'] for s in samples]
+    assert len(job_ids) == num_jobs
+    # The distribution of data points should be reasonably even
+    counts = [s['data_points'] for s in samples]
+    avg = sum(counts) / num_jobs
+    for c in counts:
+        assert abs(c - avg) <= 1
+
+
+def test_stratified_sampling_small_sample():
+    sample_size = 3
+    num_jobs = 5
+    samples = main.perform_stratified_sampling(sample_size, num_jobs)
+    # The total should still equal the sample size
+    total_points = sum(s['data_points'] for s in samples)
+    assert total_points == sample_size
+    # No more jobs than samples should receive data points
+    nonzero = [s for s in samples if s['data_points'] > 0]
+    assert len(nonzero) <= sample_size
\ No newline at end of file
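
For context, the tests above pin down a contract for perform_stratified_sampling: one entry per job, a string job_id and an integer data_points per entry, counts summing to sample_size, and each count within one of the mean. The diff does not include the implementation itself, so the following is only a minimal sketch that would satisfy these assertions; the job_id format (job_<i>) is an assumption, and the real code in python_sampling_tool/main.py may differ.

# Hypothetical sketch (not part of the diff): one implementation of
# perform_stratified_sampling consistent with the assertions in test_main.py.
def perform_stratified_sampling(sample_size: int, num_jobs: int) -> list:
    """Spread sample_size data points as evenly as possible across num_jobs jobs."""
    base, remainder = divmod(sample_size, num_jobs)
    samples = []
    for i in range(num_jobs):
        # The first `remainder` jobs receive one extra point, so counts differ by at most 1.
        points = base + (1 if i < remainder else 0)
        # 'job_<i>' is an assumed identifier format; the tests only require a string.
        samples.append({'job_id': f'job_{i}', 'data_points': points})
    return samples

With this sketch, sample_size=3 and num_jobs=5 yields counts [1, 1, 1, 0, 0], so exactly three jobs receive data points, matching the small-sample test.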