batdetect2/tests/test_preprocessing/test_preprocessor.py
2026-03-08 17:11:27 +00:00

244 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Integration and unit tests for the Preprocessor pipeline.
Covers :mod:`batdetect2.preprocess.preprocessor` — construction,
pipeline output shape/dtype, the ``process_numpy`` helper, attribute
values, output frame rate, the raw ``generate_spectrogram`` STFT,
audio transforms (``FixDuration``), and a YAML → config → build
round-trip test.
"""
import pathlib
import numpy as np
import torch
from batdetect2.preprocess.audio import FixDurationConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.preprocessor import (
Preprocessor,
build_preprocessor,
compute_output_samplerate,
)
from batdetect2.preprocess.spectrogram import (
FrequencyConfig,
PcenConfig,
ResizeConfig,
SpectralMeanSubtractionConfig,
STFTConfig,
)
SAMPLERATE = 256_000

# Clip length used throughout these tests: 0.256 s at 256 kHz is exactly
# 65536 samples — a convenient power-of-two-sized clip.
CLIP_DURATION = 0.256
# round() rather than int(): the product is a float, and truncation would
# silently drop a sample if it landed a hair below the integer value.
CLIP_SAMPLES = round(SAMPLERATE * CLIP_DURATION)


def make_sine_wav(
    samplerate: int = SAMPLERATE,
    duration: float = CLIP_DURATION,
    freq: float = 40_000.0,
) -> torch.Tensor:
    """Return a single-channel sine-wave tensor.

    Parameters
    ----------
    samplerate
        Sampling rate in Hz.
    duration
        Clip length in seconds; the output has
        ``round(samplerate * duration)`` samples.
    freq
        Tone frequency in Hz.

    Returns
    -------
    torch.Tensor
        1-D tensor in the default floating dtype where sample ``k`` equals
        ``sin(2 * pi * freq * k / samplerate)``.
    """
    n_samples = round(samplerate * duration)
    # Sample k sits at t = k / samplerate. torch.linspace(0, duration, n)
    # includes the end point, which spaces samples duration / (n - 1) apart
    # and so shifts the effective sample rate (and the tone frequency).
    # The phase is computed in float64 because float32 loses precision at
    # the large phase values reached by ultrasonic tones.
    t = torch.arange(n_samples, dtype=torch.float64) / samplerate
    wave = torch.sin(2 * torch.pi * freq * t)
    return wave.to(torch.get_default_dtype())
# ---------------------------------------------------------------------------
# build_preprocessor — construction
# ---------------------------------------------------------------------------
def test_build_preprocessor_returns_protocol():
    """Calling build_preprocessor() with no arguments yields a Preprocessor."""
    built = build_preprocessor()
    assert isinstance(built, Preprocessor)
def test_build_preprocessor_with_default_config():
    """build_preprocessor should succeed when no config is passed.

    Only the input sample rate is supplied; no explicit
    PreprocessingConfig is given, so the builder's default config is used.
    """
    preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
    # isinstance is strictly stronger than `is not None` and still implies it.
    assert isinstance(preprocessor, Preprocessor)
def test_build_preprocessor_with_explicit_config():
    """build_preprocessor should accept a fully specified config.

    Uses non-default STFT, frequency, resize and spectrogram-transform
    (PCEN + spectral mean subtraction) settings, and checks that the
    frequency bounds from the config are reflected on the result.
    """
    config = PreprocessingConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
        size=ResizeConfig(height=128, resize_factor=0.5),
        spectrogram_transforms=[PcenConfig(), SpectralMeanSubtractionConfig()],
    )
    preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
    assert isinstance(preprocessor, Preprocessor)
    # A type check alone would still pass if the config were ignored.
    assert preprocessor.min_freq == 10_000
    assert preprocessor.max_freq == 120_000
# ---------------------------------------------------------------------------
# Output shape and dtype
# ---------------------------------------------------------------------------
def test_preprocessor_output_is_2d():
    """Output should be a 2-D tensor laid out as (freq_bins, time_frames)."""
    spec = build_preprocessor(input_samplerate=SAMPLERATE)(make_sine_wav())
    assert spec.ndim == 2
def test_preprocessor_output_height_matches_config():
    """The number of output rows should equal ResizeConfig.height."""
    target_height = 64
    resize = ResizeConfig(height=target_height, resize_factor=0.5)
    preprocessor = build_preprocessor(
        PreprocessingConfig(size=resize),
        input_samplerate=SAMPLERATE,
    )
    output = preprocessor(make_sine_wav())
    assert output.shape[0] == target_height
def test_preprocessor_output_dtype_float32():
    """The pipeline should emit float32 tensors."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    output = pipeline(make_sine_wav())
    assert output.dtype == torch.float32
def test_preprocessor_output_is_finite():
    """No NaN or infinite values should appear in the processed spectrogram."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    output = pipeline(make_sine_wav())
    finite_mask = torch.isfinite(output)
    assert finite_mask.all()
# ---------------------------------------------------------------------------
# process_numpy
# ---------------------------------------------------------------------------
def test_preprocessor_process_numpy_accepts_ndarray():
    """process_numpy maps a NumPy waveform to a NumPy result."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    audio = make_sine_wav().numpy()
    output = pipeline.process_numpy(audio)
    assert isinstance(output, np.ndarray)
def test_preprocessor_process_numpy_matches_forward():
    """process_numpy and forward should agree to within float tolerance.

    Both paths run on the same waveform. The results are compared with
    ``assert_array_almost_equal`` (default: 6 decimal places), not with
    exact equality.
    """
    preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
    wav = make_sine_wav()
    # detach().cpu() so the conversion still works if the output is
    # attached to an autograd graph or lives on a non-CPU device.
    result_pt = preprocessor(wav).detach().cpu().numpy()
    result_np = preprocessor.process_numpy(wav.numpy())
    np.testing.assert_array_almost_equal(result_pt, result_np)
# ---------------------------------------------------------------------------
# Attributes
# ---------------------------------------------------------------------------
def test_preprocessor_min_max_freq_attributes():
    """The min_freq / max_freq attributes should mirror FrequencyConfig."""
    low_hz, high_hz = 15_000, 100_000
    freq_cfg = FrequencyConfig(min_freq=low_hz, max_freq=high_hz)
    pipeline = build_preprocessor(
        PreprocessingConfig(frequencies=freq_cfg),
        input_samplerate=SAMPLERATE,
    )
    assert pipeline.min_freq == low_hz
    assert pipeline.max_freq == high_hz
def test_preprocessor_input_samplerate_attribute():
    """The input_samplerate passed to the builder is exposed on the result."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    stored_rate = pipeline.input_samplerate
    assert stored_rate == SAMPLERATE
# ---------------------------------------------------------------------------
# compute_output_samplerate
# ---------------------------------------------------------------------------
def test_compute_output_samplerate_defaults():
    """With the default config the output frame rate should be 1000 fps."""
    expected_fps = 1000.0
    fps = compute_output_samplerate(
        PreprocessingConfig(),
        input_samplerate=SAMPLERATE,
    )
    assert abs(fps - expected_fps) < 1e-6
def test_preprocessor_output_samplerate_attribute_matches_compute():
    """preprocessor.output_samplerate should equal compute_output_samplerate."""
    cfg = PreprocessingConfig()
    pipeline = build_preprocessor(cfg, input_samplerate=SAMPLERATE)
    reference_rate = compute_output_samplerate(cfg, input_samplerate=SAMPLERATE)
    deviation = abs(pipeline.output_samplerate - reference_rate)
    assert deviation < 1e-6
# ---------------------------------------------------------------------------
# generate_spectrogram (raw STFT)
# ---------------------------------------------------------------------------
def test_generate_spectrogram_shape():
    """generate_spectrogram returns the full STFT, with no crop or resize."""
    pipeline = build_preprocessor(
        PreprocessingConfig(),
        input_samplerate=SAMPLERATE,
    )
    raw_spec = pipeline.generate_spectrogram(make_sine_wav())
    # An uncropped STFT has n_fft // 2 + 1 bins; at the default settings
    # this is 257.
    expected_bins = 257
    assert raw_spec.shape[0] == expected_bins
def test_generate_spectrogram_larger_than_forward():
    """The processed output should have fewer frequency bins than the raw STFT."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    signal = make_sine_wav()
    n_raw_bins = pipeline.generate_spectrogram(signal).shape[0]
    n_out_bins = pipeline(signal).shape[0]
    assert n_out_bins < n_raw_bins
# ---------------------------------------------------------------------------
# Audio transforms pipeline (FixDuration)
# ---------------------------------------------------------------------------
def test_preprocessor_with_fix_duration_audio_transform():
    """A FixDuration audio transform should produce consistent output shapes.

    Inputs shorter than, equal to, and longer than the target clip length
    are fed through the same preprocessor. Every output must be 2-D, and
    all outputs must share one shape, because FixDuration is expected to
    normalise the input length.
    """
    config = PreprocessingConfig(
        audio_transforms=[FixDurationConfig(duration=0.256)],
    )
    preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)

    # Seeded generator so the random inputs are reproducible across runs.
    generator = torch.Generator().manual_seed(0)

    shapes = []
    for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]:
        wav = torch.randn(n_samples, generator=generator)
        result = preprocessor(wav)
        assert result.ndim == 2
        shapes.append(tuple(result.shape))

    # The point of the test: every input length maps to the same shape.
    # Previously only ndim was checked, so inconsistent shapes went unnoticed.
    assert len(set(shapes)) == 1, f"inconsistent output shapes: {shapes}"
# ---------------------------------------------------------------------------
# YAML round-trip
# ---------------------------------------------------------------------------
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
    """PreprocessingConfig serialised to YAML and reloaded should produce
    a functionally identical preprocessor.

    The preprocessor built from the reloaded config is checked for basic
    output properties, then compared against a preprocessor built from the
    original in-memory config on the same waveform.
    """
    from batdetect2.preprocess.config import load_preprocessing_config

    config = PreprocessingConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
        size=ResizeConfig(height=128, resize_factor=0.5),
    )
    yaml_path = tmp_path / "preprocess_config.yaml"
    yaml_path.write_text(config.to_yaml_string())
    loaded_config = load_preprocessing_config(yaml_path)

    reloaded = build_preprocessor(loaded_config, input_samplerate=SAMPLERATE)
    original = build_preprocessor(config, input_samplerate=SAMPLERATE)

    wav = make_sine_wav()
    result = reloaded(wav)
    assert result.ndim == 2
    assert result.shape[0] == 128
    assert torch.isfinite(result).all()

    # "Functionally identical": the round-tripped config must yield the same
    # output as the original config, not merely a plausible-looking one.
    torch.testing.assert_close(result, original(wav))