Add python_sampling_tool/src/python_sampling_tool/main.py

This commit is contained in:
Mika 2025-12-09 14:56:49 +00:00
commit b1f6ef7e55

View file

@ -0,0 +1,75 @@
import json
import random
import argparse
import os
from pathlib import Path
from typing import List, Dict, Any
def perform_stratified_sampling(sample_size: int, num_jobs: int) -> List[Dict[str, Any]]:
"""Generate stratified samples for CI jobs.
Each job receives an approximately equal share of the sample size.
Returns list of Sample-like dicts with job_id and data_points.
"""
if num_jobs <= 0:
raise ValueError("num_jobs must be positive")
if sample_size <= 0:
raise ValueError("sample_size must be positive")
base = sample_size // num_jobs
remainder = sample_size % num_jobs
samples = []
for i in range(num_jobs):
data_points = base + (1 if i < remainder else 0)
sample = {
"job_id": f"job_{i+1}",
"data_points": data_points
}
samples.append(sample)
random.shuffle(samples)
return samples
def _load_config(path: Path) -> Dict[str, Any]:
with path.open("r", encoding="utf-8") as f:
cfg = json.load(f)
if not isinstance(cfg, dict):
raise ValueError("Config JSON must be an object")
return cfg
def _save_output(samples: List[Dict[str, Any]], path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as f:
json.dump(samples, f, indent=2)
def main():
parser = argparse.ArgumentParser(description="Stratified sampling tool for Mini CI experiments.")
parser.add_argument("--config", required=True, help="Path to sampling config JSON")
parser.add_argument("--output", required=False, help="Output JSON file path")
args = parser.parse_args()
config_path = Path(args.config)
cfg = _load_config(config_path)
try:
sample_size = int(cfg.get("sample_size", 0))
num_jobs = int(cfg.get("num_jobs", 0))
except (TypeError, ValueError):
raise ValueError("Invalid sample_size or num_jobs in config")
samples = perform_stratified_sampling(sample_size, num_jobs)
if args.output:
output_path = Path(args.output)
else:
output_path = Path("output/stratified_samples.json")
_save_output(samples, output_path)
print(f"Generated {len(samples)} stratified samples -> {output_path}")
if __name__ == "__main__":
main()