batdetect2/tests/test_preprocessing/test_audio.py
2026-03-08 18:41:05 +00:00

204 lines
6.1 KiB
Python

"""Tests for audio-level preprocessing transforms.
Covers :mod:`batdetect2.preprocess.audio` and the shared helper functions
in :mod:`batdetect2.preprocess.common`.
"""
import pathlib
import uuid
import numpy as np
import pytest
import soundfile as sf
import torch
from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.preprocess.audio import (
CenterAudio,
CenterAudioConfig,
FixDuration,
FixDurationConfig,
ScaleAudio,
ScaleAudioConfig,
build_audio_transform,
)
from batdetect2.preprocess.common import center_tensor, peak_normalize
SAMPLERATE = 256_000
def create_dummy_wave(
    samplerate: int,
    duration: float,
    num_channels: int = 1,
    freq: float = 440.0,
    amplitude: float = 0.5,
    dtype: type = np.float32,
) -> np.ndarray:
    """Synthesize a sine-wave test signal.

    Returns a 1-D array of ``int(samplerate * duration)`` samples, or a
    ``(num_channels, n_samples)`` array when more than one channel is
    requested (every channel carries the identical signal).
    """
    n_samples = int(samplerate * duration)
    times = np.linspace(0.0, duration, n_samples, endpoint=False, dtype=dtype)
    signal = amplitude * np.sin(2 * np.pi * freq * times)
    if num_channels > 1:
        signal = np.stack([signal] * num_channels, axis=0)
    return signal.astype(dtype)
@pytest.fixture
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
    """Write a 2-second stereo float WAV file to a temp dir and return its path."""
    sr = 48000
    wave = create_dummy_wave(sr, duration=2.0, num_channels=2)
    # soundfile expects (frames, channels), hence the transpose.
    path = tmp_path / f"{uuid.uuid4()}.wav"
    sf.write(path, wave.T, sr, format="WAV", subtype="FLOAT")
    return path
@pytest.fixture
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
    """Load the dummy WAV file into a soundevent ``Recording``."""
    recording = data.Recording.from_file(dummy_wav_path)
    return recording
@pytest.fixture
def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
    """Build a one-second ``Clip`` (0.5 s to 1.5 s) over the dummy recording."""
    start, end = 0.5, 1.5
    return data.Clip(recording=dummy_recording, start_time=start, end_time=end)
@pytest.fixture
def default_audio_config() -> AudioConfig:
    """Provide an ``AudioConfig`` built entirely from defaults."""
    config = AudioConfig()
    return config
def test_center_tensor_zero_mean():
    """Centering must remove the DC offset from the waveform."""
    samples = torch.tensor([1.0, 2.0, 3.0, 4.0])
    centered = center_tensor(samples)
    assert abs(centered.mean().item()) < 1e-5
def test_center_tensor_preserves_shape():
    """Centering is elementwise, so the tensor shape must be unchanged."""
    multichannel = torch.randn(3, 1000)
    assert center_tensor(multichannel).shape == multichannel.shape
def test_peak_normalize_max_is_one():
    """Peak normalisation rescales so the largest |sample| equals 1."""
    samples = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3])
    peak = peak_normalize(samples).abs().max().item()
    assert abs(peak - 1.0) < 1e-6
def test_peak_normalize_zero_tensor_unchanged():
    """All-zero input must pass through untouched (no division by zero)."""
    silence = torch.zeros(100)
    normalized = peak_normalize(silence)
    assert not normalized.any()
def test_peak_normalize_preserves_shape():
    """Normalisation must keep the channel/sample layout intact."""
    stereo = torch.randn(2, 512)
    assert peak_normalize(stereo).shape == stereo.shape
def test_center_audio_forward_zero_mean():
    """``CenterAudio.forward`` should produce a zero-mean waveform."""
    transform = CenterAudio()
    centered = transform(torch.tensor([1.0, 3.0, 5.0]))
    assert abs(centered.mean().item()) < 1e-5
def test_center_audio_from_config():
    """Building from a ``CenterAudioConfig`` yields a ``CenterAudio`` module."""
    built = CenterAudio.from_config(CenterAudioConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, CenterAudio)
def test_scale_audio_peak_normalises_to_one():
    """``ScaleAudio.forward`` should scale the peak absolute value to 1."""
    transform = ScaleAudio()
    scaled = transform(torch.tensor([0.0, 0.25, 0.1]))
    assert abs(scaled.abs().max().item() - 1.0) < 1e-6
def test_scale_audio_handles_zero_tensor():
    """An all-zero waveform must come back all-zero, with no error raised."""
    transform = ScaleAudio()
    out = transform(torch.zeros(100))
    assert not out.any()
def test_scale_audio_from_config():
    """Building from a ``ScaleAudioConfig`` yields a ``ScaleAudio`` module."""
    built = ScaleAudio.from_config(ScaleAudioConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, ScaleAudio)
def test_fix_duration_truncates_long_input():
    """Inputs longer than the target duration are cut down to size."""
    expected = int(SAMPLERATE * 0.5)
    transform = FixDuration(samplerate=SAMPLERATE, duration=0.5)
    out = transform(torch.randn(expected + 1000))
    assert out.shape[-1] == expected
def test_fix_duration_pads_short_input():
    """Inputs shorter than the target are zero-padded up to the target length."""
    expected = int(SAMPLERATE * 0.5)
    transform = FixDuration(samplerate=SAMPLERATE, duration=0.5)
    out = transform(torch.randn(expected - 100))
    assert out.shape[-1] == expected
    # The final 100 samples are padding and must all be zero.
    assert not out[expected - 100 :].any()
def test_fix_duration_passthrough_exact_length():
    """An exactly-sized waveform is returned without modification."""
    expected = int(SAMPLERATE * 0.5)
    transform = FixDuration(samplerate=SAMPLERATE, duration=0.5)
    original = torch.randn(expected)
    out = transform(original)
    assert out.shape[-1] == expected
    assert torch.equal(out, original)
def test_fix_duration_from_config():
    """``FixDurationConfig`` should produce a ``FixDuration`` of the right length."""
    built = FixDuration.from_config(
        FixDurationConfig(duration=0.256), samplerate=SAMPLERATE
    )
    assert isinstance(built, FixDuration)
    assert built.length == int(SAMPLERATE * 0.256)
def test_build_audio_transform_center_audio():
    """The factory maps ``CenterAudioConfig`` to a ``CenterAudio`` module."""
    built = build_audio_transform(CenterAudioConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, CenterAudio)
def test_build_audio_transform_scale_audio():
    """The factory maps ``ScaleAudioConfig`` to a ``ScaleAudio`` module."""
    built = build_audio_transform(ScaleAudioConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, ScaleAudio)
def test_build_audio_transform_fix_duration():
    """The factory maps ``FixDurationConfig`` to a ``FixDuration`` module."""
    built = build_audio_transform(
        FixDurationConfig(duration=0.5), samplerate=SAMPLERATE
    )
    assert isinstance(built, FixDuration)