From 541be15c9e428c88bb44b1b6afc2b8fa8a5a4e8c Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 22 Apr 2025 09:00:57 +0100
Subject: [PATCH] Update augmentation tests to new structure

---
 tests/test_train/test_augmentations.py | 68 ++++++++++++--------------
 1 file changed, 32 insertions(+), 36 deletions(-)

diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py
index f5e04f6..5560d85 100644
--- a/tests/test_train/test_augmentations.py
+++ b/tests/test_train/test_augmentations.py
@@ -5,18 +5,19 @@ import pytest
 import xarray as xr
 from soundevent import arrays, data
 
+from batdetect2.preprocess.types import PreprocessorProtocol
 from batdetect2.train.augmentations import (
     add_echo,
     mix_examples,
     select_subclip,
 )
-from batdetect2.train.preprocess import (
-    TrainPreprocessingConfig,
-    generate_train_example,
-)
+from batdetect2.train.preprocess import generate_train_example
+from batdetect2.train.types import ClipLabeller
 
 
 def test_mix_examples(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
@@ -28,22 +29,18 @@
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     clip_annotation_2 = data.ClipAnnotation(clip=clip2)
 
-    config = TrainPreprocessingConfig()
-
     example1 = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
     example2 = generate_train_example(
         clip_annotation_2,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
-    mixed = mix_examples(example1, example2, config=config.preprocessing)
+    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
 
     assert mixed["spectrogram"].shape == example1["spectrogram"].shape
     assert mixed["detection"].shape == example1["detection"].shape
@@ -54,6 +51,8 @@
 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
 @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
 def test_mix_examples_of_different_durations(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
     duration1: float,
     duration2: float,
@@ -67,22 +66,18 @@
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     clip_annotation_2 = data.ClipAnnotation(clip=clip2)
 
-    config = TrainPreprocessingConfig()
-
     example1 = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
     example2 = generate_train_example(
         clip_annotation_2,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
-    mixed = mix_examples(example1, example2, config=config.preprocessing)
+    mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
 
     # Check the spectrogram has the expected duration
     step = arrays.get_dim_step(mixed["spectrogram"], "time")
@@ -92,19 +87,20 @@
 
 
 def test_add_echo(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
-    config = TrainPreprocessingConfig()
+
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
-    with_echo = add_echo(original, config=config.preprocessing)
+    with_echo = add_echo(original, preprocessor=sample_preprocessor)
 
     assert with_echo["spectrogram"].shape == original["spectrogram"].shape
     xr.testing.assert_identical(with_echo["size"], original["size"])
@@ -113,17 +109,17 @@
 
 
 def test_selected_random_subclip_has_the_correct_width(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
-    config = TrainPreprocessingConfig()
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
     subclip = select_subclip(original, width=100)
@@ -131,22 +127,22 @@
 
 
 def test_add_echo_after_subclip(
+    sample_preprocessor: PreprocessorProtocol,
+    sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
     recording1 = create_recording(duration=2)
     clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
-    config = TrainPreprocessingConfig()
     original = generate_train_example(
         clip_annotation_1,
-        preprocessing_config=config.preprocessing,
-        target_config=config.target,
-        label_config=config.labels,
+        preprocessor=sample_preprocessor,
+        labeller=sample_labeller,
     )
 
     assert original.sizes["time"] > 512
 
     subclip = select_subclip(original, width=512)
-    with_echo = add_echo(subclip)
+    with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
 
     assert with_echo.sizes["time"] == 512