Add test for preprocessing

This commit is contained in:
mbsantiago 2026-03-08 17:11:27 +00:00
parent bfc88a4a0f
commit 46c02962f3
10 changed files with 1538 additions and 77 deletions

View File

@ -21,7 +21,7 @@ preprocess:
gain: 0.98 gain: 0.98
bias: 2 bias: 2
power: 0.5 power: 0.5
- name: spectral_mean_substraction - name: spectral_mean_subtraction
postprocess: postprocess:
nms_kernel_size: 9 nms_kernel_size: 9

View File

@ -1,3 +1,17 @@
"""Audio-level transforms applied to waveforms before spectrogram computation.
This module defines ``torch.nn.Module`` transforms that operate on raw
audio tensors and the Pydantic configuration classes that control them.
Each transform is registered in the ``audio_transforms`` registry so that
the pipeline can be assembled from a configuration object.
The supported transforms are:
* ``CenterAudio`` — subtracts the DC offset (mean) from the waveform.
* ``ScaleAudio`` — peak-normalises the waveform to the range ``[-1, 1]``.
* ``FixDuration`` — truncates or zero-pads the waveform to a fixed length.
"""
from typing import Annotated, Literal from typing import Annotated, Literal
import torch import torch
@ -18,14 +32,43 @@ __all__ = [
audio_transforms: Registry[torch.nn.Module, [int]] = Registry( audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
"audio_transform" "audio_transform"
) )
"""Registry mapping audio transform config classes to their builder methods."""
class CenterAudioConfig(BaseConfig): class CenterAudioConfig(BaseConfig):
"""Configuration for the DC-offset removal transform.
Attributes
----------
name : str
Fixed identifier; always ``"center_audio"``.
"""
name: Literal["center_audio"] = "center_audio" name: Literal["center_audio"] = "center_audio"
class CenterAudio(torch.nn.Module): class CenterAudio(torch.nn.Module):
"""Remove the DC offset from an audio waveform.
Subtracts the global mean of the waveform from every sample,
centring the signal around zero. This is useful when an analogue
recording chain introduces a constant voltage bias.
"""
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""Subtract the mean from the waveform.
Parameters
----------
wav : torch.Tensor
Input waveform tensor of shape ``(samples,)`` or
``(channels, samples)``.
Returns
-------
torch.Tensor
Zero-centred waveform with the same shape as the input.
"""
return center_tensor(wav) return center_tensor(wav)
@audio_transforms.register(CenterAudioConfig) @audio_transforms.register(CenterAudioConfig)
@ -35,11 +78,38 @@ class CenterAudio(torch.nn.Module):
class ScaleAudioConfig(BaseConfig): class ScaleAudioConfig(BaseConfig):
"""Configuration for the peak-normalisation transform.
Attributes
----------
name : str
Fixed identifier; always ``"scale_audio"``.
"""
name: Literal["scale_audio"] = "scale_audio" name: Literal["scale_audio"] = "scale_audio"
class ScaleAudio(torch.nn.Module): class ScaleAudio(torch.nn.Module):
"""Peak-normalise an audio waveform to the range ``[-1, 1]``.
Divides the waveform by its largest absolute sample value. If the
waveform is identically zero it is returned unchanged.
"""
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""Peak-normalise the waveform.
Parameters
----------
wav : torch.Tensor
Input waveform tensor of any shape.
Returns
-------
torch.Tensor
Normalised waveform with the same shape as the input and
values in the range ``[-1, 1]``.
"""
return peak_normalize(wav) return peak_normalize(wav)
@audio_transforms.register(ScaleAudioConfig) @audio_transforms.register(ScaleAudioConfig)
@ -49,11 +119,36 @@ class ScaleAudio(torch.nn.Module):
class FixDurationConfig(BaseConfig): class FixDurationConfig(BaseConfig):
"""Configuration for the fixed-duration transform.
Attributes
----------
name : str
Fixed identifier; always ``"fix_duration"``.
duration : float, default=0.5
Target duration in seconds. The waveform will be truncated or
zero-padded to match this length.
"""
name: Literal["fix_duration"] = "fix_duration" name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5 duration: float = 0.5
class FixDuration(torch.nn.Module): class FixDuration(torch.nn.Module):
"""Ensure a waveform has exactly a specified number of samples.
If the input is longer than the target length it is truncated from
the end. If it is shorter, it is zero-padded at the end.
Parameters
----------
samplerate : int
Sample rate of the audio in Hz. Used with ``duration`` to
compute the target number of samples.
duration : float
Target duration in seconds.
"""
def __init__(self, samplerate: int, duration: float): def __init__(self, samplerate: int, duration: float):
super().__init__() super().__init__()
self.samplerate = samplerate self.samplerate = samplerate
@ -61,6 +156,20 @@ class FixDuration(torch.nn.Module):
self.length = int(samplerate * duration) self.length = int(samplerate * duration)
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""Truncate or pad the waveform to the target length.
Parameters
----------
wav : torch.Tensor
Input waveform tensor of shape ``(samples,)`` or
``(channels, samples)``. The last dimension is adjusted.
Returns
-------
torch.Tensor
Waveform with exactly ``self.length`` samples along the last
dimension.
"""
length = wav.shape[-1] length = wav.shape[-1]
if length == self.length: if length == self.length:
@ -81,10 +190,34 @@ AudioTransform = Annotated[
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig, FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Discriminated union of all audio transform configuration types.
Use this type when a field should accept any of the supported audio
transforms. Pydantic will select the correct config class based on the
``name`` field.
"""
def build_audio_transform( def build_audio_transform(
config: AudioTransform, config: AudioTransform,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Build an audio transform module from a configuration object.
Parameters
----------
config : AudioTransform
A configuration object for one of the supported audio transforms
(``CenterAudioConfig``, ``ScaleAudioConfig``, or
``FixDurationConfig``).
samplerate : int, default=256000
Sample rate of the audio in Hz. Passed to the transform builder;
some transforms (e.g. ``FixDuration``) use it to convert seconds
to samples.
Returns
-------
torch.nn.Module
The constructed audio transform module.
"""
return audio_transforms.build(config, samplerate) return audio_transforms.build(config, samplerate)

View File

@ -1,3 +1,10 @@
"""Shared tensor primitives used across the preprocessing pipeline.
This module provides small, stateless helper functions that operate on
PyTorch tensors. They are used by both audio-level and spectrogram-level
transforms, and are kept here to avoid duplication.
"""
import torch import torch
__all__ = [ __all__ = [
@ -7,11 +14,42 @@ __all__ = [
def center_tensor(tensor: torch.Tensor) -> torch.Tensor: def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Subtract the mean of a tensor from all of its values.
This centres the signal around zero, removing any constant DC offset.
Parameters
----------
tensor : torch.Tensor
Input tensor of any shape.
Returns
-------
torch.Tensor
A new tensor of the same shape and dtype with the global mean
subtracted from every element.
"""
return tensor - tensor.mean() return tensor - tensor.mean()
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor: def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
max_value = tensor.abs().min() """Scale a tensor so that its largest absolute value equals one.
Divides the tensor by its peak absolute value. If the tensor is
identically zero, it is returned unchanged (no division by zero).
Parameters
----------
tensor : torch.Tensor
Input tensor of any shape.
Returns
-------
torch.Tensor
A new tensor of the same shape and dtype with values in the range
``[-1, 1]`` (or exactly ``[0, 0]`` for a zero tensor).
"""
max_value = tensor.abs().max()
denominator = torch.where( denominator = torch.where(
max_value == 0, max_value == 0,

View File

@ -1,3 +1,10 @@
"""Configuration for the full batdetect2 preprocessing pipeline.
This module defines :class:`PreprocessingConfig`, which aggregates all
configuration needed to convert a raw audio waveform into a normalised
spectrogram ready for the detection model.
"""
from typing import List from typing import List
from pydantic import Field from pydantic import Field
@ -9,7 +16,7 @@ from batdetect2.preprocess.spectrogram import (
FrequencyConfig, FrequencyConfig,
PcenConfig, PcenConfig,
ResizeConfig, ResizeConfig,
SpectralMeanSubstractionConfig, SpectralMeanSubtractionConfig,
SpectrogramTransform, SpectrogramTransform,
STFTConfig, STFTConfig,
) )
@ -24,19 +31,30 @@ __all__ = [
class PreprocessingConfig(BaseConfig): class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline. """Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage Aggregates the parameters for every stage of the pipeline:
and the subsequent spectrogram generation stage. audio-level transforms, STFT computation, frequency cropping,
spectrogram-level transforms, and the final resize step.
Attributes Attributes
---------- ----------
audio : AudioConfig audio_transforms : list of AudioTransform, default=[]
Configuration settings for the audio loading and initial waveform Ordered list of transforms applied to the raw audio waveform
processing steps (e.g., resampling, duration adjustment, scaling). before the STFT is computed. Each entry is a configuration
Defaults to default `AudioConfig` settings if omitted. object for one of the supported audio transforms
spectrogram : SpectrogramConfig (``"center_audio"``, ``"scale_audio"``, or ``"fix_duration"``).
Configuration settings for the spectrogram generation process spectrogram_transforms : list of SpectrogramTransform
(e.g., STFT parameters, frequency cropping, scaling, denoising, Ordered list of transforms applied to the cropped spectrogram
resizing). Defaults to default `SpectrogramConfig` settings if omitted. after the STFT and frequency crop steps. Defaults to
``[PcenConfig(), SpectralMeanSubtractionConfig()]``, which
applies PCEN followed by spectral mean subtraction.
stft : STFTConfig
Parameters for the Short-Time Fourier Transform (window
duration, overlap, and window function).
frequencies : FrequencyConfig
Frequency range (in Hz) to retain after the STFT.
size : ResizeConfig
Target height (number of frequency bins) and time-axis scaling
factor for the final resize step.
""" """
audio_transforms: List[AudioTransform] = Field(default_factory=list) audio_transforms: List[AudioTransform] = Field(default_factory=list)
@ -44,7 +62,7 @@ class PreprocessingConfig(BaseConfig):
spectrogram_transforms: List[SpectrogramTransform] = Field( spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [ default_factory=lambda: [
PcenConfig(), PcenConfig(),
SpectralMeanSubstractionConfig(), SpectralMeanSubtractionConfig(),
] ]
) )
@ -59,4 +77,20 @@ def load_preprocessing_config(
path: PathLike, path: PathLike,
field: str | None = None, field: str | None = None,
) -> PreprocessingConfig: ) -> PreprocessingConfig:
"""Load a ``PreprocessingConfig`` from a YAML file.
Parameters
----------
path : PathLike
Path to the YAML configuration file.
field : str, optional
If provided, read the config from a nested field within the
YAML document (e.g. ``"preprocessing"`` to read from a top-level
``preprocessing:`` key).
Returns
-------
PreprocessingConfig
The deserialised preprocessing configuration.
"""
return load_config(path, schema=PreprocessingConfig, field=field) return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -1,3 +1,25 @@
"""Assembles the full batdetect2 preprocessing pipeline.
This module defines :class:`Preprocessor`, the concrete implementation of
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
:func:`build_preprocessor` factory function that constructs it from a
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
The preprocessing pipeline converts a raw audio waveform (as a
``torch.Tensor``) into a normalised, cropped, and resized spectrogram ready
for the detection model. The stages are applied in this order:
1. **Audio transforms** — optional waveform-level operations such as DC
removal, peak normalisation, or duration fixing.
2. **STFT** — Short-Time Fourier Transform to produce an amplitude
spectrogram.
3. **Frequency crop** — retain only the frequency band of interest.
4. **Spectrogram transforms** — normalisation operations such as PCEN and
spectral mean subtraction.
5. **Resize** — scale the spectrogram to the model's expected height and
reduce the time resolution.
"""
import torch import torch
from loguru import logger from loguru import logger
@ -20,7 +42,32 @@ __all__ = [
class Preprocessor(torch.nn.Module, PreprocessorProtocol): class Preprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol.""" """Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
that parameters (e.g. PCEN filter coefficients) can be tracked and
moved between devices.
Parameters
----------
config : PreprocessingConfig
Full pipeline configuration.
input_samplerate : int
Sample rate of the audio that will be passed to this preprocessor,
in Hz.
Attributes
----------
input_samplerate : int
Sample rate of the input audio in Hz.
output_samplerate : float
Effective frame rate of the output spectrogram in frames per second.
Computed from the STFT hop length and the time-axis resize factor.
min_freq : float
Lower bound of the retained frequency band in Hz.
max_freq : float
Upper bound of the retained frequency band in Hz.
"""
input_samplerate: int input_samplerate: int
output_samplerate: float output_samplerate: float
@ -72,17 +119,75 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol):
) )
def forward(self, wav: torch.Tensor) -> torch.Tensor: def forward(self, wav: torch.Tensor) -> torch.Tensor:
"""Run the full preprocessing pipeline on a waveform.
Applies audio transforms, then the STFT, then
:meth:`process_spectrogram`.
Parameters
----------
wav : torch.Tensor
Input waveform of shape ``(samples,)``.
Returns
-------
torch.Tensor
Preprocessed spectrogram of shape
``(freq_bins, time_frames)``.
"""
wav = self.audio_transforms(wav) wav = self.audio_transforms(wav)
spec = self.spectrogram_builder(wav) spec = self.spectrogram_builder(wav)
return self.process_spectrogram(spec) return self.process_spectrogram(spec)
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
"""Compute the raw STFT spectrogram without any further processing.
Parameters
----------
wav : torch.Tensor
Input waveform of shape ``(samples,)``.
Returns
-------
torch.Tensor
Amplitude spectrogram of shape ``(n_fft//2 + 1, time_frames)``
with no frequency cropping, normalisation, or resizing applied.
"""
return self.spectrogram_builder(wav) return self.spectrogram_builder(wav)
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
"""Alias for :meth:`forward`.
Parameters
----------
wav : torch.Tensor
Input waveform of shape ``(samples,)``.
Returns
-------
torch.Tensor
Preprocessed spectrogram (same as calling the object directly).
"""
return self(wav) return self(wav)
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
"""Apply the post-STFT processing stages to an existing spectrogram.
Applies frequency cropping, spectrogram-level transforms (e.g.
PCEN, spectral mean subtraction), and the final resize step.
Parameters
----------
spec : torch.Tensor
Raw amplitude spectrogram of shape
``(..., n_fft//2 + 1, time_frames)``.
Returns
-------
torch.Tensor
Normalised and resized spectrogram of shape
``(..., height, scaled_time_frames)``.
"""
spec = self.spectrogram_crop(spec) spec = self.spectrogram_crop(spec)
spec = self.spectrogram_transforms(spec) spec = self.spectrogram_transforms(spec)
return self.spectrogram_resizer(spec) return self.spectrogram_resizer(spec)
@ -92,6 +197,25 @@ def compute_output_samplerate(
config: PreprocessingConfig, config: PreprocessingConfig,
input_samplerate: int = TARGET_SAMPLERATE_HZ, input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> float: ) -> float:
"""Compute the effective frame rate of the preprocessor's output.
The output frame rate (in frames per second) depends on the STFT hop
length and the time-axis resize factor applied by the final resize step.
Parameters
----------
config : PreprocessingConfig
Pipeline configuration.
input_samplerate : int, default=256000
Sample rate of the input audio in Hz.
Returns
-------
float
Output frame rate in frames per second.
For example, at the default settings (256 kHz, hop=128,
resize_factor=0.5) this equals ``1000.0``.
"""
_, hop_size = _spec_params_from_config( _, hop_size = _spec_params_from_config(
config.stft, samplerate=input_samplerate config.stft, samplerate=input_samplerate
) )
@ -103,7 +227,24 @@ def build_preprocessor(
config: PreprocessingConfig | None = None, config: PreprocessingConfig | None = None,
input_samplerate: int = TARGET_SAMPLERATE_HZ, input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol: ) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration.""" """Build the standard preprocessor from a configuration object.
Parameters
----------
config : PreprocessingConfig, optional
Pipeline configuration. If ``None``, the default
``PreprocessingConfig()`` is used (PCEN + spectral mean
subtraction, 256 kHz, standard STFT parameters).
input_samplerate : int, default=256000
Sample rate of the audio that will be fed to the preprocessor,
in Hz.
Returns
-------
PreprocessorProtocol
A :class:`Preprocessor` instance ready to convert waveforms to
spectrograms.
"""
config = config or PreprocessingConfig() config = config or PreprocessingConfig()
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}", "Building preprocessor with config: \n{}",

View File

@ -1,4 +1,14 @@
"""Computes spectrograms from audio waveforms with configurable parameters.""" """Computes spectrograms from audio waveforms with configurable parameters.
This module defines the STFT-based spectrogram builder and a collection of
spectrogram-level transforms (PCEN, spectral mean subtraction, amplitude
scaling, peak normalisation, frequency cropping, and resizing) that form the
signal-processing stage of the batdetect2 preprocessing pipeline.
Each transform is paired with a Pydantic configuration class and registered
in the ``spectrogram_transforms`` registry so that the pipeline can be fully
specified via a YAML or Python configuration object.
"""
from typing import Annotated, Callable, Literal from typing import Annotated, Callable, Literal
@ -32,17 +42,22 @@ class STFTConfig(BaseConfig):
Attributes Attributes
---------- ----------
window_duration : float, default=0.002 window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be Duration of the STFT analysis window in seconds (e.g. 0.002 for
> 0. Determines frequency resolution (longer window = finer frequency 2 ms). Must be > 0. A longer window gives finer frequency resolution
resolution). but coarser time resolution.
window_overlap : float, default=0.75 window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75 Fraction of overlap between consecutive windows (e.g. 0.75 for
for 75%). Must be >= 0 and < 1. Determines time resolution 75 %). Must be >= 0 and < 1. Higher overlap gives finer time
(higher overlap = finer time resolution). resolution at the cost of more computation.
window_fn : str, default="hann" window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common Name of the tapering window applied to each frame before the FFT.
options include "hann", "hamming", "blackman". See Supported values: ``"hann"``, ``"hamming"``, ``"kaiser"``,
`scipy.signal.get_window`. ``"blackman"``, ``"bartlett"``.
Notes
-----
At the default sample rate of 256 kHz, ``window_duration=0.002`` and
``window_overlap=0.75`` give ``n_fft=512`` and ``hop_length=128``.
""" """
window_duration: float = Field(default=0.002, gt=0) window_duration: float = Field(default=0.002, gt=0)
@ -54,6 +69,23 @@ def build_spectrogram_builder(
config: STFTConfig, config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Build a torchaudio STFT spectrogram module from an ``STFTConfig``.
Parameters
----------
config : STFTConfig
STFT parameters (window duration, overlap, and window function).
samplerate : int, default=256000
Sample rate of the input audio in Hz. Used to convert the
window duration into a number of samples.
Returns
-------
torch.nn.Module
A ``torchaudio.transforms.Spectrogram`` module configured to
produce an amplitude (``power=1``) spectrogram with centred
frames.
"""
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate) n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
return torchaudio.transforms.Spectrogram( return torchaudio.transforms.Spectrogram(
n_fft=n_fft, n_fft=n_fft,
@ -65,6 +97,25 @@ def build_spectrogram_builder(
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
"""Return the PyTorch window function matching the given name.
Parameters
----------
name : str
Name of the window function. One of ``"hann"``, ``"hamming"``,
``"kaiser"``, ``"blackman"``, or ``"bartlett"``.
Returns
-------
Callable[..., torch.Tensor]
A PyTorch window function that accepts a window length and returns
a 1-D tensor of weights.
Raises
------
NotImplementedError
If ``name`` does not match any supported window function.
"""
if name == "hann": if name == "hann":
return torch.hann_window return torch.hann_window
@ -88,7 +139,22 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
def _spec_params_from_config( def _spec_params_from_config(
config: STFTConfig, config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
): ) -> tuple[int, int]:
"""Compute ``n_fft`` and ``hop_length`` from an ``STFTConfig``.
Parameters
----------
config : STFTConfig
STFT parameters.
samplerate : int, default=256000
Sample rate of the input audio in Hz.
Returns
-------
tuple[int, int]
A pair ``(n_fft, hop_length)`` giving the FFT size and the step
between consecutive frames in samples.
"""
n_fft = int(samplerate * config.window_duration) n_fft = int(samplerate * config.window_duration)
hop_length = int(n_fft * (1 - config.window_overlap)) hop_length = int(n_fft * (1 - config.window_overlap))
return n_fft, hop_length return n_fft, hop_length
@ -99,6 +165,24 @@ def _frequency_to_index(
n_fft: int, n_fft: int,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> int | None: ) -> int | None:
"""Convert a frequency in Hz to the nearest STFT frequency bin index.
Parameters
----------
freq : float
Frequency in Hz to convert.
n_fft : int
FFT size used by the STFT.
samplerate : int, default=256000
Sample rate of the audio in Hz.
Returns
-------
int or None
The bin index corresponding to ``freq``, or ``None`` if the
frequency is outside the valid range (i.e. <= 0 Hz or >= the
Nyquist frequency).
"""
alpha = freq * 2 / samplerate alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1 height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height)) index = int(np.floor(alpha * height))
@ -118,11 +202,11 @@ class FrequencyConfig(BaseConfig):
Attributes Attributes
---------- ----------
max_freq : int, default=120000 max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT. Maximum frequency in Hz to retain after STFT. Frequency bins
Frequencies above this value will be cropped. Must be > 0. above this value are discarded. Must be >= 0.
min_freq : int, default=10000 min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT. Minimum frequency in Hz to retain after STFT. Frequency bins
Frequencies below this value will be cropped. Must be >= 0. below this value are discarded. Must be >= 0.
""" """
max_freq: int = Field(default=MAX_FREQ, ge=0) max_freq: int = Field(default=MAX_FREQ, ge=0)
@ -130,6 +214,27 @@ class FrequencyConfig(BaseConfig):
class FrequencyCrop(torch.nn.Module): class FrequencyCrop(torch.nn.Module):
"""Crop a spectrogram to a specified frequency band.
On construction the Hz boundaries are converted to STFT bin indices.
During the forward pass the spectrogram is sliced along its
frequency axis (second-to-last dimension) to retain only the bins
that fall within ``[min_freq, max_freq)``.
Parameters
----------
samplerate : int
Sample rate of the audio in Hz.
n_fft : int
FFT size used by the STFT.
min_freq : int, optional
Lower frequency bound in Hz. If ``None``, no lower crop is
applied and the DC bin is retained.
max_freq : int, optional
Upper frequency bound in Hz. If ``None``, no upper crop is
applied and all bins up to Nyquist are retained.
"""
def __init__( def __init__(
self, self,
samplerate: int, samplerate: int,
@ -162,6 +267,19 @@ class FrequencyCrop(torch.nn.Module):
self.high_index = high_index self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Crop the spectrogram to the configured frequency band.
Parameters
----------
spec : torch.Tensor
Spectrogram tensor of shape ``(..., freq_bins, time_frames)``.
Returns
-------
torch.Tensor
Cropped spectrogram with shape
``(..., n_retained_bins, time_frames)``.
"""
low_index = self.low_index low_index = self.low_index
if low_index is None: if low_index is None:
low_index = 0 low_index = 0
@ -184,6 +302,24 @@ def build_spectrogram_crop(
stft: STFTConfig | None = None, stft: STFTConfig | None = None,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Build a ``FrequencyCrop`` module from configuration objects.
Parameters
----------
config : FrequencyConfig
Frequency boundary configuration specifying ``min_freq`` and
``max_freq`` in Hz.
stft : STFTConfig, optional
STFT configuration used to derive ``n_fft``. Defaults to
``STFTConfig()`` if not provided.
samplerate : int, default=256000
Sample rate of the audio in Hz.
Returns
-------
torch.nn.Module
A ``FrequencyCrop`` module ready to crop spectrograms.
"""
stft = stft or STFTConfig() stft = stft or STFTConfig()
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate) n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
return FrequencyCrop( return FrequencyCrop(
@ -195,18 +331,61 @@ def build_spectrogram_crop(
class ResizeConfig(BaseConfig): class ResizeConfig(BaseConfig):
"""Configuration for the final spectrogram resize step.
Attributes
----------
name : str
Fixed identifier; always ``"resize_spec"``.
height : int, default=128
Target number of frequency bins in the output spectrogram.
The spectrogram is resized to this height using bilinear
interpolation.
resize_factor : float, default=0.5
Fraction by which the time axis is scaled. For example, ``0.5``
halves the number of time frames, reducing computational cost
downstream.
"""
name: Literal["resize_spec"] = "resize_spec" name: Literal["resize_spec"] = "resize_spec"
height: int = 128 height: int = 128
resize_factor: float = 0.5 resize_factor: float = 0.5
class ResizeSpec(torch.nn.Module): class ResizeSpec(torch.nn.Module):
"""Resize a spectrogram to a fixed height and scaled width.
Uses bilinear interpolation so it handles arbitrary input shapes
gracefully. Input tensors with fewer than four dimensions are
temporarily unsqueezed to satisfy ``torch.nn.functional.interpolate``.
Parameters
----------
height : int
Target number of frequency bins (output height).
time_factor : float
Multiplicative scaling applied to the time axis length.
"""
def __init__(self, height: int, time_factor: float): def __init__(self, height: int, time_factor: float):
super().__init__() super().__init__()
self.height = height self.height = height
self.time_factor = time_factor self.time_factor = time_factor
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Resize the spectrogram to the configured output dimensions.
Parameters
----------
spec : torch.Tensor
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
Returns
-------
torch.Tensor
Resized spectrogram with shape
``(..., height, int(time_factor * time_frames))``.
"""
current_length = spec.shape[-1] current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length) target_length = int(self.time_factor * current_length)
@ -227,6 +406,18 @@ class ResizeSpec(torch.nn.Module):
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module: def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
"""Build a ``ResizeSpec`` module from a ``ResizeConfig``.
Parameters
----------
config : ResizeConfig
Resize configuration specifying ``height`` and ``resize_factor``.
Returns
-------
torch.nn.Module
A ``ResizeSpec`` module configured with the given parameters.
"""
return ResizeSpec(height=config.height, time_factor=config.resize_factor) return ResizeSpec(height=config.height, time_factor=config.resize_factor)
@ -236,7 +427,32 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
class PcenConfig(BaseConfig): class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN).""" """Configuration for Per-Channel Energy Normalisation (PCEN).
PCEN is a frontend processing technique that replaces simple log
compression. It applies a learnable automatic gain control followed
by a stabilised root compression, making the representation more
robust to variations in recording level.
Attributes
----------
name : str
Fixed identifier; always ``"pcen"``.
time_constant : float, default=0.4
Time constant (in seconds) of the IIR smoothing filter used
for the background estimate. Larger values produce a slower-
adapting background.
gain : float, default=0.98
Exponent controlling how strongly the background estimate
suppresses the signal.
bias : float, default=2
Stabilisation bias added inside the root-compression step to
avoid division by zero.
power : float, default=0.5
Root-compression exponent. A value of 0.5 gives square-root
compression, similar to log compression but differentiable at
zero.
"""
name: Literal["pcen"] = "pcen" name: Literal["pcen"] = "pcen"
time_constant: float = 0.4 time_constant: float = 0.4
@ -246,6 +462,35 @@ class PcenConfig(BaseConfig):
class PCEN(torch.nn.Module): class PCEN(torch.nn.Module):
"""Per-Channel Energy Normalisation (PCEN) transform.
Applies automatic gain control and root compression to a spectrogram.
The background estimate is computed with a first-order IIR filter
applied along the time axis.
Parameters
----------
smoothing_constant : float
IIR filter coefficient ``alpha``. Derived from the time constant
and sample rate via ``_compute_smoothing_constant``.
gain : float, default=0.98
AGC gain exponent.
bias : float, default=2.0
Root-compression stabilisation bias.
power : float, default=0.5
Root-compression exponent.
eps : float, default=1e-6
Small constant for numerical stability.
dtype : torch.dtype, default=torch.float32
Floating-point precision used for internal computation.
Notes
-----
The smoothing constant is computed to match the original batdetect2
implementation for numerical compatibility. See
``_compute_smoothing_constant`` for details.
"""
def __init__( def __init__(
self, self,
smoothing_constant: float, smoothing_constant: float,
@ -269,6 +514,20 @@ class PCEN(torch.nn.Module):
) )
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Apply PCEN to a spectrogram.
Parameters
----------
spec : torch.Tensor
Input amplitude spectrogram of shape
``(..., freq_bins, time_frames)``.
Returns
-------
torch.Tensor
PCEN-normalised spectrogram with the same shape and dtype as
the input.
"""
S = spec.to(self.dtype) * 2**31 S = spec.to(self.dtype) * 2**31
M = ( M = (
@ -305,21 +564,73 @@ def _compute_smoothing_constant(
samplerate: int, samplerate: int,
time_constant: float, time_constant: float,
) -> float: ) -> float:
# NOTE: These were taken to match the original implementation """Compute the IIR smoothing coefficient for PCEN.
Parameters
----------
samplerate : int
Sample rate of the audio in Hz.
time_constant : float
Desired smoothing time constant in seconds.
Returns
-------
float
IIR filter coefficient ``alpha`` used by ``PCEN``.
Notes
-----
The hop length (512) and the sample-rate divisor (10) are fixed to
reproduce the numerical behaviour of the original batdetect2
implementation, which used ``librosa.pcen`` with ``sr=samplerate/10``
and the default ``hop_length=512``. These values do not reflect the
actual STFT hop length used in the pipeline; they are retained
solely for backward compatibility.
"""
# NOTE: These parameters are fixed to match the original implementation.
hop_length = 512 hop_length = 512
sr = samplerate / 10 sr = samplerate / 10
time_constant = time_constant
t_frames = time_constant * sr / float(hop_length) t_frames = time_constant * sr / float(hop_length)
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
class ScaleAmplitudeConfig(BaseConfig): class ScaleAmplitudeConfig(BaseConfig):
"""Configuration for amplitude scaling of a spectrogram.
Attributes
----------
name : str
Fixed identifier; always ``"scale_amplitude"``.
scale : str, default="db"
Scaling mode. Either ``"db"`` (convert amplitude to decibels
using ``torchaudio.transforms.AmplitudeToDB``) or ``"power"``
(square the amplitude values).
"""
name: Literal["scale_amplitude"] = "scale_amplitude" name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db" scale: Literal["power", "db"] = "db"
class ToPower(torch.nn.Module): class ToPower(torch.nn.Module):
"""Square the values of a spectrogram (amplitude → power).
Raises each element to the power of two, converting an amplitude
spectrogram into a power spectrogram.
"""
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Square all elements of the spectrogram.
Parameters
----------
spec : torch.Tensor
Input amplitude spectrogram.
Returns
-------
torch.Tensor
Power spectrogram (same shape as input).
"""
return spec**2 return spec**2
@ -330,12 +641,34 @@ _scalers = {
class ScaleAmplitude(torch.nn.Module): class ScaleAmplitude(torch.nn.Module):
"""Convert spectrogram amplitude values to a different scale.
Supports conversion to decibels (dB) or to power (squared amplitude).
Parameters
----------
scale : str
Either ``"db"`` or ``"power"``.
"""
def __init__(self, scale: Literal["power", "db"]): def __init__(self, scale: Literal["power", "db"]):
super().__init__() super().__init__()
self.scale = scale self.scale = scale
self.scaler = _scalers[scale]() self.scaler = _scalers[scale]()
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Apply the configured amplitude scaling.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor.
Returns
-------
torch.Tensor
Scaled spectrogram with the same shape as the input.
"""
return self.scaler(spec) return self.scaler(spec)
@spectrogram_transforms.register(ScaleAmplitudeConfig) @spectrogram_transforms.register(ScaleAmplitudeConfig)
@ -344,30 +677,86 @@ class ScaleAmplitude(torch.nn.Module):
return ScaleAmplitude(scale=config.scale) return ScaleAmplitude(scale=config.scale)
class SpectralMeanSubstractionConfig(BaseConfig): class SpectralMeanSubtractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" """Configuration for spectral mean subtraction.
Attributes
----------
name : str
Fixed identifier; always ``"spectral_mean_subtraction"``.
"""
name: Literal["spectral_mean_subtraction"] = "spectral_mean_subtraction"
class SpectralMeanSubstraction(torch.nn.Module): class SpectralMeanSubtraction(torch.nn.Module):
"""Remove the time-averaged background noise from a spectrogram.
For each frequency bin, the mean value across all time frames is
computed and subtracted. The result is then clamped to zero so that
no values fall below the baseline. This is a simple form of spectral
denoising that suppresses stationary background noise.
"""
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Subtract the time-axis mean from each frequency bin.
Parameters
----------
spec : torch.Tensor
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
Returns
-------
torch.Tensor
Denoised spectrogram with the same shape as the input. All
values are non-negative (clamped to 0).
"""
mean = spec.mean(-1, keepdim=True) mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0) return (spec - mean).clamp(min=0)
@spectrogram_transforms.register(SpectralMeanSubstractionConfig) @spectrogram_transforms.register(SpectralMeanSubtractionConfig)
@staticmethod @staticmethod
def from_config( def from_config(
config: SpectralMeanSubstractionConfig, config: SpectralMeanSubtractionConfig,
samplerate: int, samplerate: int,
): ):
return SpectralMeanSubstraction() return SpectralMeanSubtraction()
class PeakNormalizeConfig(BaseConfig): class PeakNormalizeConfig(BaseConfig):
"""Configuration for peak normalisation of a spectrogram.
Attributes
----------
name : str
Fixed identifier; always ``"peak_normalize"``.
"""
name: Literal["peak_normalize"] = "peak_normalize" name: Literal["peak_normalize"] = "peak_normalize"
class PeakNormalize(torch.nn.Module): class PeakNormalize(torch.nn.Module):
"""Scale a spectrogram so that its largest absolute value equals one.
Wraps :func:`batdetect2.preprocess.common.peak_normalize` as a
``torch.nn.Module`` for use inside a sequential transform pipeline.
"""
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Peak-normalise the spectrogram.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor of any shape.
Returns
-------
torch.Tensor
Normalised spectrogram where the maximum absolute value is 1.
If the input is identically zero, it is returned unchanged.
"""
return peak_normalize(spec) return peak_normalize(spec)
@spectrogram_transforms.register(PeakNormalizeConfig) @spectrogram_transforms.register(PeakNormalizeConfig)
@ -379,14 +768,37 @@ class PeakNormalize(torch.nn.Module):
SpectrogramTransform = Annotated[ SpectrogramTransform = Annotated[
PcenConfig PcenConfig
| ScaleAmplitudeConfig | ScaleAmplitudeConfig
| SpectralMeanSubstractionConfig | SpectralMeanSubtractionConfig
| PeakNormalizeConfig, | PeakNormalizeConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Discriminated union of all spectrogram transform configuration types.
Use this type when a field should accept any of the supported spectrogram
transforms. Pydantic will select the correct config class based on the
``name`` field.
"""
def build_spectrogram_transform( def build_spectrogram_transform(
config: SpectrogramTransform, config: SpectrogramTransform,
samplerate: int, samplerate: int,
) -> torch.nn.Module: ) -> torch.nn.Module:
"""Build a spectrogram transform module from a configuration object.
Parameters
----------
config : SpectrogramTransform
A configuration object for one of the supported spectrogram
transforms (PCEN, amplitude scaling, spectral mean subtraction,
or peak normalisation).
samplerate : int
Sample rate of the audio in Hz. Some transforms (e.g. PCEN) use
this to set internal parameters.
Returns
-------
torch.nn.Module
The constructed transform module.
"""
return spectrogram_transforms.build(config, samplerate) return spectrogram_transforms.build(config, samplerate)

View File

@ -1,12 +1,31 @@
"""Tests for audio-level preprocessing transforms.
Covers :mod:`batdetect2.preprocess.audio` and the shared helper functions
in :mod:`batdetect2.preprocess.common`.
"""
import pathlib import pathlib
import uuid import uuid
import numpy as np import numpy as np
import pytest import pytest
import soundfile as sf import soundfile as sf
import torch
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.preprocess.audio import (
CenterAudio,
CenterAudioConfig,
FixDuration,
FixDurationConfig,
ScaleAudio,
ScaleAudioConfig,
build_audio_transform,
)
from batdetect2.preprocess.common import center_tensor, peak_normalize
SAMPLERATE = 256_000
def create_dummy_wave( def create_dummy_wave(
@ -15,9 +34,9 @@ def create_dummy_wave(
num_channels: int = 1, num_channels: int = 1,
freq: float = 440.0, freq: float = 440.0,
amplitude: float = 0.5, amplitude: float = 0.5,
dtype: np.dtype = np.float32, dtype: type = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Generates a simple numpy waveform.""" """Generate a simple sine-wave waveform as a NumPy array."""
t = np.linspace( t = np.linspace(
0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype 0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype
) )
@ -29,7 +48,7 @@ def create_dummy_wave(
@pytest.fixture @pytest.fixture
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path: def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
"""Creates a dummy WAV file and returns its path.""" """Create a dummy 2-channel WAV file and return its path."""
samplerate = 48000 samplerate = 48000
duration = 2.0 duration = 2.0
num_channels = 2 num_channels = 2
@ -41,13 +60,13 @@ def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
@pytest.fixture @pytest.fixture
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording: def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
"""Creates a Recording object pointing to the dummy WAV file.""" """Create a Recording object pointing to the dummy WAV file."""
return data.Recording.from_file(dummy_wav_path) return data.Recording.from_file(dummy_wav_path)
@pytest.fixture @pytest.fixture
def dummy_clip(dummy_recording: data.Recording) -> data.Clip: def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
"""Creates a Clip object from the dummy recording.""" """Create a Clip object from the dummy recording."""
return data.Clip( return data.Clip(
recording=dummy_recording, recording=dummy_recording,
start_time=0.5, start_time=0.5,
@ -58,3 +77,165 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
@pytest.fixture @pytest.fixture
def default_audio_config() -> AudioConfig: def default_audio_config() -> AudioConfig:
return AudioConfig() return AudioConfig()
# ---------------------------------------------------------------------------
# center_tensor
# ---------------------------------------------------------------------------
def test_center_tensor_zero_mean():
"""Output tensor should have a mean very close to zero."""
wav = torch.tensor([1.0, 2.0, 3.0, 4.0])
result = center_tensor(wav)
assert result.mean().abs().item() < 1e-5
def test_center_tensor_preserves_shape():
wav = torch.randn(3, 1000)
result = center_tensor(wav)
assert result.shape == wav.shape
# ---------------------------------------------------------------------------
# peak_normalize
# ---------------------------------------------------------------------------
def test_peak_normalize_max_is_one():
"""After peak normalisation, the maximum absolute value should be 1."""
wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3])
result = peak_normalize(wav)
assert abs(result.abs().max().item() - 1.0) < 1e-6
def test_peak_normalize_zero_tensor_unchanged():
"""A zero tensor should be returned unchanged (no division by zero)."""
wav = torch.zeros(100)
result = peak_normalize(wav)
assert (result == 0).all()
def test_peak_normalize_preserves_sign():
"""Signs of all elements should be preserved after normalisation."""
wav = torch.tensor([-2.0, 1.0, -0.5])
result = peak_normalize(wav)
assert (result < 0).sum() == 2
assert result[0].item() < 0
def test_peak_normalize_preserves_shape():
wav = torch.randn(2, 512)
result = peak_normalize(wav)
assert result.shape == wav.shape
# ---------------------------------------------------------------------------
# CenterAudio
# ---------------------------------------------------------------------------
def test_center_audio_forward_zero_mean():
module = CenterAudio()
wav = torch.tensor([1.0, 3.0, 5.0])
result = module(wav)
assert result.mean().abs().item() < 1e-5
def test_center_audio_from_config():
config = CenterAudioConfig()
module = CenterAudio.from_config(config, samplerate=SAMPLERATE)
assert isinstance(module, CenterAudio)
# ---------------------------------------------------------------------------
# ScaleAudio
# ---------------------------------------------------------------------------
def test_scale_audio_peak_normalises_to_one():
"""ScaleAudio.forward should scale the peak absolute value to 1."""
module = ScaleAudio()
wav = torch.tensor([0.0, 0.25, -0.5, 0.1])
result = module(wav)
assert abs(result.abs().max().item() - 1.0) < 1e-6
def test_scale_audio_handles_zero_tensor():
"""ScaleAudio should not raise on a zero tensor."""
module = ScaleAudio()
wav = torch.zeros(100)
result = module(wav)
assert (result == 0).all()
def test_scale_audio_from_config():
config = ScaleAudioConfig()
module = ScaleAudio.from_config(config, samplerate=SAMPLERATE)
assert isinstance(module, ScaleAudio)
# ---------------------------------------------------------------------------
# FixDuration
# ---------------------------------------------------------------------------
def test_fix_duration_truncates_long_input():
"""Waveform longer than target should be truncated to the target length."""
target_samples = int(SAMPLERATE * 0.5)
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
wav = torch.randn(target_samples + 1000)
result = module(wav)
assert result.shape[-1] == target_samples
def test_fix_duration_pads_short_input():
"""Waveform shorter than target should be zero-padded to the target length."""
target_samples = int(SAMPLERATE * 0.5)
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
short_wav = torch.randn(target_samples - 100)
result = module(short_wav)
assert result.shape[-1] == target_samples
# Padded region should be zero
assert (result[target_samples - 100 :] == 0).all()
def test_fix_duration_passthrough_exact_length():
"""Waveform with exactly the right length should be returned unchanged."""
target_samples = int(SAMPLERATE * 0.5)
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
wav = torch.randn(target_samples)
result = module(wav)
assert result.shape[-1] == target_samples
assert torch.equal(result, wav)
def test_fix_duration_from_config():
"""FixDurationConfig should produce a FixDuration with the correct length."""
config = FixDurationConfig(duration=0.256)
module = FixDuration.from_config(config, samplerate=SAMPLERATE)
assert isinstance(module, FixDuration)
assert module.length == int(SAMPLERATE * 0.256)
# ---------------------------------------------------------------------------
# build_audio_transform dispatch
# ---------------------------------------------------------------------------
def test_build_audio_transform_center_audio():
config = CenterAudioConfig()
module = build_audio_transform(config, samplerate=SAMPLERATE)
assert isinstance(module, CenterAudio)
def test_build_audio_transform_scale_audio():
config = ScaleAudioConfig()
module = build_audio_transform(config, samplerate=SAMPLERATE)
assert isinstance(module, ScaleAudio)
def test_build_audio_transform_fix_duration():
config = FixDurationConfig(duration=0.5)
module = build_audio_transform(config, samplerate=SAMPLERATE)
assert isinstance(module, FixDuration)

View File

@ -0,0 +1,243 @@
"""Integration and unit tests for the Preprocessor pipeline.
Covers :mod:`batdetect2.preprocess.preprocessor` construction,
pipeline output shape/dtype, the ``process_numpy`` helper, attribute
values, output frame rate, and a round-trip YAML config build test.
"""
import pathlib
import numpy as np
import torch
from batdetect2.preprocess.audio import FixDurationConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.preprocessor import (
Preprocessor,
build_preprocessor,
compute_output_samplerate,
)
from batdetect2.preprocess.spectrogram import (
FrequencyConfig,
PcenConfig,
ResizeConfig,
SpectralMeanSubtractionConfig,
STFTConfig,
)
SAMPLERATE = 256_000
# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip
CLIP_SAMPLES = int(SAMPLERATE * 0.256)
def make_sine_wav(
samplerate: int = SAMPLERATE,
duration: float = 0.256,
freq: float = 40_000.0,
) -> torch.Tensor:
"""Return a single-channel sine-wave tensor."""
t = torch.linspace(0.0, duration, int(samplerate * duration))
return torch.sin(2 * torch.pi * freq * t)
# ---------------------------------------------------------------------------
# build_preprocessor — construction
# ---------------------------------------------------------------------------
def test_build_preprocessor_returns_protocol():
"""build_preprocessor should return a Preprocessor instance."""
preprocessor = build_preprocessor()
assert isinstance(preprocessor, Preprocessor)
def test_build_preprocessor_with_default_config():
"""build_preprocessor() with no arguments should not raise."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
assert preprocessor is not None
def test_build_preprocessor_with_explicit_config():
config = PreprocessingConfig(
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
size=ResizeConfig(height=128, resize_factor=0.5),
spectrogram_transforms=[PcenConfig(), SpectralMeanSubtractionConfig()],
)
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
assert isinstance(preprocessor, Preprocessor)
# ---------------------------------------------------------------------------
# Output shape and dtype
# ---------------------------------------------------------------------------
def test_preprocessor_output_is_2d():
"""The preprocessor output should be a 2-D tensor (freq_bins × time_frames)."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav = make_sine_wav()
result = preprocessor(wav)
assert result.ndim == 2
def test_preprocessor_output_height_matches_config():
"""Output height should match the ResizeConfig.height setting."""
config = PreprocessingConfig(
size=ResizeConfig(height=64, resize_factor=0.5)
)
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
wav = make_sine_wav()
result = preprocessor(wav)
assert result.shape[0] == 64
def test_preprocessor_output_dtype_float32():
"""Output tensor should be float32."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav = make_sine_wav()
result = preprocessor(wav)
assert result.dtype == torch.float32
def test_preprocessor_output_is_finite():
"""Output spectrogram should contain no NaN or Inf values."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav = make_sine_wav()
result = preprocessor(wav)
assert torch.isfinite(result).all()
# ---------------------------------------------------------------------------
# process_numpy
# ---------------------------------------------------------------------------
def test_preprocessor_process_numpy_accepts_ndarray():
"""process_numpy should accept a NumPy array and return a NumPy array."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav_np = make_sine_wav().numpy()
result = preprocessor.process_numpy(wav_np)
assert isinstance(result, np.ndarray)
def test_preprocessor_process_numpy_matches_forward():
"""process_numpy and forward should give numerically identical results."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav = make_sine_wav()
result_pt = preprocessor(wav).numpy()
result_np = preprocessor.process_numpy(wav.numpy())
np.testing.assert_array_almost_equal(result_pt, result_np)
# ---------------------------------------------------------------------------
# Attributes
# ---------------------------------------------------------------------------
def test_preprocessor_min_max_freq_attributes():
"""min_freq and max_freq should match the FrequencyConfig values."""
config = PreprocessingConfig(
frequencies=FrequencyConfig(min_freq=15_000, max_freq=100_000)
)
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
assert preprocessor.min_freq == 15_000
assert preprocessor.max_freq == 100_000
def test_preprocessor_input_samplerate_attribute():
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
assert preprocessor.input_samplerate == SAMPLERATE
# ---------------------------------------------------------------------------
# compute_output_samplerate
# ---------------------------------------------------------------------------
def test_compute_output_samplerate_defaults():
"""At default settings, output_samplerate should equal 1000 fps."""
config = PreprocessingConfig()
rate = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
assert abs(rate - 1000.0) < 1e-6
def test_preprocessor_output_samplerate_attribute_matches_compute():
config = PreprocessingConfig()
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
expected = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
assert abs(preprocessor.output_samplerate - expected) < 1e-6
# ---------------------------------------------------------------------------
# generate_spectrogram (raw STFT)
# ---------------------------------------------------------------------------
def test_generate_spectrogram_shape():
"""generate_spectrogram should return the full STFT without crop or resize."""
config = PreprocessingConfig()
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
wav = make_sine_wav()
spec = preprocessor.generate_spectrogram(wav)
# Full STFT: n_fft//2 + 1 = 257 bins at defaults
assert spec.shape[0] == 257
def test_generate_spectrogram_larger_than_forward():
"""Raw spectrogram should have more frequency bins than the processed output."""
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
wav = make_sine_wav()
raw = preprocessor.generate_spectrogram(wav)
processed = preprocessor(wav)
assert raw.shape[0] > processed.shape[0]
# ---------------------------------------------------------------------------
# Audio transforms pipeline (FixDuration)
# ---------------------------------------------------------------------------
def test_preprocessor_with_fix_duration_audio_transform():
"""A FixDuration audio transform should produce consistent output shapes."""
config = PreprocessingConfig(
audio_transforms=[FixDurationConfig(duration=0.256)],
)
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
# Feed different lengths; output shape should be the same after fix
for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]:
wav = torch.randn(n_samples)
result = preprocessor(wav)
assert result.ndim == 2
# ---------------------------------------------------------------------------
# YAML round-trip
# ---------------------------------------------------------------------------
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
"""PreprocessingConfig serialised to YAML and reloaded should produce
a functionally identical preprocessor."""
from batdetect2.preprocess.config import load_preprocessing_config
config = PreprocessingConfig(
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
size=ResizeConfig(height=128, resize_factor=0.5),
)
yaml_path = tmp_path / "preprocess_config.yaml"
yaml_path.write_text(config.to_yaml_string())
loaded_config = load_preprocessing_config(yaml_path)
preprocessor = build_preprocessor(
loaded_config, input_samplerate=SAMPLERATE
)
wav = make_sine_wav()
result = preprocessor(wav)
assert result.ndim == 2
assert result.shape[0] == 128
assert torch.isfinite(result).all()

View File

@ -1,37 +1,316 @@
import numpy as np """Tests for spectrogram-level preprocessing transforms.
import pytest
import xarray as xr
SAMPLERATE = 250_000 Covers :mod:`batdetect2.preprocess.spectrogram` STFT configuration,
DURATION = 0.1 frequency cropping, PCEN, spectral mean subtraction, amplitude scaling,
TEST_FREQ = 30_000 peak normalisation, and resizing.
N_SAMPLES = int(SAMPLERATE * DURATION) """
TIME_COORD = np.linspace(
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32 import torch
from batdetect2.preprocess.spectrogram import (
PCEN,
FrequencyConfig,
FrequencyCrop,
PcenConfig,
PeakNormalize,
PeakNormalizeConfig,
ResizeConfig,
ResizeSpec,
ScaleAmplitude,
ScaleAmplitudeConfig,
SpectralMeanSubtraction,
SpectralMeanSubtractionConfig,
STFTConfig,
_spec_params_from_config,
build_spectrogram_builder,
build_spectrogram_crop,
build_spectrogram_resizer,
build_spectrogram_transform,
) )
SAMPLERATE = 256_000
@pytest.fixture
def sine_wave_xr() -> xr.DataArray: # ---------------------------------------------------------------------------
"""Generate a single sine wave as an xr.DataArray.""" # STFTConfig / _spec_params_from_config
t = TIME_COORD # ---------------------------------------------------------------------------
wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
return xr.DataArray(
wav_data, def test_stft_config_defaults_give_correct_params():
coords={"time": t}, """Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128."""
dims=["time"], config = STFTConfig()
attrs={"samplerate": SAMPLERATE}, n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
assert n_fft == 512
assert hop_length == 128
def test_stft_config_custom_params():
"""Custom window duration and overlap should produce the expected sizes."""
config = STFTConfig(window_duration=0.004, window_overlap=0.5)
n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
assert n_fft == 1024
assert hop_length == 512
# ---------------------------------------------------------------------------
# build_spectrogram_builder
# ---------------------------------------------------------------------------
def test_spectrogram_builder_output_shape():
"""Builder should produce a spectrogram with the expected number of bins."""
config = STFTConfig()
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
expected_freq_bins = n_fft // 2 + 1 # 257 at defaults
builder = build_spectrogram_builder(config, samplerate=SAMPLERATE)
n_samples = SAMPLERATE # 1 second of audio
wav = torch.randn(n_samples)
spec = builder(wav)
assert spec.ndim == 2
assert spec.shape[0] == expected_freq_bins
def test_spectrogram_builder_output_is_nonnegative():
"""Amplitude spectrogram values should all be >= 0."""
config = STFTConfig()
builder = build_spectrogram_builder(config, samplerate=SAMPLERATE)
wav = torch.randn(SAMPLERATE)
spec = builder(wav)
assert (spec >= 0).all()
# ---------------------------------------------------------------------------
# FrequencyCrop
# ---------------------------------------------------------------------------
def test_frequency_crop_output_shape():
"""FrequencyCrop should reduce the number of frequency bins."""
config = STFTConfig()
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
crop = FrequencyCrop(
samplerate=SAMPLERATE,
n_fft=n_fft,
min_freq=10_000,
max_freq=120_000,
) )
spec = torch.ones(n_fft // 2 + 1, 100)
cropped = crop(spec)
assert cropped.ndim == 2
# Must be smaller than the full spectrogram
assert cropped.shape[0] < spec.shape[0]
assert cropped.shape[1] == 100 # time axis unchanged
@pytest.fixture def test_frequency_crop_build_from_config():
def constant_wave_xr() -> xr.DataArray: """build_spectrogram_crop should return a working FrequencyCrop."""
"""Generate a constant signal as an xr.DataArray.""" freq_config = FrequencyConfig(min_freq=10_000, max_freq=120_000)
t = TIME_COORD stft_config = STFTConfig()
wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5 crop = build_spectrogram_crop(
return xr.DataArray( freq_config, stft=stft_config, samplerate=SAMPLERATE
wav_data,
coords={"time": t},
dims=["time"],
attrs={"samplerate": SAMPLERATE},
) )
assert isinstance(crop, FrequencyCrop)
def test_frequency_crop_no_crop_when_bounds_are_none():
"""FrequencyCrop with no bounds should return the full spectrogram."""
config = STFTConfig()
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
crop = FrequencyCrop(samplerate=SAMPLERATE, n_fft=n_fft)
spec = torch.ones(n_fft // 2 + 1, 100)
cropped = crop(spec)
assert cropped.shape == spec.shape
# ---------------------------------------------------------------------------
# PCEN
# ---------------------------------------------------------------------------
def test_pcen_output_shape_preserved():
"""PCEN should not change the shape of the spectrogram."""
config = PcenConfig()
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
spec = torch.rand(64, 200) + 1e-6 # positive values
result = pcen(spec)
assert result.shape == spec.shape
def test_pcen_output_is_finite():
"""PCEN applied to a well-formed spectrogram should produce no NaN or Inf."""
config = PcenConfig()
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
spec = torch.rand(64, 200) + 1e-6
result = pcen(spec)
assert torch.isfinite(result).all()
def test_pcen_output_dtype_matches_input():
"""PCEN should return a tensor with the same dtype as the input."""
config = PcenConfig()
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
spec = torch.rand(64, 200, dtype=torch.float32)
result = pcen(spec)
assert result.dtype == spec.dtype
# ---------------------------------------------------------------------------
# SpectralMeanSubtraction
# ---------------------------------------------------------------------------
def test_spectral_mean_subtraction_output_nonnegative():
"""SpectralMeanSubtraction clamps output to >= 0."""
module = SpectralMeanSubtraction()
spec = torch.rand(64, 200)
result = module(spec)
assert (result >= 0).all()
def test_spectral_mean_subtraction_shape_preserved():
module = SpectralMeanSubtraction()
spec = torch.rand(64, 200)
result = module(spec)
assert result.shape == spec.shape
def test_spectral_mean_subtraction_reduces_time_mean():
"""After subtraction the time-axis mean per bin should be <= 0 (pre-clamp)."""
module = SpectralMeanSubtraction()
# Constant spectrogram: mean subtraction should produce all zeros before clamp
spec = torch.ones(32, 100) * 3.0
result = module(spec)
assert (result == 0).all()
def test_spectral_mean_subtraction_from_config():
config = SpectralMeanSubtractionConfig()
module = SpectralMeanSubtraction.from_config(config, samplerate=SAMPLERATE)
assert isinstance(module, SpectralMeanSubtraction)
# ---------------------------------------------------------------------------
# PeakNormalize (spectrogram-level)
# ---------------------------------------------------------------------------
def test_peak_normalize_spec_max_is_one():
"""PeakNormalize should scale the spectrogram peak to 1."""
module = PeakNormalize()
spec = torch.rand(64, 200) * 5.0
result = module(spec)
assert abs(result.abs().max().item() - 1.0) < 1e-6
def test_peak_normalize_spec_handles_zero():
"""PeakNormalize on a zero spectrogram should not raise."""
module = PeakNormalize()
spec = torch.zeros(64, 200)
result = module(spec)
assert (result == 0).all()
def test_peak_normalize_from_config():
    """from_config builds a PeakNormalize from its config object."""
    built = PeakNormalize.from_config(
        PeakNormalizeConfig(), samplerate=SAMPLERATE
    )
    assert isinstance(built, PeakNormalize)
# ---------------------------------------------------------------------------
# ScaleAmplitude
# ---------------------------------------------------------------------------
def test_scale_amplitude_db_output_is_finite():
    """dB scaling of strictly positive input yields no inf/nan values."""
    db_transform = ScaleAmplitude(scale="db")
    # Offset keeps every bin strictly positive so log-scaling is well defined.
    positive_spec = torch.rand(64, 200) + 1e-4
    assert torch.isfinite(db_transform(positive_spec)).all()
def test_scale_amplitude_power_output_equals_square():
    """The 'power' scale squares each element of the spectrogram."""
    values = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
    squared = ScaleAmplitude(scale="power")(values)
    assert torch.allclose(squared, values * values)
def test_scale_amplitude_from_config():
    """from_config builds a ScaleAmplitude carrying the configured scale."""
    built = ScaleAmplitude.from_config(
        ScaleAmplitudeConfig(scale="db"), samplerate=SAMPLERATE
    )
    assert isinstance(built, ScaleAmplitude)
    assert built.scale == "db"
# ---------------------------------------------------------------------------
# ResizeSpec
# ---------------------------------------------------------------------------
def test_resize_spec_output_shape():
    """ResizeSpec resizes to the target height and rescales the time axis."""
    resizer = ResizeSpec(height=64, time_factor=0.5)
    batched_spec = torch.rand(1, 128, 200)  # (batch, freq, time)
    assert resizer(batched_spec).shape == (1, 64, 100)
def test_resize_spec_2d_input():
    """A plain (freq, time) spectrogram without batch dims is also accepted."""
    resizer = ResizeSpec(height=64, time_factor=0.5)
    flat_spec = torch.rand(128, 200)
    assert resizer(flat_spec).shape == (64, 100)
def test_resize_spec_output_is_finite():
    """Resizing finite input must not introduce inf/nan values."""
    resizer = ResizeSpec(height=128, time_factor=0.5)
    resized = resizer(torch.rand(128, 200))
    assert torch.isfinite(resized).all()
def test_resize_spec_from_config():
    """build_spectrogram_resizer honours height and resize_factor fields."""
    resizer = build_spectrogram_resizer(
        ResizeConfig(height=64, resize_factor=0.25)
    )
    assert isinstance(resizer, ResizeSpec)
    assert resizer.height == 64
    assert resizer.time_factor == 0.25
# ---------------------------------------------------------------------------
# build_spectrogram_transform dispatch
# ---------------------------------------------------------------------------
def test_build_spectrogram_transform_pcen():
    """A PcenConfig dispatches to a PCEN module."""
    built = build_spectrogram_transform(PcenConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, PCEN)
def test_build_spectrogram_transform_spectral_mean_subtraction():
    """A SpectralMeanSubtractionConfig dispatches to SpectralMeanSubtraction."""
    built = build_spectrogram_transform(
        SpectralMeanSubtractionConfig(), samplerate=SAMPLERATE
    )
    assert isinstance(built, SpectralMeanSubtraction)
def test_build_spectrogram_transform_scale_amplitude():
    """A ScaleAmplitudeConfig dispatches to a ScaleAmplitude module."""
    built = build_spectrogram_transform(
        ScaleAmplitudeConfig(scale="db"), samplerate=SAMPLERATE
    )
    assert isinstance(built, ScaleAmplitude)
def test_build_spectrogram_transform_peak_normalize():
    """A PeakNormalizeConfig dispatches to a PeakNormalize module."""
    built = build_spectrogram_transform(
        PeakNormalizeConfig(), samplerate=SAMPLERATE
    )
    assert isinstance(built, PeakNormalize)

View File

@ -10,7 +10,7 @@ from batdetect2.preprocess import (
) )
from batdetect2.preprocess.spectrogram import ( from batdetect2.preprocess.spectrogram import (
ScaleAmplitudeConfig, ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig, SpectralMeanSubtractionConfig,
) )
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
DEFAULT_ANCHOR, DEFAULT_ANCHOR,
@ -597,7 +597,7 @@ def test_build_roi_mapper_for_peak_energy_bbox():
preproc_config = PreprocessingConfig( preproc_config = PreprocessingConfig(
spectrogram_transforms=[ spectrogram_transforms=[
ScaleAmplitudeConfig(scale="db"), ScaleAmplitudeConfig(scale="db"),
SpectralMeanSubstractionConfig(), SpectralMeanSubtractionConfig(),
] ]
) )
config = PeakEnergyBBoxMapperConfig( config = PeakEnergyBBoxMapperConfig(