# Source: batdetect2/tests/test_train/test_preprocessing.py
# Repository viewer metadata: 2025-08-25 17:06:17 +01:00, 230 lines, 7.7 KiB, Python

import pytest
from soundevent import data
from soundevent.terms import get_term
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ModelOutput
from batdetect2.typing.preprocess import AudioLoader
@pytest.fixture
def build_from_config(
    create_temp_yaml,
):
    """Provide a factory that builds the full pipeline from a YAML string.

    The returned callable takes YAML text, hands it to ``create_temp_yaml``
    and uses the result as the config path for every loader. It returns the
    tuple ``(targets, preprocessor, labeller, postprocessor)``.

    The sections are read from the ``targets``, ``preprocessing``,
    ``labels`` and ``postprocessing`` fields of that config.
    """

    def build(yaml_content):
        path = create_temp_yaml(yaml_content)

        # Load every config section before building any component.
        target_cfg = load_target_config(path, field="targets")
        preprocess_cfg = load_preprocessing_config(path, field="preprocessing")
        label_cfg = load_label_config(path, field="labels")
        postprocess_cfg = load_postprocess_config(path, field="postprocessing")

        targets = build_targets(target_cfg)
        preprocessor = build_preprocessor(preprocess_cfg)

        # The labeller and the postprocessor share the preprocessor's
        # frequency bounds.
        freq_bounds = {
            "min_freq": preprocessor.min_freq,
            "max_freq": preprocessor.max_freq,
        }
        labeller = build_clip_labeler(
            targets=targets,
            config=label_cfg,
            **freq_bounds,
        )
        postprocessor = build_postprocessor(
            targets,
            config=postprocess_cfg,
            **freq_bounds,
        )
        return targets, preprocessor, labeller, postprocessor

    return build
def test_encoding_decoding_roundtrip_recovers_object(
    sample_audio_loader: AudioLoader,
    build_from_config,
    recording,
):
    """Encode one annotation to heatmaps and decode it back unchanged.

    A single bounding-box annotation tagged as *Pipistrellus pipistrellus*
    is turned into training heatmaps with ``generate_train_example``. Those
    heatmaps are then passed to the postprocessor as if they were model
    output. The decoded prediction must have the original geometry (within
    tolerance) and carry both the species tag and the generic order tag.
    """
    # NOTE(review): the nesting of this YAML was reconstructed from the key
    # names (``roi`` and ``classes`` under ``targets``, ``tags`` under each
    # class entry) -- confirm against the targets config schema.
    yaml_content = """
    labels:
    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left
      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus
        generic_class:
          - key: order
            value: Chiroptera
    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[
            data.Tag(key="species", value="Pipistrellus pipistrellus"),  # type: ignore
        ],
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(
        clip_annotation, sample_audio_loader, preprocessor, labeller
    )

    # Feed the encoded targets back in as a fake model output. The
    # detection heatmap and spectrogram get two leading singleton
    # dimensions, the size and class heatmaps one (presumably batch and
    # channel axes -- confirm against ModelOutput).
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=encoded["detection_heatmap"]
            .unsqueeze(0)
            .unsqueeze(0),
            size_preds=encoded["size_heatmap"].unsqueeze(0),
            class_probs=encoded["class_heatmap"].unsqueeze(0),
            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)

    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    # Time edges must match within 0.01 and frequency edges within 1_000,
    # in the same units as the annotated box.
    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    # Exactly two tags are expected: the class's species tag and the
    # generic order tag.
    assert len(recovered.tags) == 2

    predicted_species_tag = next(
        (t for t in recovered.tags if t.tag.term == get_term("species")),
        None,
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"

    predicted_order_tag = next(
        (t for t in recovered.tags if t.tag.term == get_term("order")),
        None,
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
    sample_audio_loader: AudioLoader,
    build_from_config,
    recording,
):
    """Round-trip an annotation whose class overrides the default ROI.

    The config declares a default ``bottom-left`` anchor for all targets
    and a ``top-left`` anchor on the ``myomyo`` class. A *Myotis myotis*
    annotation is encoded with ``generate_train_example`` and decoded by
    the postprocessor. The decoded prediction must still have the original
    geometry (within tolerance) and carry both the species tag and the
    generic order tag.
    """
    # NOTE(review): the nesting of this YAML was reconstructed from the key
    # names and the test name (the second ``roi`` is placed under the
    # ``myomyo`` class as its override) -- confirm against the targets
    # config schema.
    yaml_content = """
    labels:
    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left
      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus
          - name: myomyo
            tags:
              - key: species
                value: Myotis myotis
            roi:
              name: anchor_bbox
              anchor: top-left
        generic_class:
          - key: order
            value: Chiroptera
    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(key="species", value="Myotis myotis")],  # type: ignore
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(
        clip_annotation,
        sample_audio_loader,
        preprocessor,
        labeller,
    )

    # Feed the encoded targets back in as a fake model output. The
    # detection heatmap and spectrogram get two leading singleton
    # dimensions, the size and class heatmaps one (presumably batch and
    # channel axes -- confirm against ModelOutput).
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=encoded["detection_heatmap"]
            .unsqueeze(0)
            .unsqueeze(0),
            size_preds=encoded["size_heatmap"].unsqueeze(0),
            class_probs=encoded["class_heatmap"].unsqueeze(0),
            features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)

    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    # Time edges must match within 0.01 and frequency edges within 1_000,
    # in the same units as the annotated box.
    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    # Exactly two tags are expected: the class's species tag and the
    # generic order tag.
    assert len(recovered.tags) == 2

    predicted_species_tag = next(
        (t for t in recovered.tags if t.tag.term == get_term("species")),
        None,
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Myotis myotis"

    predicted_order_tag = next(
        (t for t in recovered.tags if t.tag.term == get_term("order")),
        None,
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"