Add python_sampling_tool/src/python_sampling_tool/main.py
This commit is contained in:
commit
b1f6ef7e55
1 changed file with 75 additions and 0 deletions
75
python_sampling_tool/src/python_sampling_tool/main.py
Normal file
75
python_sampling_tool/src/python_sampling_tool/main.py
Normal file
|
|
@ -0,0 +1,75 @@
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
def perform_stratified_sampling(
    sample_size: int, num_jobs: int, seed: "int | None" = None
) -> List[Dict[str, Any]]:
    """Split ``sample_size`` data points across ``num_jobs`` CI jobs.

    Every job receives ``sample_size // num_jobs`` data points and the first
    ``sample_size % num_jobs`` jobs receive one extra, so the per-job counts
    differ by at most one and always sum to ``sample_size``. The list is
    shuffled before it is returned.

    Args:
        sample_size: Total number of data points to distribute; must be > 0.
        num_jobs: Number of jobs to distribute across; must be > 0.
        seed: Optional seed for the shuffle. When given, a dedicated
            ``random.Random(seed)`` performs the shuffle, so the output order
            is reproducible and the global RNG state is not touched. When
            ``None`` (the default) the module-level RNG is used.

    Returns:
        A list of ``num_jobs`` dicts, each holding a ``"job_id"`` string
        (``"job_1"`` through ``"job_<num_jobs>"``) and an integer
        ``"data_points"`` count.

    Raises:
        ValueError: If ``num_jobs`` or ``sample_size`` is not positive.
    """
    if num_jobs <= 0:
        raise ValueError("num_jobs must be positive")
    if sample_size <= 0:
        raise ValueError("sample_size must be positive")

    base, remainder = divmod(sample_size, num_jobs)
    samples: List[Dict[str, Any]] = [
        {
            "job_id": f"job_{index + 1}",
            # The first `remainder` jobs absorb the leftover points.
            "data_points": base + (1 if index < remainder else 0),
        }
        for index in range(num_jobs)
    ]

    if seed is None:
        random.shuffle(samples)
    else:
        # A private generator keeps seeded runs from disturbing global state.
        random.Random(seed).shuffle(samples)
    return samples
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config(path: Path) -> Dict[str, Any]:
|
||||||
|
with path.open("r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
if not isinstance(cfg, dict):
|
||||||
|
raise ValueError("Config JSON must be an object")
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _save_output(samples: List[Dict[str, Any]], path: Path) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with path.open("w", encoding="utf-8") as f:
|
||||||
|
json.dump(samples, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv: "List[str] | None" = None) -> None:
    """Command-line entry point: read a config, build samples, write JSON.

    The config file (``--config``) must be a JSON object. Its ``sample_size``
    and ``num_jobs`` entries are converted with ``int()``; a missing entry
    defaults to 0, which the sampler then rejects as non-positive. The
    samples are written to ``--output`` or, when that is omitted, to
    ``output/stratified_samples.json``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default
            ``None`` makes argparse read the process arguments.

    Raises:
        ValueError: If ``sample_size`` or ``num_jobs`` cannot be converted to
            an integer; the original conversion error is chained as the cause.
    """
    parser = argparse.ArgumentParser(description="Stratified sampling tool for Mini CI experiments.")
    parser.add_argument("--config", required=True, help="Path to sampling config JSON")
    parser.add_argument("--output", required=False, help="Output JSON file path")
    args = parser.parse_args(argv)

    cfg = _load_config(Path(args.config))

    try:
        sample_size = int(cfg.get("sample_size", 0))
        num_jobs = int(cfg.get("num_jobs", 0))
    except (TypeError, ValueError) as exc:
        # Chain explicitly so the traceback shows which value failed to convert.
        raise ValueError("Invalid sample_size or num_jobs in config") from exc

    samples = perform_stratified_sampling(sample_size, num_jobs)

    output_path = Path(args.output) if args.output else Path("output/stratified_samples.json")
    _save_output(samples, output_path)

    print(f"Generated {len(samples)} stratified samples -> {output_path}")


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in a new issue