Allow loading predictions in different formats

This commit is contained in:
mbsantiago 2026-03-18 16:17:50 +00:00
parent 9fa703b34b
commit a332c5c3bd
4 changed files with 277 additions and 2 deletions

View File

@ -304,8 +304,20 @@ class BatDetect2API:
def load_predictions( def load_predictions(
self, self,
path: data.PathLike, path: data.PathLike,
) -> list[ClipDetections]: format: str | None = None,
return self.formatter.load(path) config: OutputFormatConfig | None = None,
) -> list[object]:
formatter = self.formatter
if format is not None or config is not None:
format = format or config.name # type: ignore
formatter = get_output_formatter(
name=format,
targets=self.targets,
config=config,
)
return formatter.load(path)
@classmethod @classmethod
def from_config( def from_config(

View File

@ -0,0 +1,115 @@
from pathlib import Path
import numpy as np
import pytest
import torch
from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
@pytest.fixture
def api_v2() -> BatDetect2API:
"""User story: users can create a ready-to-use API from config."""
config = BatDetect2Config()
config.inference.loader.batch_size = 2
return BatDetect2API.from_config(config)
def test_process_file_returns_recording_level_predictions(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: process a file and get detections in recording time."""
prediction = api_v2.process_file(example_audio_files[0])
assert prediction.clip.recording.path == example_audio_files[0]
assert prediction.clip.start_time == 0
assert prediction.clip.end_time == prediction.clip.recording.duration
for detection in prediction.detections:
start, low, end, high = compute_bounds(detection.geometry)
assert 0 <= start <= end <= prediction.clip.recording.duration
assert prediction.clip.recording.samplerate > 2 * low
assert prediction.clip.recording.samplerate > 2 * high
assert detection.class_scores.shape[0] == len(
api_v2.targets.class_names
)
def test_process_files_is_batch_size_invariant(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: changing batch size should not change predictions."""
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
assert len(preds_batch_1) == len(preds_batch_3)
by_key_1 = {
(
str(pred.clip.recording.path),
pred.clip.start_time,
pred.clip.end_time,
): pred
for pred in preds_batch_1
}
by_key_3 = {
(
str(pred.clip.recording.path),
pred.clip.start_time,
pred.clip.end_time,
): pred
for pred in preds_batch_3
}
assert set(by_key_1) == set(by_key_3)
for key in by_key_1:
pred_1 = by_key_1[key]
pred_3 = by_key_3[key]
assert pred_1.clip.start_time == pred_3.clip.start_time
assert pred_1.clip.end_time == pred_3.clip.end_time
assert len(pred_1.detections) == len(pred_3.detections)
def test_process_audio_matches_process_spectrogram(
api_v2: BatDetect2API,
example_audio_files: list[Path],
) -> None:
"""User story: users can call either audio or spectrogram entrypoint."""
audio = api_v2.load_audio(example_audio_files[0])
from_audio = api_v2.process_audio(audio)
spec = api_v2.generate_spectrogram(audio)
from_spec = api_v2.process_spectrogram(spec)
assert len(from_audio) == len(from_spec)
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
bounds_audio = np.array(compute_bounds(det_audio.geometry))
bounds_spec = np.array(compute_bounds(det_spec.geometry))
np.testing.assert_allclose(bounds_audio, bounds_spec, atol=1e-6)
assert np.isclose(det_audio.detection_score, det_spec.detection_score)
np.testing.assert_allclose(
det_audio.class_scores,
det_spec.class_scores,
atol=1e-6,
)
def test_process_spectrogram_rejects_batched_input(
api_v2: BatDetect2API,
) -> None:
"""User story: invalid batched input gives a clear error."""
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
api_v2.process_spectrogram(spec)

View File

@ -0,0 +1,123 @@
from pathlib import Path
from typing import cast
import numpy as np
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
SoundEventOutputConfig,
)
from batdetect2.postprocess.types import ClipDetections
@pytest.fixture
def api_v2() -> BatDetect2API:
"""User story: API object manages prediction IO formats."""
return BatDetect2API.from_config(BatDetect2Config())
@pytest.fixture
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
"""User story: users save/load predictions produced by API inference."""
return api_v2.process_file(example_audio_files[0])
def test_save_and_load_predictions_roundtrip_default_raw(
api_v2: BatDetect2API,
file_prediction,
tmp_path: Path,
) -> None:
output_dir = tmp_path / "raw_preds"
api_v2.save_predictions([file_prediction], path=output_dir)
loaded = cast(list[ClipDetections], api_v2.load_predictions(output_dir))
assert len(loaded) == 1
loaded_prediction = loaded[0]
assert loaded_prediction.clip == file_prediction.clip
assert len(loaded_prediction.detections) == len(file_prediction.detections)
for loaded_det, det in zip(
loaded_prediction.detections,
file_prediction.detections,
strict=True,
):
assert loaded_det.geometry == det.geometry
assert np.isclose(loaded_det.detection_score, det.detection_score)
np.testing.assert_allclose(
loaded_det.class_scores,
det.class_scores,
atol=1e-6,
)
def test_save_predictions_with_batdetect2_override(
api_v2: BatDetect2API,
file_prediction,
tmp_path: Path,
) -> None:
output_dir = tmp_path / "batdetect2_preds"
api_v2.save_predictions(
[file_prediction],
path=output_dir,
format="batdetect2",
)
formatter = build_output_formatter(
targets=api_v2.targets,
config=BatDetect2OutputConfig(),
)
loaded = formatter.load(output_dir)
assert len(loaded) == 1
assert "annotation" in loaded[0]
assert len(loaded[0]["annotation"]) == len(file_prediction.detections)
def test_load_predictions_with_format_override(
api_v2: BatDetect2API,
file_prediction,
tmp_path: Path,
) -> None:
output_dir = tmp_path / "batdetect2_preds_load"
api_v2.save_predictions(
[file_prediction],
path=output_dir,
format="batdetect2",
)
loaded = api_v2.load_predictions(output_dir, format="batdetect2")
assert len(loaded) == 1
loaded_item = loaded[0]
assert isinstance(loaded_item, dict)
assert "annotation" in loaded_item
def test_save_predictions_with_soundevent_override(
api_v2: BatDetect2API,
file_prediction,
tmp_path: Path,
) -> None:
output_path = tmp_path / "soundevent_preds"
api_v2.save_predictions(
[file_prediction],
path=output_path,
format="soundevent",
)
formatter = build_output_formatter(
targets=api_v2.targets,
config=SoundEventOutputConfig(),
)
load_path = output_path.with_suffix(".json")
loaded = formatter.load(load_path)
assert load_path.exists()
assert len(loaded) == 1
assert len(loaded[0].sound_events) == len(file_prediction.detections)

View File

@ -0,0 +1,25 @@
from soundevent import data
from batdetect2.inference.clips import get_recording_clips
def test_get_recording_clips_uses_requested_duration(create_recording) -> None:
recording = create_recording(duration=2.0, samplerate=256_000)
clips = get_recording_clips(
recording,
duration=0.5,
overlap=0.0,
discard_empty=False,
)
assert len(clips) == 4
assert all(isinstance(clip, data.Clip) for clip in clips)
assert clips[0].start_time == 0.0
assert clips[0].end_time == 0.5
assert clips[1].start_time == 0.5
assert clips[1].end_time == 1.0
assert clips[2].start_time == 1.0
assert clips[2].end_time == 1.5
assert clips[3].start_time == 1.5
assert clips[3].end_time == 2.0