Add data_visualization/src/data_visualization/cli.py

This commit is contained in:
Mika 2026-04-19 02:07:47 +00:00
parent 2fb8c533c7
commit b2333d4978

View file

@ -0,0 +1,70 @@
import argparse
import csv
import os
from pathlib import Path
from typing import List
import numpy as np
import matplotlib.pyplot as plt
from data_visualization import core
def _validate_frequency_row(row: dict) -> bool:
"""Validates that a frequency data row contains valid float values."""
try:
float(row['frequency_hz'])
float(row['amplitude_db'])
return True
except (KeyError, ValueError, TypeError):
return False
def _read_frequency_data(csv_path: Path) -> List[float]:
"""Reads and validates frequency data from a CSV file."""
frequency_data = []
with csv_path.open('r', newline='', encoding='utf-8') as csvfile:
reader = csv.DictReader(csvfile)
required_fields = {'frequency_hz', 'amplitude_db'}
if not required_fields.issubset(reader.fieldnames or []):
raise ValueError(f"CSV file must contain fields: {required_fields}")
for row in reader:
if _validate_frequency_row(row):
frequency_data.append((float(row['frequency_hz']), float(row['amplitude_db'])))
else:
raise ValueError(f"Invalid row in CSV: {row}")
assert len(frequency_data) > 0, "No valid frequency data found."
return frequency_data
def main() -> None:
"""CLI entry point for generating spectrum visualization from frequency CSV data."""
parser = argparse.ArgumentParser(description="Visualize frequency spectrum from urban acoustic reflection data.")
parser.add_argument("--input", required=True, help="Path to the CSV file containing frequency data.")
parser.add_argument("--output", required=False, default="output/spectrum_plot.png", help="Path to save the spectrum PNG plot.")
args = parser.parse_args()
input_path = Path(args.input).resolve()
output_path = Path(args.output).resolve()
if not input_path.exists() or not input_path.is_file():
raise FileNotFoundError(f"Input file not found: {input_path}")
frequency_data = _read_frequency_data(input_path)
# Extract just amplitude list for plotting function if it expects list[float]
amplitudes = [amp for _, amp in frequency_data]
fig = core.plot_spectrum(amplitudes)
output_dir = output_path.parent
os.makedirs(output_dir, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close(fig)
print(f"Spectrum plot saved to: {output_path}")
if __name__ == "__main__":
main()