Add data_visualization/src/data_visualization/cli.py
This commit is contained in:
parent
2fb8c533c7
commit
b2333d4978
1 changed files with 70 additions and 0 deletions
70
data_visualization/src/data_visualization/cli.py
Normal file
70
data_visualization/src/data_visualization/cli.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from data_visualization import core
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_frequency_row(row: dict) -> bool:
|
||||||
|
"""Validates that a frequency data row contains valid float values."""
|
||||||
|
try:
|
||||||
|
float(row['frequency_hz'])
|
||||||
|
float(row['amplitude_db'])
|
||||||
|
return True
|
||||||
|
except (KeyError, ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _read_frequency_data(csv_path: Path) -> List[float]:
|
||||||
|
"""Reads and validates frequency data from a CSV file."""
|
||||||
|
frequency_data = []
|
||||||
|
with csv_path.open('r', newline='', encoding='utf-8') as csvfile:
|
||||||
|
reader = csv.DictReader(csvfile)
|
||||||
|
required_fields = {'frequency_hz', 'amplitude_db'}
|
||||||
|
if not required_fields.issubset(reader.fieldnames or []):
|
||||||
|
raise ValueError(f"CSV file must contain fields: {required_fields}")
|
||||||
|
|
||||||
|
for row in reader:
|
||||||
|
if _validate_frequency_row(row):
|
||||||
|
frequency_data.append((float(row['frequency_hz']), float(row['amplitude_db'])))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid row in CSV: {row}")
|
||||||
|
|
||||||
|
assert len(frequency_data) > 0, "No valid frequency data found."
|
||||||
|
return frequency_data
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""CLI entry point for generating spectrum visualization from frequency CSV data."""
|
||||||
|
parser = argparse.ArgumentParser(description="Visualize frequency spectrum from urban acoustic reflection data.")
|
||||||
|
parser.add_argument("--input", required=True, help="Path to the CSV file containing frequency data.")
|
||||||
|
parser.add_argument("--output", required=False, default="output/spectrum_plot.png", help="Path to save the spectrum PNG plot.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
input_path = Path(args.input).resolve()
|
||||||
|
output_path = Path(args.output).resolve()
|
||||||
|
|
||||||
|
if not input_path.exists() or not input_path.is_file():
|
||||||
|
raise FileNotFoundError(f"Input file not found: {input_path}")
|
||||||
|
|
||||||
|
frequency_data = _read_frequency_data(input_path)
|
||||||
|
|
||||||
|
# Extract just amplitude list for plotting function if it expects list[float]
|
||||||
|
amplitudes = [amp for _, amp in frequency_data]
|
||||||
|
fig = core.plot_spectrum(amplitudes)
|
||||||
|
|
||||||
|
output_dir = output_path.parent
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
fig.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
print(f"Spectrum plot saved to: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in a new issue