mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add test for preprocessing
This commit is contained in:
parent
bfc88a4a0f
commit
46c02962f3
@ -21,7 +21,7 @@ preprocess:
|
||||
gain: 0.98
|
||||
bias: 2
|
||||
power: 0.5
|
||||
- name: spectral_mean_substraction
|
||||
- name: spectral_mean_subtraction
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
|
||||
@ -1,3 +1,17 @@
|
||||
"""Audio-level transforms applied to waveforms before spectrogram computation.
|
||||
|
||||
This module defines ``torch.nn.Module`` transforms that operate on raw
|
||||
audio tensors and the Pydantic configuration classes that control them.
|
||||
Each transform is registered in the ``audio_transforms`` registry so that
|
||||
the pipeline can be assembled from a configuration object.
|
||||
|
||||
The supported transforms are:
|
||||
|
||||
* ``CenterAudio`` — subtract the DC offset (mean) from the waveform.
|
||||
* ``ScaleAudio`` — peak-normalise the waveform to the range ``[-1, 1]``.
|
||||
* ``FixDuration`` — truncate or zero-pad the waveform to a fixed length.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
@ -18,14 +32,43 @@ __all__ = [
|
||||
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||
"audio_transform"
|
||||
)
|
||||
"""Registry mapping audio transform config classes to their builder methods."""
|
||||
|
||||
|
||||
class CenterAudioConfig(BaseConfig):
|
||||
"""Configuration for the DC-offset removal transform.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"center_audio"``.
|
||||
"""
|
||||
|
||||
name: Literal["center_audio"] = "center_audio"
|
||||
|
||||
|
||||
class CenterAudio(torch.nn.Module):
|
||||
"""Remove the DC offset from an audio waveform.
|
||||
|
||||
Subtracts the global mean of the waveform from every sample,
|
||||
centring the signal around zero. This is useful when an analogue
|
||||
recording chain introduces a constant voltage bias.
|
||||
"""
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Subtract the mean from the waveform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform tensor of shape ``(samples,)`` or
|
||||
``(channels, samples)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Zero-centred waveform with the same shape as the input.
|
||||
"""
|
||||
return center_tensor(wav)
|
||||
|
||||
@audio_transforms.register(CenterAudioConfig)
|
||||
@ -35,11 +78,38 @@ class CenterAudio(torch.nn.Module):
|
||||
|
||||
|
||||
class ScaleAudioConfig(BaseConfig):
|
||||
"""Configuration for the peak-normalisation transform.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"scale_audio"``.
|
||||
"""
|
||||
|
||||
name: Literal["scale_audio"] = "scale_audio"
|
||||
|
||||
|
||||
class ScaleAudio(torch.nn.Module):
|
||||
"""Peak-normalise an audio waveform to the range ``[-1, 1]``.
|
||||
|
||||
Divides the waveform by its largest absolute sample value. If the
|
||||
waveform is identically zero it is returned unchanged.
|
||||
"""
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Peak-normalise the waveform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform tensor of any shape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Normalised waveform with the same shape as the input and
|
||||
values in the range ``[-1, 1]``.
|
||||
"""
|
||||
return peak_normalize(wav)
|
||||
|
||||
@audio_transforms.register(ScaleAudioConfig)
|
||||
@ -49,11 +119,36 @@ class ScaleAudio(torch.nn.Module):
|
||||
|
||||
|
||||
class FixDurationConfig(BaseConfig):
|
||||
"""Configuration for the fixed-duration transform.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"fix_duration"``.
|
||||
duration : float, default=0.5
|
||||
Target duration in seconds. The waveform will be truncated or
|
||||
zero-padded to match this length.
|
||||
"""
|
||||
|
||||
name: Literal["fix_duration"] = "fix_duration"
|
||||
duration: float = 0.5
|
||||
|
||||
|
||||
class FixDuration(torch.nn.Module):
|
||||
"""Ensure a waveform has exactly a specified number of samples.
|
||||
|
||||
If the input is longer than the target length it is truncated from
|
||||
the end. If it is shorter, it is zero-padded at the end.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samplerate : int
|
||||
Sample rate of the audio in Hz. Used with ``duration`` to
|
||||
compute the target number of samples.
|
||||
duration : float
|
||||
Target duration in seconds.
|
||||
"""
|
||||
|
||||
def __init__(self, samplerate: int, duration: float):
|
||||
super().__init__()
|
||||
self.samplerate = samplerate
|
||||
@ -61,6 +156,20 @@ class FixDuration(torch.nn.Module):
|
||||
self.length = int(samplerate * duration)
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Truncate or pad the waveform to the target length.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform tensor of shape ``(samples,)`` or
|
||||
``(channels, samples)``. The last dimension is adjusted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Waveform with exactly ``self.length`` samples along the last
|
||||
dimension.
|
||||
"""
|
||||
length = wav.shape[-1]
|
||||
|
||||
if length == self.length:
|
||||
@ -81,10 +190,34 @@ AudioTransform = Annotated[
|
||||
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Discriminated union of all audio transform configuration types.
|
||||
|
||||
Use this type when a field should accept any of the supported audio
|
||||
transforms. Pydantic will select the correct config class based on the
|
||||
``name`` field.
|
||||
"""
|
||||
|
||||
|
||||
def build_audio_transform(
|
||||
config: AudioTransform,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> torch.nn.Module:
|
||||
"""Build an audio transform module from a configuration object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : AudioTransform
|
||||
A configuration object for one of the supported audio transforms
|
||||
(``CenterAudioConfig``, ``ScaleAudioConfig``, or
|
||||
``FixDurationConfig``).
|
||||
samplerate : int, default=256000
|
||||
Sample rate of the audio in Hz. Passed to the transform builder;
|
||||
some transforms (e.g. ``FixDuration``) use it to convert seconds
|
||||
to samples.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
The constructed audio transform module.
|
||||
"""
|
||||
return audio_transforms.build(config, samplerate)
|
||||
|
||||
@ -1,3 +1,10 @@
|
||||
"""Shared tensor primitives used across the preprocessing pipeline.
|
||||
|
||||
This module provides small, stateless helper functions that operate on
|
||||
PyTorch tensors. They are used by both audio-level and spectrogram-level
|
||||
transforms, and are kept here to avoid duplication.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
@ -7,11 +14,42 @@ __all__ = [
|
||||
|
||||
|
||||
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Subtract the mean of a tensor from all of its values.
|
||||
|
||||
This centres the signal around zero, removing any constant DC offset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor of any shape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A new tensor of the same shape and dtype with the global mean
|
||||
subtracted from every element.
|
||||
"""
|
||||
return tensor - tensor.mean()
|
||||
|
||||
|
||||
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||
max_value = tensor.abs().min()
|
||||
"""Scale a tensor so that its largest absolute value equals one.
|
||||
|
||||
Divides the tensor by its peak absolute value. If the tensor is
|
||||
identically zero, it is returned unchanged (no division by zero).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor of any shape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A new tensor of the same shape and dtype with values in the range
|
||||
``[-1, 1]`` (or exactly ``[0, 0]`` for a zero tensor).
|
||||
"""
|
||||
max_value = tensor.abs().max()
|
||||
|
||||
denominator = torch.where(
|
||||
max_value == 0,
|
||||
|
||||
@ -1,3 +1,10 @@
|
||||
"""Configuration for the full batdetect2 preprocessing pipeline.
|
||||
|
||||
This module defines :class:`PreprocessingConfig`, which aggregates all
|
||||
configuration needed to convert a raw audio waveform into a normalised
|
||||
spectrogram ready for the detection model.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
@ -9,7 +16,7 @@ from batdetect2.preprocess.spectrogram import (
|
||||
FrequencyConfig,
|
||||
PcenConfig,
|
||||
ResizeConfig,
|
||||
SpectralMeanSubstractionConfig,
|
||||
SpectralMeanSubtractionConfig,
|
||||
SpectrogramTransform,
|
||||
STFTConfig,
|
||||
)
|
||||
@ -24,19 +31,30 @@ __all__ = [
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Unified configuration for the audio preprocessing pipeline.
|
||||
|
||||
Aggregates the configuration for both the initial audio processing stage
|
||||
and the subsequent spectrogram generation stage.
|
||||
Aggregates the parameters for every stage of the pipeline:
|
||||
audio-level transforms, STFT computation, frequency cropping,
|
||||
spectrogram-level transforms, and the final resize step.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
audio : AudioConfig
|
||||
Configuration settings for the audio loading and initial waveform
|
||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||
Defaults to default `AudioConfig` settings if omitted.
|
||||
spectrogram : SpectrogramConfig
|
||||
Configuration settings for the spectrogram generation process
|
||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||
audio_transforms : list of AudioTransform, default=[]
|
||||
Ordered list of transforms applied to the raw audio waveform
|
||||
before the STFT is computed. Each entry is a configuration
|
||||
object for one of the supported audio transforms
|
||||
(``"center_audio"``, ``"scale_audio"``, or ``"fix_duration"``).
|
||||
spectrogram_transforms : list of SpectrogramTransform
|
||||
Ordered list of transforms applied to the cropped spectrogram
|
||||
after the STFT and frequency crop steps. Defaults to
|
||||
``[PcenConfig(), SpectralMeanSubtractionConfig()]``, which
|
||||
applies PCEN followed by spectral mean subtraction.
|
||||
stft : STFTConfig
|
||||
Parameters for the Short-Time Fourier Transform (window
|
||||
duration, overlap, and window function).
|
||||
frequencies : FrequencyConfig
|
||||
Frequency range (in Hz) to retain after the STFT.
|
||||
size : ResizeConfig
|
||||
Target height (number of frequency bins) and time-axis scaling
|
||||
factor for the final resize step.
|
||||
"""
|
||||
|
||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||
@ -44,7 +62,7 @@ class PreprocessingConfig(BaseConfig):
|
||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||
default_factory=lambda: [
|
||||
PcenConfig(),
|
||||
SpectralMeanSubstractionConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
)
|
||||
|
||||
@ -59,4 +77,20 @@ def load_preprocessing_config(
|
||||
path: PathLike,
|
||||
field: str | None = None,
|
||||
) -> PreprocessingConfig:
|
||||
"""Load a ``PreprocessingConfig`` from a YAML file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the YAML configuration file.
|
||||
field : str, optional
|
||||
If provided, read the config from a nested field within the
|
||||
YAML document (e.g. ``"preprocessing"`` to read from a top-level
|
||||
``preprocessing:`` key).
|
||||
|
||||
Returns
|
||||
-------
|
||||
PreprocessingConfig
|
||||
The deserialised preprocessing configuration.
|
||||
"""
|
||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||
|
||||
@ -1,3 +1,25 @@
|
||||
"""Assembles the full batdetect2 preprocessing pipeline.
|
||||
|
||||
This module defines :class:`Preprocessor`, the concrete implementation of
|
||||
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
|
||||
:func:`build_preprocessor` factory function that constructs it from a
|
||||
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
|
||||
|
||||
The preprocessing pipeline converts a raw audio waveform (as a
|
||||
``torch.Tensor``) into a normalised, cropped, and resized spectrogram ready
|
||||
for the detection model. The stages are applied in this order:
|
||||
|
||||
1. **Audio transforms** — optional waveform-level operations such as DC
|
||||
removal, peak normalisation, or duration fixing.
|
||||
2. **STFT** — Short-Time Fourier Transform to produce an amplitude
|
||||
spectrogram.
|
||||
3. **Frequency crop** — retain only the frequency band of interest.
|
||||
4. **Spectrogram transforms** — normalisation operations such as PCEN and
|
||||
spectral mean subtraction.
|
||||
5. **Resize** — scale the spectrogram to the model's expected height and
|
||||
reduce the time resolution.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@ -20,7 +42,32 @@ __all__ = [
|
||||
|
||||
|
||||
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
"""Standard implementation of the `Preprocessor` protocol."""
|
||||
"""Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
|
||||
|
||||
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
|
||||
that parameters (e.g. PCEN filter coefficients) can be tracked and
|
||||
moved between devices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : PreprocessingConfig
|
||||
Full pipeline configuration.
|
||||
input_samplerate : int
|
||||
Sample rate of the audio that will be passed to this preprocessor,
|
||||
in Hz.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_samplerate : int
|
||||
Sample rate of the input audio in Hz.
|
||||
output_samplerate : float
|
||||
Effective frame rate of the output spectrogram in frames per second.
|
||||
Computed from the STFT hop length and the time-axis resize factor.
|
||||
min_freq : float
|
||||
Lower bound of the retained frequency band in Hz.
|
||||
max_freq : float
|
||||
Upper bound of the retained frequency band in Hz.
|
||||
"""
|
||||
|
||||
input_samplerate: int
|
||||
output_samplerate: float
|
||||
@ -72,17 +119,75 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
)
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the full preprocessing pipeline on a waveform.
|
||||
|
||||
Applies audio transforms, then the STFT, then
|
||||
:meth:`process_spectrogram`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform of shape ``(samples,)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Preprocessed spectrogram of shape
|
||||
``(freq_bins, time_frames)``.
|
||||
"""
|
||||
wav = self.audio_transforms(wav)
|
||||
spec = self.spectrogram_builder(wav)
|
||||
return self.process_spectrogram(spec)
|
||||
|
||||
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute the raw STFT spectrogram without any further processing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform of shape ``(samples,)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Amplitude spectrogram of shape ``(n_fft//2 + 1, time_frames)``
|
||||
with no frequency cropping, normalisation, or resizing applied.
|
||||
"""
|
||||
return self.spectrogram_builder(wav)
|
||||
|
||||
def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Alias for :meth:`forward`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Input waveform of shape ``(samples,)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Preprocessed spectrogram (same as calling the object directly).
|
||||
"""
|
||||
return self(wav)
|
||||
|
||||
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply the post-STFT processing stages to an existing spectrogram.
|
||||
|
||||
Applies frequency cropping, spectrogram-level transforms (e.g.
|
||||
PCEN, spectral mean subtraction), and the final resize step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Raw amplitude spectrogram of shape
|
||||
``(..., n_fft//2 + 1, time_frames)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Normalised and resized spectrogram of shape
|
||||
``(..., height, scaled_time_frames)``.
|
||||
"""
|
||||
spec = self.spectrogram_crop(spec)
|
||||
spec = self.spectrogram_transforms(spec)
|
||||
return self.spectrogram_resizer(spec)
|
||||
@ -92,6 +197,25 @@ def compute_output_samplerate(
|
||||
config: PreprocessingConfig,
|
||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> float:
|
||||
"""Compute the effective frame rate of the preprocessor's output.
|
||||
|
||||
The output frame rate (in frames per second) depends on the STFT hop
|
||||
length and the time-axis resize factor applied by the final resize step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : PreprocessingConfig
|
||||
Pipeline configuration.
|
||||
input_samplerate : int, default=256000
|
||||
Sample rate of the input audio in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Output frame rate in frames per second.
|
||||
For example, at the default settings (256 kHz, hop=128,
|
||||
resize_factor=0.5) this equals ``1000.0``.
|
||||
"""
|
||||
_, hop_size = _spec_params_from_config(
|
||||
config.stft, samplerate=input_samplerate
|
||||
)
|
||||
@ -103,7 +227,24 @@ def build_preprocessor(
|
||||
config: PreprocessingConfig | None = None,
|
||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> PreprocessorProtocol:
|
||||
"""Factory function to build the standard preprocessor from configuration."""
|
||||
"""Build the standard preprocessor from a configuration object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : PreprocessingConfig, optional
|
||||
Pipeline configuration. If ``None``, the default
|
||||
``PreprocessingConfig()`` is used (PCEN + spectral mean
|
||||
subtraction, 256 kHz, standard STFT parameters).
|
||||
input_samplerate : int, default=256000
|
||||
Sample rate of the audio that will be fed to the preprocessor,
|
||||
in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PreprocessorProtocol
|
||||
A :class:`Preprocessor` instance ready to convert waveforms to
|
||||
spectrograms.
|
||||
"""
|
||||
config = config or PreprocessingConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building preprocessor with config: \n{}",
|
||||
|
||||
@ -1,4 +1,14 @@
|
||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||
"""Computes spectrograms from audio waveforms with configurable parameters.
|
||||
|
||||
This module defines the STFT-based spectrogram builder and a collection of
|
||||
spectrogram-level transforms (PCEN, spectral mean subtraction, amplitude
|
||||
scaling, peak normalisation, frequency cropping, and resizing) that form the
|
||||
signal-processing stage of the batdetect2 preprocessing pipeline.
|
||||
|
||||
Each transform is paired with a Pydantic configuration class and registered
|
||||
in the ``spectrogram_transforms`` registry so that the pipeline can be fully
|
||||
specified via a YAML or Python configuration object.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Callable, Literal
|
||||
|
||||
@ -32,17 +42,22 @@ class STFTConfig(BaseConfig):
|
||||
Attributes
|
||||
----------
|
||||
window_duration : float, default=0.002
|
||||
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
|
||||
> 0. Determines frequency resolution (longer window = finer frequency
|
||||
resolution).
|
||||
Duration of the STFT analysis window in seconds (e.g. 0.002 for
|
||||
2 ms). Must be > 0. A longer window gives finer frequency resolution
|
||||
but coarser time resolution.
|
||||
window_overlap : float, default=0.75
|
||||
Fraction of overlap between consecutive STFT windows (e.g., 0.75
|
||||
for 75%). Must be >= 0 and < 1. Determines time resolution
|
||||
(higher overlap = finer time resolution).
|
||||
Fraction of overlap between consecutive windows (e.g. 0.75 for
|
||||
75 %). Must be >= 0 and < 1. Higher overlap gives finer time
|
||||
resolution at the cost of more computation.
|
||||
window_fn : str, default="hann"
|
||||
Name of the window function to apply before FFT calculation. Common
|
||||
options include "hann", "hamming", "blackman". See
|
||||
`scipy.signal.get_window`.
|
||||
Name of the tapering window applied to each frame before the FFT.
|
||||
Supported values: ``"hann"``, ``"hamming"``, ``"kaiser"``,
|
||||
``"blackman"``, ``"bartlett"``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
At the default sample rate of 256 kHz, ``window_duration=0.002`` and
|
||||
``window_overlap=0.75`` give ``n_fft=512`` and ``hop_length=128``.
|
||||
"""
|
||||
|
||||
window_duration: float = Field(default=0.002, gt=0)
|
||||
@ -54,6 +69,23 @@ def build_spectrogram_builder(
|
||||
config: STFTConfig,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> torch.nn.Module:
|
||||
"""Build a torchaudio STFT spectrogram module from an ``STFTConfig``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : STFTConfig
|
||||
STFT parameters (window duration, overlap, and window function).
|
||||
samplerate : int, default=256000
|
||||
Sample rate of the input audio in Hz. Used to convert the
|
||||
window duration into a number of samples.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
A ``torchaudio.transforms.Spectrogram`` module configured to
|
||||
produce an amplitude (``power=1``) spectrogram with centred
|
||||
frames.
|
||||
"""
|
||||
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
|
||||
return torchaudio.transforms.Spectrogram(
|
||||
n_fft=n_fft,
|
||||
@ -65,6 +97,25 @@ def build_spectrogram_builder(
|
||||
|
||||
|
||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||
"""Return the PyTorch window function matching the given name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Name of the window function. One of ``"hann"``, ``"hamming"``,
|
||||
``"kaiser"``, ``"blackman"``, or ``"bartlett"``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable[..., torch.Tensor]
|
||||
A PyTorch window function that accepts a window length and returns
|
||||
a 1-D tensor of weights.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
If ``name`` does not match any supported window function.
|
||||
"""
|
||||
if name == "hann":
|
||||
return torch.hann_window
|
||||
|
||||
@ -88,7 +139,22 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||
def _spec_params_from_config(
|
||||
config: STFTConfig,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
):
|
||||
) -> tuple[int, int]:
|
||||
"""Compute ``n_fft`` and ``hop_length`` from an ``STFTConfig``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : STFTConfig
|
||||
STFT parameters.
|
||||
samplerate : int, default=256000
|
||||
Sample rate of the input audio in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[int, int]
|
||||
A pair ``(n_fft, hop_length)`` giving the FFT size and the step
|
||||
between consecutive frames in samples.
|
||||
"""
|
||||
n_fft = int(samplerate * config.window_duration)
|
||||
hop_length = int(n_fft * (1 - config.window_overlap))
|
||||
return n_fft, hop_length
|
||||
@ -99,6 +165,24 @@ def _frequency_to_index(
|
||||
n_fft: int,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> int | None:
|
||||
"""Convert a frequency in Hz to the nearest STFT frequency bin index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : float
|
||||
Frequency in Hz to convert.
|
||||
n_fft : int
|
||||
FFT size used by the STFT.
|
||||
samplerate : int, default=256000
|
||||
Sample rate of the audio in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int or None
|
||||
The bin index corresponding to ``freq``, or ``None`` if the
|
||||
frequency is outside the valid range (i.e. <= 0 Hz or >= the
|
||||
Nyquist frequency).
|
||||
"""
|
||||
alpha = freq * 2 / samplerate
|
||||
height = np.floor(n_fft / 2) + 1
|
||||
index = int(np.floor(alpha * height))
|
||||
@ -118,11 +202,11 @@ class FrequencyConfig(BaseConfig):
|
||||
Attributes
|
||||
----------
|
||||
max_freq : int, default=120000
|
||||
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies above this value will be cropped. Must be > 0.
|
||||
Maximum frequency in Hz to retain after STFT. Frequency bins
|
||||
above this value are discarded. Must be >= 0.
|
||||
min_freq : int, default=10000
|
||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies below this value will be cropped. Must be >= 0.
|
||||
Minimum frequency in Hz to retain after STFT. Frequency bins
|
||||
below this value are discarded. Must be >= 0.
|
||||
"""
|
||||
|
||||
max_freq: int = Field(default=MAX_FREQ, ge=0)
|
||||
@ -130,6 +214,27 @@ class FrequencyConfig(BaseConfig):
|
||||
|
||||
|
||||
class FrequencyCrop(torch.nn.Module):
|
||||
"""Crop a spectrogram to a specified frequency band.
|
||||
|
||||
On construction the Hz boundaries are converted to STFT bin indices.
|
||||
During the forward pass the spectrogram is sliced along its
|
||||
frequency axis (second-to-last dimension) to retain only the bins
|
||||
that fall within ``[min_freq, max_freq)``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samplerate : int
|
||||
Sample rate of the audio in Hz.
|
||||
n_fft : int
|
||||
FFT size used by the STFT.
|
||||
min_freq : int, optional
|
||||
Lower frequency bound in Hz. If ``None``, no lower crop is
|
||||
applied and the DC bin is retained.
|
||||
max_freq : int, optional
|
||||
Upper frequency bound in Hz. If ``None``, no upper crop is
|
||||
applied and all bins up to Nyquist are retained.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: int,
|
||||
@ -162,6 +267,19 @@ class FrequencyCrop(torch.nn.Module):
|
||||
self.high_index = high_index
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Crop the spectrogram to the configured frequency band.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Spectrogram tensor of shape ``(..., freq_bins, time_frames)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Cropped spectrogram with shape
|
||||
``(..., n_retained_bins, time_frames)``.
|
||||
"""
|
||||
low_index = self.low_index
|
||||
if low_index is None:
|
||||
low_index = 0
|
||||
@ -184,6 +302,24 @@ def build_spectrogram_crop(
|
||||
stft: STFTConfig | None = None,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> torch.nn.Module:
|
||||
"""Build a ``FrequencyCrop`` module from configuration objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : FrequencyConfig
|
||||
Frequency boundary configuration specifying ``min_freq`` and
|
||||
``max_freq`` in Hz.
|
||||
stft : STFTConfig, optional
|
||||
STFT configuration used to derive ``n_fft``. Defaults to
|
||||
``STFTConfig()`` if not provided.
|
||||
samplerate : int, default=256000
|
||||
Sample rate of the audio in Hz.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
A ``FrequencyCrop`` module ready to crop spectrograms.
|
||||
"""
|
||||
stft = stft or STFTConfig()
|
||||
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
|
||||
return FrequencyCrop(
|
||||
@ -195,18 +331,61 @@ def build_spectrogram_crop(
|
||||
|
||||
|
||||
class ResizeConfig(BaseConfig):
|
||||
"""Configuration for the final spectrogram resize step.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"resize_spec"``.
|
||||
height : int, default=128
|
||||
Target number of frequency bins in the output spectrogram.
|
||||
The spectrogram is resized to this height using bilinear
|
||||
interpolation.
|
||||
resize_factor : float, default=0.5
|
||||
Fraction by which the time axis is scaled. For example, ``0.5``
|
||||
halves the number of time frames, reducing computational cost
|
||||
downstream.
|
||||
"""
|
||||
|
||||
name: Literal["resize_spec"] = "resize_spec"
|
||||
height: int = 128
|
||||
resize_factor: float = 0.5
|
||||
|
||||
|
||||
class ResizeSpec(torch.nn.Module):
|
||||
"""Resize a spectrogram to a fixed height and scaled width.
|
||||
|
||||
Uses bilinear interpolation so it handles arbitrary input shapes
|
||||
gracefully. Input tensors with fewer than four dimensions are
|
||||
temporarily unsqueezed to satisfy ``torch.nn.functional.interpolate``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
height : int
|
||||
Target number of frequency bins (output height).
|
||||
time_factor : float
|
||||
Multiplicative scaling applied to the time axis length.
|
||||
"""
|
||||
|
||||
def __init__(self, height: int, time_factor: float):
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.time_factor = time_factor
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Resize the spectrogram to the configured output dimensions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Resized spectrogram with shape
|
||||
``(..., height, int(time_factor * time_frames))``.
|
||||
"""
|
||||
current_length = spec.shape[-1]
|
||||
target_length = int(self.time_factor * current_length)
|
||||
|
||||
@ -227,6 +406,18 @@ class ResizeSpec(torch.nn.Module):
|
||||
|
||||
|
||||
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
|
||||
"""Build a ``ResizeSpec`` module from a ``ResizeConfig``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ResizeConfig
|
||||
Resize configuration specifying ``height`` and ``resize_factor``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
A ``ResizeSpec`` module configured with the given parameters.
|
||||
"""
|
||||
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
|
||||
|
||||
|
||||
@ -236,7 +427,32 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||
|
||||
|
||||
class PcenConfig(BaseConfig):
|
||||
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
||||
"""Configuration for Per-Channel Energy Normalisation (PCEN).
|
||||
|
||||
PCEN is a frontend processing technique that replaces simple log
|
||||
compression. It applies a learnable automatic gain control followed
|
||||
by a stabilised root compression, making the representation more
|
||||
robust to variations in recording level.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"pcen"``.
|
||||
time_constant : float, default=0.4
|
||||
Time constant (in seconds) of the IIR smoothing filter used
|
||||
for the background estimate. Larger values produce a slower-
|
||||
adapting background.
|
||||
gain : float, default=0.98
|
||||
Exponent controlling how strongly the background estimate
|
||||
suppresses the signal.
|
||||
bias : float, default=2
|
||||
Stabilisation bias added inside the root-compression step to
|
||||
avoid division by zero.
|
||||
power : float, default=0.5
|
||||
Root-compression exponent. A value of 0.5 gives square-root
|
||||
compression, similar to log compression but differentiable at
|
||||
zero.
|
||||
"""
|
||||
|
||||
name: Literal["pcen"] = "pcen"
|
||||
time_constant: float = 0.4
|
||||
@ -246,6 +462,35 @@ class PcenConfig(BaseConfig):
|
||||
|
||||
|
||||
class PCEN(torch.nn.Module):
|
||||
"""Per-Channel Energy Normalisation (PCEN) transform.
|
||||
|
||||
Applies automatic gain control and root compression to a spectrogram.
|
||||
The background estimate is computed with a first-order IIR filter
|
||||
applied along the time axis.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
smoothing_constant : float
|
||||
IIR filter coefficient ``alpha``. Derived from the time constant
|
||||
and sample rate via ``_compute_smoothing_constant``.
|
||||
gain : float, default=0.98
|
||||
AGC gain exponent.
|
||||
bias : float, default=2.0
|
||||
Root-compression stabilisation bias.
|
||||
power : float, default=0.5
|
||||
Root-compression exponent.
|
||||
eps : float, default=1e-6
|
||||
Small constant for numerical stability.
|
||||
dtype : torch.dtype, default=torch.float32
|
||||
Floating-point precision used for internal computation.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The smoothing constant is computed to match the original batdetect2
|
||||
implementation for numerical compatibility. See
|
||||
``_compute_smoothing_constant`` for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
smoothing_constant: float,
|
||||
@ -269,6 +514,20 @@ class PCEN(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply PCEN to a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input amplitude spectrogram of shape
|
||||
``(..., freq_bins, time_frames)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
PCEN-normalised spectrogram with the same shape and dtype as
|
||||
the input.
|
||||
"""
|
||||
S = spec.to(self.dtype) * 2**31
|
||||
|
||||
M = (
|
||||
@ -305,21 +564,73 @@ def _compute_smoothing_constant(
|
||||
samplerate: int,
|
||||
time_constant: float,
|
||||
) -> float:
|
||||
# NOTE: These were taken to match the original implementation
|
||||
"""Compute the IIR smoothing coefficient for PCEN.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
samplerate : int
|
||||
Sample rate of the audio in Hz.
|
||||
time_constant : float
|
||||
Desired smoothing time constant in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
IIR filter coefficient ``alpha`` used by ``PCEN``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The hop length (512) and the sample-rate divisor (10) are fixed to
|
||||
reproduce the numerical behaviour of the original batdetect2
|
||||
implementation, which used ``librosa.pcen`` with ``sr=samplerate/10``
|
||||
and the default ``hop_length=512``. These values do not reflect the
|
||||
actual STFT hop length used in the pipeline; they are retained
|
||||
solely for backward compatibility.
|
||||
"""
|
||||
# NOTE: These parameters are fixed to match the original implementation.
|
||||
hop_length = 512
|
||||
sr = samplerate / 10
|
||||
time_constant = time_constant
|
||||
t_frames = time_constant * sr / float(hop_length)
|
||||
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||
|
||||
|
||||
class ScaleAmplitudeConfig(BaseConfig):
|
||||
"""Configuration for amplitude scaling of a spectrogram.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"scale_amplitude"``.
|
||||
scale : str, default="db"
|
||||
Scaling mode. Either ``"db"`` (convert amplitude to decibels
|
||||
using ``torchaudio.transforms.AmplitudeToDB``) or ``"power"``
|
||||
(square the amplitude values).
|
||||
"""
|
||||
|
||||
name: Literal["scale_amplitude"] = "scale_amplitude"
|
||||
scale: Literal["power", "db"] = "db"
|
||||
|
||||
|
||||
class ToPower(torch.nn.Module):
|
||||
"""Square the values of a spectrogram (amplitude → power).
|
||||
|
||||
Raises each element to the power of two, converting an amplitude
|
||||
spectrogram into a power spectrogram.
|
||||
"""
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Square all elements of the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input amplitude spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Power spectrogram (same shape as input).
|
||||
"""
|
||||
return spec**2
|
||||
|
||||
|
||||
@ -330,12 +641,34 @@ _scalers = {
|
||||
|
||||
|
||||
class ScaleAmplitude(torch.nn.Module):
|
||||
"""Convert spectrogram amplitude values to a different scale.
|
||||
|
||||
Supports conversion to decibels (dB) or to power (squared amplitude).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scale : str
|
||||
Either ``"db"`` or ``"power"``.
|
||||
"""
|
||||
|
||||
def __init__(self, scale: Literal["power", "db"]):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.scaler = _scalers[scale]()
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply the configured amplitude scaling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Scaled spectrogram with the same shape as the input.
|
||||
"""
|
||||
return self.scaler(spec)
|
||||
|
||||
@spectrogram_transforms.register(ScaleAmplitudeConfig)
|
||||
@ -344,30 +677,86 @@ class ScaleAmplitude(torch.nn.Module):
|
||||
return ScaleAmplitude(scale=config.scale)
|
||||
|
||||
|
||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
||||
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
||||
class SpectralMeanSubtractionConfig(BaseConfig):
|
||||
"""Configuration for spectral mean subtraction.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"spectral_mean_subtraction"``.
|
||||
"""
|
||||
|
||||
name: Literal["spectral_mean_subtraction"] = "spectral_mean_subtraction"
|
||||
|
||||
|
||||
class SpectralMeanSubstraction(torch.nn.Module):
|
||||
class SpectralMeanSubtraction(torch.nn.Module):
|
||||
"""Remove the time-averaged background noise from a spectrogram.
|
||||
|
||||
For each frequency bin, the mean value across all time frames is
|
||||
computed and subtracted. The result is then clamped to zero so that
|
||||
no values fall below the baseline. This is a simple form of spectral
|
||||
denoising that suppresses stationary background noise.
|
||||
"""
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Subtract the time-axis mean from each frequency bin.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Denoised spectrogram with the same shape as the input. All
|
||||
values are non-negative (clamped to 0).
|
||||
"""
|
||||
mean = spec.mean(-1, keepdim=True)
|
||||
return (spec - mean).clamp(min=0)
|
||||
|
||||
@spectrogram_transforms.register(SpectralMeanSubstractionConfig)
|
||||
@spectrogram_transforms.register(SpectralMeanSubtractionConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: SpectralMeanSubstractionConfig,
|
||||
config: SpectralMeanSubtractionConfig,
|
||||
samplerate: int,
|
||||
):
|
||||
return SpectralMeanSubstraction()
|
||||
return SpectralMeanSubtraction()
|
||||
|
||||
|
||||
class PeakNormalizeConfig(BaseConfig):
|
||||
"""Configuration for peak normalisation of a spectrogram.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Fixed identifier; always ``"peak_normalize"``.
|
||||
"""
|
||||
|
||||
name: Literal["peak_normalize"] = "peak_normalize"
|
||||
|
||||
|
||||
class PeakNormalize(torch.nn.Module):
|
||||
"""Scale a spectrogram so that its largest absolute value equals one.
|
||||
|
||||
Wraps :func:`batdetect2.preprocess.common.peak_normalize` as a
|
||||
``torch.nn.Module`` for use inside a sequential transform pipeline.
|
||||
"""
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Peak-normalise the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor of any shape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Normalised spectrogram where the maximum absolute value is 1.
|
||||
If the input is identically zero, it is returned unchanged.
|
||||
"""
|
||||
return peak_normalize(spec)
|
||||
|
||||
@spectrogram_transforms.register(PeakNormalizeConfig)
|
||||
@ -379,14 +768,37 @@ class PeakNormalize(torch.nn.Module):
|
||||
SpectrogramTransform = Annotated[
|
||||
PcenConfig
|
||||
| ScaleAmplitudeConfig
|
||||
| SpectralMeanSubstractionConfig
|
||||
| SpectralMeanSubtractionConfig
|
||||
| PeakNormalizeConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Discriminated union of all spectrogram transform configuration types.
|
||||
|
||||
Use this type when a field should accept any of the supported spectrogram
|
||||
transforms. Pydantic will select the correct config class based on the
|
||||
``name`` field.
|
||||
"""
|
||||
|
||||
|
||||
def build_spectrogram_transform(
|
||||
config: SpectrogramTransform,
|
||||
samplerate: int,
|
||||
) -> torch.nn.Module:
|
||||
"""Build a spectrogram transform module from a configuration object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : SpectrogramTransform
|
||||
A configuration object for one of the supported spectrogram
|
||||
transforms (PCEN, amplitude scaling, spectral mean subtraction,
|
||||
or peak normalisation).
|
||||
samplerate : int
|
||||
Sample rate of the audio in Hz. Some transforms (e.g. PCEN) use
|
||||
this to set internal parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.nn.Module
|
||||
The constructed transform module.
|
||||
"""
|
||||
return spectrogram_transforms.build(config, samplerate)
|
||||
|
||||
@ -1,12 +1,31 @@
|
||||
"""Tests for audio-level preprocessing transforms.
|
||||
|
||||
Covers :mod:`batdetect2.preprocess.audio` and the shared helper functions
|
||||
in :mod:`batdetect2.preprocess.common`.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.preprocess.audio import (
|
||||
CenterAudio,
|
||||
CenterAudioConfig,
|
||||
FixDuration,
|
||||
FixDurationConfig,
|
||||
ScaleAudio,
|
||||
ScaleAudioConfig,
|
||||
build_audio_transform,
|
||||
)
|
||||
from batdetect2.preprocess.common import center_tensor, peak_normalize
|
||||
|
||||
SAMPLERATE = 256_000
|
||||
|
||||
|
||||
def create_dummy_wave(
|
||||
@ -15,9 +34,9 @@ def create_dummy_wave(
|
||||
num_channels: int = 1,
|
||||
freq: float = 440.0,
|
||||
amplitude: float = 0.5,
|
||||
dtype: np.dtype = np.float32,
|
||||
dtype: type = np.float32,
|
||||
) -> np.ndarray:
|
||||
"""Generates a simple numpy waveform."""
|
||||
"""Generate a simple sine-wave waveform as a NumPy array."""
|
||||
t = np.linspace(
|
||||
0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype
|
||||
)
|
||||
@ -29,7 +48,7 @@ def create_dummy_wave(
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||
"""Creates a dummy WAV file and returns its path."""
|
||||
"""Create a dummy 2-channel WAV file and return its path."""
|
||||
samplerate = 48000
|
||||
duration = 2.0
|
||||
num_channels = 2
|
||||
@ -41,13 +60,13 @@ def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
|
||||
"""Creates a Recording object pointing to the dummy WAV file."""
|
||||
"""Create a Recording object pointing to the dummy WAV file."""
|
||||
return data.Recording.from_file(dummy_wav_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
||||
"""Creates a Clip object from the dummy recording."""
|
||||
"""Create a Clip object from the dummy recording."""
|
||||
return data.Clip(
|
||||
recording=dummy_recording,
|
||||
start_time=0.5,
|
||||
@ -58,3 +77,165 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
||||
@pytest.fixture
|
||||
def default_audio_config() -> AudioConfig:
|
||||
return AudioConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# center_tensor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_center_tensor_zero_mean():
|
||||
"""Output tensor should have a mean very close to zero."""
|
||||
wav = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
result = center_tensor(wav)
|
||||
assert result.mean().abs().item() < 1e-5
|
||||
|
||||
|
||||
def test_center_tensor_preserves_shape():
|
||||
wav = torch.randn(3, 1000)
|
||||
result = center_tensor(wav)
|
||||
assert result.shape == wav.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# peak_normalize
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_peak_normalize_max_is_one():
|
||||
"""After peak normalisation, the maximum absolute value should be 1."""
|
||||
wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3])
|
||||
result = peak_normalize(wav)
|
||||
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||
|
||||
|
||||
def test_peak_normalize_zero_tensor_unchanged():
|
||||
"""A zero tensor should be returned unchanged (no division by zero)."""
|
||||
wav = torch.zeros(100)
|
||||
result = peak_normalize(wav)
|
||||
assert (result == 0).all()
|
||||
|
||||
|
||||
def test_peak_normalize_preserves_sign():
|
||||
"""Signs of all elements should be preserved after normalisation."""
|
||||
wav = torch.tensor([-2.0, 1.0, -0.5])
|
||||
result = peak_normalize(wav)
|
||||
assert (result < 0).sum() == 2
|
||||
assert result[0].item() < 0
|
||||
|
||||
|
||||
def test_peak_normalize_preserves_shape():
|
||||
wav = torch.randn(2, 512)
|
||||
result = peak_normalize(wav)
|
||||
assert result.shape == wav.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CenterAudio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_center_audio_forward_zero_mean():
|
||||
module = CenterAudio()
|
||||
wav = torch.tensor([1.0, 3.0, 5.0])
|
||||
result = module(wav)
|
||||
assert result.mean().abs().item() < 1e-5
|
||||
|
||||
|
||||
def test_center_audio_from_config():
|
||||
config = CenterAudioConfig()
|
||||
module = CenterAudio.from_config(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, CenterAudio)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ScaleAudio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_scale_audio_peak_normalises_to_one():
|
||||
"""ScaleAudio.forward should scale the peak absolute value to 1."""
|
||||
module = ScaleAudio()
|
||||
wav = torch.tensor([0.0, 0.25, -0.5, 0.1])
|
||||
result = module(wav)
|
||||
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||
|
||||
|
||||
def test_scale_audio_handles_zero_tensor():
|
||||
"""ScaleAudio should not raise on a zero tensor."""
|
||||
module = ScaleAudio()
|
||||
wav = torch.zeros(100)
|
||||
result = module(wav)
|
||||
assert (result == 0).all()
|
||||
|
||||
|
||||
def test_scale_audio_from_config():
|
||||
config = ScaleAudioConfig()
|
||||
module = ScaleAudio.from_config(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, ScaleAudio)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FixDuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fix_duration_truncates_long_input():
|
||||
"""Waveform longer than target should be truncated to the target length."""
|
||||
target_samples = int(SAMPLERATE * 0.5)
|
||||
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||
wav = torch.randn(target_samples + 1000)
|
||||
result = module(wav)
|
||||
assert result.shape[-1] == target_samples
|
||||
|
||||
|
||||
def test_fix_duration_pads_short_input():
|
||||
"""Waveform shorter than target should be zero-padded to the target length."""
|
||||
target_samples = int(SAMPLERATE * 0.5)
|
||||
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||
short_wav = torch.randn(target_samples - 100)
|
||||
result = module(short_wav)
|
||||
assert result.shape[-1] == target_samples
|
||||
# Padded region should be zero
|
||||
assert (result[target_samples - 100 :] == 0).all()
|
||||
|
||||
|
||||
def test_fix_duration_passthrough_exact_length():
|
||||
"""Waveform with exactly the right length should be returned unchanged."""
|
||||
target_samples = int(SAMPLERATE * 0.5)
|
||||
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||
wav = torch.randn(target_samples)
|
||||
result = module(wav)
|
||||
assert result.shape[-1] == target_samples
|
||||
assert torch.equal(result, wav)
|
||||
|
||||
|
||||
def test_fix_duration_from_config():
|
||||
"""FixDurationConfig should produce a FixDuration with the correct length."""
|
||||
config = FixDurationConfig(duration=0.256)
|
||||
module = FixDuration.from_config(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, FixDuration)
|
||||
assert module.length == int(SAMPLERATE * 0.256)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_audio_transform dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_audio_transform_center_audio():
|
||||
config = CenterAudioConfig()
|
||||
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, CenterAudio)
|
||||
|
||||
|
||||
def test_build_audio_transform_scale_audio():
|
||||
config = ScaleAudioConfig()
|
||||
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, ScaleAudio)
|
||||
|
||||
|
||||
def test_build_audio_transform_fix_duration():
|
||||
config = FixDurationConfig(duration=0.5)
|
||||
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||
assert isinstance(module, FixDuration)
|
||||
|
||||
243
tests/test_preprocessing/test_preprocessor.py
Normal file
243
tests/test_preprocessing/test_preprocessor.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""Integration and unit tests for the Preprocessor pipeline.
|
||||
|
||||
Covers :mod:`batdetect2.preprocess.preprocessor` — construction,
|
||||
pipeline output shape/dtype, the ``process_numpy`` helper, attribute
|
||||
values, output frame rate, and a round-trip YAML → config → build test.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from batdetect2.preprocess.audio import FixDurationConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.preprocess.preprocessor import (
|
||||
Preprocessor,
|
||||
build_preprocessor,
|
||||
compute_output_samplerate,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
FrequencyConfig,
|
||||
PcenConfig,
|
||||
ResizeConfig,
|
||||
SpectralMeanSubtractionConfig,
|
||||
STFTConfig,
|
||||
)
|
||||
|
||||
SAMPLERATE = 256_000
|
||||
# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip
|
||||
CLIP_SAMPLES = int(SAMPLERATE * 0.256)
|
||||
|
||||
|
||||
def make_sine_wav(
|
||||
samplerate: int = SAMPLERATE,
|
||||
duration: float = 0.256,
|
||||
freq: float = 40_000.0,
|
||||
) -> torch.Tensor:
|
||||
"""Return a single-channel sine-wave tensor."""
|
||||
t = torch.linspace(0.0, duration, int(samplerate * duration))
|
||||
return torch.sin(2 * torch.pi * freq * t)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_preprocessor — construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_preprocessor_returns_protocol():
|
||||
"""build_preprocessor should return a Preprocessor instance."""
|
||||
preprocessor = build_preprocessor()
|
||||
assert isinstance(preprocessor, Preprocessor)
|
||||
|
||||
|
||||
def test_build_preprocessor_with_default_config():
|
||||
"""build_preprocessor() with no arguments should not raise."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
assert preprocessor is not None
|
||||
|
||||
|
||||
def test_build_preprocessor_with_explicit_config():
|
||||
config = PreprocessingConfig(
|
||||
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||
size=ResizeConfig(height=128, resize_factor=0.5),
|
||||
spectrogram_transforms=[PcenConfig(), SpectralMeanSubtractionConfig()],
|
||||
)
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
assert isinstance(preprocessor, Preprocessor)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Output shape and dtype
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_output_is_2d():
|
||||
"""The preprocessor output should be a 2-D tensor (freq_bins × time_frames)."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
result = preprocessor(wav)
|
||||
assert result.ndim == 2
|
||||
|
||||
|
||||
def test_preprocessor_output_height_matches_config():
|
||||
"""Output height should match the ResizeConfig.height setting."""
|
||||
config = PreprocessingConfig(
|
||||
size=ResizeConfig(height=64, resize_factor=0.5)
|
||||
)
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
result = preprocessor(wav)
|
||||
assert result.shape[0] == 64
|
||||
|
||||
|
||||
def test_preprocessor_output_dtype_float32():
|
||||
"""Output tensor should be float32."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
result = preprocessor(wav)
|
||||
assert result.dtype == torch.float32
|
||||
|
||||
|
||||
def test_preprocessor_output_is_finite():
|
||||
"""Output spectrogram should contain no NaN or Inf values."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
result = preprocessor(wav)
|
||||
assert torch.isfinite(result).all()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_numpy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_process_numpy_accepts_ndarray():
|
||||
"""process_numpy should accept a NumPy array and return a NumPy array."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav_np = make_sine_wav().numpy()
|
||||
result = preprocessor.process_numpy(wav_np)
|
||||
assert isinstance(result, np.ndarray)
|
||||
|
||||
|
||||
def test_preprocessor_process_numpy_matches_forward():
|
||||
"""process_numpy and forward should give numerically identical results."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
result_pt = preprocessor(wav).numpy()
|
||||
result_np = preprocessor.process_numpy(wav.numpy())
|
||||
np.testing.assert_array_almost_equal(result_pt, result_np)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_min_max_freq_attributes():
|
||||
"""min_freq and max_freq should match the FrequencyConfig values."""
|
||||
config = PreprocessingConfig(
|
||||
frequencies=FrequencyConfig(min_freq=15_000, max_freq=100_000)
|
||||
)
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
assert preprocessor.min_freq == 15_000
|
||||
assert preprocessor.max_freq == 100_000
|
||||
|
||||
|
||||
def test_preprocessor_input_samplerate_attribute():
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
assert preprocessor.input_samplerate == SAMPLERATE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_output_samplerate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_output_samplerate_defaults():
|
||||
"""At default settings, output_samplerate should equal 1000 fps."""
|
||||
config = PreprocessingConfig()
|
||||
rate = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
|
||||
assert abs(rate - 1000.0) < 1e-6
|
||||
|
||||
|
||||
def test_preprocessor_output_samplerate_attribute_matches_compute():
|
||||
config = PreprocessingConfig()
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
expected = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
|
||||
assert abs(preprocessor.output_samplerate - expected) < 1e-6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_spectrogram (raw STFT)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_generate_spectrogram_shape():
|
||||
"""generate_spectrogram should return the full STFT without crop or resize."""
|
||||
config = PreprocessingConfig()
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
spec = preprocessor.generate_spectrogram(wav)
|
||||
# Full STFT: n_fft//2 + 1 = 257 bins at defaults
|
||||
assert spec.shape[0] == 257
|
||||
|
||||
|
||||
def test_generate_spectrogram_larger_than_forward():
|
||||
"""Raw spectrogram should have more frequency bins than the processed output."""
|
||||
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||
wav = make_sine_wav()
|
||||
raw = preprocessor.generate_spectrogram(wav)
|
||||
processed = preprocessor(wav)
|
||||
assert raw.shape[0] > processed.shape[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio transforms pipeline (FixDuration)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_with_fix_duration_audio_transform():
|
||||
"""A FixDuration audio transform should produce consistent output shapes."""
|
||||
config = PreprocessingConfig(
|
||||
audio_transforms=[FixDurationConfig(duration=0.256)],
|
||||
)
|
||||
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||
# Feed different lengths; output shape should be the same after fix
|
||||
for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]:
|
||||
wav = torch.randn(n_samples)
|
||||
result = preprocessor(wav)
|
||||
assert result.ndim == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
|
||||
"""PreprocessingConfig serialised to YAML and reloaded should produce
|
||||
a functionally identical preprocessor."""
|
||||
from batdetect2.preprocess.config import load_preprocessing_config
|
||||
|
||||
config = PreprocessingConfig(
|
||||
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||
size=ResizeConfig(height=128, resize_factor=0.5),
|
||||
)
|
||||
|
||||
yaml_path = tmp_path / "preprocess_config.yaml"
|
||||
yaml_path.write_text(config.to_yaml_string())
|
||||
|
||||
loaded_config = load_preprocessing_config(yaml_path)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
loaded_config, input_samplerate=SAMPLERATE
|
||||
)
|
||||
wav = make_sine_wav()
|
||||
result = preprocessor(wav)
|
||||
|
||||
assert result.ndim == 2
|
||||
assert result.shape[0] == 128
|
||||
assert torch.isfinite(result).all()
|
||||
@ -1,37 +1,316 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
"""Tests for spectrogram-level preprocessing transforms.
|
||||
|
||||
SAMPLERATE = 250_000
|
||||
DURATION = 0.1
|
||||
TEST_FREQ = 30_000
|
||||
N_SAMPLES = int(SAMPLERATE * DURATION)
|
||||
TIME_COORD = np.linspace(
|
||||
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
|
||||
Covers :mod:`batdetect2.preprocess.spectrogram` — STFT configuration,
|
||||
frequency cropping, PCEN, spectral mean subtraction, amplitude scaling,
|
||||
peak normalisation, and resizing.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
PCEN,
|
||||
FrequencyConfig,
|
||||
FrequencyCrop,
|
||||
PcenConfig,
|
||||
PeakNormalize,
|
||||
PeakNormalizeConfig,
|
||||
ResizeConfig,
|
||||
ResizeSpec,
|
||||
ScaleAmplitude,
|
||||
ScaleAmplitudeConfig,
|
||||
SpectralMeanSubtraction,
|
||||
SpectralMeanSubtractionConfig,
|
||||
STFTConfig,
|
||||
_spec_params_from_config,
|
||||
build_spectrogram_builder,
|
||||
build_spectrogram_crop,
|
||||
build_spectrogram_resizer,
|
||||
build_spectrogram_transform,
|
||||
)
|
||||
|
||||
SAMPLERATE = 256_000
|
||||
|
||||
@pytest.fixture
|
||||
def sine_wave_xr() -> xr.DataArray:
|
||||
"""Generate a single sine wave as an xr.DataArray."""
|
||||
t = TIME_COORD
|
||||
wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
|
||||
return xr.DataArray(
|
||||
wav_data,
|
||||
coords={"time": t},
|
||||
dims=["time"],
|
||||
attrs={"samplerate": SAMPLERATE},
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STFTConfig / _spec_params_from_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stft_config_defaults_give_correct_params():
|
||||
"""Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128."""
|
||||
config = STFTConfig()
|
||||
n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||
assert n_fft == 512
|
||||
assert hop_length == 128
|
||||
|
||||
|
||||
def test_stft_config_custom_params():
|
||||
"""Custom window duration and overlap should produce the expected sizes."""
|
||||
config = STFTConfig(window_duration=0.004, window_overlap=0.5)
|
||||
n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||
assert n_fft == 1024
|
||||
assert hop_length == 512
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_spectrogram_builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spectrogram_builder_output_shape():
|
||||
"""Builder should produce a spectrogram with the expected number of bins."""
|
||||
config = STFTConfig()
|
||||
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||
expected_freq_bins = n_fft // 2 + 1 # 257 at defaults
|
||||
|
||||
builder = build_spectrogram_builder(config, samplerate=SAMPLERATE)
|
||||
n_samples = SAMPLERATE # 1 second of audio
|
||||
wav = torch.randn(n_samples)
|
||||
spec = builder(wav)
|
||||
|
||||
assert spec.ndim == 2
|
||||
assert spec.shape[0] == expected_freq_bins
|
||||
|
||||
|
||||
def test_spectrogram_builder_output_is_nonnegative():
    """Every value of an amplitude spectrogram is >= 0."""
    builder = build_spectrogram_builder(STFTConfig(), samplerate=SAMPLERATE)
    spectrogram = builder(torch.randn(SAMPLERATE))
    assert bool((spectrogram >= 0).all())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FrequencyCrop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_frequency_crop_output_shape():
    """Cropping to a band removes frequency bins but keeps the time axis."""
    n_fft, _ = _spec_params_from_config(STFTConfig(), samplerate=SAMPLERATE)

    crop = FrequencyCrop(
        samplerate=SAMPLERATE,
        n_fft=n_fft,
        min_freq=10_000,
        max_freq=120_000,
    )
    full_spec = torch.ones(n_fft // 2 + 1, 100)
    cropped = crop(full_spec)

    assert cropped.ndim == 2
    # Some bins fall outside [10 kHz, 120 kHz] and must be dropped.
    assert cropped.shape[0] < full_spec.shape[0]
    # The time axis is untouched by a frequency crop.
    assert cropped.shape[1] == 100
|
||||
|
||||
|
||||
@pytest.fixture
def constant_wave_xr() -> xr.DataArray:
    """Generate a constant (DC, amplitude 0.5) signal as an xr.DataArray.

    Returns
    -------
    xr.DataArray
        A 1-D float32 array of length ``N_SAMPLES`` with a ``time``
        coordinate and the test samplerate stored in ``attrs``.
    """
    t = TIME_COORD
    wav_data = np.full(N_SAMPLES, 0.5, dtype=np.float32)
    # NOTE: the DataArray call was previously missing its closing
    # parenthesis, which made the module unparseable; fixed here.
    return xr.DataArray(
        wav_data,
        coords={"time": t},
        dims=["time"],
        attrs={"samplerate": SAMPLERATE},
    )
|
||||
def test_frequency_crop_build_from_config():
    """build_spectrogram_crop assembles a FrequencyCrop from config objects."""
    crop = build_spectrogram_crop(
        FrequencyConfig(min_freq=10_000, max_freq=120_000),
        stft=STFTConfig(),
        samplerate=SAMPLERATE,
    )
    assert isinstance(crop, FrequencyCrop)
|
||||
|
||||
|
||||
def test_frequency_crop_no_crop_when_bounds_are_none():
    """With no frequency bounds the crop leaves the spectrogram untouched."""
    n_fft, _ = _spec_params_from_config(STFTConfig(), samplerate=SAMPLERATE)
    spectrogram = torch.ones(n_fft // 2 + 1, 100)

    unbounded_crop = FrequencyCrop(samplerate=SAMPLERATE, n_fft=n_fft)

    assert unbounded_crop(spectrogram).shape == spectrogram.shape
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PCEN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_pcen_output_shape_preserved():
    """PCEN normalisation must not alter the spectrogram's shape."""
    pcen = PCEN.from_config(PcenConfig(), samplerate=SAMPLERATE)
    spectrogram = torch.rand(64, 200) + 1e-6  # strictly positive input
    assert pcen(spectrogram).shape == spectrogram.shape
|
||||
|
||||
|
||||
def test_pcen_output_is_finite():
    """PCEN on a strictly positive spectrogram produces no NaN or Inf."""
    pcen = PCEN.from_config(PcenConfig(), samplerate=SAMPLERATE)
    output = pcen(torch.rand(64, 200) + 1e-6)
    assert bool(torch.isfinite(output).all())
|
||||
|
||||
|
||||
def test_pcen_output_dtype_matches_input():
    """PCEN must not silently promote or demote the input dtype."""
    pcen = PCEN.from_config(PcenConfig(), samplerate=SAMPLERATE)
    spectrogram = torch.rand(64, 200, dtype=torch.float32)
    assert pcen(spectrogram).dtype == spectrogram.dtype
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SpectralMeanSubtraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spectral_mean_subtraction_output_nonnegative():
    """SpectralMeanSubtraction clamps its output at zero."""
    transform = SpectralMeanSubtraction()
    output = transform(torch.rand(64, 200))
    assert bool((output >= 0).all())
|
||||
|
||||
|
||||
def test_spectral_mean_subtraction_shape_preserved():
    """Mean subtraction is element-wise, so the shape passes through."""
    transform = SpectralMeanSubtraction()
    spectrogram = torch.rand(64, 200)
    assert transform(spectrogram).shape == spectrogram.shape
|
||||
|
||||
|
||||
def test_spectral_mean_subtraction_reduces_time_mean():
    """A constant spectrogram collapses to all zeros after mean removal.

    Each frequency bin equals its own time-axis mean, so subtraction
    yields zero everywhere even before the non-negativity clamp.
    """
    transform = SpectralMeanSubtraction()
    constant_spec = torch.full((32, 100), 3.0)
    assert bool((transform(constant_spec) == 0).all())
|
||||
|
||||
|
||||
def test_spectral_mean_subtraction_from_config():
    """from_config builds a SpectralMeanSubtraction instance."""
    built = SpectralMeanSubtraction.from_config(
        SpectralMeanSubtractionConfig(),
        samplerate=SAMPLERATE,
    )
    assert isinstance(built, SpectralMeanSubtraction)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PeakNormalize (spectrogram-level)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_peak_normalize_spec_max_is_one():
    """PeakNormalize rescales so the absolute peak equals 1."""
    transform = PeakNormalize()
    normalized = transform(torch.rand(64, 200) * 5.0)
    peak = normalized.abs().max().item()
    assert abs(peak - 1.0) < 1e-6
|
||||
|
||||
|
||||
def test_peak_normalize_spec_handles_zero():
    """An all-zero spectrogram must not raise (no division blow-up)."""
    transform = PeakNormalize()
    output = transform(torch.zeros(64, 200))
    assert bool((output == 0).all())
|
||||
|
||||
|
||||
def test_peak_normalize_from_config():
    """from_config builds a PeakNormalize instance."""
    built = PeakNormalize.from_config(PeakNormalizeConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, PeakNormalize)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ScaleAmplitude
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_scale_amplitude_db_output_is_finite():
    """dB scaling of a strictly positive spectrogram stays finite."""
    transform = ScaleAmplitude(scale="db")
    output = transform(torch.rand(64, 200) + 1e-4)
    assert bool(torch.isfinite(output).all())
|
||||
|
||||
|
||||
def test_scale_amplitude_power_output_equals_square():
    """With scale='power' every magnitude value is squared."""
    transform = ScaleAmplitude(scale="power")
    values = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
    assert torch.allclose(transform(values), values * values)
|
||||
|
||||
|
||||
def test_scale_amplitude_from_config():
    """from_config builds a ScaleAmplitude carrying the requested scale."""
    built = ScaleAmplitude.from_config(
        ScaleAmplitudeConfig(scale="db"),
        samplerate=SAMPLERATE,
    )
    assert isinstance(built, ScaleAmplitude)
    assert built.scale == "db"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResizeSpec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resize_spec_output_shape():
    """ResizeSpec yields the target height and a time axis scaled by factor."""
    resizer = ResizeSpec(height=64, time_factor=0.5)
    batched = torch.rand(1, 128, 200)  # (batch, freq, time)
    assert resizer(batched).shape == (1, 64, 100)
|
||||
|
||||
|
||||
def test_resize_spec_2d_input():
    """ResizeSpec also accepts plain 2-D input (no batch/channel dims)."""
    resizer = ResizeSpec(height=64, time_factor=0.5)
    flat = torch.rand(128, 200)
    assert resizer(flat).shape == (64, 100)
|
||||
|
||||
|
||||
def test_resize_spec_output_is_finite():
    """Interpolation during resizing introduces no NaN or Inf values."""
    resizer = ResizeSpec(height=128, time_factor=0.5)
    output = resizer(torch.rand(128, 200))
    assert bool(torch.isfinite(output).all())
|
||||
|
||||
|
||||
def test_resize_spec_from_config():
    """build_spectrogram_resizer copies height and factor from the config."""
    built = build_spectrogram_resizer(ResizeConfig(height=64, resize_factor=0.25))
    assert isinstance(built, ResizeSpec)
    assert built.height == 64
    assert built.time_factor == 0.25
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_spectrogram_transform dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_spectrogram_transform_pcen():
    """A PcenConfig dispatches to a PCEN module."""
    built = build_spectrogram_transform(PcenConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, PCEN)
|
||||
|
||||
|
||||
def test_build_spectrogram_transform_spectral_mean_subtraction():
    """A SpectralMeanSubtractionConfig dispatches to SpectralMeanSubtraction."""
    built = build_spectrogram_transform(
        SpectralMeanSubtractionConfig(), samplerate=SAMPLERATE
    )
    assert isinstance(built, SpectralMeanSubtraction)
|
||||
|
||||
|
||||
def test_build_spectrogram_transform_scale_amplitude():
    """A ScaleAmplitudeConfig dispatches to ScaleAmplitude."""
    built = build_spectrogram_transform(
        ScaleAmplitudeConfig(scale="db"), samplerate=SAMPLERATE
    )
    assert isinstance(built, ScaleAmplitude)
|
||||
|
||||
|
||||
def test_build_spectrogram_transform_peak_normalize():
    """A PeakNormalizeConfig dispatches to PeakNormalize."""
    built = build_spectrogram_transform(PeakNormalizeConfig(), samplerate=SAMPLERATE)
    assert isinstance(built, PeakNormalize)
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.preprocess import (
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
ScaleAmplitudeConfig,
|
||||
SpectralMeanSubstractionConfig,
|
||||
SpectralMeanSubtractionConfig,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
DEFAULT_ANCHOR,
|
||||
@ -597,7 +597,7 @@ def test_build_roi_mapper_for_peak_energy_bbox():
|
||||
preproc_config = PreprocessingConfig(
|
||||
spectrogram_transforms=[
|
||||
ScaleAmplitudeConfig(scale="db"),
|
||||
SpectralMeanSubstractionConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
)
|
||||
config = PeakEnergyBBoxMapperConfig(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user