mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Expanded cli tests
This commit is contained in:
parent
f0af5dd79e
commit
0163a572cb
@ -164,6 +164,7 @@ def inference_file_list_command(
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
) -> None:
|
||||
file_list = Path(file_list)
|
||||
audio_files = [
|
||||
Path(line.strip())
|
||||
for line in file_list.read_text().splitlines()
|
||||
@ -207,6 +208,7 @@ def inference_dataset_command(
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
) -> None:
|
||||
dataset_path = Path(dataset_path)
|
||||
dataset = io.load(dataset_path, type="annotation_set")
|
||||
audio_files = sorted(
|
||||
{
|
||||
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Callable, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import lightning as L
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
@ -12,6 +13,7 @@ from soundevent import data, terms
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.clips import build_clipper
|
||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
@ -24,6 +26,7 @@ from batdetect2.targets import (
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
|
||||
@ -452,3 +455,23 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||
return temp_file
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tiny_checkpoint_path(tmp_path: Path) -> Path:
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "model.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def single_audio_dir(tmp_path: Path, example_audio_files: List[Path]) -> Path:
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
source = example_audio_files[0]
|
||||
target = audio_dir / source.name
|
||||
target.write_bytes(source.read_bytes())
|
||||
return audio_dir
|
||||
|
||||
@ -1,185 +0,0 @@
|
||||
"""Test the command line interface."""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import lightning as L
|
||||
import pandas as pd
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_cli_base_command():
|
||||
"""Test the base command."""
|
||||
result = runner.invoke(cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert (
|
||||
"BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||
)
|
||||
|
||||
|
||||
def test_cli_detect_command_help():
|
||||
"""Test the detect command help."""
|
||||
result = runner.invoke(cli, ["detect", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||
|
||||
|
||||
def test_cli_predict_command_help():
|
||||
"""Test the predict command help."""
|
||||
result = runner.invoke(cli, ["predict", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "directory" in result.output
|
||||
assert "file_list" in result.output
|
||||
assert "dataset" in result.output
|
||||
|
||||
|
||||
def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path):
|
||||
"""User story: run prediction from CLI on a small directory."""
|
||||
|
||||
source_audio = Path("example_data/audio")
|
||||
source_file = next(source_audio.glob("*.wav"))
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
target_file = audio_dir / source_file.name
|
||||
shutil.copy(source_file, target_file)
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
model_path = tmp_path / "model.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(model_path)
|
||||
output_path = tmp_path / "predictions"
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"directory",
|
||||
str(model_path),
|
||||
str(audio_dir),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
"batdetect2",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output_path.exists()
|
||||
output_files = list(output_path.glob("*.json"))
|
||||
assert len(output_files) == 1
|
||||
|
||||
|
||||
def test_cli_detect_command_on_test_audio(tmp_path):
|
||||
"""Test the detect command on test audio."""
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
# Remove results dir if it exists
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||
assert len(list(results_dir.glob("*.json"))) == 3
|
||||
|
||||
|
||||
def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
||||
"""Test the detect command with a non-trivial time expansion factor."""
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
# Remove results dir if it exists
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--time_expansion_factor",
|
||||
"10",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Time Expansion Factor: 10" in result.stdout
|
||||
|
||||
|
||||
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||
"""Test the detect command with the spec feature flag."""
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
# Remove results dir if it exists
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--spec_features",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
|
||||
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
||||
|
||||
expected_files = [
|
||||
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
|
||||
]
|
||||
|
||||
for expected_file in expected_files:
|
||||
assert expected_file in csv_files
|
||||
|
||||
df = pd.read_csv(results_dir / expected_file)
|
||||
assert not (df.duration == -1).any()
|
||||
|
||||
|
||||
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
|
||||
results_dir = tmp_path / "results"
|
||||
target = tmp_path / "audio"
|
||||
target.mkdir()
|
||||
|
||||
# Create an empty file with the .wav extension
|
||||
empty_file = target / "empty.wav"
|
||||
empty_file.touch()
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
args=[
|
||||
"detect",
|
||||
str(target),
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--spec_features",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert f"Error processing file {empty_file}" in result.output
|
||||
18
tests/test_cli/test_base.py
Normal file
18
tests/test_cli/test_base.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""Behavior-focused tests for top-level CLI command discovery."""
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_base_help_lists_main_commands() -> None:
|
||||
"""User story: discover available workflows from top-level help."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "predict" in result.output
|
||||
assert "train" in result.output
|
||||
assert "evaluate" in result.output
|
||||
assert "data" in result.output
|
||||
assert "detect" in result.output
|
||||
60
tests/test_cli/test_data.py
Normal file
60
tests/test_cli/test_data.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Behavior tests for data CLI command group."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_data_help() -> None:
|
||||
"""User story: discover data subcommands."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["data", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "summary" in result.output
|
||||
assert "convert" in result.output
|
||||
|
||||
|
||||
def test_cli_data_convert_creates_annotation_set(tmp_path: Path) -> None:
|
||||
"""User story: convert dataset config into a soundevent annotation set."""
|
||||
|
||||
output = tmp_path / "annotations.json"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"data",
|
||||
"convert",
|
||||
"example_data/dataset.yaml",
|
||||
"--base-dir",
|
||||
".",
|
||||
"--output",
|
||||
str(output),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output.exists()
|
||||
|
||||
|
||||
def test_cli_data_convert_fails_with_invalid_field(tmp_path: Path) -> None:
|
||||
"""User story: invalid nested field in dataset config fails clearly."""
|
||||
|
||||
output = tmp_path / "annotations.json"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"data",
|
||||
"convert",
|
||||
"example_data/dataset.yaml",
|
||||
"--field",
|
||||
"does.not.exist",
|
||||
"--output",
|
||||
str(output),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
119
tests/test_cli/test_detect.py
Normal file
119
tests/test_cli/test_detect.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""Behavior tests for legacy detect command."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_detect_help() -> None:
|
||||
"""User story: get usage help for legacy detect command."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["detect", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||
|
||||
|
||||
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||
"""User story: run legacy detect on example audio directory."""
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||
assert len(list(results_dir.glob("*.json"))) == 3
|
||||
|
||||
|
||||
def test_cli_detect_command_with_non_trivial_time_expansion(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""User story: set time expansion in legacy detect command."""
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--time_expansion_factor",
|
||||
"10",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Time Expansion Factor: 10" in result.stdout
|
||||
|
||||
|
||||
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
|
||||
"""User story: request extra spectral features in output CSV."""
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
"example_data/audio",
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--spec_features",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
|
||||
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
||||
|
||||
expected_files = [
|
||||
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
|
||||
]
|
||||
|
||||
for expected_file in expected_files:
|
||||
assert expected_file in csv_files
|
||||
df = pd.read_csv(results_dir / expected_file)
|
||||
assert not (df.duration == -1).any()
|
||||
|
||||
|
||||
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path) -> None:
|
||||
"""User story: bad/empty input file reports error but command survives."""
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
target = tmp_path / "audio"
|
||||
target.mkdir()
|
||||
|
||||
empty_file = target / "empty.wav"
|
||||
empty_file.touch()
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
args=[
|
||||
"detect",
|
||||
str(target),
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--spec_features",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"Error processing file {empty_file}" in result.output
|
||||
47
tests/test_cli/test_evaluate.py
Normal file
47
tests/test_cli/test_evaluate.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""CLI tests for evaluate command."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def test_cli_evaluate_help() -> None:
|
||||
"""User story: inspect evaluate command interface and options."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["evaluate", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "MODEL_PATH" in result.output
|
||||
assert "TEST_DATASET" in result.output
|
||||
assert "--evaluation-config" in result.output
|
||||
|
||||
|
||||
def test_cli_evaluate_writes_metrics_for_small_dataset(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: evaluate a checkpoint and get metrics artifacts."""
|
||||
|
||||
output_dir = tmp_path / "eval_out"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(tiny_checkpoint_path),
|
||||
str(BASE_DIR / "example_data" / "dataset.yaml"),
|
||||
"--base-dir",
|
||||
str(BASE_DIR),
|
||||
"--workers",
|
||||
"0",
|
||||
"--output-dir",
|
||||
str(output_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert len(list(output_dir.rglob("metrics.csv"))) >= 1
|
||||
257
tests/test_cli/test_predict.py
Normal file
257
tests/test_cli/test_predict.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""Behavior tests for predict CLI workflows."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_predict_help() -> None:
|
||||
"""User story: discover available predict modes."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["predict", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "directory" in result.output
|
||||
assert "file_list" in result.output
|
||||
assert "dataset" in result.output
|
||||
|
||||
|
||||
def test_cli_predict_directory_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: run prediction for all files in a directory."""
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
"batdetect2",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output_path.exists()
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_file_list_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: run prediction from an explicit list of files."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
file_list = tmp_path / "files.txt"
|
||||
file_list.write_text(f"{audio_file}\n")
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"file_list",
|
||||
str(tiny_checkpoint_path),
|
||||
str(file_list),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
"batdetect2",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output_path.exists()
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_dataset_runs_on_aoef_metadata(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: predict from AOEF dataset metadata file."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
recording = data.Recording.from_file(audio_file)
|
||||
clip = data.Clip(
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
)
|
||||
annotation_set = data.AnnotationSet(
|
||||
name="test",
|
||||
description="predict dataset test",
|
||||
clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
|
||||
)
|
||||
|
||||
dataset_path = tmp_path / "dataset.json"
|
||||
io.save(annotation_set, dataset_path)
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"dataset",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
"batdetect2",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output_path.exists()
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("format_name", "expected_pattern", "writes_single_file"),
|
||||
[
|
||||
("batdetect2", "*.json", False),
|
||||
("raw", "*.nc", False),
|
||||
("soundevent", "*.json", True),
|
||||
],
|
||||
)
|
||||
def test_cli_predict_directory_supports_output_format_override(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
format_name: str,
|
||||
expected_pattern: str,
|
||||
writes_single_file: bool,
|
||||
) -> None:
|
||||
"""User story: change output format via --format only."""
|
||||
|
||||
output_path = tmp_path / f"predictions_{format_name}"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
format_name,
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
|
||||
if writes_single_file:
|
||||
assert output_path.with_suffix(".json").exists()
|
||||
else:
|
||||
assert output_path.exists()
|
||||
assert len(list(output_path.glob(expected_pattern))) >= 1
|
||||
|
||||
|
||||
def test_cli_predict_dataset_deduplicates_recordings(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: duplicated recording entries are predicted once."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
recording = data.Recording.from_file(audio_file)
|
||||
first_clip = data.Clip(
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
)
|
||||
second_clip = data.Clip(
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
)
|
||||
annotation_set = data.AnnotationSet(
|
||||
name="dupe-recording-dataset",
|
||||
description="contains same recording twice",
|
||||
clip_annotations=[
|
||||
data.ClipAnnotation(clip=first_clip, sound_events=[]),
|
||||
data.ClipAnnotation(clip=second_clip, sound_events=[]),
|
||||
],
|
||||
)
|
||||
|
||||
dataset_path = tmp_path / "dupes.json"
|
||||
io.save(annotation_set, dataset_path)
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"dataset",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
str(output_path),
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--workers",
|
||||
"0",
|
||||
"--format",
|
||||
"raw",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert output_path.exists()
|
||||
assert len(list(output_path.glob("*.nc"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_rejects_unknown_output_format(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: invalid output format fails with error."""
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
str(output_path),
|
||||
"--format",
|
||||
"not_a_real_format",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
81
tests/test_cli/test_train.py
Normal file
81
tests/test_cli/test_train.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""CLI tests for train command."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
from batdetect2.models import ModelConfig
|
||||
|
||||
|
||||
def test_cli_train_help() -> None:
|
||||
"""User story: inspect train command interface and options."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["train", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "TRAIN_DATASET" in result.output
|
||||
assert "--training-config" in result.output
|
||||
assert "--model" in result.output
|
||||
|
||||
|
||||
def test_cli_train_from_checkpoint_runs_on_small_dataset(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: continue training from checkpoint via CLI."""
|
||||
|
||||
ckpt_dir = tmp_path / "checkpoints"
|
||||
log_dir = tmp_path / "logs"
|
||||
ckpt_dir.mkdir()
|
||||
log_dir.mkdir()
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
"example_data/dataset.yaml",
|
||||
"--val-dataset",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--num-epochs",
|
||||
"1",
|
||||
"--train-workers",
|
||||
"0",
|
||||
"--val-workers",
|
||||
"0",
|
||||
"--ckpt-dir",
|
||||
str(ckpt_dir),
|
||||
"--log-dir",
|
||||
str(log_dir),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert len(list(ckpt_dir.rglob("*.ckpt"))) >= 1
|
||||
|
||||
|
||||
def test_cli_train_rejects_model_and_model_config_together(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
) -> None:
|
||||
"""User story: invalid train flags fail with clear usage error."""
|
||||
|
||||
model_config_path = tmp_path / "model.yaml"
|
||||
model_config_path.write_text(ModelConfig().to_yaml_string())
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
"example_data/dataset.yaml",
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--model-config",
|
||||
str(model_config_path),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model-config cannot be used with --model" in result.output
|
||||
Loading…
Reference in New Issue
Block a user