From 0163a572cb0a68ffc665473c6e8b650b8c36d06b Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 20:35:08 +0000 Subject: [PATCH] Expanded cli tests --- src/batdetect2/cli/inference.py | 2 + tests/conftest.py | 23 +++ tests/test_cli.py | 185 ----------------------- tests/test_cli/test_base.py | 18 +++ tests/test_cli/test_data.py | 60 ++++++++ tests/test_cli/test_detect.py | 119 +++++++++++++++ tests/test_cli/test_evaluate.py | 47 ++++++ tests/test_cli/test_predict.py | 257 ++++++++++++++++++++++++++++++++ tests/test_cli/test_train.py | 81 ++++++++++ 9 files changed, 607 insertions(+), 185 deletions(-) delete mode 100644 tests/test_cli.py create mode 100644 tests/test_cli/test_base.py create mode 100644 tests/test_cli/test_data.py create mode 100644 tests/test_cli/test_detect.py create mode 100644 tests/test_cli/test_evaluate.py create mode 100644 tests/test_cli/test_predict.py create mode 100644 tests/test_cli/test_train.py diff --git a/src/batdetect2/cli/inference.py b/src/batdetect2/cli/inference.py index 70abe42..edc268f 100644 --- a/src/batdetect2/cli/inference.py +++ b/src/batdetect2/cli/inference.py @@ -164,6 +164,7 @@ def inference_file_list_command( num_workers: int, format_name: str | None, ) -> None: + file_list = Path(file_list) audio_files = [ Path(line.strip()) for line in file_list.read_text().splitlines() @@ -207,6 +208,7 @@ def inference_dataset_command( num_workers: int, format_name: str | None, ) -> None: + dataset_path = Path(dataset_path) dataset = io.load(dataset_path, type="annotation_set") audio_files = sorted( { diff --git a/tests/conftest.py b/tests/conftest.py index 367992a..f3d864d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Callable, List, Optional from uuid import uuid4 +import lightning as L import numpy as np import pytest import soundfile as sf @@ -12,6 +13,7 @@ from soundevent import data, terms from batdetect2.audio import build_audio_loader from batdetect2.audio.clips import build_clipper from batdetect2.audio.types import AudioLoader, ClipperProtocol +from batdetect2.config import BatDetect2Config from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor @@ -24,6 +26,7 @@ from batdetect2.targets import ( from batdetect2.targets.classes import TargetClassConfig from batdetect2.targets.types import TargetProtocol from batdetect2.train.labels import build_clip_labeler +from batdetect2.train.lightning import build_training_module from batdetect2.train.types import ClipLabeller @@ -452,3 +455,23 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]: return temp_file return factory + + +@pytest.fixture +def tiny_checkpoint_path(tmp_path: Path) -> Path: + module = build_training_module(model_config=BatDetect2Config().model) + trainer = L.Trainer(enable_checkpointing=False, logger=False) + checkpoint_path = tmp_path / "model.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(checkpoint_path) + return checkpoint_path + + +@pytest.fixture +def single_audio_dir(tmp_path: Path, example_audio_files: List[Path]) -> Path: + audio_dir = tmp_path / "audio" + audio_dir.mkdir() + source = example_audio_files[0] + target = audio_dir / source.name + target.write_bytes(source.read_bytes()) + return audio_dir diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index e4e4d99..0000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Test the command line interface.""" - -import shutil -from pathlib import Path - -import lightning as L -import pandas as pd -from click.testing import CliRunner - -from batdetect2.cli import cli -from batdetect2.config import BatDetect2Config -from batdetect2.train.lightning import build_training_module - -runner = CliRunner() - - -def test_cli_base_command(): - """Test the base command.""" - result = runner.invoke(cli, ["--help"]) - assert result.exit_code == 0 - assert ( - "BatDetect2 - Bat Call Detection and Classification" in result.output - ) - - -def test_cli_detect_command_help(): - """Test the detect command help.""" - result = runner.invoke(cli, ["detect", "--help"]) - assert result.exit_code == 0 - assert "Detect bat calls in files in AUDIO_DIR" in result.output - - -def test_cli_predict_command_help(): - """Test the predict command help.""" - result = runner.invoke(cli, ["predict", "--help"]) - assert result.exit_code == 0 - assert "directory" in result.output - assert "file_list" in result.output - assert "dataset" in result.output - - -def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path): - """User story: run prediction from CLI on a small directory.""" - - source_audio = Path("example_data/audio") - source_file = next(source_audio.glob("*.wav")) - audio_dir = tmp_path / "audio" - audio_dir.mkdir() - target_file = audio_dir / source_file.name - shutil.copy(source_file, target_file) - - module = build_training_module(model_config=BatDetect2Config().model) - trainer = L.Trainer(enable_checkpointing=False, logger=False) - model_path = tmp_path / "model.ckpt" - trainer.strategy.connect(module) - trainer.save_checkpoint(model_path) - output_path = tmp_path / "predictions" - - result = runner.invoke( - cli, - [ - "predict", - "directory", - str(model_path), - str(audio_dir), - str(output_path), - "--batch-size", - "1", - "--workers", - "0", - "--format", - "batdetect2", - ], - ) - - assert result.exit_code == 0 - assert output_path.exists() - output_files = list(output_path.glob("*.json")) - assert len(output_files) == 1 - - -def test_cli_detect_command_on_test_audio(tmp_path): - """Test the detect command on test audio.""" - results_dir = tmp_path / "results" - - # Remove results dir if it exists - if results_dir.exists(): - results_dir.rmdir() - - result = runner.invoke( - cli, - [ - "detect", - "example_data/audio", - str(results_dir), - "0.3", - ], - ) - assert result.exit_code == 0 - assert results_dir.exists() - assert len(list(results_dir.glob("*.csv"))) == 3 - assert len(list(results_dir.glob("*.json"))) == 3 - - -def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path): - """Test the detect command with a non-trivial time expansion factor.""" - results_dir = tmp_path / "results" - - # Remove results dir if it exists - if results_dir.exists(): - results_dir.rmdir() - - result = runner.invoke( - cli, - [ - "detect", - "example_data/audio", - str(results_dir), - "0.3", - "--time_expansion_factor", - "10", - ], - ) - - assert result.exit_code == 0 - assert "Time Expansion Factor: 10" in result.stdout - - -def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path): - """Test the detect command with the spec feature flag.""" - results_dir = tmp_path / "results" - - # Remove results dir if it exists - if results_dir.exists(): - results_dir.rmdir() - - result = runner.invoke( - cli, - [ - "detect", - "example_data/audio", - str(results_dir), - "0.3", - "--spec_features", - ], - ) - assert result.exit_code == 0 - assert results_dir.exists() - - csv_files = [path.name for path in results_dir.glob("*.csv")] - - expected_files = [ - "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv", - "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv", - "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv", - ] - - for expected_file in expected_files: - assert expected_file in csv_files - - df = pd.read_csv(results_dir / expected_file) - assert not (df.duration == -1).any() - - -def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path): - results_dir = tmp_path / "results" - target = tmp_path / "audio" - target.mkdir() - - # Create an empty file with the .wav extension - empty_file = target / "empty.wav" - empty_file.touch() - - result = runner.invoke( - cli, - args=[ - "detect", - str(target), - str(results_dir), - "0.3", - "--spec_features", - ], - ) - assert result.exit_code == 0 - assert f"Error processing file {empty_file}" in result.output diff --git a/tests/test_cli/test_base.py b/tests/test_cli/test_base.py new file mode 100644 index 0000000..95782d4 --- /dev/null +++ b/tests/test_cli/test_base.py @@ -0,0 +1,18 @@ +"""Behavior-focused tests for top-level CLI command discovery.""" + +from click.testing import CliRunner + +from batdetect2.cli import cli + + +def test_cli_base_help_lists_main_commands() -> None: + """User story: discover available workflows from top-level help.""" + + result = CliRunner().invoke(cli, ["--help"]) + + assert result.exit_code == 0 + assert "predict" in result.output + assert "train" in result.output + assert "evaluate" in result.output + assert "data" in result.output + assert "detect" in result.output diff --git a/tests/test_cli/test_data.py b/tests/test_cli/test_data.py new file mode 100644 index 0000000..8d4b9df --- /dev/null +++ b/tests/test_cli/test_data.py @@ -0,0 +1,60 @@ +"""Behavior tests for data CLI command group.""" + +from pathlib import Path + +from click.testing import CliRunner + +from batdetect2.cli import cli + + +def test_cli_data_help() -> None: + """User story: discover data subcommands.""" + + result = CliRunner().invoke(cli, ["data", "--help"]) + + assert result.exit_code == 0 + assert "summary" in result.output + assert "convert" in result.output + + +def test_cli_data_convert_creates_annotation_set(tmp_path: Path) -> None: + """User story: convert dataset config into a soundevent annotation set.""" + + output = tmp_path / "annotations.json" + + result = CliRunner().invoke( + cli, + [ + "data", + "convert", + "example_data/dataset.yaml", + "--base-dir", + ".", + "--output", + str(output), + ], + ) + + assert result.exit_code == 0 + assert output.exists() + + +def test_cli_data_convert_fails_with_invalid_field(tmp_path: Path) -> None: + """User story: invalid nested field in dataset config fails clearly.""" + + output = tmp_path / "annotations.json" + + result = CliRunner().invoke( + cli, + [ + "data", + "convert", + "example_data/dataset.yaml", + "--field", + "does.not.exist", + "--output", + str(output), + ], + ) + + assert result.exit_code != 0 diff --git a/tests/test_cli/test_detect.py b/tests/test_cli/test_detect.py new file mode 100644 index 0000000..ed26b95 --- /dev/null +++ b/tests/test_cli/test_detect.py @@ -0,0 +1,119 @@ +"""Behavior tests for legacy detect command.""" + +from pathlib import Path + +import pandas as pd +from click.testing import CliRunner + +from batdetect2.cli import cli + + +def test_cli_detect_help() -> None: + """User story: get usage help for legacy detect command.""" + + result = CliRunner().invoke(cli, ["detect", "--help"]) + + assert result.exit_code == 0 + assert "Detect bat calls in files in AUDIO_DIR" in result.output + + +def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None: + """User story: run legacy detect on example audio directory.""" + + results_dir = tmp_path / "results" + + result = CliRunner().invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + ], + ) + + assert result.exit_code == 0 + assert results_dir.exists() + assert len(list(results_dir.glob("*.csv"))) == 3 + assert len(list(results_dir.glob("*.json"))) == 3 + + +def test_cli_detect_command_with_non_trivial_time_expansion( + tmp_path: Path, +) -> None: + """User story: set time expansion in legacy detect command.""" + + results_dir = tmp_path / "results" + + result = CliRunner().invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + "--time_expansion_factor", + "10", + ], + ) + + assert result.exit_code == 0 + assert "Time Expansion Factor: 10" in result.stdout + + +def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None: + """User story: request extra spectral features in output CSV.""" + + results_dir = tmp_path / "results" + + result = CliRunner().invoke( + cli, + [ + "detect", + "example_data/audio", + str(results_dir), + "0.3", + "--spec_features", + ], + ) + + assert result.exit_code == 0 + assert results_dir.exists() + + csv_files = [path.name for path in results_dir.glob("*.csv")] + + expected_files = [ + "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv", + "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv", + "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv", + ] + + for expected_file in expected_files: + assert expected_file in csv_files + df = pd.read_csv(results_dir / expected_file) + assert not (df.duration == -1).any() + + +def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path) -> None: + """User story: bad/empty input file reports error but command survives.""" + + results_dir = tmp_path / "results" + target = tmp_path / "audio" + target.mkdir() + + empty_file = target / "empty.wav" + empty_file.touch() + + result = CliRunner().invoke( + cli, + args=[ + "detect", + str(target), + str(results_dir), + "0.3", + "--spec_features", + ], + ) + + assert result.exit_code == 0 + assert f"Error processing file {empty_file}" in result.output diff --git a/tests/test_cli/test_evaluate.py b/tests/test_cli/test_evaluate.py new file mode 100644 index 0000000..8eb282d --- /dev/null +++ b/tests/test_cli/test_evaluate.py @@ -0,0 +1,47 @@ +"""CLI tests for evaluate command.""" + +from pathlib import Path + +from click.testing import CliRunner + +from batdetect2.cli import cli + +BASE_DIR = Path(__file__).parent.parent.parent + + +def test_cli_evaluate_help() -> None: + """User story: inspect evaluate command interface and options.""" + + result = CliRunner().invoke(cli, ["evaluate", "--help"]) + + assert result.exit_code == 0 + assert "MODEL_PATH" in result.output + assert "TEST_DATASET" in result.output + assert "--evaluation-config" in result.output + + +def test_cli_evaluate_writes_metrics_for_small_dataset( + tmp_path: Path, + tiny_checkpoint_path: Path, +) -> None: + """User story: evaluate a checkpoint and get metrics artifacts.""" + + output_dir = tmp_path / "eval_out" + + result = CliRunner().invoke( + cli, + [ + "evaluate", + str(tiny_checkpoint_path), + str(BASE_DIR / "example_data" / "dataset.yaml"), + "--base-dir", + str(BASE_DIR), + "--workers", + "0", + "--output-dir", + str(output_dir), + ], + ) + + assert result.exit_code == 0 + assert len(list(output_dir.rglob("metrics.csv"))) >= 1 diff --git a/tests/test_cli/test_predict.py b/tests/test_cli/test_predict.py new file mode 100644 index 0000000..e820e12 --- /dev/null +++ b/tests/test_cli/test_predict.py @@ -0,0 +1,257 @@ +"""Behavior tests for predict CLI workflows.""" + +from pathlib import Path + +import pytest +from click.testing import CliRunner +from soundevent import data, io + +from batdetect2.cli import cli + + +def test_cli_predict_help() -> None: + """User story: discover available predict modes.""" + + result = CliRunner().invoke(cli, ["predict", "--help"]) + + assert result.exit_code == 0 + assert "directory" in result.output + assert "file_list" in result.output + assert "dataset" in result.output + + +def test_cli_predict_directory_runs_on_real_audio( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, +) -> None: + """User story: run prediction for all files in a directory.""" + + output_path = tmp_path / "predictions" + + result = CliRunner().invoke( + cli, + [ + "predict", + "directory", + str(tiny_checkpoint_path), + str(single_audio_dir), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "batdetect2", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + assert len(list(output_path.glob("*.json"))) == 1 + + +def test_cli_predict_file_list_runs_on_real_audio( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, +) -> None: + """User story: run prediction from an explicit list of files.""" + + audio_file = next(single_audio_dir.glob("*.wav")) + file_list = tmp_path / "files.txt" + file_list.write_text(f"{audio_file}\n") + + output_path = tmp_path / "predictions" + + result = CliRunner().invoke( + cli, + [ + "predict", + "file_list", + str(tiny_checkpoint_path), + str(file_list), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "batdetect2", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + assert len(list(output_path.glob("*.json"))) == 1 + + +def test_cli_predict_dataset_runs_on_aoef_metadata( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, +) -> None: + """User story: predict from AOEF dataset metadata file.""" + + audio_file = next(single_audio_dir.glob("*.wav")) + recording = data.Recording.from_file(audio_file) + clip = data.Clip( + recording=recording, + start_time=0, + end_time=recording.duration, + ) + annotation_set = data.AnnotationSet( + name="test", + description="predict dataset test", + clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])], + ) + + dataset_path = tmp_path / "dataset.json" + io.save(annotation_set, dataset_path) + + output_path = tmp_path / "predictions" + + result = CliRunner().invoke( + cli, + [ + "predict", + "dataset", + str(tiny_checkpoint_path), + str(dataset_path), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "batdetect2", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + assert len(list(output_path.glob("*.json"))) == 1 + + +@pytest.mark.parametrize( + ("format_name", "expected_pattern", "writes_single_file"), + [ + ("batdetect2", "*.json", False), + ("raw", "*.nc", False), + ("soundevent", "*.json", True), + ], +) +def test_cli_predict_directory_supports_output_format_override( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, + format_name: str, + expected_pattern: str, + writes_single_file: bool, +) -> None: + """User story: change output format via --format only.""" + + output_path = tmp_path / f"predictions_{format_name}" + + result = CliRunner().invoke( + cli, + [ + "predict", + "directory", + str(tiny_checkpoint_path), + str(single_audio_dir), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + format_name, + ], + ) + + assert result.exit_code == 0 + + if writes_single_file: + assert output_path.with_suffix(".json").exists() + else: + assert output_path.exists() + assert len(list(output_path.glob(expected_pattern))) >= 1 + + +def test_cli_predict_dataset_deduplicates_recordings( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, +) -> None: + """User story: duplicated recording entries are predicted once.""" + + audio_file = next(single_audio_dir.glob("*.wav")) + recording = data.Recording.from_file(audio_file) + first_clip = data.Clip( + recording=recording, + start_time=0, + end_time=recording.duration, + ) + second_clip = data.Clip( + recording=recording, + start_time=0, + end_time=recording.duration, + ) + annotation_set = data.AnnotationSet( + name="dupe-recording-dataset", + description="contains same recording twice", + clip_annotations=[ + data.ClipAnnotation(clip=first_clip, sound_events=[]), + data.ClipAnnotation(clip=second_clip, sound_events=[]), + ], + ) + + dataset_path = tmp_path / "dupes.json" + io.save(annotation_set, dataset_path) + + output_path = tmp_path / "predictions" + result = CliRunner().invoke( + cli, + [ + "predict", + "dataset", + str(tiny_checkpoint_path), + str(dataset_path), + str(output_path), + "--batch-size", + "1", + "--workers", + "0", + "--format", + "raw", + ], + ) + + assert result.exit_code == 0 + assert output_path.exists() + assert len(list(output_path.glob("*.nc"))) == 1 + + +def test_cli_predict_rejects_unknown_output_format( + tmp_path: Path, + tiny_checkpoint_path: Path, + single_audio_dir: Path, +) -> None: + """User story: invalid output format fails with error.""" + + output_path = tmp_path / "predictions" + result = CliRunner().invoke( + cli, + [ + "predict", + "directory", + str(tiny_checkpoint_path), + str(single_audio_dir), + str(output_path), + "--format", + "not_a_real_format", + ], + ) + + assert result.exit_code != 0 diff --git a/tests/test_cli/test_train.py b/tests/test_cli/test_train.py new file mode 100644 index 0000000..8f2876f --- /dev/null +++ b/tests/test_cli/test_train.py @@ -0,0 +1,81 @@ +"""CLI tests for train command.""" + +from pathlib import Path + +from click.testing import CliRunner + +from batdetect2.cli import cli +from batdetect2.models import ModelConfig + + +def test_cli_train_help() -> None: + """User story: inspect train command interface and options.""" + + result = CliRunner().invoke(cli, ["train", "--help"]) + + assert result.exit_code == 0 + assert "TRAIN_DATASET" in result.output + assert "--training-config" in result.output + assert "--model" in result.output + + +def test_cli_train_from_checkpoint_runs_on_small_dataset( + tmp_path: Path, + tiny_checkpoint_path: Path, +) -> None: + """User story: continue training from checkpoint via CLI.""" + + ckpt_dir = tmp_path / "checkpoints" + log_dir = tmp_path / "logs" + ckpt_dir.mkdir() + log_dir.mkdir() + + result = CliRunner().invoke( + cli, + [ + "train", + "example_data/dataset.yaml", + "--val-dataset", + "example_data/dataset.yaml", + "--model", + str(tiny_checkpoint_path), + "--num-epochs", + "1", + "--train-workers", + "0", + "--val-workers", + "0", + "--ckpt-dir", + str(ckpt_dir), + "--log-dir", + str(log_dir), + ], + ) + + assert result.exit_code == 0 + assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1 + + +def test_cli_train_rejects_model_and_model_config_together( + tmp_path: Path, + tiny_checkpoint_path: Path, +) -> None: + """User story: invalid train flags fail with clear usage error.""" + + model_config_path = tmp_path / "model.yaml" + model_config_path.write_text(ModelConfig().to_yaml_string()) + + result = CliRunner().invoke( + cli, + [ + "train", + "example_data/dataset.yaml", + "--model", + str(tiny_checkpoint_path), + "--model-config", + str(model_config_path), + ], + ) + + assert result.exit_code != 0 + assert "--model-config cannot be used with --model" in result.output