mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
244 lines
8.7 KiB
Python
244 lines
8.7 KiB
Python
"""Integration and unit tests for the Preprocessor pipeline.
|
||
|
||
Covers :mod:`batdetect2.preprocess.preprocessor` — construction,
|
||
pipeline output shape/dtype, the ``process_numpy`` helper, attribute
|
||
values, output frame rate, and a round-trip YAML → config → build test.
|
||
"""
|
||
|
||
import pathlib
|
||
|
||
import numpy as np
|
||
import torch
|
||
|
||
from batdetect2.preprocess.audio import FixDurationConfig
|
||
from batdetect2.preprocess.config import PreprocessingConfig
|
||
from batdetect2.preprocess.preprocessor import (
|
||
Preprocessor,
|
||
build_preprocessor,
|
||
compute_output_samplerate,
|
||
)
|
||
from batdetect2.preprocess.spectrogram import (
|
||
FrequencyConfig,
|
||
PcenConfig,
|
||
ResizeConfig,
|
||
SpectralMeanSubtractionConfig,
|
||
STFTConfig,
|
||
)
|
||
|
||
# Sample rate (Hz) shared by every synthetic test signal in this module.
SAMPLERATE = 256_000
# Number of samples in a 0.256 s clip at SAMPLERATE: 65536, i.e. 2**16.
CLIP_SAMPLES = int(0.256 * SAMPLERATE)
|
||
|
||
|
||
def make_sine_wav(
    samplerate: int = SAMPLERATE,
    duration: float = 0.256,
    freq: float = 40_000.0,
) -> torch.Tensor:
    """Return a single-channel sine-wave tensor.

    Parameters
    ----------
    samplerate
        Sampling rate in Hz used to space the time axis.
    duration
        Length of the signal in seconds.
    freq
        Frequency of the sine tone in Hz.

    Returns
    -------
    torch.Tensor
        1-D tensor with ``int(samplerate * duration)`` samples.
    """
    n_samples = int(samplerate * duration)
    # Sample k sits at t = k / samplerate. ``torch.linspace(0, duration, n)``
    # would include the end point and therefore space the samples by
    # duration / (n - 1), slightly shifting the generated tone's frequency.
    t = torch.arange(n_samples, dtype=torch.get_default_dtype()) / samplerate
    return torch.sin(2 * torch.pi * freq * t)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# build_preprocessor — construction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_build_preprocessor_returns_protocol():
    """Calling ``build_preprocessor`` with no arguments yields a ``Preprocessor``."""
    assert isinstance(build_preprocessor(), Preprocessor)
|
||
|
||
|
||
def test_build_preprocessor_with_default_config():
    """Building with only a sample rate (no explicit config) should succeed."""
    built = build_preprocessor(input_samplerate=SAMPLERATE)
    assert built is not None
|
||
|
||
|
||
def test_build_preprocessor_with_explicit_config():
    """A fully specified config should build into a ``Preprocessor``."""
    stft = STFTConfig(window_duration=0.002, window_overlap=0.75)
    frequencies = FrequencyConfig(min_freq=10_000, max_freq=120_000)
    size = ResizeConfig(height=128, resize_factor=0.5)
    transforms = [PcenConfig(), SpectralMeanSubtractionConfig()]

    explicit_config = PreprocessingConfig(
        stft=stft,
        frequencies=frequencies,
        size=size,
        spectrogram_transforms=transforms,
    )

    built = build_preprocessor(explicit_config, input_samplerate=SAMPLERATE)
    assert isinstance(built, Preprocessor)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Output shape and dtype
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_preprocessor_output_is_2d():
    """Running the default pipeline on a sine wave should give a 2-D tensor."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    spectrogram = pipeline(make_sine_wav())
    assert spectrogram.ndim == 2
|
||
|
||
|
||
def test_preprocessor_output_height_matches_config():
    """The first output dimension should equal ``ResizeConfig.height``."""
    resize = ResizeConfig(height=64, resize_factor=0.5)
    pipeline = build_preprocessor(
        PreprocessingConfig(size=resize),
        input_samplerate=SAMPLERATE,
    )
    spectrogram = pipeline(make_sine_wav())
    assert spectrogram.shape[0] == 64
|
||
|
||
|
||
def test_preprocessor_output_dtype_float32():
    """The default pipeline should emit a ``torch.float32`` tensor."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    spectrogram = pipeline(make_sine_wav())
    assert spectrogram.dtype == torch.float32
|
||
|
||
|
||
def test_preprocessor_output_is_finite():
    """Every value of the output spectrogram should be finite (no NaN/Inf)."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    spectrogram = pipeline(make_sine_wav())
    finite_mask = torch.isfinite(spectrogram)
    assert finite_mask.all()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# process_numpy
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_preprocessor_process_numpy_accepts_ndarray():
    """``process_numpy`` should take an ndarray and hand back an ndarray."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    samples = make_sine_wav().numpy()
    output = pipeline.process_numpy(samples)
    assert isinstance(output, np.ndarray)
|
||
|
||
|
||
def test_preprocessor_process_numpy_matches_forward():
    """The tensor path and the NumPy path should agree on the same input.

    Agreement is checked with ``np.testing.assert_array_almost_equal`` at
    its default precision.
    """
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    signal = make_sine_wav()

    # Tensor path first, then the NumPy helper on the same samples.
    from_tensor = pipeline(signal).numpy()
    from_numpy = pipeline.process_numpy(signal.numpy())

    np.testing.assert_array_almost_equal(from_tensor, from_numpy)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Attributes
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_preprocessor_min_max_freq_attributes():
    """``min_freq`` / ``max_freq`` should mirror the ``FrequencyConfig`` values."""
    band = FrequencyConfig(min_freq=15_000, max_freq=100_000)
    pipeline = build_preprocessor(
        PreprocessingConfig(frequencies=band),
        input_samplerate=SAMPLERATE,
    )
    assert pipeline.min_freq == 15_000
    assert pipeline.max_freq == 100_000
|
||
|
||
|
||
def test_preprocessor_input_samplerate_attribute():
    """``input_samplerate`` should echo the rate passed to ``build_preprocessor``."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    assert pipeline.input_samplerate == SAMPLERATE
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# compute_output_samplerate
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_compute_output_samplerate_defaults():
    """With a default config the output rate should be 1000 frames per second."""
    computed = compute_output_samplerate(
        PreprocessingConfig(),
        input_samplerate=SAMPLERATE,
    )
    deviation = abs(computed - 1000.0)
    assert deviation < 1e-6
|
||
|
||
|
||
def test_preprocessor_output_samplerate_attribute_matches_compute():
    """``output_samplerate`` should agree with ``compute_output_samplerate``."""
    default_config = PreprocessingConfig()
    pipeline = build_preprocessor(default_config, input_samplerate=SAMPLERATE)
    reference_rate = compute_output_samplerate(
        default_config,
        input_samplerate=SAMPLERATE,
    )
    deviation = abs(pipeline.output_samplerate - reference_rate)
    assert deviation < 1e-6
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# generate_spectrogram (raw STFT)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_generate_spectrogram_shape():
    """``generate_spectrogram`` should expose the uncropped, unresized STFT."""
    pipeline = build_preprocessor(
        PreprocessingConfig(),
        input_samplerate=SAMPLERATE,
    )
    raw_spectrogram = pipeline.generate_spectrogram(make_sine_wav())
    # 257 frequency bins are expected at the default settings; per the
    # original author's note this is n_fft // 2 + 1 (the STFT defaults
    # themselves are defined outside this file).
    assert raw_spectrogram.shape[0] == 257
|
||
|
||
|
||
def test_generate_spectrogram_larger_than_forward():
    """The raw spectrogram should have more rows than the processed output."""
    pipeline = build_preprocessor(input_samplerate=SAMPLERATE)
    signal = make_sine_wav()

    raw_bins = pipeline.generate_spectrogram(signal).shape[0]
    processed_bins = pipeline(signal).shape[0]

    assert raw_bins > processed_bins
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Audio transforms pipeline (FixDuration)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_preprocessor_with_fix_duration_audio_transform():
    """A FixDuration audio transform should produce consistent output shapes.

    Clips shorter than, equal to, and longer than the configured duration
    are fed through the pipeline; each output must be 2-D and all outputs
    must share one shape.
    """
    config = PreprocessingConfig(
        audio_transforms=[FixDurationConfig(duration=0.256)],
    )
    preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)

    input_lengths = [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]
    output_shapes = []
    for n_samples in input_lengths:
        wav = torch.randn(n_samples)
        result = preprocessor(wav)
        assert result.ndim == 2
        output_shapes.append(tuple(result.shape))

    # The point of the test: without comparing the shapes against each
    # other, a no-op FixDuration transform would pass unnoticed.
    assert len(set(output_shapes)) == 1, (
        f"output shapes differ across input lengths {input_lengths}: "
        f"{output_shapes}"
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# YAML round-trip
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
    """PreprocessingConfig serialised to YAML and reloaded should produce
    a functionally identical preprocessor.

    The config is written to a temporary YAML file, loaded back with
    ``load_preprocessing_config``, and a preprocessor is built from both
    the original and the reloaded config. Both are run on the same sine
    wave and their outputs are compared.
    """
    from batdetect2.preprocess.config import load_preprocessing_config

    config = PreprocessingConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
        size=ResizeConfig(height=128, resize_factor=0.5),
    )

    yaml_path = tmp_path / "preprocess_config.yaml"
    yaml_path.write_text(config.to_yaml_string())

    loaded_config = load_preprocessing_config(yaml_path)

    preprocessor = build_preprocessor(
        loaded_config, input_samplerate=SAMPLERATE
    )
    wav = make_sine_wav()
    result = preprocessor(wav)

    assert result.ndim == 2
    assert result.shape[0] == 128
    assert torch.isfinite(result).all()

    # "Functionally identical" means the reloaded pipeline must reproduce
    # the output of a pipeline built from the in-memory config, not merely
    # yield a plausible shape.
    reference = build_preprocessor(config, input_samplerate=SAMPLERATE)
    torch.testing.assert_close(result, reference(wav))
|