mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Update augmentation tests to new structure
This commit is contained in:
parent
8a463e3942
commit
541be15c9e
@ -5,18 +5,19 @@ import pytest
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import arrays, data
|
from soundevent import arrays, data
|
||||||
|
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
add_echo,
|
add_echo,
|
||||||
mix_examples,
|
mix_examples,
|
||||||
select_subclip,
|
select_subclip,
|
||||||
)
|
)
|
||||||
from batdetect2.train.preprocess import (
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
TrainPreprocessingConfig,
|
from batdetect2.train.types import ClipLabeller
|
||||||
generate_train_example,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mix_examples(
|
def test_mix_examples(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
):
|
):
|
||||||
recording1 = create_recording()
|
recording1 = create_recording()
|
||||||
@ -28,22 +29,18 @@ def test_mix_examples(
|
|||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||||
|
|
||||||
config = TrainPreprocessingConfig()
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
example1 = generate_train_example(
|
||||||
clip_annotation_1,
|
clip_annotation_1,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
example2 = generate_train_example(
|
example2 = generate_train_example(
|
||||||
clip_annotation_2,
|
clip_annotation_2,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
assert mixed["spectrogram"].shape == example1["spectrogram"].shape
|
assert mixed["spectrogram"].shape == example1["spectrogram"].shape
|
||||||
assert mixed["detection"].shape == example1["detection"].shape
|
assert mixed["detection"].shape == example1["detection"].shape
|
||||||
@ -54,6 +51,8 @@ def test_mix_examples(
|
|||||||
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
@pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
|
||||||
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
|
||||||
def test_mix_examples_of_different_durations(
|
def test_mix_examples_of_different_durations(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
duration1: float,
|
duration1: float,
|
||||||
duration2: float,
|
duration2: float,
|
||||||
@ -67,22 +66,18 @@ def test_mix_examples_of_different_durations(
|
|||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
clip_annotation_2 = data.ClipAnnotation(clip=clip2)
|
||||||
|
|
||||||
config = TrainPreprocessingConfig()
|
|
||||||
|
|
||||||
example1 = generate_train_example(
|
example1 = generate_train_example(
|
||||||
clip_annotation_1,
|
clip_annotation_1,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
example2 = generate_train_example(
|
example2 = generate_train_example(
|
||||||
clip_annotation_2,
|
clip_annotation_2,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed = mix_examples(example1, example2, config=config.preprocessing)
|
mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
# Check the spectrogram has the expected duration
|
# Check the spectrogram has the expected duration
|
||||||
step = arrays.get_dim_step(mixed["spectrogram"], "time")
|
step = arrays.get_dim_step(mixed["spectrogram"], "time")
|
||||||
@ -92,19 +87,20 @@ def test_mix_examples_of_different_durations(
|
|||||||
|
|
||||||
|
|
||||||
def test_add_echo(
|
def test_add_echo(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
):
|
):
|
||||||
recording1 = create_recording()
|
recording1 = create_recording()
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
|
||||||
original = generate_train_example(
|
original = generate_train_example(
|
||||||
clip_annotation_1,
|
clip_annotation_1,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
with_echo = add_echo(original, config=config.preprocessing)
|
with_echo = add_echo(original, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
|
assert with_echo["spectrogram"].shape == original["spectrogram"].shape
|
||||||
xr.testing.assert_identical(with_echo["size"], original["size"])
|
xr.testing.assert_identical(with_echo["size"], original["size"])
|
||||||
@ -113,17 +109,17 @@ def test_add_echo(
|
|||||||
|
|
||||||
|
|
||||||
def test_selected_random_subclip_has_the_correct_width(
|
def test_selected_random_subclip_has_the_correct_width(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
):
|
):
|
||||||
recording1 = create_recording()
|
recording1 = create_recording()
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
|
||||||
original = generate_train_example(
|
original = generate_train_example(
|
||||||
clip_annotation_1,
|
clip_annotation_1,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
subclip = select_subclip(original, width=100)
|
subclip = select_subclip(original, width=100)
|
||||||
|
|
||||||
@ -131,22 +127,22 @@ def test_selected_random_subclip_has_the_correct_width(
|
|||||||
|
|
||||||
|
|
||||||
def test_add_echo_after_subclip(
|
def test_add_echo_after_subclip(
|
||||||
|
sample_preprocessor: PreprocessorProtocol,
|
||||||
|
sample_labeller: ClipLabeller,
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
):
|
):
|
||||||
recording1 = create_recording(duration=2)
|
recording1 = create_recording(duration=2)
|
||||||
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
|
clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
|
||||||
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
|
||||||
config = TrainPreprocessingConfig()
|
|
||||||
original = generate_train_example(
|
original = generate_train_example(
|
||||||
clip_annotation_1,
|
clip_annotation_1,
|
||||||
preprocessing_config=config.preprocessing,
|
preprocessor=sample_preprocessor,
|
||||||
target_config=config.target,
|
labeller=sample_labeller,
|
||||||
label_config=config.labels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert original.sizes["time"] > 512
|
assert original.sizes["time"] > 512
|
||||||
|
|
||||||
subclip = select_subclip(original, width=512)
|
subclip = select_subclip(original, width=512)
|
||||||
with_echo = add_echo(subclip)
|
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
|
||||||
|
|
||||||
assert with_echo.sizes["time"] == 512
|
assert with_echo.sizes["time"] == 512
|
||||||
|
Loading…
Reference in New Issue
Block a user