mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add roundtrip test for encoding decoding geometries
This commit is contained in:
parent
7b1cb402b4
commit
6276a8884e
72
tests/test_outputs/test_transform/test_roundtrip.py
Normal file
72
tests/test_outputs/test_transform/test_roundtrip.py
Normal file
@ -0,0 +1,72 @@
|
||||
import pytest
|
||||
import torch
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
|
||||
|
||||
def test_annotation_roundtrip_through_postprocess_and_output_transform(
|
||||
create_recording,
|
||||
create_clip,
|
||||
sample_preprocessor,
|
||||
sample_targets: TargetProtocol,
|
||||
pippip_tag: data.Tag,
|
||||
bat_tag: data.Tag,
|
||||
) -> None:
|
||||
recording = create_recording(duration=30, samplerate=256_000)
|
||||
clip = create_clip(recording=recording, start_time=10.0, end_time=10.5)
|
||||
|
||||
annotation = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[10.2, 40_000, 10.26, 55_000]
|
||||
),
|
||||
),
|
||||
tags=[pippip_tag, bat_tag],
|
||||
)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[annotation])
|
||||
|
||||
height = 128
|
||||
duration = clip.end_time - clip.start_time
|
||||
width = int(duration * sample_preprocessor.output_samplerate)
|
||||
spec = torch.zeros((1, height, width), dtype=torch.float32)
|
||||
|
||||
labeler = build_clip_labeler(targets=sample_targets)
|
||||
heatmaps = labeler(clip_annotation, spec)
|
||||
|
||||
output = ModelOutput(
|
||||
detection_probs=heatmaps.detection.unsqueeze(0),
|
||||
size_preds=heatmaps.size.unsqueeze(0),
|
||||
class_probs=heatmaps.classes.unsqueeze(0),
|
||||
features=torch.zeros((1, 1, height, width), dtype=torch.float32),
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(preprocessor=sample_preprocessor)
|
||||
clip_detection_tensors = postprocessor(output)
|
||||
assert len(clip_detection_tensors) == 1
|
||||
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
clip_detections = transform.to_clip_detections(
|
||||
detections=clip_detection_tensors[0],
|
||||
clip=clip,
|
||||
)
|
||||
|
||||
assert len(clip_detections.detections) == 1
|
||||
recovered = clip_detections.detections[0]
|
||||
|
||||
recovered_bounds = compute_bounds(recovered.geometry)
|
||||
original_bounds = compute_bounds(annotation.sound_event.geometry)
|
||||
|
||||
# 1 ms of tolerance (spectrogram resolution)
|
||||
assert recovered_bounds[0] == pytest.approx(original_bounds[0], abs=0.001)
|
||||
assert recovered_bounds[2] == pytest.approx(original_bounds[2], abs=0.001)
|
||||
|
||||
# 1000 Hz of tolerance (spectrogram resolution)
|
||||
assert recovered_bounds[1] == pytest.approx(original_bounds[1], abs=1000)
|
||||
assert recovered_bounds[3] == pytest.approx(original_bounds[3], abs=1000)
|
||||
Loading…
Reference in New Issue
Block a user