batdetect2/tests/test_train/test_augmentations.py
2025-04-22 09:00:57 +01:00

149 lines
4.9 KiB
Python

from collections.abc import Callable
import numpy as np
import pytest
import xarray as xr
from soundevent import arrays, data
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.augmentations import (
add_echo,
mix_examples,
select_subclip,
)
from batdetect2.train.preprocess import generate_train_example
from batdetect2.train.types import ClipLabeller
def test_mix_examples(
    sample_preprocessor: PreprocessorProtocol,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that mixing two examples keeps the first example's shapes.

    Two half-second clips from separate recordings are converted into
    training examples and mixed; each array in the result must have the
    same shape as its counterpart in the first example.
    """
    first_recording = create_recording()
    second_recording = create_recording()

    first_clip = data.Clip(
        recording=first_recording,
        start_time=0.2,
        end_time=0.7,
    )
    second_clip = data.Clip(
        recording=second_recording,
        start_time=0.3,
        end_time=0.8,
    )

    first_annotation = data.ClipAnnotation(clip=first_clip)
    second_annotation = data.ClipAnnotation(clip=second_clip)

    first_example = generate_train_example(
        first_annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    second_example = generate_train_example(
        second_annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    mixed = mix_examples(
        first_example,
        second_example,
        preprocessor=sample_preprocessor,
    )

    # Every array of the mixed example is compared against the first input.
    for name in ("spectrogram", "detection", "size", "class"):
        assert mixed[name].shape == first_example[name].shape
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
    sample_preprocessor: PreprocessorProtocol,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
    duration1: float,
    duration2: float,
):
    """Check the mixed spectrogram spans the duration of the first clip.

    Runs over every combination of the parametrized clip durations. Both
    clips start at time 0; the mixed spectrogram's time axis must start at
    0 and end (within two time steps) at ``duration1``.
    """
    first_recording = create_recording()
    second_recording = create_recording()

    first_clip = data.Clip(
        recording=first_recording,
        start_time=0,
        end_time=duration1,
    )
    second_clip = data.Clip(
        recording=second_recording,
        start_time=0,
        end_time=duration2,
    )

    first_annotation = data.ClipAnnotation(clip=first_clip)
    second_annotation = data.ClipAnnotation(clip=second_clip)

    first_example = generate_train_example(
        first_annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    second_example = generate_train_example(
        second_annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    mixed = mix_examples(
        first_example,
        second_example,
        preprocessor=sample_preprocessor,
    )

    # The time axis of the mixed spectrogram should match the first clip:
    # it starts at zero and its end lands within two steps of duration1.
    time_step = arrays.get_dim_step(mixed["spectrogram"], "time")
    time_start, time_stop = arrays.get_dim_range(mixed["spectrogram"], "time")
    tolerance = 2 * time_step

    assert time_start == 0
    assert np.isclose(time_stop + time_step, duration1, atol=tolerance)
def test_add_echo(
    sample_preprocessor: PreprocessorProtocol,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that adding an echo leaves shapes and targets unchanged.

    The spectrogram must keep its shape, and the ``size``, ``class`` and
    ``detection`` arrays must be identical to those of the original example.
    """
    recording = create_recording()
    clip = data.Clip(recording=recording, start_time=0.2, end_time=0.7)
    annotation = data.ClipAnnotation(clip=clip)

    original = generate_train_example(
        annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    with_echo = add_echo(original, preprocessor=sample_preprocessor)

    assert with_echo["spectrogram"].shape == original["spectrogram"].shape

    # The target arrays are expected to come through exactly as they went in.
    for name in ("size", "class", "detection"):
        xr.testing.assert_identical(with_echo[name], original[name])
def test_selected_random_subclip_has_the_correct_width(
    sample_preprocessor: PreprocessorProtocol,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that a selected subclip has the requested width.

    A subclip of width 100 is taken from a generated training example and
    the second axis of its spectrogram is asserted to have length 100.
    """
    expected_width = 100

    recording = create_recording()
    clip = data.Clip(recording=recording, start_time=0.2, end_time=0.7)
    annotation = data.ClipAnnotation(clip=clip)

    original = generate_train_example(
        annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    subclip = select_subclip(original, width=expected_width)

    spectrogram_shape = subclip["spectrogram"].shape
    assert spectrogram_shape[1] == expected_width
def test_add_echo_after_subclip(
    sample_preprocessor: PreprocessorProtocol,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
    """Check that adding an echo to a subclip keeps the subclip's width.

    A one-second clip is cut from a recording created with ``duration=2``.
    The generated example must have more than 512 time steps, so that a
    512-wide subclip can be selected; after the echo is added the ``time``
    dimension must still have size 512.
    """
    subclip_width = 512

    recording = create_recording(duration=2)
    clip = data.Clip(recording=recording, start_time=0, end_time=1)
    annotation = data.ClipAnnotation(clip=clip)

    original = generate_train_example(
        annotation,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )

    # Precondition: the example must be longer than the requested subclip.
    assert original.sizes["time"] > subclip_width

    subclip = select_subclip(original, width=subclip_width)
    with_echo = add_echo(subclip, preprocessor=sample_preprocessor)

    assert with_echo.sizes["time"] == subclip_width