From b2333d49782b604ef4ffda8bd11454e0455b05ff Mon Sep 17 00:00:00 2001 From: Mika Date: Sun, 19 Apr 2026 02:07:47 +0000 Subject: [PATCH] Add data_visualization/src/data_visualization/cli.py --- .../src/data_visualization/cli.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 data_visualization/src/data_visualization/cli.py diff --git a/data_visualization/src/data_visualization/cli.py b/data_visualization/src/data_visualization/cli.py new file mode 100644 index 0000000..40b8d72 --- /dev/null +++ b/data_visualization/src/data_visualization/cli.py @@ -0,0 +1,70 @@ +import argparse +import csv +import os +from pathlib import Path +from typing import List +import numpy as np +import matplotlib.pyplot as plt + +from data_visualization import core + + +def _validate_frequency_row(row: dict) -> bool: + """Validates that a frequency data row contains valid float values.""" + try: + float(row['frequency_hz']) + float(row['amplitude_db']) + return True + except (KeyError, ValueError, TypeError): + return False + + +def _read_frequency_data(csv_path: Path) -> List[float]: + """Reads and validates frequency data from a CSV file.""" + frequency_data = [] + with csv_path.open('r', newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + required_fields = {'frequency_hz', 'amplitude_db'} + if not required_fields.issubset(reader.fieldnames or []): + raise ValueError(f"CSV file must contain fields: {required_fields}") + + for row in reader: + if _validate_frequency_row(row): + frequency_data.append((float(row['frequency_hz']), float(row['amplitude_db']))) + else: + raise ValueError(f"Invalid row in CSV: {row}") + + assert len(frequency_data) > 0, "No valid frequency data found." + return frequency_data + + +def main() -> None: + """CLI entry point for generating spectrum visualization from frequency CSV data.""" + parser = argparse.ArgumentParser(description="Visualize frequency spectrum from urban acoustic reflection data.") + parser.add_argument("--input", required=True, help="Path to the CSV file containing frequency data.") + parser.add_argument("--output", required=False, default="output/spectrum_plot.png", help="Path to save the spectrum PNG plot.") + + args = parser.parse_args() + + input_path = Path(args.input).resolve() + output_path = Path(args.output).resolve() + + if not input_path.exists() or not input_path.is_file(): + raise FileNotFoundError(f"Input file not found: {input_path}") + + frequency_data = _read_frequency_data(input_path) + + # Extract just amplitude list for plotting function if it expects list[float] + amplitudes = [amp for _, amp in frequency_data] + fig = core.plot_spectrum(amplitudes) + + output_dir = output_path.parent + os.makedirs(output_dir, exist_ok=True) + fig.savefig(output_path, dpi=300, bbox_inches='tight') + plt.close(fig) + + print(f"Spectrum plot saved to: {output_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file