mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 23:30:21 +02:00
Expanded cli tests
This commit is contained in:
parent
f0af5dd79e
commit
0163a572cb
@ -164,6 +164,7 @@ def inference_file_list_command(
|
|||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
file_list = Path(file_list)
|
||||||
audio_files = [
|
audio_files = [
|
||||||
Path(line.strip())
|
Path(line.strip())
|
||||||
for line in file_list.read_text().splitlines()
|
for line in file_list.read_text().splitlines()
|
||||||
@ -207,6 +208,7 @@ def inference_dataset_command(
|
|||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
dataset_path = Path(dataset_path)
|
||||||
dataset = io.load(dataset_path, type="annotation_set")
|
dataset = io.load(dataset_path, type="annotation_set")
|
||||||
audio_files = sorted(
|
audio_files = sorted(
|
||||||
{
|
{
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import lightning as L
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -12,6 +13,7 @@ from soundevent import data, terms
|
|||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.audio.clips import build_clipper
|
from batdetect2.audio.clips import build_clipper
|
||||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
@ -24,6 +26,7 @@ from batdetect2.targets import (
|
|||||||
from batdetect2.targets.classes import TargetClassConfig
|
from batdetect2.targets.classes import TargetClassConfig
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
from batdetect2.train.lightning import build_training_module
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
|
|
||||||
@ -452,3 +455,23 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
|||||||
return temp_file
|
return temp_file
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path:
    """Save a checkpoint of a default-config training module and return it.

    The module is built from ``BatDetect2Config().model`` and attached to a
    trainer that has checkpointing and logging disabled; no fitting is run
    before the checkpoint is written to ``tmp_path / "model.ckpt"``.
    """
    training_module = build_training_module(
        model_config=BatDetect2Config().model,
    )
    lightning_trainer = L.Trainer(enable_checkpointing=False, logger=False)

    # The trainer must be connected to the module before it can save it.
    lightning_trainer.strategy.connect(training_module)

    destination = tmp_path / "model.ckpt"
    lightning_trainer.save_checkpoint(destination)
    return destination
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def single_audio_dir(tmp_path: Path, example_audio_files: List[Path]) -> Path:
    """Return a new ``audio`` directory holding one example recording.

    The first entry of ``example_audio_files`` is copied byte-for-byte into
    ``tmp_path / "audio"`` under its original file name.
    """
    directory = tmp_path / "audio"
    directory.mkdir()

    first_example = example_audio_files[0]
    copied_file = directory / first_example.name
    copied_file.write_bytes(first_example.read_bytes())

    return directory
|
||||||
|
|||||||
@ -1,185 +0,0 @@
|
|||||||
"""Test the command line interface."""
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import lightning as L
|
|
||||||
import pandas as pd
|
|
||||||
from click.testing import CliRunner
|
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.train.lightning import build_training_module
|
|
||||||
|
|
||||||
# Module-level Click test runner shared by the tests below.
runner = CliRunner()
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_base_command():
    """Check that top-level ``--help`` succeeds and prints the title line."""
    title = "BatDetect2 - Bat Call Detection and Classification"

    help_result = runner.invoke(cli, ["--help"])

    assert help_result.exit_code == 0
    assert title in help_result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_help():
    """Check that ``detect --help`` succeeds and prints its description."""
    description = "Detect bat calls in files in AUDIO_DIR"

    help_result = runner.invoke(cli, ["detect", "--help"])

    assert help_result.exit_code == 0
    assert description in help_result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_command_help():
    """Check that ``predict --help`` succeeds and names each predict mode."""
    help_result = runner.invoke(cli, ["predict", "--help"])

    assert help_result.exit_code == 0
    for mode in ("directory", "file_list", "dataset"):
        assert mode in help_result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_directory_runs_on_real_audio(tmp_path: Path):
    """User story: predict on a one-file directory through the CLI.

    Copies one example recording into a temporary directory, saves a
    checkpoint of a default-config module, runs ``predict directory`` and
    expects exactly one JSON file in the output directory.
    """
    # Stage a single example recording in its own directory.
    example = next(Path("example_data/audio").glob("*.wav"))
    staged_dir = tmp_path / "audio"
    staged_dir.mkdir()
    shutil.copy(example, staged_dir / example.name)

    # Save a checkpoint without running any training.
    training_module = build_training_module(
        model_config=BatDetect2Config().model,
    )
    lightning_trainer = L.Trainer(enable_checkpointing=False, logger=False)
    checkpoint = tmp_path / "model.ckpt"
    lightning_trainer.strategy.connect(training_module)
    lightning_trainer.save_checkpoint(checkpoint)

    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(checkpoint),
        str(staged_dir),
        str(predictions_dir),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = runner.invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert predictions_dir.exists()
    json_outputs = list(predictions_dir.glob("*.json"))
    assert len(json_outputs) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_on_test_audio(tmp_path):
    """Test the detect command on test audio.

    Runs ``detect`` on ``example_data/audio`` with the positional value
    ``0.3`` and expects three CSV and three JSON files in the results
    directory.
    """
    # ``tmp_path`` is a fresh, empty directory for each test, so "results"
    # cannot exist yet; no clean-up is needed before invoking the command.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
        ],
    )

    assert result.exit_code == 0
    assert results_dir.exists()
    assert len(list(results_dir.glob("*.csv"))) == 3
    assert len(list(results_dir.glob("*.json"))) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
    """Test the detect command with a non-trivial time expansion factor.

    Passes ``--time_expansion_factor 10`` and expects the command to succeed
    and to echo the factor on standard output.
    """
    # ``tmp_path`` is a fresh, empty directory for each test, so "results"
    # cannot exist yet; no clean-up is needed before invoking the command.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--time_expansion_factor",
            "10",
        ],
    )

    assert result.exit_code == 0
    assert "Time Expansion Factor: 10" in result.stdout
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
    """Test the detect command with the spec feature flag.

    Runs ``detect`` with ``--spec_features`` and expects one
    ``*_spec_features.csv`` file per example recording, none of which has
    a ``duration`` value of ``-1``.
    """
    # ``tmp_path`` is a fresh, empty directory for each test, so "results"
    # cannot exist yet; no clean-up is needed before invoking the command.
    results_dir = tmp_path / "results"

    result = runner.invoke(
        cli,
        [
            "detect",
            "example_data/audio",
            str(results_dir),
            "0.3",
            "--spec_features",
        ],
    )
    assert result.exit_code == 0
    assert results_dir.exists()

    csv_files = [path.name for path in results_dir.glob("*.csv")]

    expected_files = [
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    ]

    for expected_file in expected_files:
        assert expected_file in csv_files

        # Each feature file is checked for the -1 duration value.
        df = pd.read_csv(results_dir / expected_file)
        assert not (df.duration == -1).any()
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
    """Check that an empty ``.wav`` file is reported without a failing exit.

    The command must still exit with code 0 and mention the offending file
    in its output.
    """
    output_dir = tmp_path / "results"
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()

    # A zero-byte file that only looks like audio because of its extension.
    empty_file = audio_dir / "empty.wav"
    empty_file.touch()

    cli_args = [
        "detect",
        str(audio_dir),
        str(output_dir),
        "0.3",
        "--spec_features",
    ]
    outcome = runner.invoke(cli, args=cli_args)

    assert outcome.exit_code == 0
    assert f"Error processing file {empty_file}" in outcome.output
|
|
||||||
18
tests/test_cli/test_base.py
Normal file
18
tests/test_cli/test_base.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""Behavior-focused tests for top-level CLI command discovery."""
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_base_help_lists_main_commands() -> None:
    """User story: top-level help names every main command group."""
    help_result = CliRunner().invoke(cli, ["--help"])

    assert help_result.exit_code == 0
    for command in ("predict", "train", "evaluate", "data", "detect"):
        assert command in help_result.output
|
||||||
60
tests/test_cli/test_data.py
Normal file
60
tests/test_cli/test_data.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Behavior tests for data CLI command group."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_data_help() -> None:
    """User story: ``data --help`` lists the data subcommands."""
    help_result = CliRunner().invoke(cli, ["data", "--help"])

    assert help_result.exit_code == 0
    for subcommand in ("summary", "convert"):
        assert subcommand in help_result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_data_convert_creates_annotation_set(tmp_path: Path) -> None:
    """User story: ``data convert`` writes the requested output file.

    Converts ``example_data/dataset.yaml`` with the current directory as base
    directory and expects the ``--output`` file to exist afterwards.
    """
    annotations_file = tmp_path / "annotations.json"
    cli_args = [
        "data",
        "convert",
        "example_data/dataset.yaml",
        "--base-dir",
        ".",
        "--output",
        str(annotations_file),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert annotations_file.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_data_convert_fails_with_invalid_field(tmp_path: Path) -> None:
    """User story: an unknown ``--field`` path makes ``data convert`` fail."""
    annotations_file = tmp_path / "annotations.json"
    cli_args = [
        "data",
        "convert",
        "example_data/dataset.yaml",
        "--field",
        "does.not.exist",
        "--output",
        str(annotations_file),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    # Only a non-zero exit status is required, not a specific code.
    assert outcome.exit_code != 0
|
||||||
119
tests/test_cli/test_detect.py
Normal file
119
tests/test_cli/test_detect.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
"""Behavior tests for legacy detect command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_help() -> None:
    """User story: ``detect --help`` shows the legacy command description."""
    description = "Detect bat calls in files in AUDIO_DIR"

    help_result = CliRunner().invoke(cli, ["detect", "--help"])

    assert help_result.exit_code == 0
    assert description in help_result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
    """User story: legacy detect writes CSV and JSON results per recording.

    Runs on ``example_data/audio`` and expects three files of each type.
    """
    output_dir = tmp_path / "results"
    cli_args = ["detect", "example_data/audio", str(output_dir), "0.3"]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_dir.exists()
    for pattern in ("*.csv", "*.json"):
        assert len(list(output_dir.glob(pattern))) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_with_non_trivial_time_expansion(
    tmp_path: Path,
) -> None:
    """User story: legacy detect echoes a custom time expansion factor."""
    output_dir = tmp_path / "results"
    cli_args = [
        "detect",
        "example_data/audio",
        str(output_dir),
        "0.3",
        "--time_expansion_factor",
        "10",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert "Time Expansion Factor: 10" in outcome.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
    """User story: ``--spec_features`` adds a feature CSV per recording.

    Each expected ``*_spec_features.csv`` file must exist and must not
    contain a ``duration`` value of ``-1``.
    """
    output_dir = tmp_path / "results"
    cli_args = [
        "detect",
        "example_data/audio",
        str(output_dir),
        "0.3",
        "--spec_features",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert output_dir.exists()

    written_names = [csv_path.name for csv_path in output_dir.glob("*.csv")]
    feature_files = (
        "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
        "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
    )

    for feature_file in feature_files:
        assert feature_file in written_names
        features = pd.read_csv(output_dir / feature_file)
        assert not (features.duration == -1).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path) -> None:
    """User story: an empty audio file is reported, yet the run exits 0."""
    output_dir = tmp_path / "results"
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()

    # A zero-byte file that only looks like audio because of its extension.
    empty_file = audio_dir / "empty.wav"
    empty_file.touch()

    cli_args = [
        "detect",
        str(audio_dir),
        str(output_dir),
        "0.3",
        "--spec_features",
    ]
    outcome = CliRunner().invoke(cli, args=cli_args)

    assert outcome.exit_code == 0
    assert f"Error processing file {empty_file}" in outcome.output
|
||||||
47
tests/test_cli/test_evaluate.py
Normal file
47
tests/test_cli/test_evaluate.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""CLI tests for evaluate command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
# Directory three levels above this file; the evaluate test below resolves
# ``example_data/dataset.yaml`` and the ``--base-dir`` value against it.
BASE_DIR = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_evaluate_help() -> None:
    """User story: ``evaluate --help`` shows its arguments and config flag."""
    help_result = CliRunner().invoke(cli, ["evaluate", "--help"])

    assert help_result.exit_code == 0
    for expected in ("MODEL_PATH", "TEST_DATASET", "--evaluation-config"):
        assert expected in help_result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_evaluate_writes_metrics_for_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """User story: evaluating a checkpoint leaves a ``metrics.csv`` behind.

    The example dataset is evaluated with zero workers; at least one
    ``metrics.csv`` must appear somewhere below the output directory.
    """
    metrics_root = tmp_path / "eval_out"
    dataset_config = BASE_DIR / "example_data" / "dataset.yaml"
    cli_args = [
        "evaluate",
        str(tiny_checkpoint_path),
        str(dataset_config),
        "--base-dir",
        str(BASE_DIR),
        "--workers",
        "0",
        "--output-dir",
        str(metrics_root),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    metrics_files = list(metrics_root.rglob("metrics.csv"))
    assert len(metrics_files) >= 1
|
||||||
257
tests/test_cli/test_predict.py
Normal file
257
tests/test_cli/test_predict.py
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
"""Behavior tests for predict CLI workflows."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from click.testing import CliRunner
|
||||||
|
from soundevent import data, io
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_help() -> None:
    """User story: ``predict --help`` names each available predict mode."""
    help_result = CliRunner().invoke(cli, ["predict", "--help"])

    assert help_result.exit_code == 0
    for mode in ("directory", "file_list", "dataset"):
        assert mode in help_result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_directory_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: ``predict directory`` writes one JSON for one recording."""
    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(predictions_dir),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert predictions_dir.exists()
    json_outputs = list(predictions_dir.glob("*.json"))
    assert len(json_outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_file_list_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: ``predict file_list`` handles a one-line list of paths."""
    # The list file holds one audio path followed by a newline.
    audio_file = next(single_audio_dir.glob("*.wav"))
    listing = tmp_path / "files.txt"
    listing.write_text(f"{audio_file}\n")

    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "file_list",
        str(tiny_checkpoint_path),
        str(listing),
        str(predictions_dir),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert predictions_dir.exists()
    json_outputs = list(predictions_dir.glob("*.json"))
    assert len(json_outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_dataset_runs_on_aoef_metadata(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: ``predict dataset`` reads recordings from a saved set.

    An annotation set with one full-length clip and no sound events is
    saved to disk and passed to the CLI; one JSON prediction file is
    expected.
    """
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)
    full_clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    empty_annotation = data.ClipAnnotation(clip=full_clip, sound_events=[])
    annotation_set = data.AnnotationSet(
        name="test",
        description="predict dataset test",
        clip_annotations=[empty_annotation],
    )

    dataset_file = tmp_path / "dataset.json"
    io.save(annotation_set, dataset_file)

    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_file),
        str(predictions_dir),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert predictions_dir.exists()
    json_outputs = list(predictions_dir.glob("*.json"))
    assert len(json_outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("format_name", "expected_pattern", "writes_single_file"),
    [
        ("batdetect2", "*.json", False),
        ("raw", "*.nc", False),
        ("soundevent", "*.json", True),
    ],
)
def test_cli_predict_directory_supports_output_format_override(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
    format_name: str,
    expected_pattern: str,
    writes_single_file: bool,
) -> None:
    """User story: ``--format`` alone selects the prediction output format.

    Single-file formats are expected at the output path with a ``.json``
    suffix; the others are expected to fill a directory with files
    matching ``expected_pattern``.
    """
    output_path = tmp_path / f"predictions_{format_name}"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        format_name,
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0

    if writes_single_file:
        assert output_path.with_suffix(".json").exists()
        return

    assert output_path.exists()
    matching_files = list(output_path.glob(expected_pattern))
    assert len(matching_files) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_dataset_deduplicates_recordings(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: a recording listed twice yields one prediction file.

    Two full-length clips of the same recording are saved in one annotation
    set; with the ``raw`` format exactly one ``.nc`` file is expected.
    """
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)

    # Two separate clip objects covering the same full recording.
    duplicate_clips = [
        data.Clip(
            recording=recording,
            start_time=0,
            end_time=recording.duration,
        )
        for _ in range(2)
    ]
    annotation_set = data.AnnotationSet(
        name="dupe-recording-dataset",
        description="contains same recording twice",
        clip_annotations=[
            data.ClipAnnotation(clip=clip, sound_events=[])
            for clip in duplicate_clips
        ],
    )

    dataset_file = tmp_path / "dupes.json"
    io.save(annotation_set, dataset_file)

    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_file),
        str(predictions_dir),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "raw",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    assert predictions_dir.exists()
    raw_outputs = list(predictions_dir.glob("*.nc"))
    assert len(raw_outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_predict_rejects_unknown_output_format(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: an unrecognised ``--format`` value makes predict fail."""
    predictions_dir = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(predictions_dir),
        "--format",
        "not_a_real_format",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    # Only a non-zero exit status is required, not a specific code.
    assert outcome.exit_code != 0
|
||||||
81
tests/test_cli/test_train.py
Normal file
81
tests/test_cli/test_train.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""CLI tests for train command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
from batdetect2.models import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_train_help() -> None:
    """User story: ``train --help`` shows its argument and key options."""
    help_result = CliRunner().invoke(cli, ["train", "--help"])

    assert help_result.exit_code == 0
    for expected in ("TRAIN_DATASET", "--training-config", "--model"):
        assert expected in help_result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_train_from_checkpoint_runs_on_small_dataset(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """User story: one training epoch from a checkpoint saves a new one.

    The example dataset is used for both training and validation, with zero
    workers; at least one ``.ckpt`` file must appear below ``--ckpt-dir``.
    """
    checkpoints = tmp_path / "checkpoints"
    logs = tmp_path / "logs"
    for directory in (checkpoints, logs):
        directory.mkdir()

    cli_args = [
        "train",
        "example_data/dataset.yaml",
        "--val-dataset",
        "example_data/dataset.yaml",
        "--model",
        str(tiny_checkpoint_path),
        "--num-epochs",
        "1",
        "--train-workers",
        "0",
        "--val-workers",
        "0",
        "--ckpt-dir",
        str(checkpoints),
        "--log-dir",
        str(logs),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code == 0
    saved_checkpoints = list(checkpoints.rglob("*.ckpt"))
    assert len(saved_checkpoints) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_train_rejects_model_and_model_config_together(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
) -> None:
    """User story: ``--model`` with ``--model-config`` is a usage error."""
    # Write a default model config so the option points at a real file.
    config_file = tmp_path / "model.yaml"
    config_file.write_text(ModelConfig().to_yaml_string())

    cli_args = [
        "train",
        "example_data/dataset.yaml",
        "--model",
        str(tiny_checkpoint_path),
        "--model-config",
        str(config_file),
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code != 0
    assert "--model-config cannot be used with --model" in outcome.output
|
||||||
Loading…
Reference in New Issue
Block a user