"""Tests for saving and loading predictions through ``BatDetect2API``.

Covers the default (raw) format round-trip as well as explicit ``format=``
overrides for the ``batdetect2`` and ``soundevent`` output formats.
"""

from pathlib import Path
from typing import cast

import numpy as np
import pytest

from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import (
    BatDetect2OutputConfig,
    SoundEventOutputConfig,
)
from batdetect2.postprocess.types import ClipDetections


@pytest.fixture
def api_v2() -> BatDetect2API:
    """User story: API object manages prediction IO formats."""
    return BatDetect2API.from_config(BatDetect2Config())


@pytest.fixture
def file_prediction(api_v2: BatDetect2API, example_audio_files: list[Path]):
    """User story: users save/load predictions produced by API inference."""
    # NOTE(review): ``example_audio_files`` is not defined in this module;
    # presumably it comes from a conftest fixture.
    return api_v2.process_file(example_audio_files[0])


def test_save_and_load_predictions_roundtrip_default_raw(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Predictions saved with the default format load back equivalently."""
    output_dir = tmp_path / "raw_preds"
    api_v2.save_predictions([file_prediction], path=output_dir)

    loaded = cast(list[ClipDetections], api_v2.load_predictions(output_dir))

    assert len(loaded) == 1
    loaded_prediction = loaded[0]
    assert loaded_prediction.clip == file_prediction.clip
    assert len(loaded_prediction.detections) == len(file_prediction.detections)

    for loaded_det, det in zip(
        loaded_prediction.detections,
        file_prediction.detections,
        strict=True,
    ):
        assert loaded_det.geometry == det.geometry
        # Scores are compared with a tolerance rather than exactly, so small
        # float differences introduced by save/load do not fail the test.
        assert np.isclose(loaded_det.detection_score, det.detection_score)
        np.testing.assert_allclose(
            loaded_det.class_scores,
            det.class_scores,
            atol=1e-6,
        )


def test_save_predictions_with_batdetect2_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``format="batdetect2"`` on save writes files the batdetect2 formatter reads."""
    output_dir = tmp_path / "batdetect2_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        format="batdetect2",
    )

    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=BatDetect2OutputConfig(),
    )
    loaded = formatter.load(output_dir)

    assert len(loaded) == 1
    assert "annotation" in loaded[0]
    assert len(loaded[0]["annotation"]) == len(file_prediction.detections)


def test_load_predictions_with_format_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``format="batdetect2"`` on load returns the batdetect2 dict payload."""
    output_dir = tmp_path / "batdetect2_preds_load"
    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        format="batdetect2",
    )

    loaded = api_v2.load_predictions(output_dir, format="batdetect2")

    assert len(loaded) == 1
    loaded_item = loaded[0]
    assert isinstance(loaded_item, dict)
    assert "annotation" in loaded_item
    # Same content check as the save-override test: one annotation per
    # detection produced by inference.
    assert len(loaded_item["annotation"]) == len(file_prediction.detections)


def test_save_predictions_with_soundevent_override(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """``format="soundevent"`` on save writes a JSON file with all detections."""
    output_path = tmp_path / "soundevent_preds"
    api_v2.save_predictions(
        [file_prediction],
        path=output_path,
        format="soundevent",
    )

    formatter = build_output_formatter(
        targets=api_v2.targets,
        config=SoundEventOutputConfig(),
    )
    # The test expects the output at ``output_path`` with a ``.json`` suffix.
    load_path = output_path.with_suffix(".json")
    # Check existence before loading, so a missing file fails here with a
    # clear assertion instead of inside the formatter's loader.
    assert load_path.exists()

    loaded = formatter.load(load_path)

    assert len(loaded) == 1
    assert len(loaded[0].sound_events) == len(file_prediction.detections)