batdetect2/tests/test_cli.py
2026-03-18 20:07:53 +00:00

186 lines
4.9 KiB
Python

"""Test the command line interface."""
import shutil
from pathlib import Path
import lightning as L
import pandas as pd
from click.testing import CliRunner
from batdetect2.cli import cli
from batdetect2.config import BatDetect2Config
from batdetect2.train.lightning import build_training_module
# Shared Click test runner; the tests in this module invoke `cli` through it.
runner = CliRunner()
def test_cli_base_command():
    """Check that the top-level ``--help`` succeeds and shows the title."""
    expected_text = "BatDetect2 - Bat Call Detection and Classification"

    outcome = runner.invoke(cli, ["--help"])

    assert outcome.exit_code == 0
    assert expected_text in outcome.output
def test_cli_detect_command_help():
    """Check that ``detect --help`` exits cleanly and shows its summary."""
    help_args = ["detect", "--help"]

    outcome = runner.invoke(cli, help_args)

    assert outcome.exit_code == 0
    assert "Detect bat calls in files in AUDIO_DIR" in outcome.output
def test_cli_predict_command_help():
    """Check that ``predict --help`` succeeds and lists the expected words.

    The help text must mention ``directory``, ``file_list`` and ``dataset``
    (presumably the available ``predict`` sub-commands).
    """
    outcome = runner.invoke(cli, ["predict", "--help"])
    assert outcome.exit_code == 0

    help_text = outcome.output
    for keyword in ("directory", "file_list", "dataset"):
        assert keyword in help_text
def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path):
    """User story: run prediction from CLI on a small directory.

    Copies one example recording into a temporary directory, saves a
    checkpoint of a module built from the default ``BatDetect2Config``
    model config, and runs ``predict directory`` on that directory. The
    command must exit with code 0 and write exactly one ``.json`` file
    into the output directory.
    """
    source_audio = Path("example_data/audio")
    # Sort so the same recording is chosen on every run, and fail with a
    # readable message (instead of a bare StopIteration) when the example
    # data is missing.
    candidates = sorted(source_audio.glob("*.wav"))
    assert candidates, f"No .wav files found in {source_audio}"
    source_file = candidates[0]

    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    shutil.copy(source_file, audio_dir / source_file.name)

    module = build_training_module(model_config=BatDetect2Config().model)
    trainer = L.Trainer(enable_checkpointing=False, logger=False)
    model_path = tmp_path / "model.ckpt"
    # NOTE(review): connecting the module to the trainer's strategy appears
    # to be what allows saving a checkpoint without running a fit first;
    # confirm against the Lightning documentation.
    trainer.strategy.connect(module)
    trainer.save_checkpoint(model_path)

    output_path = tmp_path / "predictions"
    result = runner.invoke(
        cli,
        [
            "predict",
            "directory",
            str(model_path),
            str(audio_dir),
            str(output_path),
            "--batch-size",
            "1",
            "--workers",
            "0",
            "--format",
            "batdetect2",
        ],
    )

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output
    assert output_path.exists()
    output_files = list(output_path.glob("*.json"))
    assert len(output_files) == 1
def test_cli_detect_command_on_test_audio(tmp_path: Path):
    """Test the detect command on test audio.

    Runs ``detect`` on ``example_data/audio`` with the positional value
    ``0.3`` (presumably the detection threshold) and expects three
    ``.csv`` and three ``.json`` files in the results directory.
    """
    # ``tmp_path`` is a fresh directory for every test invocation, so the
    # results directory cannot exist yet and needs no cleanup.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
        ],
    )

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output
    assert results_dir.exists()
    assert len(list(results_dir.glob("*.csv"))) == 3
    assert len(list(results_dir.glob("*.json"))) == 3
def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path: Path):
    """Test the detect command with a non-trivial time expansion factor.

    Passes ``--time_expansion_factor 10`` and checks that the command
    succeeds and reports ``Time Expansion Factor: 10`` on stdout.
    """
    # ``tmp_path`` is a fresh directory for every test invocation, so the
    # results directory cannot exist yet and needs no cleanup.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--time_expansion_factor",
            "10",
        ],
    )

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output
    assert "Time Expansion Factor: 10" in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Test the detect command with the spec feature flag.

    Runs ``detect`` with ``--spec_features`` and checks that a
    ``*_spec_features.csv`` file is written for each expected example
    recording, and that none of its ``duration`` values equals ``-1``.
    """
    # ``tmp_path`` is a fresh directory for every test invocation, so the
    # results directory cannot exist yet and needs no cleanup.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--spec_features",
        ],
    )

    # Include the CLI output in the failure message to ease debugging.
    assert result.exit_code == 0, result.output
    assert results_dir.exists()

    csv_files = {path.name for path in results_dir.glob("*.csv")}
    expected_files = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    ]

    for expected_file in expected_files:
        assert expected_file in csv_files

        df = pd.read_csv(results_dir / expected_file)
        # Item access always resolves to the column, never to a DataFrame
        # attribute that happens to share the name.
        assert not (df["duration"] == -1).any()
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
    """Check that ``detect`` reports an empty ``.wav`` file without crashing.

    The command must still exit with code 0 and its output must contain
    an ``Error processing file`` message naming the empty file.
    """
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()

    # A zero-byte file that only looks like audio because of its extension.
    blank_wav = audio_dir / "empty.wav"
    blank_wav.touch()

    output_dir = tmp_path / "results"
    cli_args = [
        "detect",
        str(audio_dir),
        str(output_dir),
        "0.3",
        "--spec_features",
    ]

    outcome = runner.invoke(cli, args=cli_args)

    assert outcome.exit_code == 0
    expected_message = f"Error processing file {blank_wav}"
    assert expected_message in outcome.output