From 1cec332dd57291fef061f2e5fb36822e6238c435 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 30 Aug 2025 14:08:00 +0100 Subject: [PATCH] Change default train duration to 0.256 instead of 0.512 --- src/batdetect2/train/clips.py | 2 +- tests/conftest.py | 9 +++++++++ tests/test_train/test_clips.py | 28 ++++++++++++++++++++++++++-- 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 9ed9147..04da59f 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -10,7 +10,7 @@ from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.train import PreprocessedExample from batdetect2.utils.arrays import adjust_width, slice_tensor -DEFAULT_TRAIN_CLIP_DURATION = 0.512 +DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_MAX_EMPTY_CLIP = 0.1 diff --git a/tests/conftest.py b/tests/conftest.py index 49bea5b..d036c62 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ from batdetect2.targets import ( from batdetect2.targets.classes import ClassesConfig, TargetClass from batdetect2.targets.filtering import FilterConfig, FilterRule from batdetect2.targets.terms import TagInfo +from batdetect2.train.clips import build_clipper from batdetect2.train.labels import build_clip_labeler from batdetect2.typing import ( ClipLabeller, @@ -29,6 +30,7 @@ from batdetect2.typing import ( TargetProtocol, ) from batdetect2.typing.preprocess import AudioLoader +from batdetect2.typing.train import ClipperProtocol @pytest.fixture @@ -440,6 +442,13 @@ def sample_labeller( ) +@pytest.fixture +def sample_clipper( + sample_preprocessor: PreprocessorProtocol, +) -> ClipperProtocol: + return build_clipper(preprocessor=sample_preprocessor) + + @pytest.fixture def example_dataset(example_data_dir: Path) -> DatasetConfig: return DatasetConfig( diff --git a/tests/test_train/test_clips.py b/tests/test_train/test_clips.py index 0f4f4f4..b6f7953 100644 --- a/tests/test_train/test_clips.py +++ b/tests/test_train/test_clips.py @@ -1,3 +1,27 @@ -import numpy as np +from soundevent import data -from batdetect2.train.clips import select_subclip +from batdetect2.train import generate_train_example +from batdetect2.typing import ( + AudioLoader, + ClipLabeller, + ClipperProtocol, + PreprocessorProtocol, +) + + +def test_default_clip_size_is_correct( + sample_clipper: ClipperProtocol, + sample_labeller: ClipLabeller, + sample_audio_loader: AudioLoader, + clip_annotation: data.ClipAnnotation, + sample_preprocessor: PreprocessorProtocol, +): + example = generate_train_example( + clip_annotation=clip_annotation, + audio_loader=sample_audio_loader, + preprocessor=sample_preprocessor, + labeller=sample_labeller, + ) + + clip, _, _ = sample_clipper(example) + assert clip.spectrogram.shape == (1, 128, 256)