From a332c5c3bd1467562a60a13e18d7508c5da387f0 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 16:17:50 +0000 Subject: [PATCH] Allow loading predictions in different formats --- src/batdetect2/api_v2.py | 16 +++- tests/test_api_v2/test_api_v2.py | 115 +++++++++++++++++++++++++ tests/test_api_v2/test_outputs_io.py | 123 +++++++++++++++++++++++++++ tests/test_inference/test_clips.py | 25 ++++++ 4 files changed, 277 insertions(+), 2 deletions(-) create mode 100644 tests/test_api_v2/test_api_v2.py create mode 100644 tests/test_api_v2/test_outputs_io.py create mode 100644 tests/test_inference/test_clips.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 5ff689e..bed0924 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -304,8 +304,20 @@ class BatDetect2API: def load_predictions( self, path: data.PathLike, - ) -> list[ClipDetections]: - return self.formatter.load(path) + format: str | None = None, + config: OutputFormatConfig | None = None, + ) -> list[object]: + formatter = self.formatter + + if format is not None or config is not None: + format = format or config.name # type: ignore + formatter = get_output_formatter( + name=format, + targets=self.targets, + config=config, + ) + + return formatter.load(path) @classmethod def from_config( diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py new file mode 100644 index 0000000..a3910f7 --- /dev/null +++ b/tests/test_api_v2/test_api_v2.py @@ -0,0 +1,115 @@ +from pathlib import Path + +import numpy as np +import pytest +import torch +from soundevent.geometry import compute_bounds + +from batdetect2.api_v2 import BatDetect2API +from batdetect2.config import BatDetect2Config + + +@pytest.fixture +def api_v2() -> BatDetect2API: + """User story: users can create a ready-to-use API from config.""" + + config = BatDetect2Config() + config.inference.loader.batch_size = 2 + return BatDetect2API.from_config(config) + + +def test_process_file_returns_recording_level_predictions( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: process a file and get detections in recording time.""" + + prediction = api_v2.process_file(example_audio_files[0]) + + assert prediction.clip.recording.path == example_audio_files[0] + assert prediction.clip.start_time == 0 + assert prediction.clip.end_time == prediction.clip.recording.duration + + for detection in prediction.detections: + start, low, end, high = compute_bounds(detection.geometry) + assert 0 <= start <= end <= prediction.clip.recording.duration + assert prediction.clip.recording.samplerate > 2 * low + assert prediction.clip.recording.samplerate > 2 * high + assert detection.class_scores.shape[0] == len( + api_v2.targets.class_names + ) + + +def test_process_files_is_batch_size_invariant( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: changing batch size should not change predictions.""" + + preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1) + preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3) + + assert len(preds_batch_1) == len(preds_batch_3) + + by_key_1 = { + ( + str(pred.clip.recording.path), + pred.clip.start_time, + pred.clip.end_time, + ): pred + for pred in preds_batch_1 + } + by_key_3 = { + ( + str(pred.clip.recording.path), + pred.clip.start_time, + pred.clip.end_time, + ): pred + for pred in preds_batch_3 + } + + assert set(by_key_1) == set(by_key_3) + + for key in by_key_1: + pred_1 = by_key_1[key] + pred_3 = by_key_3[key] + assert pred_1.clip.start_time == pred_3.clip.start_time + assert pred_1.clip.end_time == pred_3.clip.end_time + assert len(pred_1.detections) == len(pred_3.detections) + + +def test_process_audio_matches_process_spectrogram( + api_v2: BatDetect2API, + example_audio_files: list[Path], +) -> None: + """User story: users can call either audio or spectrogram entrypoint.""" + + audio = api_v2.load_audio(example_audio_files[0]) + from_audio = api_v2.process_audio(audio) + + spec = api_v2.generate_spectrogram(audio) + from_spec = api_v2.process_spectrogram(spec) + + assert len(from_audio) == len(from_spec) + + for det_audio, det_spec in zip(from_audio, from_spec, strict=True): + bounds_audio = np.array(compute_bounds(det_audio.geometry)) + bounds_spec = np.array(compute_bounds(det_spec.geometry)) + np.testing.assert_allclose(bounds_audio, bounds_spec, atol=1e-6) + assert np.isclose(det_audio.detection_score, det_spec.detection_score) + np.testing.assert_allclose( + det_audio.class_scores, + det_spec.class_scores, + atol=1e-6, + ) + + +def test_process_spectrogram_rejects_batched_input( + api_v2: BatDetect2API, +) -> None: + """User story: invalid batched input gives a clear error.""" + + spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32) + + with pytest.raises(ValueError, match="Batched spectrograms not supported"): + api_v2.process_spectrogram(spec) diff --git a/tests/test_api_v2/test_outputs_io.py b/tests/test_api_v2/test_outputs_io.py new file mode 100644 index 0000000..7cb5062 --- /dev/null +++ b/tests/test_api_v2/test_outputs_io.py @@ -0,0 +1,123 @@ +from pathlib import Path +from typing import cast + +import numpy as np +import pytest + +from batdetect2.api_v2 import BatDetect2API +from batdetect2.config import BatDetect2Config +from batdetect2.outputs import build_output_formatter +from batdetect2.outputs.formats import ( + BatDetect2OutputConfig, + SoundEventOutputConfig, +) +from batdetect2.postprocess.types import ClipDetections + + +@pytest.fixture +def api_v2() -> BatDetect2API: + """User story: API object manages prediction IO formats.""" + + return BatDetect2API.from_config(BatDetect2Config()) + + +@pytest.fixture +def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]): + """User story: users save/load predictions produced by API inference.""" + + return api_v2.process_file(example_audio_files[0]) + + +def test_save_and_load_predictions_roundtrip_default_raw( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "raw_preds" + api_v2.save_predictions([file_prediction], path=output_dir) + loaded = cast(list[ClipDetections], api_v2.load_predictions(output_dir)) + + assert len(loaded) == 1 + loaded_prediction = loaded[0] + assert loaded_prediction.clip == file_prediction.clip + assert len(loaded_prediction.detections) == len(file_prediction.detections) + + for loaded_det, det in zip( + loaded_prediction.detections, + file_prediction.detections, + strict=True, + ): + assert loaded_det.geometry == det.geometry + assert np.isclose(loaded_det.detection_score, det.detection_score) + np.testing.assert_allclose( + loaded_det.class_scores, + det.class_scores, + atol=1e-6, + ) + + +def test_save_predictions_with_batdetect2_override( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "batdetect2_preds" + api_v2.save_predictions( + [file_prediction], + path=output_dir, + format="batdetect2", + ) + + formatter = build_output_formatter( + targets=api_v2.targets, + config=BatDetect2OutputConfig(), + ) + loaded = formatter.load(output_dir) + + assert len(loaded) == 1 + assert "annotation" in loaded[0] + assert len(loaded[0]["annotation"]) == len(file_prediction.detections) + + +def test_load_predictions_with_format_override( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "batdetect2_preds_load" + api_v2.save_predictions( + [file_prediction], + path=output_dir, + format="batdetect2", + ) + + loaded = api_v2.load_predictions(output_dir, format="batdetect2") + + assert len(loaded) == 1 + loaded_item = loaded[0] + assert isinstance(loaded_item, dict) + assert "annotation" in loaded_item + + +def test_save_predictions_with_soundevent_override( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_path = tmp_path / "soundevent_preds" + api_v2.save_predictions( + [file_prediction], + path=output_path, + format="soundevent", + ) + + formatter = build_output_formatter( + targets=api_v2.targets, + config=SoundEventOutputConfig(), + ) + load_path = output_path.with_suffix(".json") + loaded = formatter.load(load_path) + + assert load_path.exists() + assert len(loaded) == 1 + assert len(loaded[0].sound_events) == len(file_prediction.detections) diff --git a/tests/test_inference/test_clips.py b/tests/test_inference/test_clips.py new file mode 100644 index 0000000..207e823 --- /dev/null +++ b/tests/test_inference/test_clips.py @@ -0,0 +1,25 @@ +from soundevent import data + +from batdetect2.inference.clips import get_recording_clips + + +def test_get_recording_clips_uses_requested_duration(create_recording) -> None: + recording = create_recording(duration=2.0, samplerate=256_000) + + clips = get_recording_clips( + recording, + duration=0.5, + overlap=0.0, + discard_empty=False, + ) + + assert len(clips) == 4 + assert all(isinstance(clip, data.Clip) for clip in clips) + assert clips[0].start_time == 0.0 + assert clips[0].end_time == 0.5 + assert clips[1].start_time == 0.5 + assert clips[1].end_time == 1.0 + assert clips[2].start_time == 1.0 + assert clips[2].end_time == 1.5 + assert clips[3].start_time == 1.5 + assert clips[3].end_time == 2.0