mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Update tests
This commit is contained in:
parent
46c02962f3
commit
0d590a26cc
@ -83,17 +83,17 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
super().__init__()
|
||||
|
||||
self.audio_transforms = torch.nn.Sequential(
|
||||
*[
|
||||
*(
|
||||
build_audio_transform(step, samplerate=input_samplerate)
|
||||
for step in config.audio_transforms
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.spectrogram_transforms = torch.nn.Sequential(
|
||||
*[
|
||||
*(
|
||||
build_spectrogram_transform(step, samplerate=input_samplerate)
|
||||
for step in config.spectrogram_transforms
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
self.spectrogram_builder = build_spectrogram_builder(
|
||||
|
||||
@ -79,11 +79,6 @@ def default_audio_config() -> AudioConfig:
|
||||
return AudioConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# center_tensor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_center_tensor_zero_mean():
|
||||
"""Output tensor should have a mean very close to zero."""
|
||||
wav = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
@ -97,11 +92,6 @@ def test_center_tensor_preserves_shape():
|
||||
assert result.shape == wav.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# peak_normalize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_peak_normalize_max_is_one():
|
||||
"""After peak normalisation, the maximum absolute value should be 1."""
|
||||
wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3])
|
||||
@ -116,25 +106,12 @@ def test_peak_normalize_zero_tensor_unchanged():
|
||||
assert (result == 0).all()
|
||||
|
||||
|
||||
def test_peak_normalize_preserves_sign():
|
||||
"""Signs of all elements should be preserved after normalisation."""
|
||||
wav = torch.tensor([-2.0, 1.0, -0.5])
|
||||
result = peak_normalize(wav)
|
||||
assert (result < 0).sum() == 2
|
||||
assert result[0].item() < 0
|
||||
|
||||
|
||||
def test_peak_normalize_preserves_shape():
|
||||
wav = torch.randn(2, 512)
|
||||
result = peak_normalize(wav)
|
||||
assert result.shape == wav.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CenterAudio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_center_audio_forward_zero_mean():
|
||||
module = CenterAudio()
|
||||
wav = torch.tensor([1.0, 3.0, 5.0])
|
||||
@ -148,15 +125,10 @@ def test_center_audio_from_config():
|
||||
assert isinstance(module, CenterAudio)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ScaleAudio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_scale_audio_peak_normalises_to_one():
|
||||
"""ScaleAudio.forward should scale the peak absolute value to 1."""
|
||||
module = ScaleAudio()
|
||||
wav = torch.tensor([0.0, 0.25, -0.5, 0.1])
|
||||
wav = torch.tensor([0.0, 0.25, 0.1])
|
||||
result = module(wav)
|
||||
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||
|
||||
@ -175,11 +147,6 @@ def test_scale_audio_from_config():
|
||||
assert isinstance(module, ScaleAudio)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FixDuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fix_duration_truncates_long_input():
|
||||
"""Waveform longer than target should be truncated to the target length."""
|
||||
target_samples = int(SAMPLERATE * 0.5)
|
||||
@ -218,11 +185,6 @@ def test_fix_duration_from_config():
|
||||
assert module.length == int(SAMPLERATE * 0.256)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_audio_transform dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_audio_transform_center_audio():
|
||||
config = CenterAudioConfig()
|
||||
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||
|
||||
@ -26,7 +26,6 @@ from batdetect2.preprocess.spectrogram import (
|
||||
)
|
||||
|
||||
SAMPLERATE = 256_000
|
||||
# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip
|
||||
CLIP_SAMPLES = int(SAMPLERATE * 0.256)
|
||||
|
||||
|
||||
@ -40,11 +39,6 @@ def make_sine_wav(
|
||||
return torch.sin(2 * torch.pi * freq * t)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_preprocessor — construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_preprocessor_returns_protocol():
|
||||
"""build_preprocessor should return a Preprocessor instance."""
|
||||
preprocessor = build_preprocessor()
|
||||
@ -68,11 +62,6 @@ def test_build_preprocessor_with_explicit_config():
|
||||
assert isinstance(preprocessor, Preprocessor)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output shape and dtype
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_output_is_2d():
|
||||
"""The preprocessor output should be a 2-D tensor (freq_bins × time_frames)."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
@ -108,11 +97,6 @@ def test_preprocessor_output_is_finite():
|
||||
assert torch.isfinite(result).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_numpy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_process_numpy_accepts_ndarray():
|
||||
"""process_numpy should accept a NumPy array and return a NumPy array."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
@ -130,11 +114,6 @@ def test_preprocessor_process_numpy_matches_forward():
|
||||
np.testing.assert_array_almost_equal(result_pt, result_np)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_min_max_freq_attributes():
|
||||
"""min_freq and max_freq should match the FrequencyConfig values."""
|
||||
config = PreprocessingConfig(
|
||||
@ -150,11 +129,6 @@ def test_preprocessor_input_samplerate_attribute():
|
||||
assert preprocessor.input_samplerate == SAMPLERATE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_output_samplerate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_output_samplerate_defaults():
|
||||
"""At default settings, output_samplerate should equal 1000 fps."""
|
||||
config = PreprocessingConfig()
|
||||
@ -169,11 +143,6 @@ def test_preprocessor_output_samplerate_attribute_matches_compute():
|
||||
assert abs(preprocessor.output_samplerate - expected) < 1e-6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_spectrogram (raw STFT)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_spectrogram_shape():
|
||||
"""generate_spectrogram should return the full STFT without crop or resize."""
|
||||
config = PreprocessingConfig()
|
||||
@ -193,29 +162,18 @@ def test_generate_spectrogram_larger_than_forward():
|
||||
assert raw.shape[0] > processed.shape[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio transforms pipeline (FixDuration)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_with_fix_duration_audio_transform():
|
||||
"""A FixDuration audio transform should produce consistent output shapes."""
|
||||
config = PreprocessingConfig(
|
||||
audio_transforms=[FixDurationConfig(duration=0.256)],
|
||||
)
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
# Feed different lengths; output shape should be the same after fix
|
||||
for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]:
|
||||
wav = torch.randn(n_samples)
|
||||
result = preprocessor(wav)
|
||||
assert result.ndim == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
|
||||
"""PreprocessingConfig serialised to YAML and reloaded should produce
|
||||
a functionally identical preprocessor."""
|
||||
|
||||
@ -31,11 +31,6 @@ from batdetect2.preprocess.spectrogram import (
|
||||
SAMPLERATE = 256_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STFTConfig / _spec_params_from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stft_config_defaults_give_correct_params():
|
||||
"""Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128."""
|
||||
config = STFTConfig()
|
||||
@ -52,11 +47,6 @@ def test_stft_config_custom_params():
|
||||
assert hop_length == 512
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_spectrogram_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spectrogram_builder_output_shape():
|
||||
"""Builder should produce a spectrogram with the expected number of bins."""
|
||||
config = STFTConfig()
|
||||
@ -81,11 +71,6 @@ def test_spectrogram_builder_output_is_nonnegative():
|
||||
assert (spec >= 0).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FrequencyCrop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_frequency_crop_output_shape():
|
||||
"""FrequencyCrop should reduce the number of frequency bins."""
|
||||
config = STFTConfig()
|
||||
@ -128,11 +113,6 @@ def test_frequency_crop_no_crop_when_bounds_are_none():
|
||||
assert cropped.shape == spec.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PCEN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_pcen_output_shape_preserved():
|
||||
"""PCEN should not change the shape of the spectrogram."""
|
||||
config = PcenConfig()
|
||||
@ -160,11 +140,6 @@ def test_pcen_output_dtype_matches_input():
|
||||
assert result.dtype == spec.dtype
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SpectralMeanSubtraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spectral_mean_subtraction_output_nonnegative():
|
||||
"""SpectralMeanSubtraction clamps output to >= 0."""
|
||||
module = SpectralMeanSubtraction()
|
||||
@ -195,11 +170,6 @@ def test_spectral_mean_subtraction_from_config():
|
||||
assert isinstance(module, SpectralMeanSubtraction)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PeakNormalize (spectrogram-level)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_peak_normalize_spec_max_is_one():
|
||||
"""PeakNormalize should scale the spectrogram peak to 1."""
|
||||
module = PeakNormalize()
|
||||
@ -222,11 +192,6 @@ def test_peak_normalize_from_config():
|
||||
assert isinstance(module, PeakNormalize)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ScaleAmplitude
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_scale_amplitude_db_output_is_finite():
|
||||
"""AmplitudeToDB scaling should produce finite values for positive input."""
|
||||
module = ScaleAmplitude(scale="db")
|
||||
@ -251,11 +216,6 @@ def test_scale_amplitude_from_config():
|
||||
assert module.scale == "db"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResizeSpec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resize_spec_output_shape():
|
||||
"""ResizeSpec should produce the target height and scaled width."""
|
||||
module = ResizeSpec(height=64, time_factor=0.5)
|
||||
@ -287,11 +247,6 @@ def test_resize_spec_from_config():
|
||||
assert module.time_factor == 0.25
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_spectrogram_transform dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_spectrogram_transform_pcen():
|
||||
config = PcenConfig()
|
||||
module = build_spectrogram_transform(config, samplerate=SAMPLERATE)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user