batdetect2/tests/test_train/test_labels.py
2025-09-07 11:03:46 +01:00

71 lines
1.9 KiB
Python

from pathlib import Path
import torch
from soundevent import data
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo
from batdetect2.train.labels import generate_heatmaps
# Module-level objects shared by the test(s) in this file.
# The recording is metadata only; the path is a placeholder string.
recording = data.Recording(
    path=Path("/path/to/audio.wav"),
    hash="asdf98sdf",
    duration=1,
    samplerate=256_000,
    channels=1,
    time_expansion=1,
)

# NOTE(review): the clip spans 0-100 although the recording declares
# duration=1. Presumably this is deliberate so that clip times line up with
# the 100-wide spectrogram used in the test below -- confirm.
clip = data.Clip(
    recording=recording,
    start_time=0,
    end_time=100,
)
def test_generated_heatmap_are_non_zero_at_correct_positions(
    sample_target_config: TargetConfig,
    pippip_tag: TagInfo,
) -> None:
    """Check the heatmap values produced for one tagged bounding box.

    A single sound event with bounding box ``[10, 10, 20, 30]`` and the
    ``pippip_tag`` tag is annotated on the module-level ``clip``. The
    heatmaps are generated for a ``[1, 100, 100]`` spectrogram with
    ``min_freq=0`` and ``max_freq=100``, and the cell at index ``(10, 10)``
    is inspected in each of the three returned heatmaps.

    Args:
        sample_target_config: Fixture with the base target configuration;
            it is copied with the ROI mapper replaced by one using unit
            time and frequency scales.
        pippip_tag: Fixture supplying the key/value of the tag attached to
            the annotated sound event.
    """
    # Unit scales: the size heatmap is then expected to hold the raw box
    # extents (20 - 10 = 10 and 30 - 10 = 20) rather than scaled values.
    unit_scale_roi = AnchorBBoxMapperConfig(
        time_scale=1,
        frequency_scale=1,
    )
    target_config = sample_target_config.model_copy(
        update={"roi": unit_scale_roi},
    )
    targets = build_targets(target_config)

    # One annotated event, built piece by piece for readability.
    bounding_box = data.BoundingBox(coordinates=[10, 10, 20, 30])
    sound_event = data.SoundEvent(
        recording=recording,
        geometry=bounding_box,
    )
    event_annotation = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)],  # type: ignore
    )
    clip_annotation = data.ClipAnnotation(
        clip=clip,
        sound_events=[event_annotation],
    )

    # The spectrogram content is random; presumably only its shape matters
    # to generate_heatmaps -- TODO confirm.
    heatmaps = generate_heatmaps(
        clip_annotation,
        torch.rand([1, 100, 100]),
        min_freq=0,
        max_freq=100,
        targets=targets,
    )
    detection_heatmap, class_heatmap, size_heatmap = heatmaps

    # Class names are expected to come from ``sample_target_config``.
    pippip_index = targets.class_names.index("pippip")
    myomyo_index = targets.class_names.index("myomyo")

    # Size channels 0 and 1 carry the two box extents.
    assert size_heatmap[0, 10, 10] == 10
    assert size_heatmap[1, 10, 10] == 20

    # Only the class matching the event's tag is active at that cell.
    assert class_heatmap[pippip_index, 10, 10] == 1.0
    assert class_heatmap[myomyo_index, 10, 10] == 0.0

    # The detection heatmap has its full value at the same cell.
    assert detection_heatmap[0, 10, 10] == 1.0