batdetect2/tests/test_api_v2/test_outputs_io.py
2026-03-18 16:17:50 +00:00

124 lines
3.4 KiB
Python

from pathlib import Path
from typing import cast
import numpy as np
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
SoundEventOutputConfig,
)
from batdetect2.postprocess.types import ClipDetections
@pytest.fixture
def api_v2() -> BatDetect2API:
    """Provide an API object built from a default ``BatDetect2Config``.

    User story: the API object manages prediction IO formats.
    """
    default_config = BatDetect2Config()
    return BatDetect2API.from_config(default_config)
@pytest.fixture
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
    """Return the result of processing the first example audio file.

    User story: users save/load predictions produced by API inference.
    """
    first_audio_file = example_audio_files[0]
    return api_v2.process_file(first_audio_file)
def test_save_and_load_predictions_roundtrip_default_raw(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Save one prediction without a format override and load it back.

    The loaded prediction must reference the same clip and carry the same
    detections: equal geometry, and detection/class scores that match within
    floating-point tolerance.
    """
    target_dir = tmp_path / "raw_preds"
    api_v2.save_predictions([file_prediction], path=target_dir)

    restored = cast(
        list[ClipDetections],
        api_v2.load_predictions(target_dir),
    )
    assert len(restored) == 1

    restored_prediction = restored[0]
    assert restored_prediction.clip == file_prediction.clip

    restored_detections = restored_prediction.detections
    expected_detections = file_prediction.detections
    assert len(restored_detections) == len(expected_detections)

    # strict=True makes zip raise if the two sequences differ in length.
    pairs = zip(restored_detections, expected_detections, strict=True)
    for actual, expected in pairs:
        assert actual.geometry == expected.geometry
        assert np.isclose(actual.detection_score, expected.detection_score)
        np.testing.assert_allclose(
            actual.class_scores,
            expected.class_scores,
            atol=1e-6,
        )
def test_save_predictions_with_batdetect2_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Save with ``format="batdetect2"`` and read back via a formatter.

    The formatter is built directly from ``BatDetect2OutputConfig`` rather
    than going through the API, and the loaded record must hold one
    annotation per detection.
    """
    destination = tmp_path / "batdetect2_preds"
    predictions = [file_prediction]
    api_v2.save_predictions(
        predictions,
        path=destination,
        format="batdetect2",
    )

    batdetect2_formatter = build_output_formatter(
        targets=api_v2.targets,
        config=BatDetect2OutputConfig(),
    )
    records = batdetect2_formatter.load(destination)

    assert len(records) == 1
    assert "annotation" in records[0]
    annotation_count = len(records[0]["annotation"])
    assert annotation_count == len(file_prediction.detections)
def test_load_predictions_with_format_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Save and load through the API, both with ``format="batdetect2"``.

    The single loaded item must be a dict containing an ``"annotation"`` key.
    """
    destination = tmp_path / "batdetect2_preds_load"
    predictions = [file_prediction]
    api_v2.save_predictions(
        predictions,
        path=destination,
        format="batdetect2",
    )

    restored = api_v2.load_predictions(destination, format="batdetect2")
    assert len(restored) == 1

    first_item = restored[0]
    assert isinstance(first_item, dict)
    assert "annotation" in first_item
def test_save_predictions_with_soundevent_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Save with ``format="soundevent"`` and read back via a formatter.

    The test expects the output at ``output_path`` with a ``.json`` suffix.
    The loaded result must contain one entry whose ``sound_events`` match the
    number of detections in the saved prediction.
    """
    output_path = tmp_path / "soundevent_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_path,
        format="soundevent",
    )

    # Check the file exists *before* loading it: if the writer did not produce
    # the expected file, the test fails here with a clear message instead of
    # with whatever error the loader raises for a missing path.
    load_path = output_path.with_suffix(".json")
    assert load_path.exists(), f"expected output file at {load_path}"

    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=SoundEventOutputConfig(),
    )
    loaded = formatter.load(load_path)

    assert len(loaded) == 1
    assert len(loaded[0].sound_events) == len(file_prediction.detections)