mirror of https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
124 lines · 3.4 KiB · Python
from pathlib import Path
|
|
from typing import cast
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from batdetect2.api_v2 import BatDetect2API
|
|
from batdetect2.config import BatDetect2Config
|
|
from batdetect2.outputs import build_output_formatter
|
|
from batdetect2.outputs.formats import (
|
|
BatDetect2OutputConfig,
|
|
SoundEventOutputConfig,
|
|
)
|
|
from batdetect2.postprocess.types import ClipDetections
|
|
|
|
|
|
@pytest.fixture
def api_v2() -> BatDetect2API:
    """Build a BatDetect2API instance from the default configuration.

    User story: API object manages prediction IO formats.
    """
    default_config = BatDetect2Config()
    return BatDetect2API.from_config(default_config)
|
|
|
|
|
|
@pytest.fixture
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
    """Run API inference on the first example audio file.

    User story: users save/load predictions produced by API inference.

    ``example_audio_files`` is a fixture defined outside this module view
    (presumably a conftest); it is assumed to be non-empty.
    """
    first_audio_file = example_audio_files[0]
    return api_v2.process_file(first_audio_file)
|
|
|
|
|
|
def test_save_and_load_predictions_roundtrip_default_raw(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Saving with the default format and loading back preserves predictions.

    Compares the clip, the number of detections, and each detection's
    geometry, detection score and class scores against the original.
    """
    raw_dir = tmp_path / "raw_preds"
    api_v2.save_predictions([file_prediction], path=raw_dir)

    # load_predictions has a broad return type; narrow it for attribute access.
    restored = cast(list[ClipDetections], api_v2.load_predictions(raw_dir))
    assert len(restored) == 1

    restored_prediction = restored[0]
    assert restored_prediction.clip == file_prediction.clip

    restored_detections = restored_prediction.detections
    original_detections = file_prediction.detections
    assert len(restored_detections) == len(original_detections)

    pairs = zip(restored_detections, original_detections, strict=True)
    for restored_det, original_det in pairs:
        assert restored_det.geometry == original_det.geometry
        assert np.isclose(
            restored_det.detection_score, original_det.detection_score
        )
        np.testing.assert_allclose(
            restored_det.class_scores,
            original_det.class_scores,
            atol=1e-6,
        )
|
|
|
|
|
|
def test_save_predictions_with_batdetect2_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``format="batdetect2"`` output is readable by a batdetect2 formatter.

    The saved directory is read back with a formatter built directly from
    ``BatDetect2OutputConfig`` rather than through the API's loader.
    """
    target_dir = tmp_path / "batdetect2_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=target_dir,
        format="batdetect2",
    )

    reader = build_output_formatter(
        targets=api_v2.targets,
        config=BatDetect2OutputConfig(),
    )
    records = reader.load(target_dir)

    assert len(records) == 1
    first_record = records[0]
    assert "annotation" in first_record
    assert len(first_record["annotation"]) == len(file_prediction.detections)
|
|
|
|
|
|
def test_load_predictions_with_format_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``load_predictions(format="batdetect2")`` reads batdetect2-format output.

    Predictions are saved with an explicit ``format="batdetect2"`` override
    and loaded back through the API using the same override.
    """
    output_dir = tmp_path / "batdetect2_preds_load"
    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        format="batdetect2",
    )

    loaded = api_v2.load_predictions(output_dir, format="batdetect2")

    assert len(loaded) == 1
    loaded_item = loaded[0]
    assert isinstance(loaded_item, dict)
    assert "annotation" in loaded_item
    # Loading through the API must not drop or duplicate detections.
    assert len(loaded_item["annotation"]) == len(file_prediction.detections)
|
|
|
|
|
|
def test_save_predictions_with_soundevent_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``format="soundevent"`` writes a ``.json`` file that loads back.

    The save is given a suffix-less path; the test expects the written file
    at that path with a ``.json`` suffix.
    """
    output_path = tmp_path / "soundevent_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_path,
        format="soundevent",
    )

    load_path = output_path.with_suffix(".json")
    # Check existence before loading so a missing file fails with a clear
    # assertion rather than an error raised from inside the formatter.
    assert load_path.exists()

    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=SoundEventOutputConfig(),
    )
    loaded = formatter.load(load_path)

    assert len(loaded) == 1
    assert len(loaded[0].sound_events) == len(file_prediction.detections)
|