mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow loading predictions in different formats
This commit is contained in:
parent
9fa703b34b
commit
a332c5c3bd
@ -304,8 +304,20 @@ class BatDetect2API:
|
|||||||
def load_predictions(
    self,
    path: data.PathLike,
    format: str | None = None,
    config: OutputFormatConfig | None = None,
) -> list[object]:
    """Load predictions from ``path``, optionally overriding the format.

    When neither ``format`` nor ``config`` is given, the formatter already
    attached to this API instance (``self.formatter``) is used. Otherwise a
    formatter is built with ``get_output_formatter`` for this call only.

    Args:
        path: Location to read from; passed unchanged to the formatter's
            ``load`` method.
        format: Name of the output format to load with. If omitted (or
            empty) the name is taken from ``config.name``.
        config: Output format configuration, forwarded to
            ``get_output_formatter``.

    Returns:
        Whatever the selected formatter's ``load`` returns. The element type
        depends on the chosen format, hence ``list[object]``.

    Raises:
        ValueError: If ``format`` is an empty string and no ``config`` is
            given, so no formatter name can be determined.
    """
    # No override requested: keep using the formatter configured on the API.
    if format is None and config is None:
        return self.formatter.load(path)

    # Resolve the formatter name explicitly instead of `format or config.name`
    # so `config` is never dereferenced while it is None.
    if format:
        name = format
    elif config is not None:
        name = config.name
    else:
        raise ValueError(
            "`format` must be a non-empty string when no `config` is given."
        )

    formatter = get_output_formatter(
        name=name,
        targets=self.targets,
        config=config,
    )
    return formatter.load(path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
|
|||||||
115
tests/test_api_v2/test_api_v2.py
Normal file
115
tests/test_api_v2/test_api_v2.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def api_v2() -> BatDetect2API:
    """Provide an API built from a default config with loader batch size 2."""
    settings = BatDetect2Config()
    settings.inference.loader.batch_size = 2
    api = BatDetect2API.from_config(settings)
    return api
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_returns_recording_level_predictions(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
) -> None:
    """Processing one file yields a clip spanning the whole recording.

    Every detection must lie within the recording in time, stay below half
    the sample rate in frequency, and carry one class score per target class.
    """
    audio_path = example_audio_files[0]
    prediction = api_v2.process_file(audio_path)

    clip = prediction.clip
    assert clip.recording.path == audio_path
    assert clip.start_time == 0
    assert clip.end_time == clip.recording.duration

    for detection in prediction.detections:
        bounds = compute_bounds(detection.geometry)
        start, low, end, high = bounds

        # Time bounds are expressed relative to the recording, not the clip.
        assert 0 <= start <= end <= prediction.clip.recording.duration

        # Both frequency bounds must respect the Nyquist limit.
        assert prediction.clip.recording.samplerate > 2 * low
        assert prediction.clip.recording.samplerate > 2 * high

        n_scores = detection.class_scores.shape[0]
        assert n_scores == len(api_v2.targets.class_names)
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_files_is_batch_size_invariant(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
) -> None:
    """Running with batch size 1 or 3 must give matching per-clip results."""

    def _index_by_clip(predictions):
        # Key each prediction by (recording path, clip start, clip end) so the
        # two runs can be matched regardless of output order.
        indexed = {}
        for pred in predictions:
            clip_key = (
                str(pred.clip.recording.path),
                pred.clip.start_time,
                pred.clip.end_time,
            )
            indexed[clip_key] = pred
        return indexed

    preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
    preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)

    assert len(preds_batch_1) == len(preds_batch_3)

    by_key_1 = _index_by_clip(preds_batch_1)
    by_key_3 = _index_by_clip(preds_batch_3)

    assert set(by_key_1) == set(by_key_3)

    for key, pred_1 in by_key_1.items():
        pred_3 = by_key_3[key]
        assert pred_1.clip.start_time == pred_3.clip.start_time
        assert pred_1.clip.end_time == pred_3.clip.end_time
        assert len(pred_1.detections) == len(pred_3.detections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_audio_matches_process_spectrogram(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
) -> None:
    """The audio and spectrogram entrypoints must agree on the same input."""
    waveform = api_v2.load_audio(example_audio_files[0])
    spectrogram = api_v2.generate_spectrogram(waveform)

    from_audio = api_v2.process_audio(waveform)
    from_spec = api_v2.process_spectrogram(spectrogram)

    assert len(from_audio) == len(from_spec)

    # strict=True guards against silently truncating a length mismatch.
    for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
        np.testing.assert_allclose(
            np.array(compute_bounds(det_audio.geometry)),
            np.array(compute_bounds(det_spec.geometry)),
            atol=1e-6,
        )

        scores_match = np.isclose(
            det_audio.detection_score,
            det_spec.detection_score,
        )
        assert scores_match

        np.testing.assert_allclose(
            det_audio.class_scores,
            det_spec.class_scores,
            atol=1e-6,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_spectrogram_rejects_batched_input(
    api_v2: BatDetect2API,
) -> None:
    """A spectrogram with a leading batch dimension of 2 must be rejected."""
    batched_shape = (2, 1, 128, 64)
    batched_spec = torch.zeros(batched_shape, dtype=torch.float32)

    expected_error = pytest.raises(
        ValueError,
        match="Batched spectrograms not supported",
    )
    with expected_error:
        api_v2.process_spectrogram(batched_spec)
|
||||||
123
tests/test_api_v2/test_outputs_io.py
Normal file
123
tests/test_api_v2/test_outputs_io.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.outputs import build_output_formatter
|
||||||
|
from batdetect2.outputs.formats import (
|
||||||
|
BatDetect2OutputConfig,
|
||||||
|
SoundEventOutputConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def api_v2() -> BatDetect2API:
    """Provide an API instance built from an all-defaults configuration."""
    default_config = BatDetect2Config()
    return BatDetect2API.from_config(default_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
    """Provide the prediction obtained by processing the first example file."""
    first_file = example_audio_files[0]
    return api_v2.process_file(first_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_and_load_predictions_roundtrip_default_raw(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Saving then loading with the default formatter preserves predictions."""
    target_dir = tmp_path / "raw_preds"
    api_v2.save_predictions([file_prediction], path=target_dir)

    # load_predictions is typed as list[object]; narrow it for attribute use.
    restored = api_v2.load_predictions(target_dir)
    loaded = cast(list[ClipDetections], restored)

    assert len(loaded) == 1
    loaded_prediction = loaded[0]
    assert loaded_prediction.clip == file_prediction.clip
    assert len(loaded_prediction.detections) == len(file_prediction.detections)

    detection_pairs = zip(
        loaded_prediction.detections,
        file_prediction.detections,
        strict=True,
    )
    for loaded_det, det in detection_pairs:
        assert loaded_det.geometry == det.geometry
        assert np.isclose(loaded_det.detection_score, det.detection_score)
        np.testing.assert_allclose(
            loaded_det.class_scores,
            det.class_scores,
            atol=1e-6,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_predictions_with_batdetect2_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Saving with format="batdetect2" writes output that format can load."""
    output_dir = tmp_path / "batdetect2_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        format="batdetect2",
    )

    # Read back with an independently built formatter rather than the API.
    batdetect2_config = BatDetect2OutputConfig()
    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=batdetect2_config,
    )
    loaded = formatter.load(output_dir)

    assert len(loaded) == 1
    first_entry = loaded[0]
    assert "annotation" in first_entry
    assert len(first_entry["annotation"]) == len(file_prediction.detections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_predictions_with_format_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Loading with format="batdetect2" returns dict entries with annotations."""
    fmt_name = "batdetect2"
    output_dir = tmp_path / "batdetect2_preds_load"

    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        format=fmt_name,
    )
    loaded = api_v2.load_predictions(output_dir, format=fmt_name)

    assert len(loaded) == 1
    loaded_item = loaded[0]
    assert isinstance(loaded_item, dict)
    assert "annotation" in loaded_item
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_predictions_with_soundevent_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Saving with format="soundevent" writes a JSON file that can be loaded.

    The predictions are saved to a suffix-less path and are expected to be
    found at the same path with a ``.json`` suffix.
    """
    output_path = tmp_path / "soundevent_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_path,
        format="soundevent",
    )

    # Check the file exists *before* loading, so a missing output fails on
    # this assertion rather than as an error raised from inside the loader.
    load_path = output_path.with_suffix(".json")
    assert load_path.exists()

    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=SoundEventOutputConfig(),
    )
    loaded = formatter.load(load_path)

    assert len(loaded) == 1
    assert len(loaded[0].sound_events) == len(file_prediction.detections)
|
||||||
25
tests/test_inference/test_clips.py
Normal file
25
tests/test_inference/test_clips.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.inference.clips import get_recording_clips
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_recording_clips_uses_requested_duration(create_recording) -> None:
    """A 2 s recording split at 0.5 s without overlap gives four clips."""
    recording = create_recording(duration=2.0, samplerate=256_000)

    clips = get_recording_clips(
        recording,
        duration=0.5,
        overlap=0.0,
        discard_empty=False,
    )

    expected_bounds = [
        (0.0, 0.5),
        (0.5, 1.0),
        (1.0, 1.5),
        (1.5, 2.0),
    ]

    assert len(clips) == 4
    assert all(isinstance(clip, data.Clip) for clip in clips)

    for index, (expected_start, expected_end) in enumerate(expected_bounds):
        assert clips[index].start_time == expected_start
        assert clips[index].end_time == expected_end
|
||||||
Loading…
Reference in New Issue
Block a user