# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2026-04-04 15:20:19 +02:00
# 258 lines, 6.4 KiB, Python
"""Behavior tests for predict CLI workflows."""
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
from soundevent import data, io
|
|
|
|
from batdetect2.cli import cli
|
|
|
|
|
|
def test_cli_predict_help() -> None:
    """User story: discover available predict modes.

    Invokes ``predict --help`` through click's test runner and checks that
    the command exits with status 0 and that the help text mentions each of
    the ``directory``, ``file_list`` and ``dataset`` modes.
    """
    result = CliRunner().invoke(cli, ["predict", "--help"])

    # The runner captures the command's output instead of letting it reach
    # the terminal, so attach it to the assertion; otherwise a failure only
    # reports the mismatching exit code.
    assert result.exit_code == 0, result.output

    for mode in ("directory", "file_list", "dataset"):
        assert mode in result.output
|
|
|
|
|
|
def test_cli_predict_directory_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: run prediction for all files in a directory.

    Runs ``predict directory`` with the checkpoint, the audio directory and
    an output location under ``tmp_path``, then checks that the command
    succeeds and that the output directory holds exactly one ``*.json`` file.

    Args:
        tmp_path: Temporary directory that receives the predictions.
        tiny_checkpoint_path: Checkpoint path handed to the CLI (fixture
            defined outside this file).
        single_audio_dir: Directory handed to the CLI as input (fixture
            defined outside this file; presumably holds a single recording,
            given the exactly-one-output assertion).
    """
    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    result = CliRunner().invoke(cli, cli_args)

    # The runner captures the command's output; surface it on failure so the
    # report shows more than a mismatching exit code.
    assert result.exit_code == 0, result.output
    assert output_path.exists()

    written = list(output_path.glob("*.json"))
    assert len(written) == 1, written
|
|
|
|
|
|
def test_cli_predict_file_list_runs_on_real_audio(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: run prediction from an explicit list of files.

    Writes a text file containing the path of the first ``*.wav`` found in
    ``single_audio_dir`` (one path, newline-terminated), runs
    ``predict file_list`` on it, and checks that the command succeeds and
    that the output directory holds exactly one ``*.json`` file.

    Args:
        tmp_path: Temporary directory for the file list and the predictions.
        tiny_checkpoint_path: Checkpoint path handed to the CLI (fixture
            defined outside this file).
        single_audio_dir: Directory searched for a ``*.wav`` file (fixture
            defined outside this file).
    """
    audio_file = next(single_audio_dir.glob("*.wav"))
    file_list = tmp_path / "files.txt"
    file_list.write_text(f"{audio_file}\n")

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "file_list",
        str(tiny_checkpoint_path),
        str(file_list),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    result = CliRunner().invoke(cli, cli_args)

    # The runner captures the command's output; surface it on failure so the
    # report shows more than a mismatching exit code.
    assert result.exit_code == 0, result.output
    assert output_path.exists()

    written = list(output_path.glob("*.json"))
    assert len(written) == 1, written
|
|
|
|
|
|
def test_cli_predict_dataset_runs_on_aoef_metadata(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: predict from AOEF dataset metadata file.

    Builds an annotation set with one clip annotation (no sound events)
    spanning the full duration of the first ``*.wav`` in
    ``single_audio_dir``, saves it with ``soundevent.io.save`` as
    ``dataset.json``, runs ``predict dataset`` on that file, and checks that
    the command succeeds and writes exactly one ``*.json`` prediction file.

    Args:
        tmp_path: Temporary directory for the dataset file and predictions.
        tiny_checkpoint_path: Checkpoint path handed to the CLI (fixture
            defined outside this file).
        single_audio_dir: Directory searched for a ``*.wav`` file (fixture
            defined outside this file).
    """
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    annotation_set = data.AnnotationSet(
        name="test",
        description="predict dataset test",
        clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
    )

    dataset_path = tmp_path / "dataset.json"
    io.save(annotation_set, dataset_path)

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_path),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "batdetect2",
    ]

    result = CliRunner().invoke(cli, cli_args)

    # The runner captures the command's output; surface it on failure so the
    # report shows more than a mismatching exit code.
    assert result.exit_code == 0, result.output
    assert output_path.exists()

    written = list(output_path.glob("*.json"))
    assert len(written) == 1, written
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("format_name", "expected_pattern", "writes_single_file"),
    [
        ("batdetect2", "*.json", False),
        ("raw", "*.nc", False),
        ("soundevent", "*.json", True),
    ],
)
def test_cli_predict_directory_supports_output_format_override(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
    format_name: str,
    expected_pattern: str,
    writes_single_file: bool,
) -> None:
    """User story: change output format via --format only.

    Runs ``predict directory`` once per parametrized format and checks where
    the output landed.

    Args:
        tmp_path: Temporary directory that receives the predictions.
        tiny_checkpoint_path: Checkpoint path handed to the CLI (fixture
            defined outside this file).
        single_audio_dir: Directory handed to the CLI as input (fixture
            defined outside this file).
        format_name: Value passed to ``--format``.
        expected_pattern: Glob that must match at least one file inside the
            output directory; only consulted for directory-style outputs.
        writes_single_file: When true, the output is expected to be a single
            file at the output path with a ``.json`` suffix instead of a
            directory.
    """
    output_path = tmp_path / f"predictions_{format_name}"
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        format_name,
    ]

    result = CliRunner().invoke(cli, cli_args)

    # The runner captures the command's output; surface it on failure so the
    # report shows more than a mismatching exit code.
    assert result.exit_code == 0, result.output

    if writes_single_file:
        single_file = output_path.with_suffix(".json")
        assert single_file.exists(), f"missing output file: {single_file}"
    else:
        assert output_path.exists(), f"missing output dir: {output_path}"
        matches = list(output_path.glob(expected_pattern))
        assert len(matches) >= 1, (
            f"no files matching {expected_pattern!r} in {output_path}"
        )
|
|
|
|
|
|
def test_cli_predict_dataset_deduplicates_recordings(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: duplicated recording entries are predicted once.

    Builds an annotation set whose two clip annotations both reference the
    same recording over its full duration, saves it as ``dupes.json``, runs
    ``predict dataset`` with ``--format raw``, and checks that exactly one
    ``*.nc`` file is written.

    Args:
        tmp_path: Temporary directory for the dataset file and predictions.
        tiny_checkpoint_path: Checkpoint path handed to the CLI (fixture
            defined outside this file).
        single_audio_dir: Directory searched for a ``*.wav`` file (fixture
            defined outside this file).
    """
    audio_file = next(single_audio_dir.glob("*.wav"))
    recording = data.Recording.from_file(audio_file)

    # Two separately constructed clips over the same recording and time span.
    duplicate_clips = [
        data.Clip(
            recording=recording,
            start_time=0,
            end_time=recording.duration,
        )
        for _ in range(2)
    ]
    annotation_set = data.AnnotationSet(
        name="dupe-recording-dataset",
        description="contains same recording twice",
        clip_annotations=[
            data.ClipAnnotation(clip=clip, sound_events=[])
            for clip in duplicate_clips
        ],
    )

    dataset_path = tmp_path / "dupes.json"
    io.save(annotation_set, dataset_path)

    output_path = tmp_path / "predictions"
    cli_args = [
        "predict",
        "dataset",
        str(tiny_checkpoint_path),
        str(dataset_path),
        str(output_path),
        "--batch-size",
        "1",
        "--workers",
        "0",
        "--format",
        "raw",
    ]

    result = CliRunner().invoke(cli, cli_args)

    # The runner captures the command's output; surface it on failure so the
    # report shows more than a mismatching exit code.
    assert result.exit_code == 0, result.output
    assert output_path.exists()

    written = list(output_path.glob("*.nc"))
    assert len(written) == 1, written
|
|
|
|
|
|
def test_cli_predict_rejects_unknown_output_format(
    tmp_path: Path,
    tiny_checkpoint_path: Path,
    single_audio_dir: Path,
) -> None:
    """User story: invalid output format fails with error.

    Runs ``predict directory`` with a ``--format`` value that is not one of
    the formats exercised elsewhere in this file and checks that the command
    reports a non-zero exit status.
    """
    # NOTE(review): any non-zero exit status satisfies this test, whatever
    # the cause; consider also asserting on the error text once the CLI's
    # message for an unknown format is confirmed.
    cli_args = [
        "predict",
        "directory",
        str(tiny_checkpoint_path),
        str(single_audio_dir),
        str(tmp_path / "predictions"),
        "--format",
        "not_a_real_format",
    ]

    outcome = CliRunner().invoke(cli, cli_args)

    assert outcome.exit_code != 0
|