Update augmentation tests to new structure

Author: mbsantiago
Date: 2025-04-22 09:00:57 +01:00
Parent: 8a463e3942
Commit: 541be15c9e

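In short: the tests no longer build a TrainPreprocessingConfig and thread its sub-configs into every call; each test now receives a PreprocessorProtocol and a ClipLabeller through pytest fixtures and passes them explicitly. A minimal sketch of the old versus new call pattern, with variable names chosen for illustration (sample_preprocessor and sample_labeller are fixtures assumed to be provided by the test suite's conftest):

    # Old structure: one config object, unpacked into per-stage sub-configs.
    config = TrainPreprocessingConfig()
    example = generate_train_example(
        clip_annotation,
        preprocessing_config=config.preprocessing,
        target_config=config.target,
        label_config=config.labels,
    )

    # New structure: explicit preprocessor and labeller dependencies.
    example = generate_train_example(
        clip_annotation,
        preprocessor=sample_preprocessor,  # PreprocessorProtocol fixture
        labeller=sample_labeller,          # ClipLabeller fixture
    )
    mixed = mix_examples(example, other_example, preprocessor=sample_preprocessor)
    with_echo = add_echo(example, preprocessor=sample_preprocessor)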

@@ -5,18 +5,19 @@ import pytest
 import xarray as xr
 from soundevent import arrays, data
 
+from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.train.augmentations import (
     add_echo,
     mix_examples,
     select_subclip,
 )
-from batdetect2.train.preprocess import (
-    TrainPreprocessingConfig,
-    generate_train_example,
-)
+from batdetect2.train.preprocess import generate_train_example
+from batdetect2.train.types import ClipLabeller
 
 
 def test_mix_examples(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
@@ -28,22 +29,18 @@ def test_mix_examples(
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     clip_annotation_2 = data.ClipAnnotation(clip=clip2)
 
-    config = TrainPreprocessingConfig()
-
     example1 = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
     example2 = generate_train_example(
         clip_annotation_2,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
-    mixed = mix_examples(example1, example2, config=config.preprocessing)
+    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
 
     assert mixed["spectrogram"].shape == example1["spectrogram"].shape
     assert mixed["detection"].shape == example1["detection"].shape
@@ -54,6 +51,8 @@ def test_mix_examples(
 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
 @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
 def test_mix_examples_of_different_durations(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
     duration1: float,
     duration2: float,
@@ -67,22 +66,18 @@ def test_mix_examples_of_different_durations(
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     clip_annotation_2 = data.ClipAnnotation(clip=clip2)
 
-    config = TrainPreprocessingConfig()
-
     example1 = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
     example2 = generate_train_example(
         clip_annotation_2,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
-    mixed = mix_examples(example1, example2, config=config.preprocessing)
+    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
 
     # Check the spectrogram has the expected duration
     step = arrays.get_dim_step(mixed["spectrogram"], "time")
@@ -92,19 +87,20 @@ def test_mix_examples_of_different_durations(
 def test_add_echo(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
 
-    config = TrainPreprocessingConfig()
-
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
-    with_echo = add_echo(original, config=config.preprocessing)
+    with_echo = add_echo(original, preprocessor=sample_preprocessor)
 
     assert with_echo["spectrogram"].shape == original["spectrogram"].shape
     xr.testing.assert_identical(with_echo["size"], original["size"])
@@ -113,17 +109,17 @@ def test_add_echo(
 def test_selected_random_subclip_has_the_correct_width(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
 
-    config = TrainPreprocessingConfig()
-
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
     subclip = select_subclip(original, width=100)
@@ -131,22 +127,22 @@ def test_selected_random_subclip_has_the_correct_width(
 def test_add_echo_after_subclip(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording(duration=2)
     clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
 
-    config = TrainPreprocessingConfig()
-
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
    )
 
     assert original.sizes["time"] > 512
 
     subclip = select_subclip(original, width=512)
-    with_echo = add_echo(subclip)
+    with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
 
     assert with_echo.sizes["time"] == 512
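The sample_preprocessor and sample_labeller fixtures referenced throughout are not part of this diff; they would be defined in the test suite's conftest. A rough sketch of how such fixtures might be wired up, where build_preprocessor and build_clip_labeler are hypothetical factory names standing in for whatever the project actually exposes:

    # conftest.py -- sketch only; fixture wiring is assumed, not taken from this commit.
    import pytest

    from batdetect2.preprocess.types import PreprocessorProtocol
    from batdetect2.train.types import ClipLabeller


    @pytest.fixture
    def sample_preprocessor() -> PreprocessorProtocol:
        # Hypothetical factory; the real fixture may build the preprocessor
        # from a default preprocessing config instead.
        from batdetect2.preprocess import build_preprocessor

        return build_preprocessor()


    @pytest.fixture
    def sample_labeller() -> ClipLabeller:
        # Hypothetical factory; the real fixture may derive the labeller
        # from the target and label configuration.
        from batdetect2.train.labels import build_clip_labeler

        return build_clip_labeler()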