# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2026-01-10 17:19:34 +01:00; 71 lines, 1.9 KiB, Python.
from pathlib import Path
|
|
|
|
import torch
|
|
from soundevent import data
|
|
|
|
from batdetect2.targets import TargetConfig, build_targets
|
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
|
from batdetect2.targets.terms import TagInfo
|
|
from batdetect2.train.labels import generate_heatmaps
|
|
|
|
# Synthetic single-channel recording shared by the tests in this module.
# The audio path is only a placeholder: the test below passes a random
# spectrogram to ``generate_heatmaps`` instead of loading any audio.
#
# The duration covers the 0-100 s span of ``clip`` (and the 10-20 s sound
# event annotated in the test). With the previous 1 s duration, both the clip
# and the annotated event lay outside the recording they reference.
recording = data.Recording(
    samplerate=256_000,
    duration=100,
    channels=1,
    time_expansion=1,
    hash="asdf98sdf",
    path=Path("/path/to/audio.wav"),
)
|
|
|
|
# Clip covering 0-100 s of ``recording``. Its length matches the 100 columns
# of the random spectrogram used in the test below, which is presumably why a
# box starting at t=10 is expected at heatmap index 10 (one second per bin).
clip = data.Clip(start_time=0, end_time=100, recording=recording)
|
|
|
|
|
|
def test_generated_heatmap_are_non_zero_at_correct_positions(
    sample_target_config: TargetConfig,
    pippip_tag: TagInfo,
):
    """A single pippip call should mark all three heatmaps at one cell.

    The annotated box has coordinates ``[10, 10, 20, 30]``. Assuming
    soundevent's ``[start_time, low_freq, end_time, high_freq]`` ordering,
    that is a time extent of 10 and a frequency extent of 20.

    The ROI mapper uses unit time/frequency scales. The test expects, at
    heatmap cell ``(10, 10)``:

    - the size heatmap to hold those raw extents;
    - the class heatmap to flag the pippip channel and not the myomyo one;
    - the detection heatmap to reach 1.0.
    """
    # Unit scales: the size heatmap is expected to store the box's raw
    # extent, unscaled (see the size asserts at the end).
    roi_config = AnchorBBoxMapperConfig(time_scale=1, frequency_scale=1)
    targets = build_targets(
        sample_target_config.model_copy(update={"roi": roi_config})
    )

    pippip_call = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            recording=recording,
            geometry=data.BoundingBox(coordinates=[10, 10, 20, 30]),
        ),
        tags=[data.Tag(key=pippip_tag.key, value=pippip_tag.value)],  # type: ignore
    )
    annotation = data.ClipAnnotation(clip=clip, sound_events=[pippip_call])

    # Random content: only the spectrogram's 100x100 shape matters here.
    spectrogram = torch.rand([100, 100])
    detection, classes, sizes = generate_heatmaps(
        annotation,
        spectrogram,
        min_freq=0,
        max_freq=100,
        targets=targets,
    )

    names = targets.class_names
    pippip = names.index("pippip")
    myomyo = names.index("myomyo")

    # Expected cell of the call in every heatmap (same index on both axes).
    at = 10

    # Size channels: 0 -> 20 - 10 = 10, 1 -> 30 - 10 = 20.
    assert sizes[0, at, at] == 10
    assert sizes[1, at, at] == 20
    assert classes[pippip, at, at] == 1.0
    assert classes[myomyo, at, at] == 0.0
    assert detection[at, at] == 1.0
|