Expanded cli tests

This commit is contained in:
mbsantiago 2026-03-18 20:35:08 +00:00
parent f0af5dd79e
commit 0163a572cb
9 changed files with 607 additions and 185 deletions

View File

@ -164,6 +164,7 @@ def inference_file_list_command(
num_workers: int,
format_name: str | None,
) -> None:
file_list = Path(file_list)
audio_files = [
Path(line.strip())
for line in file_list.read_text().splitlines()
@ -207,6 +208,7 @@ def inference_dataset_command(
num_workers: int,
format_name: str | None,
) -> None:
dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted(
{

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Callable, List, Optional
from uuid import uuid4
import lightning as L
import numpy as np
import pytest
import soundfile as sf
@ -12,6 +13,7 @@ from soundevent import data, terms
from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
@ -24,6 +26,7 @@ from batdetect2.targets import (
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
from batdetect2.train.types import ClipLabeller
@ -452,3 +455,23 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
return temp_file
return factory
@pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path:
    """Save a checkpoint for a default-config model and return its path.

    The training module is only connected to the trainer's strategy so the
    trainer can serialise it; this fixture never calls ``fit``.
    """
    ckpt_file = tmp_path / "model.ckpt"
    lightning_module = build_training_module(
        model_config=BatDetect2Config().model,
    )
    saver = L.Trainer(enable_checkpointing=False, logger=False)
    # The module must be attached before the trainer is asked to save it.
    saver.strategy.connect(lightning_module)
    saver.save_checkpoint(ckpt_file)
    return ckpt_file
@pytest.fixture
def single_audio_dir(tmp_path: Path, example_audio_files: List[Path]) -> Path:
    """Return a new ``audio`` directory holding a copy of one example file.

    Only the first entry of ``example_audio_files`` is copied; the copy keeps
    the original file name.
    """
    destination_dir = tmp_path / "audio"
    destination_dir.mkdir()
    first_example = example_audio_files[0]
    copied_file = destination_dir / first_example.name
    copied_file.write_bytes(first_example.read_bytes())
    return destination_dir

View File

@ -1,185 +0,0 @@
"""Test the command line interface."""
import shutil
from pathlib import Path
import lightning as L
import pandas as pd
from click.testing import CliRunner
from batdetect2.cli import cli
from batdetect2.config import BatDetect2Config
from batdetect2.train.lightning import build_training_module
# Module-level click test runner shared by every test function in this file.
runner = CliRunner()
def test_cli_base_command():
    """Top-level ``--help`` exits cleanly and prints the program title."""
    expected_title = "BatDetect2 - Bat Call Detection and Classification"
    outcome = runner.invoke(cli, ["--help"])
    assert outcome.exit_code == 0
    assert expected_title in outcome.output
def test_cli_detect_command_help():
    """``detect --help`` exits cleanly and shows the command summary."""
    help_args = ["detect", "--help"]
    outcome = runner.invoke(cli, help_args)
    assert outcome.exit_code == 0
    assert "Detect bat calls in files in AUDIO_DIR" in outcome.output
def test_cli_predict_command_help():
    """``predict --help`` exits cleanly and names each predict mode."""
    outcome = runner.invoke(cli, ["predict", "--help"])
    assert outcome.exit_code == 0
    help_text = outcome.output
    assert "directory" in help_text
    assert "file_list" in help_text
    assert "dataset" in help_text
def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path):
    """Run ``predict directory`` end to end on one example recording."""
    # Stage a single example recording in an isolated input directory.
    example_wav = next(Path("example_data/audio").glob("*.wav"))
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    shutil.copy(example_wav, audio_dir / example_wav.name)

    # Save a checkpoint of a default-config model for the CLI to load.
    lightning_module = build_training_module(
        model_config=BatDetect2Config().model,
    )
    saver = L.Trainer(enable_checkpointing=False, logger=False)
    model_path = tmp_path / "model.ckpt"
    saver.strategy.connect(lightning_module)
    saver.save_checkpoint(model_path)

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(model_path),
        str(audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]
    outcome = runner.invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_path.exists()
    # One input recording is expected to produce exactly one JSON file.
    json_outputs = list(output_path.glob("*.json"))
    assert len(json_outputs) == 1
def test_cli_detect_command_on_test_audio(tmp_path):
    """Run ``detect`` on the example audio and check the result files.

    The example directory is expected to yield three CSV and three JSON
    result files.
    """
    # ``tmp_path`` is a fresh directory for every test, so ``results`` cannot
    # exist yet and needs no clean-up before the command runs.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
        ],
    )

    assert result.exit_code == 0
    assert results_dir.exists()
    assert len(list(results_dir.glob("*.csv"))) == 3
    assert len(list(results_dir.glob("*.json"))) == 3
def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
    """Run ``detect`` with ``--time_expansion_factor 10``.

    The command must succeed and echo the chosen factor on stdout.
    """
    # ``tmp_path`` is a fresh directory for every test, so ``results`` cannot
    # exist yet and needs no clean-up before the command runs.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--time_expansion_factor",
            "10",
        ],
    )

    assert result.exit_code == 0
    assert "Time Expansion Factor: 10" in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Run ``detect`` with ``--spec_features`` and inspect the feature CSVs.

    Each example recording must produce a ``*_spec_features.csv`` file, and
    none of those files may contain a ``duration`` of ``-1``.
    """
    # ``tmp_path`` is a fresh directory for every test, so ``results`` cannot
    # exist yet and needs no clean-up before the command runs.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--spec_features",
        ],
    )
    assert result.exit_code == 0
    assert results_dir.exists()

    csv_files = [path.name for path in results_dir.glob("*.csv")]

    expected_files = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    ]

    # Every expected file is checked, not only the last one in the list.
    for expected_file in expected_files:
        assert expected_file in csv_files

        df = pd.read_csv(results_dir / expected_file)
        assert not (df.duration == -1).any()
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
    """An empty ``.wav`` input is reported while ``detect`` still exits 0."""
    results_dir = tmp_path / "results"
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()

    # A zero-byte file that only looks like audio because of its extension.
    empty_file = audio_dir / "empty.wav"
    empty_file.touch()

    cli_args = [
        "detect",
        str(audio_dir),
        str(results_dir),
        "0.3",
        "--spec_features",
    ]
    outcome = runner.invoke(cli, args=cli_args)

    assert outcome.exit_code == 0
    assert f"Error processing file {empty_file}" in outcome.output

View File

@ -0,0 +1,18 @@
"""Behavior-focused tests for top-level CLI command discovery."""
from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_base_help_lists_main_commands() -> None:
    """Top-level ``--help`` mentions each of the main command names."""
    outcome = CliRunner().invoke(cli, ["--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    assert "predict" in help_text
    assert "train" in help_text
    assert "evaluate" in help_text
    assert "data" in help_text
    assert "detect" in help_text

View File

@ -0,0 +1,60 @@
"""Behavior tests for data CLI command group."""
from pathlib import Path
from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_data_help() -> None:
    """``data --help`` exits cleanly and lists its subcommands."""
    cli_runner = CliRunner()
    outcome = cli_runner.invoke(cli, ["data", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    assert "summary" in help_text
    assert "convert" in help_text
def test_cli_data_convert_creates_annotation_set(tmp_path: Path) -> None:
    """``data convert`` on the example dataset config writes an output file."""
    output = tmp_path / "annotations.json"
    cli_args = [
        "data",
        "convert",
        "example_data/dataset.yaml",
        "--base-dir",
        ".",
        "--output",
        str(output),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output.exists()
def test_cli_data_convert_fails_with_invalid_field(tmp_path: Path) -> None:
    """``data convert`` exits non-zero when ``--field`` names a missing key."""
    output = tmp_path / "annotations.json"
    cli_args = [
        "data",
        "convert",
        "example_data/dataset.yaml",
        "--field",
        "does.not.exist",
        "--output",
        str(output),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code != 0

View File

@ -0,0 +1,119 @@
"""Behavior tests for legacy detect command."""
from pathlib import Path
import pandas as pd
from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_detect_help() -> None:
    """``detect --help`` exits cleanly and shows the command summary."""
    cli_runner = CliRunner()
    outcome = cli_runner.invoke(cli, ["detect", "--help"])
    assert outcome.exit_code == 0
    assert "Detect bat calls in files in AUDIO_DIR" in outcome.output
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
    """Run ``detect`` on the example audio and count the result files."""
    results_dir = tmp_path / "results"
    cli_args = [
        "detect",
        "example_data/audio",
        str(results_dir),
        "0.3",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert results_dir.exists()
    # Three CSV and three JSON result files are expected for the examples.
    csv_outputs = list(results_dir.glob("*.csv"))
    json_outputs = list(results_dir.glob("*.json"))
    assert len(csv_outputs) == 3
    assert len(json_outputs) == 3
def test_cli_detect_command_with_non_trivial_time_expansion(
    tmp_path: Path,
) -> None:
    """``detect --time_expansion_factor 10`` succeeds and echoes the factor."""
    results_dir = tmp_path / "results"
    cli_args = [
        "detect",
        "example_data/audio",
        str(results_dir),
        "0.3",
        "--time_expansion_factor",
        "10",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert "Time Expansion Factor: 10" in outcome.stdout
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
    """``detect --spec_features`` writes a feature CSV per example recording.

    Each expected CSV must exist and must not contain a ``duration`` of ``-1``.
    """
    results_dir = tmp_path / "results"
    cli_args = [
        "detect",
        "example_data/audio",
        str(results_dir),
        "0.3",
        "--spec_features",
    ]

    outcome = CliRunner().invoke(cli, cli_args)
    assert outcome.exit_code == 0
    assert results_dir.exists()

    produced_names = [csv_path.name for csv_path in results_dir.glob("*.csv")]
    expected_names = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    ]

    for feature_csv in expected_names:
        assert feature_csv in produced_names
        features = pd.read_csv(results_dir / feature_csv)
        has_invalid_duration = (features.duration == -1).any()
        assert not has_invalid_duration
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path) -> None:
    """An empty ``.wav`` input is reported while ``detect`` still exits 0."""
    results_dir = tmp_path / "results"
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()

    # A zero-byte file that only looks like audio because of its extension.
    empty_file = audio_dir / "empty.wav"
    empty_file.touch()

    cli_args = [
        "detect",
        str(audio_dir),
        str(results_dir),
        "0.3",
        "--spec_features",
    ]
    outcome = CliRunner().invoke(cli, args=cli_args)

    assert outcome.exit_code == 0
    assert f"Error processing file {empty_file}" in outcome.output

View File

@ -0,0 +1,47 @@
"""CLI tests for evaluate command."""
from pathlib import Path
from click.testing import CliRunner
from batdetect2.cli import cli
# Ancestor directory reached by three ``.parent`` steps from this file. The
# evaluate test below resolves ``example_data/dataset.yaml`` against it and
# passes it as ``--base-dir``.
# NOTE(review): presumably the repository root; confirm against the layout.
BASE_DIR = Path(__file__).parent.parent.parent
def test_cli_evaluate_help() -> None:
    """``evaluate --help`` exits cleanly and names its arguments and option."""
    outcome = CliRunner().invoke(cli, ["evaluate", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    assert "MODEL_PATH" in help_text
    assert "TEST_DATASET" in help_text
    assert "--evaluation-config" in help_text
def test_cli_evaluate_writes_metrics_for_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """``evaluate`` on the example dataset leaves a ``metrics.csv`` behind."""
    output_dir = tmp_path / "eval_out"
    dataset_config = BASE_DIR / "example_data" / "dataset.yaml"
    cli_args = [
        "evaluate",
        str(tiny_checkpoint_path),
        str(dataset_config),
        "--base-dir",
        str(BASE_DIR),
        "--workers",
        "0",
        "--output-dir",
        str(output_dir),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    # The metrics file may sit in any subdirectory of the output directory.
    metrics_files = list(output_dir.rglob("metrics.csv"))
    assert len(metrics_files) >= 1

View File

@ -0,0 +1,257 @@
"""Behavior tests for predict CLI workflows."""
from pathlib import Path
import pytest
from click.testing import CliRunner
from soundevent import data, io
from batdetect2.cli import cli
def test_cli_predict_help() -> None:
    """``predict --help`` exits cleanly and names each predict mode."""
    outcome = CliRunner().invoke(cli, ["predict", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    assert "directory" in help_text
    assert "file_list" in help_text
    assert "dataset" in help_text
def test_cli_predict_directory_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """``predict directory`` writes one JSON file for a one-file directory."""
    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_path.exists()
    json_outputs = list(output_path.glob("*.json"))
    assert len(json_outputs) == 1
def test_cli_predict_file_list_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """``predict file_list`` processes the file named in a text listing."""
    # The listing holds a single line: the path of the staged recording.
    audio_file = next(single_audio_dir.glob("*.wav"))
    file_list = tmp_path / "files.txt"
    file_list.write_text(f"{audio_file}\n")

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "file_list",
        str(tiny_checkpoint_path),
        str(file_list),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_path.exists()
    json_outputs = list(output_path.glob("*.json"))
    assert len(json_outputs) == 1
def test_cli_predict_dataset_runs_on_aoef_metadata(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """``predict dataset`` reads recordings from a saved annotation set."""
    # Build an annotation set with one whole-recording clip and no events.
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)
    whole_clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    clip_annotation = data.ClipAnnotation(clip=whole_clip, sound_events=[])
    annotation_set = data.AnnotationSet(
        name="test",
        description="predict dataset test",
        clip_annotations=[clip_annotation],
    )
    dataset_path = tmp_path / "dataset.json"
    io.save(annotation_set, dataset_path)

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_path),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_path.exists()
    json_outputs = list(output_path.glob("*.json"))
    assert len(json_outputs) == 1
@pytest.mark.parametrize(
    ("format_name", "expected_pattern", "writes_single_file"),
    [
        ("batdetect2", "*.json", False),
        ("raw", "*.nc", False),
        ("soundevent", "*.json", True),
    ],
)
def test_cli_predict_directory_supports_output_format_override(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
    format_name: str,
    expected_pattern: str,
    writes_single_file: bool,
) -> None:
    """``predict directory --format`` produces output for each listed format.

    Formats flagged ``writes_single_file`` are checked for one file at the
    output path with a ``.json`` suffix. The others are checked for an
    output directory containing at least one file matching
    ``expected_pattern``.
    """
    output_path = tmp_path / f"predictions_{format_name}"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        format_name,
    ]

    outcome = CliRunner().invoke(cli, cli_args)
    assert outcome.exit_code == 0

    if writes_single_file:
        assert output_path.with_suffix(".json").exists()
        return

    assert output_path.exists()
    matching_outputs = list(output_path.glob(expected_pattern))
    assert len(matching_outputs) >= 1
def test_cli_predict_dataset_deduplicates_recordings(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """A recording referenced by two clips yields a single output file."""
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)

    # Two separate clip objects that both span the same whole recording.
    duplicated_clips = [
        data.Clip(
            recording=recording,
            start_time=0,
            end_time=recording.duration,
        ),
        data.Clip(
            recording=recording,
            start_time=0,
            end_time=recording.duration,
        ),
    ]
    annotation_set = data.AnnotationSet(
        name="dupe-recording-dataset",
        description="contains same recording twice",
        clip_annotations=[
            data.ClipAnnotation(clip=clip, sound_events=[])
            for clip in duplicated_clips
        ],
    )
    dataset_path = tmp_path / "dupes.json"
    io.save(annotation_set, dataset_path)

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_path),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "raw",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_path.exists()
    netcdf_outputs = list(output_path.glob("*.nc"))
    assert len(netcdf_outputs) == 1
def test_cli_predict_rejects_unknown_output_format(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """``predict directory`` exits non-zero for an unrecognised ``--format``."""
    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--format",
        "not_a_real_format",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code != 0

View File

@ -0,0 +1,81 @@
"""CLI tests for train command."""
from pathlib import Path
from click.testing import CliRunner
from batdetect2.cli import cli
from batdetect2.models import ModelConfig
def test_cli_train_help() -> None:
    """``train --help`` exits cleanly and names its argument and options."""
    outcome = CliRunner().invoke(cli, ["train", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    assert "TRAIN_DATASET" in help_text
    assert "--training-config" in help_text
    assert "--model" in help_text
def test_cli_train_from_checkpoint_runs_on_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """``train --model <ckpt>`` runs one epoch and writes a checkpoint."""
    ckpt_dir = tmp_path / "checkpoints"
    log_dir = tmp_path / "logs"
    for directory in (ckpt_dir, log_dir):
        directory.mkdir()

    cli_args = [
        "train",
        "example_data/dataset.yaml",
        "--val-dataset",
        "example_data/dataset.yaml",
        "--model",
        str(tiny_checkpoint_path),
        "--num-epochs",
        "1",
        "--train-workers",
        "0",
        "--val-workers",
        "0",
        "--ckpt-dir",
        str(ckpt_dir),
        "--log-dir",
        str(log_dir),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    # Checkpoints may be nested below the checkpoint directory.
    saved_checkpoints = list(ckpt_dir.rglob("*.ckpt"))
    assert len(saved_checkpoints) >= 1
def test_cli_train_rejects_model_and_model_config_together(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """Passing both ``--model`` and ``--model-config`` is a usage error."""
    # Write a default model config so that the option points at a real file.
    model_config_path = tmp_path / "model.yaml"
    model_config_path.write_text(ModelConfig().to_yaml_string())

    cli_args = [
        "train",
        "example_data/dataset.yaml",
        "--model",
        str(tiny_checkpoint_path),
        "--model-config",
        str(model_config_path),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code != 0
    assert "--model-config cannot be used with --model" in outcome.output