# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2026-01-10 00:59:34 +01:00
# 65 lines, 1.7 KiB, Python
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from soundevent import data
|
|
|
|
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
|
from batdetect2.typing import (
|
|
BatDetect2Prediction,
|
|
RawPrediction,
|
|
TargetProtocol,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def sample_formatter(sample_targets: TargetProtocol):
    """Provide an output formatter built from a default ``RawOutputConfig``.

    The formatter is created by ``build_output_formatter`` and receives the
    ``sample_targets`` fixture as its ``targets`` argument.
    """
    raw_config = RawOutputConfig()
    formatter = build_output_formatter(
        config=raw_config,
        targets=sample_targets,
    )
    return formatter
|
|
|
|
|
|
def _make_random_detection(rng, num_classes):
    """Build one ``RawPrediction`` filled with random, well-formed values.

    Parameters
    ----------
    rng
        NumPy random ``Generator`` used for every random draw.
    num_classes
        Length of the generated ``class_scores`` vector.
    """
    # Sort each pair so that start <= end and low <= high. Independent draws
    # would usually be out of order; soundevent's BoundingBox presumably
    # rejects such boxes on construction (TODO confirm against soundevent).
    start_time, end_time = np.sort(rng.uniform(size=2)).tolist()
    low_freq, high_freq = np.sort(rng.uniform(size=2)).tolist()
    return RawPrediction(
        geometry=data.BoundingBox(
            # NOTE(review): assumes the soundevent coordinate layout
            # [start_time, low_freq, end_time, high_freq] -- confirm.
            coordinates=[start_time, low_freq, end_time, high_freq]
        ),
        detection_score=0.5,
        class_scores=rng.uniform(size=num_classes),
        features=rng.uniform(size=32),
    )


def test_roundtrip(
    sample_formatter,
    clip: data.Clip,
    sample_targets: TargetProtocol,
    tmp_path: Path,
):
    """Saving predictions and loading them back must preserve their content.

    Ten random detections are attached to ``clip``, written with
    ``sample_formatter.save`` and read back with ``sample_formatter.load``.
    The recovered clip, and every detection's score, class scores, features
    and geometry, must equal the originals.
    """
    # A fixed seed keeps the generated data identical between runs, so a
    # failure can be reproduced.
    rng = np.random.default_rng(seed=0)
    num_classes = len(sample_targets.class_names)
    detections = [
        _make_random_detection(rng, num_classes) for _ in range(10)
    ]

    prediction = BatDetect2Prediction(clip=clip, predictions=detections)
    path = tmp_path / "predictions"

    sample_formatter.save(predictions=[prediction], path=path)
    recovered = sample_formatter.load(path=path)

    assert len(recovered) == 1
    assert recovered[0].clip == prediction.clip

    # zip() stops at the shorter input, so without this check a formatter
    # that drops detections would still pass the loop below.
    assert len(recovered[0].predictions) == len(detections)

    for recovered_prediction, detection in zip(
        recovered[0].predictions, detections
    ):
        assert (
            recovered_prediction.detection_score == detection.detection_score
        )
        assert (
            recovered_prediction.class_scores == detection.class_scores
        ).all()
        assert (recovered_prediction.features == detection.features).all()
        assert recovered_prediction.geometry == detection.geometry
|