mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow loading predictions in different formats
This commit is contained in:
parent
9fa703b34b
commit
a332c5c3bd
@ -304,8 +304,20 @@ class BatDetect2API:
|
||||
def load_predictions(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
) -> list[ClipDetections]:
|
||||
return self.formatter.load(path)
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> list[object]:
|
||||
formatter = self.formatter
|
||||
|
||||
if format is not None or config is not None:
|
||||
format = format or config.name # type: ignore
|
||||
formatter = get_output_formatter(
|
||||
name=format,
|
||||
targets=self.targets,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return formatter.load(path)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
|
||||
115
tests/test_api_v2/test_api_v2.py
Normal file
115
tests/test_api_v2/test_api_v2.py
Normal file
@ -0,0 +1,115 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_v2() -> BatDetect2API:
|
||||
"""User story: users can create a ready-to-use API from config."""
|
||||
|
||||
config = BatDetect2Config()
|
||||
config.inference.loader.batch_size = 2
|
||||
return BatDetect2API.from_config(config)
|
||||
|
||||
|
||||
def test_process_file_returns_recording_level_predictions(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: process a file and get detections in recording time."""
|
||||
|
||||
prediction = api_v2.process_file(example_audio_files[0])
|
||||
|
||||
assert prediction.clip.recording.path == example_audio_files[0]
|
||||
assert prediction.clip.start_time == 0
|
||||
assert prediction.clip.end_time == prediction.clip.recording.duration
|
||||
|
||||
for detection in prediction.detections:
|
||||
start, low, end, high = compute_bounds(detection.geometry)
|
||||
assert 0 <= start <= end <= prediction.clip.recording.duration
|
||||
assert prediction.clip.recording.samplerate > 2 * low
|
||||
assert prediction.clip.recording.samplerate > 2 * high
|
||||
assert detection.class_scores.shape[0] == len(
|
||||
api_v2.targets.class_names
|
||||
)
|
||||
|
||||
|
||||
def test_process_files_is_batch_size_invariant(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: changing batch size should not change predictions."""
|
||||
|
||||
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
||||
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
||||
|
||||
assert len(preds_batch_1) == len(preds_batch_3)
|
||||
|
||||
by_key_1 = {
|
||||
(
|
||||
str(pred.clip.recording.path),
|
||||
pred.clip.start_time,
|
||||
pred.clip.end_time,
|
||||
): pred
|
||||
for pred in preds_batch_1
|
||||
}
|
||||
by_key_3 = {
|
||||
(
|
||||
str(pred.clip.recording.path),
|
||||
pred.clip.start_time,
|
||||
pred.clip.end_time,
|
||||
): pred
|
||||
for pred in preds_batch_3
|
||||
}
|
||||
|
||||
assert set(by_key_1) == set(by_key_3)
|
||||
|
||||
for key in by_key_1:
|
||||
pred_1 = by_key_1[key]
|
||||
pred_3 = by_key_3[key]
|
||||
assert pred_1.clip.start_time == pred_3.clip.start_time
|
||||
assert pred_1.clip.end_time == pred_3.clip.end_time
|
||||
assert len(pred_1.detections) == len(pred_3.detections)
|
||||
|
||||
|
||||
def test_process_audio_matches_process_spectrogram(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
) -> None:
|
||||
"""User story: users can call either audio or spectrogram entrypoint."""
|
||||
|
||||
audio = api_v2.load_audio(example_audio_files[0])
|
||||
from_audio = api_v2.process_audio(audio)
|
||||
|
||||
spec = api_v2.generate_spectrogram(audio)
|
||||
from_spec = api_v2.process_spectrogram(spec)
|
||||
|
||||
assert len(from_audio) == len(from_spec)
|
||||
|
||||
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
||||
bounds_audio = np.array(compute_bounds(det_audio.geometry))
|
||||
bounds_spec = np.array(compute_bounds(det_spec.geometry))
|
||||
np.testing.assert_allclose(bounds_audio, bounds_spec, atol=1e-6)
|
||||
assert np.isclose(det_audio.detection_score, det_spec.detection_score)
|
||||
np.testing.assert_allclose(
|
||||
det_audio.class_scores,
|
||||
det_spec.class_scores,
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
|
||||
def test_process_spectrogram_rejects_batched_input(
|
||||
api_v2: BatDetect2API,
|
||||
) -> None:
|
||||
"""User story: invalid batched input gives a clear error."""
|
||||
|
||||
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
|
||||
|
||||
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
||||
api_v2.process_spectrogram(spec)
|
||||
123
tests/test_api_v2/test_outputs_io.py
Normal file
123
tests/test_api_v2/test_outputs_io.py
Normal file
@ -0,0 +1,123 @@
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.outputs import build_output_formatter
|
||||
from batdetect2.outputs.formats import (
|
||||
BatDetect2OutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_v2() -> BatDetect2API:
|
||||
"""User story: API object manages prediction IO formats."""
|
||||
|
||||
return BatDetect2API.from_config(BatDetect2Config())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
|
||||
"""User story: users save/load predictions produced by API inference."""
|
||||
|
||||
return api_v2.process_file(example_audio_files[0])
|
||||
|
||||
|
||||
def test_save_and_load_predictions_roundtrip_default_raw(
|
||||
api_v2: BatDetect2API,
|
||||
file_prediction,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "raw_preds"
|
||||
api_v2.save_predictions([file_prediction], path=output_dir)
|
||||
loaded = cast(list[ClipDetections], api_v2.load_predictions(output_dir))
|
||||
|
||||
assert len(loaded) == 1
|
||||
loaded_prediction = loaded[0]
|
||||
assert loaded_prediction.clip == file_prediction.clip
|
||||
assert len(loaded_prediction.detections) == len(file_prediction.detections)
|
||||
|
||||
for loaded_det, det in zip(
|
||||
loaded_prediction.detections,
|
||||
file_prediction.detections,
|
||||
strict=True,
|
||||
):
|
||||
assert loaded_det.geometry == det.geometry
|
||||
assert np.isclose(loaded_det.detection_score, det.detection_score)
|
||||
np.testing.assert_allclose(
|
||||
loaded_det.class_scores,
|
||||
det.class_scores,
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
|
||||
def test_save_predictions_with_batdetect2_override(
|
||||
api_v2: BatDetect2API,
|
||||
file_prediction,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "batdetect2_preds"
|
||||
api_v2.save_predictions(
|
||||
[file_prediction],
|
||||
path=output_dir,
|
||||
format="batdetect2",
|
||||
)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets=api_v2.targets,
|
||||
config=BatDetect2OutputConfig(),
|
||||
)
|
||||
loaded = formatter.load(output_dir)
|
||||
|
||||
assert len(loaded) == 1
|
||||
assert "annotation" in loaded[0]
|
||||
assert len(loaded[0]["annotation"]) == len(file_prediction.detections)
|
||||
|
||||
|
||||
def test_load_predictions_with_format_override(
|
||||
api_v2: BatDetect2API,
|
||||
file_prediction,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
output_dir = tmp_path / "batdetect2_preds_load"
|
||||
api_v2.save_predictions(
|
||||
[file_prediction],
|
||||
path=output_dir,
|
||||
format="batdetect2",
|
||||
)
|
||||
|
||||
loaded = api_v2.load_predictions(output_dir, format="batdetect2")
|
||||
|
||||
assert len(loaded) == 1
|
||||
loaded_item = loaded[0]
|
||||
assert isinstance(loaded_item, dict)
|
||||
assert "annotation" in loaded_item
|
||||
|
||||
|
||||
def test_save_predictions_with_soundevent_override(
|
||||
api_v2: BatDetect2API,
|
||||
file_prediction,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
output_path = tmp_path / "soundevent_preds"
|
||||
api_v2.save_predictions(
|
||||
[file_prediction],
|
||||
path=output_path,
|
||||
format="soundevent",
|
||||
)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets=api_v2.targets,
|
||||
config=SoundEventOutputConfig(),
|
||||
)
|
||||
load_path = output_path.with_suffix(".json")
|
||||
loaded = formatter.load(load_path)
|
||||
|
||||
assert load_path.exists()
|
||||
assert len(loaded) == 1
|
||||
assert len(loaded[0].sound_events) == len(file_prediction.detections)
|
||||
25
tests/test_inference/test_clips.py
Normal file
25
tests/test_inference/test_clips.py
Normal file
@ -0,0 +1,25 @@
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.inference.clips import get_recording_clips
|
||||
|
||||
|
||||
def test_get_recording_clips_uses_requested_duration(create_recording) -> None:
|
||||
recording = create_recording(duration=2.0, samplerate=256_000)
|
||||
|
||||
clips = get_recording_clips(
|
||||
recording,
|
||||
duration=0.5,
|
||||
overlap=0.0,
|
||||
discard_empty=False,
|
||||
)
|
||||
|
||||
assert len(clips) == 4
|
||||
assert all(isinstance(clip, data.Clip) for clip in clips)
|
||||
assert clips[0].start_time == 0.0
|
||||
assert clips[0].end_time == 0.5
|
||||
assert clips[1].start_time == 0.5
|
||||
assert clips[1].end_time == 1.0
|
||||
assert clips[2].start_time == 1.0
|
||||
assert clips[2].end_time == 1.5
|
||||
assert clips[3].start_time == 1.5
|
||||
assert clips[3].end_time == 2.0
|
||||
Loading…
Reference in New Issue
Block a user