mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add test for preprocessing
This commit is contained in:
parent
bfc88a4a0f
commit
46c02962f3
@ -21,7 +21,7 @@ preprocess:
|
|||||||
gain: 0.98
|
gain: 0.98
|
||||||
bias: 2
|
bias: 2
|
||||||
power: 0.5
|
power: 0.5
|
||||||
- name: spectral_mean_substraction
|
- name: spectral_mean_subtraction
|
||||||
|
|
||||||
postprocess:
|
postprocess:
|
||||||
nms_kernel_size: 9
|
nms_kernel_size: 9
|
||||||
|
|||||||
@ -1,3 +1,17 @@
|
|||||||
|
"""Audio-level transforms applied to waveforms before spectrogram computation.
|
||||||
|
|
||||||
|
This module defines ``torch.nn.Module`` transforms that operate on raw
|
||||||
|
audio tensors and the Pydantic configuration classes that control them.
|
||||||
|
Each transform is registered in the ``audio_transforms`` registry so that
|
||||||
|
the pipeline can be assembled from a configuration object.
|
||||||
|
|
||||||
|
The supported transforms are:
|
||||||
|
|
||||||
|
* ``CenterAudio`` — subtract the DC offset (mean) from the waveform.
|
||||||
|
* ``ScaleAudio`` — peak-normalise the waveform to the range ``[-1, 1]``.
|
||||||
|
* ``FixDuration`` — truncate or zero-pad the waveform to a fixed length.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -18,14 +32,43 @@ __all__ = [
|
|||||||
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||||
"audio_transform"
|
"audio_transform"
|
||||||
)
|
)
|
||||||
|
"""Registry mapping audio transform config classes to their builder methods."""
|
||||||
|
|
||||||
|
|
||||||
class CenterAudioConfig(BaseConfig):
|
class CenterAudioConfig(BaseConfig):
|
||||||
|
"""Configuration for the DC-offset removal transform.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"center_audio"``.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["center_audio"] = "center_audio"
|
name: Literal["center_audio"] = "center_audio"
|
||||||
|
|
||||||
|
|
||||||
class CenterAudio(torch.nn.Module):
|
class CenterAudio(torch.nn.Module):
|
||||||
|
"""Remove the DC offset from an audio waveform.
|
||||||
|
|
||||||
|
Subtracts the global mean of the waveform from every sample,
|
||||||
|
centring the signal around zero. This is useful when an analogue
|
||||||
|
recording chain introduces a constant voltage bias.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Subtract the mean from the waveform.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform tensor of shape ``(samples,)`` or
|
||||||
|
``(channels, samples)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Zero-centred waveform with the same shape as the input.
|
||||||
|
"""
|
||||||
return center_tensor(wav)
|
return center_tensor(wav)
|
||||||
|
|
||||||
@audio_transforms.register(CenterAudioConfig)
|
@audio_transforms.register(CenterAudioConfig)
|
||||||
@ -35,11 +78,38 @@ class CenterAudio(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ScaleAudioConfig(BaseConfig):
|
class ScaleAudioConfig(BaseConfig):
|
||||||
|
"""Configuration for the peak-normalisation transform.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"scale_audio"``.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["scale_audio"] = "scale_audio"
|
name: Literal["scale_audio"] = "scale_audio"
|
||||||
|
|
||||||
|
|
||||||
class ScaleAudio(torch.nn.Module):
|
class ScaleAudio(torch.nn.Module):
|
||||||
|
"""Peak-normalise an audio waveform to the range ``[-1, 1]``.
|
||||||
|
|
||||||
|
Divides the waveform by its largest absolute sample value. If the
|
||||||
|
waveform is identically zero it is returned unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Peak-normalise the waveform.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform tensor of any shape.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Normalised waveform with the same shape as the input and
|
||||||
|
values in the range ``[-1, 1]``.
|
||||||
|
"""
|
||||||
return peak_normalize(wav)
|
return peak_normalize(wav)
|
||||||
|
|
||||||
@audio_transforms.register(ScaleAudioConfig)
|
@audio_transforms.register(ScaleAudioConfig)
|
||||||
@ -49,11 +119,36 @@ class ScaleAudio(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FixDurationConfig(BaseConfig):
|
class FixDurationConfig(BaseConfig):
|
||||||
|
"""Configuration for the fixed-duration transform.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"fix_duration"``.
|
||||||
|
duration : float, default=0.5
|
||||||
|
Target duration in seconds. The waveform will be truncated or
|
||||||
|
zero-padded to match this length.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["fix_duration"] = "fix_duration"
|
name: Literal["fix_duration"] = "fix_duration"
|
||||||
duration: float = 0.5
|
duration: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
class FixDuration(torch.nn.Module):
|
class FixDuration(torch.nn.Module):
|
||||||
|
"""Ensure a waveform has exactly a specified number of samples.
|
||||||
|
|
||||||
|
If the input is longer than the target length it is truncated from
|
||||||
|
the end. If it is shorter, it is zero-padded at the end.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samplerate : int
|
||||||
|
Sample rate of the audio in Hz. Used with ``duration`` to
|
||||||
|
compute the target number of samples.
|
||||||
|
duration : float
|
||||||
|
Target duration in seconds.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, samplerate: int, duration: float):
|
def __init__(self, samplerate: int, duration: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.samplerate = samplerate
|
self.samplerate = samplerate
|
||||||
@ -61,6 +156,20 @@ class FixDuration(torch.nn.Module):
|
|||||||
self.length = int(samplerate * duration)
|
self.length = int(samplerate * duration)
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Truncate or pad the waveform to the target length.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform tensor of shape ``(samples,)`` or
|
||||||
|
``(channels, samples)``. The last dimension is adjusted.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Waveform with exactly ``self.length`` samples along the last
|
||||||
|
dimension.
|
||||||
|
"""
|
||||||
length = wav.shape[-1]
|
length = wav.shape[-1]
|
||||||
|
|
||||||
if length == self.length:
|
if length == self.length:
|
||||||
@ -81,10 +190,34 @@ AudioTransform = Annotated[
|
|||||||
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
|
FixDurationConfig | ScaleAudioConfig | CenterAudioConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
"""Discriminated union of all audio transform configuration types.
|
||||||
|
|
||||||
|
Use this type when a field should accept any of the supported audio
|
||||||
|
transforms. Pydantic will select the correct config class based on the
|
||||||
|
``name`` field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_audio_transform(
|
def build_audio_transform(
|
||||||
config: AudioTransform,
|
config: AudioTransform,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
"""Build an audio transform module from a configuration object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : AudioTransform
|
||||||
|
A configuration object for one of the supported audio transforms
|
||||||
|
(``CenterAudioConfig``, ``ScaleAudioConfig``, or
|
||||||
|
``FixDurationConfig``).
|
||||||
|
samplerate : int, default=256000
|
||||||
|
Sample rate of the audio in Hz. Passed to the transform builder;
|
||||||
|
some transforms (e.g. ``FixDuration``) use it to convert seconds
|
||||||
|
to samples.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.nn.Module
|
||||||
|
The constructed audio transform module.
|
||||||
|
"""
|
||||||
return audio_transforms.build(config, samplerate)
|
return audio_transforms.build(config, samplerate)
|
||||||
|
|||||||
@ -1,3 +1,10 @@
|
|||||||
|
"""Shared tensor primitives used across the preprocessing pipeline.
|
||||||
|
|
||||||
|
This module provides small, stateless helper functions that operate on
|
||||||
|
PyTorch tensors. They are used by both audio-level and spectrogram-level
|
||||||
|
transforms, and are kept here to avoid duplication.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,11 +14,42 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Subtract the mean of a tensor from all of its values.
|
||||||
|
|
||||||
|
This centres the signal around zero, removing any constant DC offset.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Input tensor of any shape.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
A new tensor of the same shape and dtype with the global mean
|
||||||
|
subtracted from every element.
|
||||||
|
"""
|
||||||
return tensor - tensor.mean()
|
return tensor - tensor.mean()
|
||||||
|
|
||||||
|
|
||||||
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
|
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
max_value = tensor.abs().min()
|
"""Scale a tensor so that its largest absolute value equals one.
|
||||||
|
|
||||||
|
Divides the tensor by its peak absolute value. If the tensor is
|
||||||
|
identically zero, it is returned unchanged (no division by zero).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Input tensor of any shape.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
A new tensor of the same shape and dtype with values in the range
|
||||||
|
``[-1, 1]`` (or exactly ``[0, 0]`` for a zero tensor).
|
||||||
|
"""
|
||||||
|
max_value = tensor.abs().max()
|
||||||
|
|
||||||
denominator = torch.where(
|
denominator = torch.where(
|
||||||
max_value == 0,
|
max_value == 0,
|
||||||
|
|||||||
@ -1,3 +1,10 @@
|
|||||||
|
"""Configuration for the full batdetect2 preprocessing pipeline.
|
||||||
|
|
||||||
|
This module defines :class:`PreprocessingConfig`, which aggregates all
|
||||||
|
configuration needed to convert a raw audio waveform into a normalised
|
||||||
|
spectrogram ready for the detection model.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -9,7 +16,7 @@ from batdetect2.preprocess.spectrogram import (
|
|||||||
FrequencyConfig,
|
FrequencyConfig,
|
||||||
PcenConfig,
|
PcenConfig,
|
||||||
ResizeConfig,
|
ResizeConfig,
|
||||||
SpectralMeanSubstractionConfig,
|
SpectralMeanSubtractionConfig,
|
||||||
SpectrogramTransform,
|
SpectrogramTransform,
|
||||||
STFTConfig,
|
STFTConfig,
|
||||||
)
|
)
|
||||||
@ -24,19 +31,30 @@ __all__ = [
|
|||||||
class PreprocessingConfig(BaseConfig):
|
class PreprocessingConfig(BaseConfig):
|
||||||
"""Unified configuration for the audio preprocessing pipeline.
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
Aggregates the configuration for both the initial audio processing stage
|
Aggregates the parameters for every stage of the pipeline:
|
||||||
and the subsequent spectrogram generation stage.
|
audio-level transforms, STFT computation, frequency cropping,
|
||||||
|
spectrogram-level transforms, and the final resize step.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
audio : AudioConfig
|
audio_transforms : list of AudioTransform, default=[]
|
||||||
Configuration settings for the audio loading and initial waveform
|
Ordered list of transforms applied to the raw audio waveform
|
||||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
before the STFT is computed. Each entry is a configuration
|
||||||
Defaults to default `AudioConfig` settings if omitted.
|
object for one of the supported audio transforms
|
||||||
spectrogram : SpectrogramConfig
|
(``"center_audio"``, ``"scale_audio"``, or ``"fix_duration"``).
|
||||||
Configuration settings for the spectrogram generation process
|
spectrogram_transforms : list of SpectrogramTransform
|
||||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
Ordered list of transforms applied to the cropped spectrogram
|
||||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
after the STFT and frequency crop steps. Defaults to
|
||||||
|
``[PcenConfig(), SpectralMeanSubtractionConfig()]``, which
|
||||||
|
applies PCEN followed by spectral mean subtraction.
|
||||||
|
stft : STFTConfig
|
||||||
|
Parameters for the Short-Time Fourier Transform (window
|
||||||
|
duration, overlap, and window function).
|
||||||
|
frequencies : FrequencyConfig
|
||||||
|
Frequency range (in Hz) to retain after the STFT.
|
||||||
|
size : ResizeConfig
|
||||||
|
Target height (number of frequency bins) and time-axis scaling
|
||||||
|
factor for the final resize step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||||
@ -44,7 +62,7 @@ class PreprocessingConfig(BaseConfig):
|
|||||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
PcenConfig(),
|
PcenConfig(),
|
||||||
SpectralMeanSubstractionConfig(),
|
SpectralMeanSubtractionConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,4 +77,20 @@ def load_preprocessing_config(
|
|||||||
path: PathLike,
|
path: PathLike,
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
) -> PreprocessingConfig:
|
) -> PreprocessingConfig:
|
||||||
|
"""Load a ``PreprocessingConfig`` from a YAML file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the YAML configuration file.
|
||||||
|
field : str, optional
|
||||||
|
If provided, read the config from a nested field within the
|
||||||
|
YAML document (e.g. ``"preprocessing"`` to read from a top-level
|
||||||
|
``preprocessing:`` key).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PreprocessingConfig
|
||||||
|
The deserialised preprocessing configuration.
|
||||||
|
"""
|
||||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||||
|
|||||||
@ -1,3 +1,25 @@
|
|||||||
|
"""Assembles the full batdetect2 preprocessing pipeline.
|
||||||
|
|
||||||
|
This module defines :class:`Preprocessor`, the concrete implementation of
|
||||||
|
:class:`~batdetect2.typing.PreprocessorProtocol`, and the
|
||||||
|
:func:`build_preprocessor` factory function that constructs it from a
|
||||||
|
:class:`~batdetect2.preprocess.config.PreprocessingConfig`.
|
||||||
|
|
||||||
|
The preprocessing pipeline converts a raw audio waveform (as a
|
||||||
|
``torch.Tensor``) into a normalised, cropped, and resized spectrogram ready
|
||||||
|
for the detection model. The stages are applied in this order:
|
||||||
|
|
||||||
|
1. **Audio transforms** — optional waveform-level operations such as DC
|
||||||
|
removal, peak normalisation, or duration fixing.
|
||||||
|
2. **STFT** — Short-Time Fourier Transform to produce an amplitude
|
||||||
|
spectrogram.
|
||||||
|
3. **Frequency crop** — retain only the frequency band of interest.
|
||||||
|
4. **Spectrogram transforms** — normalisation operations such as PCEN and
|
||||||
|
spectral mean subtraction.
|
||||||
|
5. **Resize** — scale the spectrogram to the model's expected height and
|
||||||
|
reduce the time resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -20,7 +42,32 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||||
"""Standard implementation of the `Preprocessor` protocol."""
|
"""Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`.
|
||||||
|
|
||||||
|
Wraps all preprocessing stages as ``torch.nn.Module`` submodules so
|
||||||
|
that parameters (e.g. PCEN filter coefficients) can be tracked and
|
||||||
|
moved between devices.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : PreprocessingConfig
|
||||||
|
Full pipeline configuration.
|
||||||
|
input_samplerate : int
|
||||||
|
Sample rate of the audio that will be passed to this preprocessor,
|
||||||
|
in Hz.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
input_samplerate : int
|
||||||
|
Sample rate of the input audio in Hz.
|
||||||
|
output_samplerate : float
|
||||||
|
Effective frame rate of the output spectrogram in frames per second.
|
||||||
|
Computed from the STFT hop length and the time-axis resize factor.
|
||||||
|
min_freq : float
|
||||||
|
Lower bound of the retained frequency band in Hz.
|
||||||
|
max_freq : float
|
||||||
|
Upper bound of the retained frequency band in Hz.
|
||||||
|
"""
|
||||||
|
|
||||||
input_samplerate: int
|
input_samplerate: int
|
||||||
output_samplerate: float
|
output_samplerate: float
|
||||||
@ -72,17 +119,75 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Run the full preprocessing pipeline on a waveform.
|
||||||
|
|
||||||
|
Applies audio transforms, then the STFT, then
|
||||||
|
:meth:`process_spectrogram`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform of shape ``(samples,)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Preprocessed spectrogram of shape
|
||||||
|
``(freq_bins, time_frames)``.
|
||||||
|
"""
|
||||||
wav = self.audio_transforms(wav)
|
wav = self.audio_transforms(wav)
|
||||||
spec = self.spectrogram_builder(wav)
|
spec = self.spectrogram_builder(wav)
|
||||||
return self.process_spectrogram(spec)
|
return self.process_spectrogram(spec)
|
||||||
|
|
||||||
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute the raw STFT spectrogram without any further processing.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform of shape ``(samples,)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Amplitude spectrogram of shape ``(n_fft//2 + 1, time_frames)``
|
||||||
|
with no frequency cropping, normalisation, or resizing applied.
|
||||||
|
"""
|
||||||
return self.spectrogram_builder(wav)
|
return self.spectrogram_builder(wav)
|
||||||
|
|
||||||
def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
|
def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Alias for :meth:`forward`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Input waveform of shape ``(samples,)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Preprocessed spectrogram (same as calling the object directly).
|
||||||
|
"""
|
||||||
return self(wav)
|
return self(wav)
|
||||||
|
|
||||||
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply the post-STFT processing stages to an existing spectrogram.
|
||||||
|
|
||||||
|
Applies frequency cropping, spectrogram-level transforms (e.g.
|
||||||
|
PCEN, spectral mean subtraction), and the final resize step.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Raw amplitude spectrogram of shape
|
||||||
|
``(..., n_fft//2 + 1, time_frames)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Normalised and resized spectrogram of shape
|
||||||
|
``(..., height, scaled_time_frames)``.
|
||||||
|
"""
|
||||||
spec = self.spectrogram_crop(spec)
|
spec = self.spectrogram_crop(spec)
|
||||||
spec = self.spectrogram_transforms(spec)
|
spec = self.spectrogram_transforms(spec)
|
||||||
return self.spectrogram_resizer(spec)
|
return self.spectrogram_resizer(spec)
|
||||||
@ -92,6 +197,25 @@ def compute_output_samplerate(
|
|||||||
config: PreprocessingConfig,
|
config: PreprocessingConfig,
|
||||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
"""Compute the effective frame rate of the preprocessor's output.
|
||||||
|
|
||||||
|
The output frame rate (in frames per second) depends on the STFT hop
|
||||||
|
length and the time-axis resize factor applied by the final resize step.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : PreprocessingConfig
|
||||||
|
Pipeline configuration.
|
||||||
|
input_samplerate : int, default=256000
|
||||||
|
Sample rate of the input audio in Hz.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
Output frame rate in frames per second.
|
||||||
|
For example, at the default settings (256 kHz, hop=128,
|
||||||
|
resize_factor=0.5) this equals ``1000.0``.
|
||||||
|
"""
|
||||||
_, hop_size = _spec_params_from_config(
|
_, hop_size = _spec_params_from_config(
|
||||||
config.stft, samplerate=input_samplerate
|
config.stft, samplerate=input_samplerate
|
||||||
)
|
)
|
||||||
@ -103,7 +227,24 @@ def build_preprocessor(
|
|||||||
config: PreprocessingConfig | None = None,
|
config: PreprocessingConfig | None = None,
|
||||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> PreprocessorProtocol:
|
) -> PreprocessorProtocol:
|
||||||
"""Factory function to build the standard preprocessor from configuration."""
|
"""Build the standard preprocessor from a configuration object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : PreprocessingConfig, optional
|
||||||
|
Pipeline configuration. If ``None``, the default
|
||||||
|
``PreprocessingConfig()`` is used (PCEN + spectral mean
|
||||||
|
subtraction, 256 kHz, standard STFT parameters).
|
||||||
|
input_samplerate : int, default=256000
|
||||||
|
Sample rate of the audio that will be fed to the preprocessor,
|
||||||
|
in Hz.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PreprocessorProtocol
|
||||||
|
A :class:`Preprocessor` instance ready to convert waveforms to
|
||||||
|
spectrograms.
|
||||||
|
"""
|
||||||
config = config or PreprocessingConfig()
|
config = config or PreprocessingConfig()
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building preprocessor with config: \n{}",
|
"Building preprocessor with config: \n{}",
|
||||||
|
|||||||
@ -1,4 +1,14 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters.
|
||||||
|
|
||||||
|
This module defines the STFT-based spectrogram builder and a collection of
|
||||||
|
spectrogram-level transforms (PCEN, spectral mean subtraction, amplitude
|
||||||
|
scaling, peak normalisation, frequency cropping, and resizing) that form the
|
||||||
|
signal-processing stage of the batdetect2 preprocessing pipeline.
|
||||||
|
|
||||||
|
Each transform is paired with a Pydantic configuration class and registered
|
||||||
|
in the ``spectrogram_transforms`` registry so that the pipeline can be fully
|
||||||
|
specified via a YAML or Python configuration object.
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Callable, Literal
|
from typing import Annotated, Callable, Literal
|
||||||
|
|
||||||
@ -32,17 +42,22 @@ class STFTConfig(BaseConfig):
|
|||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
window_duration : float, default=0.002
|
window_duration : float, default=0.002
|
||||||
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
|
Duration of the STFT analysis window in seconds (e.g. 0.002 for
|
||||||
> 0. Determines frequency resolution (longer window = finer frequency
|
2 ms). Must be > 0. A longer window gives finer frequency resolution
|
||||||
resolution).
|
but coarser time resolution.
|
||||||
window_overlap : float, default=0.75
|
window_overlap : float, default=0.75
|
||||||
Fraction of overlap between consecutive STFT windows (e.g., 0.75
|
Fraction of overlap between consecutive windows (e.g. 0.75 for
|
||||||
for 75%). Must be >= 0 and < 1. Determines time resolution
|
75 %). Must be >= 0 and < 1. Higher overlap gives finer time
|
||||||
(higher overlap = finer time resolution).
|
resolution at the cost of more computation.
|
||||||
window_fn : str, default="hann"
|
window_fn : str, default="hann"
|
||||||
Name of the window function to apply before FFT calculation. Common
|
Name of the tapering window applied to each frame before the FFT.
|
||||||
options include "hann", "hamming", "blackman". See
|
Supported values: ``"hann"``, ``"hamming"``, ``"kaiser"``,
|
||||||
`scipy.signal.get_window`.
|
``"blackman"``, ``"bartlett"``.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
At the default sample rate of 256 kHz, ``window_duration=0.002`` and
|
||||||
|
``window_overlap=0.75`` give ``n_fft=512`` and ``hop_length=128``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
window_duration: float = Field(default=0.002, gt=0)
|
window_duration: float = Field(default=0.002, gt=0)
|
||||||
@ -54,6 +69,23 @@ def build_spectrogram_builder(
|
|||||||
config: STFTConfig,
|
config: STFTConfig,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
"""Build a torchaudio STFT spectrogram module from an ``STFTConfig``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : STFTConfig
|
||||||
|
STFT parameters (window duration, overlap, and window function).
|
||||||
|
samplerate : int, default=256000
|
||||||
|
Sample rate of the input audio in Hz. Used to convert the
|
||||||
|
window duration into a number of samples.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.nn.Module
|
||||||
|
A ``torchaudio.transforms.Spectrogram`` module configured to
|
||||||
|
produce an amplitude (``power=1``) spectrogram with centred
|
||||||
|
frames.
|
||||||
|
"""
|
||||||
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
|
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
|
||||||
return torchaudio.transforms.Spectrogram(
|
return torchaudio.transforms.Spectrogram(
|
||||||
n_fft=n_fft,
|
n_fft=n_fft,
|
||||||
@ -65,6 +97,25 @@ def build_spectrogram_builder(
|
|||||||
|
|
||||||
|
|
||||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||||
|
"""Return the PyTorch window function matching the given name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Name of the window function. One of ``"hann"``, ``"hamming"``,
|
||||||
|
``"kaiser"``, ``"blackman"``, or ``"bartlett"``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Callable[..., torch.Tensor]
|
||||||
|
A PyTorch window function that accepts a window length and returns
|
||||||
|
a 1-D tensor of weights.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
NotImplementedError
|
||||||
|
If ``name`` does not match any supported window function.
|
||||||
|
"""
|
||||||
if name == "hann":
|
if name == "hann":
|
||||||
return torch.hann_window
|
return torch.hann_window
|
||||||
|
|
||||||
@ -88,7 +139,22 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
|||||||
def _spec_params_from_config(
|
def _spec_params_from_config(
|
||||||
config: STFTConfig,
|
config: STFTConfig,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
):
|
) -> tuple[int, int]:
|
||||||
|
"""Compute ``n_fft`` and ``hop_length`` from an ``STFTConfig``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : STFTConfig
|
||||||
|
STFT parameters.
|
||||||
|
samplerate : int, default=256000
|
||||||
|
Sample rate of the input audio in Hz.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[int, int]
|
||||||
|
A pair ``(n_fft, hop_length)`` giving the FFT size and the step
|
||||||
|
between consecutive frames in samples.
|
||||||
|
"""
|
||||||
n_fft = int(samplerate * config.window_duration)
|
n_fft = int(samplerate * config.window_duration)
|
||||||
hop_length = int(n_fft * (1 - config.window_overlap))
|
hop_length = int(n_fft * (1 - config.window_overlap))
|
||||||
return n_fft, hop_length
|
return n_fft, hop_length
|
||||||
@ -99,6 +165,24 @@ def _frequency_to_index(
|
|||||||
n_fft: int,
|
n_fft: int,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
|
"""Convert a frequency in Hz to the nearest STFT frequency bin index.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
freq : float
|
||||||
|
Frequency in Hz to convert.
|
||||||
|
n_fft : int
|
||||||
|
FFT size used by the STFT.
|
||||||
|
samplerate : int, default=256000
|
||||||
|
Sample rate of the audio in Hz.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int or None
|
||||||
|
The bin index corresponding to ``freq``, or ``None`` if the
|
||||||
|
frequency is outside the valid range (i.e. <= 0 Hz or >= the
|
||||||
|
Nyquist frequency).
|
||||||
|
"""
|
||||||
alpha = freq * 2 / samplerate
|
alpha = freq * 2 / samplerate
|
||||||
height = np.floor(n_fft / 2) + 1
|
height = np.floor(n_fft / 2) + 1
|
||||||
index = int(np.floor(alpha * height))
|
index = int(np.floor(alpha * height))
|
||||||
@ -118,11 +202,11 @@ class FrequencyConfig(BaseConfig):
|
|||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
max_freq : int, default=120000
|
max_freq : int, default=120000
|
||||||
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
Maximum frequency in Hz to retain after STFT. Frequency bins
|
||||||
Frequencies above this value will be cropped. Must be > 0.
|
above this value are discarded. Must be >= 0.
|
||||||
min_freq : int, default=10000
|
min_freq : int, default=10000
|
||||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
Minimum frequency in Hz to retain after STFT. Frequency bins
|
||||||
Frequencies below this value will be cropped. Must be >= 0.
|
below this value are discarded. Must be >= 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_freq: int = Field(default=MAX_FREQ, ge=0)
|
max_freq: int = Field(default=MAX_FREQ, ge=0)
|
||||||
@ -130,6 +214,27 @@ class FrequencyConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class FrequencyCrop(torch.nn.Module):
|
class FrequencyCrop(torch.nn.Module):
|
||||||
|
"""Crop a spectrogram to a specified frequency band.
|
||||||
|
|
||||||
|
On construction the Hz boundaries are converted to STFT bin indices.
|
||||||
|
During the forward pass the spectrogram is sliced along its
|
||||||
|
frequency axis (second-to-last dimension) to retain only the bins
|
||||||
|
that fall within ``[min_freq, max_freq)``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samplerate : int
|
||||||
|
Sample rate of the audio in Hz.
|
||||||
|
n_fft : int
|
||||||
|
FFT size used by the STFT.
|
||||||
|
min_freq : int, optional
|
||||||
|
Lower frequency bound in Hz. If ``None``, no lower crop is
|
||||||
|
applied and the DC bin is retained.
|
||||||
|
max_freq : int, optional
|
||||||
|
Upper frequency bound in Hz. If ``None``, no upper crop is
|
||||||
|
applied and all bins up to Nyquist are retained.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
@ -162,6 +267,19 @@ class FrequencyCrop(torch.nn.Module):
|
|||||||
self.high_index = high_index
|
self.high_index = high_index
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Crop the spectrogram to the configured frequency band.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Spectrogram tensor of shape ``(..., freq_bins, time_frames)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Cropped spectrogram with shape
|
||||||
|
``(..., n_retained_bins, time_frames)``.
|
||||||
|
"""
|
||||||
low_index = self.low_index
|
low_index = self.low_index
|
||||||
if low_index is None:
|
if low_index is None:
|
||||||
low_index = 0
|
low_index = 0
|
||||||
@ -184,6 +302,24 @@ def build_spectrogram_crop(
|
|||||||
stft: STFTConfig | None = None,
|
stft: STFTConfig | None = None,
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
"""Build a ``FrequencyCrop`` module from configuration objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : FrequencyConfig
|
||||||
|
Frequency boundary configuration specifying ``min_freq`` and
|
||||||
|
``max_freq`` in Hz.
|
||||||
|
stft : STFTConfig, optional
|
||||||
|
STFT configuration used to derive ``n_fft``. Defaults to
|
||||||
|
``STFTConfig()`` if not provided.
|
||||||
|
samplerate : int, default=256000
|
||||||
|
Sample rate of the audio in Hz.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.nn.Module
|
||||||
|
A ``FrequencyCrop`` module ready to crop spectrograms.
|
||||||
|
"""
|
||||||
stft = stft or STFTConfig()
|
stft = stft or STFTConfig()
|
||||||
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
|
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
|
||||||
return FrequencyCrop(
|
return FrequencyCrop(
|
||||||
@ -195,18 +331,61 @@ def build_spectrogram_crop(
|
|||||||
|
|
||||||
|
|
||||||
class ResizeConfig(BaseConfig):
|
class ResizeConfig(BaseConfig):
|
||||||
|
"""Configuration for the final spectrogram resize step.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"resize_spec"``.
|
||||||
|
height : int, default=128
|
||||||
|
Target number of frequency bins in the output spectrogram.
|
||||||
|
The spectrogram is resized to this height using bilinear
|
||||||
|
interpolation.
|
||||||
|
resize_factor : float, default=0.5
|
||||||
|
Fraction by which the time axis is scaled. For example, ``0.5``
|
||||||
|
halves the number of time frames, reducing computational cost
|
||||||
|
downstream.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["resize_spec"] = "resize_spec"
|
name: Literal["resize_spec"] = "resize_spec"
|
||||||
height: int = 128
|
height: int = 128
|
||||||
resize_factor: float = 0.5
|
resize_factor: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
class ResizeSpec(torch.nn.Module):
|
class ResizeSpec(torch.nn.Module):
|
||||||
|
"""Resize a spectrogram to a fixed height and scaled width.
|
||||||
|
|
||||||
|
Uses bilinear interpolation so it handles arbitrary input shapes
|
||||||
|
gracefully. Input tensors with fewer than four dimensions are
|
||||||
|
temporarily unsqueezed to satisfy ``torch.nn.functional.interpolate``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
height : int
|
||||||
|
Target number of frequency bins (output height).
|
||||||
|
time_factor : float
|
||||||
|
Multiplicative scaling applied to the time axis length.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, height: int, time_factor: float):
|
def __init__(self, height: int, time_factor: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.height = height
|
self.height = height
|
||||||
self.time_factor = time_factor
|
self.time_factor = time_factor
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Resize the spectrogram to the configured output dimensions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Resized spectrogram with shape
|
||||||
|
``(..., height, int(time_factor * time_frames))``.
|
||||||
|
"""
|
||||||
current_length = spec.shape[-1]
|
current_length = spec.shape[-1]
|
||||||
target_length = int(self.time_factor * current_length)
|
target_length = int(self.time_factor * current_length)
|
||||||
|
|
||||||
@ -227,6 +406,18 @@ class ResizeSpec(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
|
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
|
||||||
|
"""Build a ``ResizeSpec`` module from a ``ResizeConfig``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ResizeConfig
|
||||||
|
Resize configuration specifying ``height`` and ``resize_factor``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.nn.Module
|
||||||
|
A ``ResizeSpec`` module configured with the given parameters.
|
||||||
|
"""
|
||||||
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
|
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
|
||||||
|
|
||||||
|
|
||||||
@ -236,7 +427,32 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
|||||||
|
|
||||||
|
|
||||||
class PcenConfig(BaseConfig):
|
class PcenConfig(BaseConfig):
|
||||||
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
"""Configuration for Per-Channel Energy Normalisation (PCEN).
|
||||||
|
|
||||||
|
PCEN is a frontend processing technique that replaces simple log
|
||||||
|
compression. It applies a learnable automatic gain control followed
|
||||||
|
by a stabilised root compression, making the representation more
|
||||||
|
robust to variations in recording level.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"pcen"``.
|
||||||
|
time_constant : float, default=0.4
|
||||||
|
Time constant (in seconds) of the IIR smoothing filter used
|
||||||
|
for the background estimate. Larger values produce a slower-
|
||||||
|
adapting background.
|
||||||
|
gain : float, default=0.98
|
||||||
|
Exponent controlling how strongly the background estimate
|
||||||
|
suppresses the signal.
|
||||||
|
bias : float, default=2
|
||||||
|
Stabilisation bias added inside the root-compression step to
|
||||||
|
avoid division by zero.
|
||||||
|
power : float, default=0.5
|
||||||
|
Root-compression exponent. A value of 0.5 gives square-root
|
||||||
|
compression, similar to log compression but differentiable at
|
||||||
|
zero.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["pcen"] = "pcen"
|
name: Literal["pcen"] = "pcen"
|
||||||
time_constant: float = 0.4
|
time_constant: float = 0.4
|
||||||
@ -246,6 +462,35 @@ class PcenConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class PCEN(torch.nn.Module):
|
class PCEN(torch.nn.Module):
|
||||||
|
"""Per-Channel Energy Normalisation (PCEN) transform.
|
||||||
|
|
||||||
|
Applies automatic gain control and root compression to a spectrogram.
|
||||||
|
The background estimate is computed with a first-order IIR filter
|
||||||
|
applied along the time axis.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
smoothing_constant : float
|
||||||
|
IIR filter coefficient ``alpha``. Derived from the time constant
|
||||||
|
and sample rate via ``_compute_smoothing_constant``.
|
||||||
|
gain : float, default=0.98
|
||||||
|
AGC gain exponent.
|
||||||
|
bias : float, default=2.0
|
||||||
|
Root-compression stabilisation bias.
|
||||||
|
power : float, default=0.5
|
||||||
|
Root-compression exponent.
|
||||||
|
eps : float, default=1e-6
|
||||||
|
Small constant for numerical stability.
|
||||||
|
dtype : torch.dtype, default=torch.float32
|
||||||
|
Floating-point precision used for internal computation.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
The smoothing constant is computed to match the original batdetect2
|
||||||
|
implementation for numerical compatibility. See
|
||||||
|
``_compute_smoothing_constant`` for details.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
smoothing_constant: float,
|
smoothing_constant: float,
|
||||||
@ -269,6 +514,20 @@ class PCEN(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply PCEN to a spectrogram.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input amplitude spectrogram of shape
|
||||||
|
``(..., freq_bins, time_frames)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
PCEN-normalised spectrogram with the same shape and dtype as
|
||||||
|
the input.
|
||||||
|
"""
|
||||||
S = spec.to(self.dtype) * 2**31
|
S = spec.to(self.dtype) * 2**31
|
||||||
|
|
||||||
M = (
|
M = (
|
||||||
@ -305,21 +564,73 @@ def _compute_smoothing_constant(
|
|||||||
samplerate: int,
|
samplerate: int,
|
||||||
time_constant: float,
|
time_constant: float,
|
||||||
) -> float:
|
) -> float:
|
||||||
# NOTE: These were taken to match the original implementation
|
"""Compute the IIR smoothing coefficient for PCEN.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
samplerate : int
|
||||||
|
Sample rate of the audio in Hz.
|
||||||
|
time_constant : float
|
||||||
|
Desired smoothing time constant in seconds.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
IIR filter coefficient ``alpha`` used by ``PCEN``.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
The hop length (512) and the sample-rate divisor (10) are fixed to
|
||||||
|
reproduce the numerical behaviour of the original batdetect2
|
||||||
|
implementation, which used ``librosa.pcen`` with ``sr=samplerate/10``
|
||||||
|
and the default ``hop_length=512``. These values do not reflect the
|
||||||
|
actual STFT hop length used in the pipeline; they are retained
|
||||||
|
solely for backward compatibility.
|
||||||
|
"""
|
||||||
|
# NOTE: These parameters are fixed to match the original implementation.
|
||||||
hop_length = 512
|
hop_length = 512
|
||||||
sr = samplerate / 10
|
sr = samplerate / 10
|
||||||
time_constant = time_constant
|
|
||||||
t_frames = time_constant * sr / float(hop_length)
|
t_frames = time_constant * sr / float(hop_length)
|
||||||
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||||
|
|
||||||
|
|
||||||
class ScaleAmplitudeConfig(BaseConfig):
|
class ScaleAmplitudeConfig(BaseConfig):
|
||||||
|
"""Configuration for amplitude scaling of a spectrogram.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"scale_amplitude"``.
|
||||||
|
scale : str, default="db"
|
||||||
|
Scaling mode. Either ``"db"`` (convert amplitude to decibels
|
||||||
|
using ``torchaudio.transforms.AmplitudeToDB``) or ``"power"``
|
||||||
|
(square the amplitude values).
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["scale_amplitude"] = "scale_amplitude"
|
name: Literal["scale_amplitude"] = "scale_amplitude"
|
||||||
scale: Literal["power", "db"] = "db"
|
scale: Literal["power", "db"] = "db"
|
||||||
|
|
||||||
|
|
||||||
class ToPower(torch.nn.Module):
|
class ToPower(torch.nn.Module):
|
||||||
|
"""Square the values of a spectrogram (amplitude → power).
|
||||||
|
|
||||||
|
Raises each element to the power of two, converting an amplitude
|
||||||
|
spectrogram into a power spectrogram.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Square all elements of the spectrogram.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input amplitude spectrogram.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Power spectrogram (same shape as input).
|
||||||
|
"""
|
||||||
return spec**2
|
return spec**2
|
||||||
|
|
||||||
|
|
||||||
@ -330,12 +641,34 @@ _scalers = {
|
|||||||
|
|
||||||
|
|
||||||
class ScaleAmplitude(torch.nn.Module):
|
class ScaleAmplitude(torch.nn.Module):
|
||||||
|
"""Convert spectrogram amplitude values to a different scale.
|
||||||
|
|
||||||
|
Supports conversion to decibels (dB) or to power (squared amplitude).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
scale : str
|
||||||
|
Either ``"db"`` or ``"power"``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, scale: Literal["power", "db"]):
|
def __init__(self, scale: Literal["power", "db"]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.scaler = _scalers[scale]()
|
self.scaler = _scalers[scale]()
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply the configured amplitude scaling.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input spectrogram tensor.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Scaled spectrogram with the same shape as the input.
|
||||||
|
"""
|
||||||
return self.scaler(spec)
|
return self.scaler(spec)
|
||||||
|
|
||||||
@spectrogram_transforms.register(ScaleAmplitudeConfig)
|
@spectrogram_transforms.register(ScaleAmplitudeConfig)
|
||||||
@ -344,30 +677,86 @@ class ScaleAmplitude(torch.nn.Module):
|
|||||||
return ScaleAmplitude(scale=config.scale)
|
return ScaleAmplitude(scale=config.scale)
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
class SpectralMeanSubtractionConfig(BaseConfig):
|
||||||
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
"""Configuration for spectral mean subtraction.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"spectral_mean_subtraction"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["spectral_mean_subtraction"] = "spectral_mean_subtraction"
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstraction(torch.nn.Module):
|
class SpectralMeanSubtraction(torch.nn.Module):
|
||||||
|
"""Remove the time-averaged background noise from a spectrogram.
|
||||||
|
|
||||||
|
For each frequency bin, the mean value across all time frames is
|
||||||
|
computed and subtracted. The result is then clamped to zero so that
|
||||||
|
no values fall below the baseline. This is a simple form of spectral
|
||||||
|
denoising that suppresses stationary background noise.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Subtract the time-axis mean from each frequency bin.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input spectrogram of shape ``(..., freq_bins, time_frames)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Denoised spectrogram with the same shape as the input. All
|
||||||
|
values are non-negative (clamped to 0).
|
||||||
|
"""
|
||||||
mean = spec.mean(-1, keepdim=True)
|
mean = spec.mean(-1, keepdim=True)
|
||||||
return (spec - mean).clamp(min=0)
|
return (spec - mean).clamp(min=0)
|
||||||
|
|
||||||
@spectrogram_transforms.register(SpectralMeanSubstractionConfig)
|
@spectrogram_transforms.register(SpectralMeanSubtractionConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
config: SpectralMeanSubstractionConfig,
|
config: SpectralMeanSubtractionConfig,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
):
|
):
|
||||||
return SpectralMeanSubstraction()
|
return SpectralMeanSubtraction()
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalizeConfig(BaseConfig):
|
class PeakNormalizeConfig(BaseConfig):
|
||||||
|
"""Configuration for peak normalisation of a spectrogram.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Fixed identifier; always ``"peak_normalize"``.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["peak_normalize"] = "peak_normalize"
|
name: Literal["peak_normalize"] = "peak_normalize"
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalize(torch.nn.Module):
|
class PeakNormalize(torch.nn.Module):
|
||||||
|
"""Scale a spectrogram so that its largest absolute value equals one.
|
||||||
|
|
||||||
|
Wraps :func:`batdetect2.preprocess.common.peak_normalize` as a
|
||||||
|
``torch.nn.Module`` for use inside a sequential transform pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Peak-normalise the spectrogram.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Input spectrogram tensor of any shape.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Normalised spectrogram where the maximum absolute value is 1.
|
||||||
|
If the input is identically zero, it is returned unchanged.
|
||||||
|
"""
|
||||||
return peak_normalize(spec)
|
return peak_normalize(spec)
|
||||||
|
|
||||||
@spectrogram_transforms.register(PeakNormalizeConfig)
|
@spectrogram_transforms.register(PeakNormalizeConfig)
|
||||||
@ -379,14 +768,37 @@ class PeakNormalize(torch.nn.Module):
|
|||||||
SpectrogramTransform = Annotated[
|
SpectrogramTransform = Annotated[
|
||||||
PcenConfig
|
PcenConfig
|
||||||
| ScaleAmplitudeConfig
|
| ScaleAmplitudeConfig
|
||||||
| SpectralMeanSubstractionConfig
|
| SpectralMeanSubtractionConfig
|
||||||
| PeakNormalizeConfig,
|
| PeakNormalizeConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
"""Discriminated union of all spectrogram transform configuration types.
|
||||||
|
|
||||||
|
Use this type when a field should accept any of the supported spectrogram
|
||||||
|
transforms. Pydantic will select the correct config class based on the
|
||||||
|
``name`` field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_spectrogram_transform(
|
def build_spectrogram_transform(
|
||||||
config: SpectrogramTransform,
|
config: SpectrogramTransform,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
"""Build a spectrogram transform module from a configuration object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : SpectrogramTransform
|
||||||
|
A configuration object for one of the supported spectrogram
|
||||||
|
transforms (PCEN, amplitude scaling, spectral mean subtraction,
|
||||||
|
or peak normalisation).
|
||||||
|
samplerate : int
|
||||||
|
Sample rate of the audio in Hz. Some transforms (e.g. PCEN) use
|
||||||
|
this to set internal parameters.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.nn.Module
|
||||||
|
The constructed transform module.
|
||||||
|
"""
|
||||||
return spectrogram_transforms.build(config, samplerate)
|
return spectrogram_transforms.build(config, samplerate)
|
||||||
|
|||||||
@ -1,12 +1,31 @@
|
|||||||
|
"""Tests for audio-level preprocessing transforms.
|
||||||
|
|
||||||
|
Covers :mod:`batdetect2.preprocess.audio` and the shared helper functions
|
||||||
|
in :mod:`batdetect2.preprocess.common`.
|
||||||
|
"""
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
|
from batdetect2.preprocess.audio import (
|
||||||
|
CenterAudio,
|
||||||
|
CenterAudioConfig,
|
||||||
|
FixDuration,
|
||||||
|
FixDurationConfig,
|
||||||
|
ScaleAudio,
|
||||||
|
ScaleAudioConfig,
|
||||||
|
build_audio_transform,
|
||||||
|
)
|
||||||
|
from batdetect2.preprocess.common import center_tensor, peak_normalize
|
||||||
|
|
||||||
|
SAMPLERATE = 256_000
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_wave(
|
def create_dummy_wave(
|
||||||
@ -15,9 +34,9 @@ def create_dummy_wave(
|
|||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
freq: float = 440.0,
|
freq: float = 440.0,
|
||||||
amplitude: float = 0.5,
|
amplitude: float = 0.5,
|
||||||
dtype: np.dtype = np.float32,
|
dtype: type = np.float32,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Generates a simple numpy waveform."""
|
"""Generate a simple sine-wave waveform as a NumPy array."""
|
||||||
t = np.linspace(
|
t = np.linspace(
|
||||||
0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype
|
0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype
|
||||||
)
|
)
|
||||||
@ -29,7 +48,7 @@ def create_dummy_wave(
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
|
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
|
||||||
"""Creates a dummy WAV file and returns its path."""
|
"""Create a dummy 2-channel WAV file and return its path."""
|
||||||
samplerate = 48000
|
samplerate = 48000
|
||||||
duration = 2.0
|
duration = 2.0
|
||||||
num_channels = 2
|
num_channels = 2
|
||||||
@ -41,13 +60,13 @@ def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
|
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
|
||||||
"""Creates a Recording object pointing to the dummy WAV file."""
|
"""Create a Recording object pointing to the dummy WAV file."""
|
||||||
return data.Recording.from_file(dummy_wav_path)
|
return data.Recording.from_file(dummy_wav_path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
||||||
"""Creates a Clip object from the dummy recording."""
|
"""Create a Clip object from the dummy recording."""
|
||||||
return data.Clip(
|
return data.Clip(
|
||||||
recording=dummy_recording,
|
recording=dummy_recording,
|
||||||
start_time=0.5,
|
start_time=0.5,
|
||||||
@ -58,3 +77,165 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_audio_config() -> AudioConfig:
|
def default_audio_config() -> AudioConfig:
|
||||||
return AudioConfig()
|
return AudioConfig()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# center_tensor
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_center_tensor_zero_mean():
|
||||||
|
"""Output tensor should have a mean very close to zero."""
|
||||||
|
wav = torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||||
|
result = center_tensor(wav)
|
||||||
|
assert result.mean().abs().item() < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_center_tensor_preserves_shape():
|
||||||
|
wav = torch.randn(3, 1000)
|
||||||
|
result = center_tensor(wav)
|
||||||
|
assert result.shape == wav.shape
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# peak_normalize
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_max_is_one():
|
||||||
|
"""After peak normalisation, the maximum absolute value should be 1."""
|
||||||
|
wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3])
|
||||||
|
result = peak_normalize(wav)
|
||||||
|
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_zero_tensor_unchanged():
|
||||||
|
"""A zero tensor should be returned unchanged (no division by zero)."""
|
||||||
|
wav = torch.zeros(100)
|
||||||
|
result = peak_normalize(wav)
|
||||||
|
assert (result == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_preserves_sign():
|
||||||
|
"""Signs of all elements should be preserved after normalisation."""
|
||||||
|
wav = torch.tensor([-2.0, 1.0, -0.5])
|
||||||
|
result = peak_normalize(wav)
|
||||||
|
assert (result < 0).sum() == 2
|
||||||
|
assert result[0].item() < 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_preserves_shape():
|
||||||
|
wav = torch.randn(2, 512)
|
||||||
|
result = peak_normalize(wav)
|
||||||
|
assert result.shape == wav.shape
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CenterAudio
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_center_audio_forward_zero_mean():
|
||||||
|
module = CenterAudio()
|
||||||
|
wav = torch.tensor([1.0, 3.0, 5.0])
|
||||||
|
result = module(wav)
|
||||||
|
assert result.mean().abs().item() < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_center_audio_from_config():
|
||||||
|
config = CenterAudioConfig()
|
||||||
|
module = CenterAudio.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, CenterAudio)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ScaleAudio
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_audio_peak_normalises_to_one():
|
||||||
|
"""ScaleAudio.forward should scale the peak absolute value to 1."""
|
||||||
|
module = ScaleAudio()
|
||||||
|
wav = torch.tensor([0.0, 0.25, -0.5, 0.1])
|
||||||
|
result = module(wav)
|
||||||
|
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_audio_handles_zero_tensor():
|
||||||
|
"""ScaleAudio should not raise on a zero tensor."""
|
||||||
|
module = ScaleAudio()
|
||||||
|
wav = torch.zeros(100)
|
||||||
|
result = module(wav)
|
||||||
|
assert (result == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_audio_from_config():
|
||||||
|
config = ScaleAudioConfig()
|
||||||
|
module = ScaleAudio.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, ScaleAudio)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FixDuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_duration_truncates_long_input():
|
||||||
|
"""Waveform longer than target should be truncated to the target length."""
|
||||||
|
target_samples = int(SAMPLERATE * 0.5)
|
||||||
|
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||||
|
wav = torch.randn(target_samples + 1000)
|
||||||
|
result = module(wav)
|
||||||
|
assert result.shape[-1] == target_samples
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_duration_pads_short_input():
|
||||||
|
"""Waveform shorter than target should be zero-padded to the target length."""
|
||||||
|
target_samples = int(SAMPLERATE * 0.5)
|
||||||
|
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||||
|
short_wav = torch.randn(target_samples - 100)
|
||||||
|
result = module(short_wav)
|
||||||
|
assert result.shape[-1] == target_samples
|
||||||
|
# Padded region should be zero
|
||||||
|
assert (result[target_samples - 100 :] == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_duration_passthrough_exact_length():
|
||||||
|
"""Waveform with exactly the right length should be returned unchanged."""
|
||||||
|
target_samples = int(SAMPLERATE * 0.5)
|
||||||
|
module = FixDuration(samplerate=SAMPLERATE, duration=0.5)
|
||||||
|
wav = torch.randn(target_samples)
|
||||||
|
result = module(wav)
|
||||||
|
assert result.shape[-1] == target_samples
|
||||||
|
assert torch.equal(result, wav)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_duration_from_config():
|
||||||
|
"""FixDurationConfig should produce a FixDuration with the correct length."""
|
||||||
|
config = FixDurationConfig(duration=0.256)
|
||||||
|
module = FixDuration.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, FixDuration)
|
||||||
|
assert module.length == int(SAMPLERATE * 0.256)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_audio_transform dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_audio_transform_center_audio():
|
||||||
|
config = CenterAudioConfig()
|
||||||
|
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, CenterAudio)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_audio_transform_scale_audio():
|
||||||
|
config = ScaleAudioConfig()
|
||||||
|
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, ScaleAudio)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_audio_transform_fix_duration():
|
||||||
|
config = FixDurationConfig(duration=0.5)
|
||||||
|
module = build_audio_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, FixDuration)
|
||||||
|
|||||||
243
tests/test_preprocessing/test_preprocessor.py
Normal file
243
tests/test_preprocessing/test_preprocessor.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
"""Integration and unit tests for the Preprocessor pipeline.
|
||||||
|
|
||||||
|
Covers :mod:`batdetect2.preprocess.preprocessor` — construction,
|
||||||
|
pipeline output shape/dtype, the ``process_numpy`` helper, attribute
|
||||||
|
values, output frame rate, and a round-trip YAML → config → build test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.preprocess.audio import FixDurationConfig
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.preprocess.preprocessor import (
|
||||||
|
Preprocessor,
|
||||||
|
build_preprocessor,
|
||||||
|
compute_output_samplerate,
|
||||||
|
)
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
FrequencyConfig,
|
||||||
|
PcenConfig,
|
||||||
|
ResizeConfig,
|
||||||
|
SpectralMeanSubtractionConfig,
|
||||||
|
STFTConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
SAMPLERATE = 256_000
|
||||||
|
# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip
|
||||||
|
CLIP_SAMPLES = int(SAMPLERATE * 0.256)
|
||||||
|
|
||||||
|
|
||||||
|
def make_sine_wav(
|
||||||
|
samplerate: int = SAMPLERATE,
|
||||||
|
duration: float = 0.256,
|
||||||
|
freq: float = 40_000.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Return a single-channel sine-wave tensor."""
|
||||||
|
t = torch.linspace(0.0, duration, int(samplerate * duration))
|
||||||
|
return torch.sin(2 * torch.pi * freq * t)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_preprocessor — construction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_preprocessor_returns_protocol():
|
||||||
|
"""build_preprocessor should return a Preprocessor instance."""
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
assert isinstance(preprocessor, Preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_preprocessor_with_default_config():
|
||||||
|
"""build_preprocessor() with no arguments should not raise."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
assert preprocessor is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_preprocessor_with_explicit_config():
|
||||||
|
config = PreprocessingConfig(
|
||||||
|
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||||
|
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||||
|
size=ResizeConfig(height=128, resize_factor=0.5),
|
||||||
|
spectrogram_transforms=[PcenConfig(), SpectralMeanSubtractionConfig()],
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(preprocessor, Preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Output shape and dtype
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_output_is_2d():
|
||||||
|
"""The preprocessor output should be a 2-D tensor (freq_bins × time_frames)."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result = preprocessor(wav)
|
||||||
|
assert result.ndim == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_output_height_matches_config():
|
||||||
|
"""Output height should match the ResizeConfig.height setting."""
|
||||||
|
config = PreprocessingConfig(
|
||||||
|
size=ResizeConfig(height=64, resize_factor=0.5)
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result = preprocessor(wav)
|
||||||
|
assert result.shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_output_dtype_float32():
|
||||||
|
"""Output tensor should be float32."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result = preprocessor(wav)
|
||||||
|
assert result.dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_output_is_finite():
|
||||||
|
"""Output spectrogram should contain no NaN or Inf values."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result = preprocessor(wav)
|
||||||
|
assert torch.isfinite(result).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# process_numpy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_process_numpy_accepts_ndarray():
|
||||||
|
"""process_numpy should accept a NumPy array and return a NumPy array."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav_np = make_sine_wav().numpy()
|
||||||
|
result = preprocessor.process_numpy(wav_np)
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_process_numpy_matches_forward():
|
||||||
|
"""process_numpy and forward should give numerically identical results."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result_pt = preprocessor(wav).numpy()
|
||||||
|
result_np = preprocessor.process_numpy(wav.numpy())
|
||||||
|
np.testing.assert_array_almost_equal(result_pt, result_np)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Attributes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_min_max_freq_attributes():
|
||||||
|
"""min_freq and max_freq should match the FrequencyConfig values."""
|
||||||
|
config = PreprocessingConfig(
|
||||||
|
frequencies=FrequencyConfig(min_freq=15_000, max_freq=100_000)
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
assert preprocessor.min_freq == 15_000
|
||||||
|
assert preprocessor.max_freq == 100_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_input_samplerate_attribute():
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
assert preprocessor.input_samplerate == SAMPLERATE
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# compute_output_samplerate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_output_samplerate_defaults():
|
||||||
|
"""At default settings, output_samplerate should equal 1000 fps."""
|
||||||
|
config = PreprocessingConfig()
|
||||||
|
rate = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
|
||||||
|
assert abs(rate - 1000.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_output_samplerate_attribute_matches_compute():
|
||||||
|
config = PreprocessingConfig()
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
expected = compute_output_samplerate(config, input_samplerate=SAMPLERATE)
|
||||||
|
assert abs(preprocessor.output_samplerate - expected) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# generate_spectrogram (raw STFT)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_spectrogram_shape():
|
||||||
|
"""generate_spectrogram should return the full STFT without crop or resize."""
|
||||||
|
config = PreprocessingConfig()
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
spec = preprocessor.generate_spectrogram(wav)
|
||||||
|
# Full STFT: n_fft//2 + 1 = 257 bins at defaults
|
||||||
|
assert spec.shape[0] == 257
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_spectrogram_larger_than_forward():
|
||||||
|
"""Raw spectrogram should have more frequency bins than the processed output."""
|
||||||
|
preprocessor = build_preprocessor(input_samplerate=SAMPLERATE)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
raw = preprocessor.generate_spectrogram(wav)
|
||||||
|
processed = preprocessor(wav)
|
||||||
|
assert raw.shape[0] > processed.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Audio transforms pipeline (FixDuration)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_with_fix_duration_audio_transform():
|
||||||
|
"""A FixDuration audio transform should produce consistent output shapes."""
|
||||||
|
config = PreprocessingConfig(
|
||||||
|
audio_transforms=[FixDurationConfig(duration=0.256)],
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE)
|
||||||
|
# Feed different lengths; output shape should be the same after fix
|
||||||
|
for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]:
|
||||||
|
wav = torch.randn(n_samples)
|
||||||
|
result = preprocessor(wav)
|
||||||
|
assert result.ndim == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# YAML round-trip
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path):
|
||||||
|
"""PreprocessingConfig serialised to YAML and reloaded should produce
|
||||||
|
a functionally identical preprocessor."""
|
||||||
|
from batdetect2.preprocess.config import load_preprocessing_config
|
||||||
|
|
||||||
|
config = PreprocessingConfig(
|
||||||
|
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||||
|
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||||
|
size=ResizeConfig(height=128, resize_factor=0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
yaml_path = tmp_path / "preprocess_config.yaml"
|
||||||
|
yaml_path.write_text(config.to_yaml_string())
|
||||||
|
|
||||||
|
loaded_config = load_preprocessing_config(yaml_path)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
loaded_config, input_samplerate=SAMPLERATE
|
||||||
|
)
|
||||||
|
wav = make_sine_wav()
|
||||||
|
result = preprocessor(wav)
|
||||||
|
|
||||||
|
assert result.ndim == 2
|
||||||
|
assert result.shape[0] == 128
|
||||||
|
assert torch.isfinite(result).all()
|
||||||
@ -1,37 +1,316 @@
|
|||||||
import numpy as np
|
"""Tests for spectrogram-level preprocessing transforms.
|
||||||
import pytest
|
|
||||||
import xarray as xr
|
|
||||||
|
|
||||||
SAMPLERATE = 250_000
|
Covers :mod:`batdetect2.preprocess.spectrogram` — STFT configuration,
|
||||||
DURATION = 0.1
|
frequency cropping, PCEN, spectral mean subtraction, amplitude scaling,
|
||||||
TEST_FREQ = 30_000
|
peak normalisation, and resizing.
|
||||||
N_SAMPLES = int(SAMPLERATE * DURATION)
|
"""
|
||||||
TIME_COORD = np.linspace(
|
|
||||||
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
PCEN,
|
||||||
|
FrequencyConfig,
|
||||||
|
FrequencyCrop,
|
||||||
|
PcenConfig,
|
||||||
|
PeakNormalize,
|
||||||
|
PeakNormalizeConfig,
|
||||||
|
ResizeConfig,
|
||||||
|
ResizeSpec,
|
||||||
|
ScaleAmplitude,
|
||||||
|
ScaleAmplitudeConfig,
|
||||||
|
SpectralMeanSubtraction,
|
||||||
|
SpectralMeanSubtractionConfig,
|
||||||
|
STFTConfig,
|
||||||
|
_spec_params_from_config,
|
||||||
|
build_spectrogram_builder,
|
||||||
|
build_spectrogram_crop,
|
||||||
|
build_spectrogram_resizer,
|
||||||
|
build_spectrogram_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
SAMPLERATE = 256_000
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sine_wave_xr() -> xr.DataArray:
|
# ---------------------------------------------------------------------------
|
||||||
"""Generate a single sine wave as an xr.DataArray."""
|
# STFTConfig / _spec_params_from_config
|
||||||
t = TIME_COORD
|
# ---------------------------------------------------------------------------
|
||||||
wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
|
|
||||||
return xr.DataArray(
|
|
||||||
wav_data,
|
def test_stft_config_defaults_give_correct_params():
|
||||||
coords={"time": t},
|
"""Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128."""
|
||||||
dims=["time"],
|
config = STFTConfig()
|
||||||
attrs={"samplerate": SAMPLERATE},
|
n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert n_fft == 512
|
||||||
|
assert hop_length == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_stft_config_custom_params():
|
||||||
|
"""Custom window duration and overlap should produce the expected sizes."""
|
||||||
|
config = STFTConfig(window_duration=0.004, window_overlap=0.5)
|
||||||
|
n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert n_fft == 1024
|
||||||
|
assert hop_length == 512
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_spectrogram_builder
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectrogram_builder_output_shape():
|
||||||
|
"""Builder should produce a spectrogram with the expected number of bins."""
|
||||||
|
config = STFTConfig()
|
||||||
|
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||||
|
expected_freq_bins = n_fft // 2 + 1 # 257 at defaults
|
||||||
|
|
||||||
|
builder = build_spectrogram_builder(config, samplerate=SAMPLERATE)
|
||||||
|
n_samples = SAMPLERATE # 1 second of audio
|
||||||
|
wav = torch.randn(n_samples)
|
||||||
|
spec = builder(wav)
|
||||||
|
|
||||||
|
assert spec.ndim == 2
|
||||||
|
assert spec.shape[0] == expected_freq_bins
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectrogram_builder_output_is_nonnegative():
|
||||||
|
"""Amplitude spectrogram values should all be >= 0."""
|
||||||
|
config = STFTConfig()
|
||||||
|
builder = build_spectrogram_builder(config, samplerate=SAMPLERATE)
|
||||||
|
wav = torch.randn(SAMPLERATE)
|
||||||
|
spec = builder(wav)
|
||||||
|
assert (spec >= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FrequencyCrop
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_crop_output_shape():
|
||||||
|
"""FrequencyCrop should reduce the number of frequency bins."""
|
||||||
|
config = STFTConfig()
|
||||||
|
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||||
|
|
||||||
|
crop = FrequencyCrop(
|
||||||
|
samplerate=SAMPLERATE,
|
||||||
|
n_fft=n_fft,
|
||||||
|
min_freq=10_000,
|
||||||
|
max_freq=120_000,
|
||||||
)
|
)
|
||||||
|
spec = torch.ones(n_fft // 2 + 1, 100)
|
||||||
|
cropped = crop(spec)
|
||||||
|
|
||||||
|
assert cropped.ndim == 2
|
||||||
|
# Must be smaller than the full spectrogram
|
||||||
|
assert cropped.shape[0] < spec.shape[0]
|
||||||
|
assert cropped.shape[1] == 100 # time axis unchanged
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_frequency_crop_build_from_config():
|
||||||
def constant_wave_xr() -> xr.DataArray:
|
"""build_spectrogram_crop should return a working FrequencyCrop."""
|
||||||
"""Generate a constant signal as an xr.DataArray."""
|
freq_config = FrequencyConfig(min_freq=10_000, max_freq=120_000)
|
||||||
t = TIME_COORD
|
stft_config = STFTConfig()
|
||||||
wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5
|
crop = build_spectrogram_crop(
|
||||||
return xr.DataArray(
|
freq_config, stft=stft_config, samplerate=SAMPLERATE
|
||||||
wav_data,
|
|
||||||
coords={"time": t},
|
|
||||||
dims=["time"],
|
|
||||||
attrs={"samplerate": SAMPLERATE},
|
|
||||||
)
|
)
|
||||||
|
assert isinstance(crop, FrequencyCrop)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_crop_no_crop_when_bounds_are_none():
|
||||||
|
"""FrequencyCrop with no bounds should return the full spectrogram."""
|
||||||
|
config = STFTConfig()
|
||||||
|
n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE)
|
||||||
|
|
||||||
|
crop = FrequencyCrop(samplerate=SAMPLERATE, n_fft=n_fft)
|
||||||
|
spec = torch.ones(n_fft // 2 + 1, 100)
|
||||||
|
cropped = crop(spec)
|
||||||
|
|
||||||
|
assert cropped.shape == spec.shape
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PCEN
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_pcen_output_shape_preserved():
|
||||||
|
"""PCEN should not change the shape of the spectrogram."""
|
||||||
|
config = PcenConfig()
|
||||||
|
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
spec = torch.rand(64, 200) + 1e-6 # positive values
|
||||||
|
result = pcen(spec)
|
||||||
|
assert result.shape == spec.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_pcen_output_is_finite():
|
||||||
|
"""PCEN applied to a well-formed spectrogram should produce no NaN or Inf."""
|
||||||
|
config = PcenConfig()
|
||||||
|
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
spec = torch.rand(64, 200) + 1e-6
|
||||||
|
result = pcen(spec)
|
||||||
|
assert torch.isfinite(result).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_pcen_output_dtype_matches_input():
|
||||||
|
"""PCEN should return a tensor with the same dtype as the input."""
|
||||||
|
config = PcenConfig()
|
||||||
|
pcen = PCEN.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
spec = torch.rand(64, 200, dtype=torch.float32)
|
||||||
|
result = pcen(spec)
|
||||||
|
assert result.dtype == spec.dtype
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SpectralMeanSubtraction
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectral_mean_subtraction_output_nonnegative():
|
||||||
|
"""SpectralMeanSubtraction clamps output to >= 0."""
|
||||||
|
module = SpectralMeanSubtraction()
|
||||||
|
spec = torch.rand(64, 200)
|
||||||
|
result = module(spec)
|
||||||
|
assert (result >= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectral_mean_subtraction_shape_preserved():
|
||||||
|
module = SpectralMeanSubtraction()
|
||||||
|
spec = torch.rand(64, 200)
|
||||||
|
result = module(spec)
|
||||||
|
assert result.shape == spec.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectral_mean_subtraction_reduces_time_mean():
|
||||||
|
"""After subtraction the time-axis mean per bin should be <= 0 (pre-clamp)."""
|
||||||
|
module = SpectralMeanSubtraction()
|
||||||
|
# Constant spectrogram: mean subtraction should produce all zeros before clamp
|
||||||
|
spec = torch.ones(32, 100) * 3.0
|
||||||
|
result = module(spec)
|
||||||
|
assert (result == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectral_mean_subtraction_from_config():
|
||||||
|
config = SpectralMeanSubtractionConfig()
|
||||||
|
module = SpectralMeanSubtraction.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, SpectralMeanSubtraction)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PeakNormalize (spectrogram-level)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_spec_max_is_one():
|
||||||
|
"""PeakNormalize should scale the spectrogram peak to 1."""
|
||||||
|
module = PeakNormalize()
|
||||||
|
spec = torch.rand(64, 200) * 5.0
|
||||||
|
result = module(spec)
|
||||||
|
assert abs(result.abs().max().item() - 1.0) < 1e-6
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_spec_handles_zero():
|
||||||
|
"""PeakNormalize on a zero spectrogram should not raise."""
|
||||||
|
module = PeakNormalize()
|
||||||
|
spec = torch.zeros(64, 200)
|
||||||
|
result = module(spec)
|
||||||
|
assert (result == 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_peak_normalize_from_config():
|
||||||
|
config = PeakNormalizeConfig()
|
||||||
|
module = PeakNormalize.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, PeakNormalize)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ScaleAmplitude
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_amplitude_db_output_is_finite():
|
||||||
|
"""AmplitudeToDB scaling should produce finite values for positive input."""
|
||||||
|
module = ScaleAmplitude(scale="db")
|
||||||
|
spec = torch.rand(64, 200) + 1e-4
|
||||||
|
result = module(spec)
|
||||||
|
assert torch.isfinite(result).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_amplitude_power_output_equals_square():
|
||||||
|
"""ScaleAmplitude('power') should square every element."""
|
||||||
|
module = ScaleAmplitude(scale="power")
|
||||||
|
spec = torch.tensor([[2.0, 3.0], [4.0, 5.0]])
|
||||||
|
result = module(spec)
|
||||||
|
expected = spec**2
|
||||||
|
assert torch.allclose(result, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_amplitude_from_config():
|
||||||
|
config = ScaleAmplitudeConfig(scale="db")
|
||||||
|
module = ScaleAmplitude.from_config(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, ScaleAmplitude)
|
||||||
|
assert module.scale == "db"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ResizeSpec
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_spec_output_shape():
|
||||||
|
"""ResizeSpec should produce the target height and scaled width."""
|
||||||
|
module = ResizeSpec(height=64, time_factor=0.5)
|
||||||
|
spec = torch.rand(1, 128, 200) # (batch, freq, time)
|
||||||
|
result = module(spec)
|
||||||
|
assert result.shape == (1, 64, 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_spec_2d_input():
|
||||||
|
"""ResizeSpec should handle 2-D input (no batch or channel dimensions)."""
|
||||||
|
module = ResizeSpec(height=64, time_factor=0.5)
|
||||||
|
spec = torch.rand(128, 200)
|
||||||
|
result = module(spec)
|
||||||
|
assert result.shape == (64, 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_spec_output_is_finite():
|
||||||
|
module = ResizeSpec(height=128, time_factor=0.5)
|
||||||
|
spec = torch.rand(128, 200)
|
||||||
|
result = module(spec)
|
||||||
|
assert torch.isfinite(result).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resize_spec_from_config():
|
||||||
|
config = ResizeConfig(height=64, resize_factor=0.25)
|
||||||
|
module = build_spectrogram_resizer(config)
|
||||||
|
assert isinstance(module, ResizeSpec)
|
||||||
|
assert module.height == 64
|
||||||
|
assert module.time_factor == 0.25
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_spectrogram_transform dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_spectrogram_transform_pcen():
|
||||||
|
config = PcenConfig()
|
||||||
|
module = build_spectrogram_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, PCEN)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_spectrogram_transform_spectral_mean_subtraction():
|
||||||
|
config = SpectralMeanSubtractionConfig()
|
||||||
|
module = build_spectrogram_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, SpectralMeanSubtraction)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_spectrogram_transform_scale_amplitude():
|
||||||
|
config = ScaleAmplitudeConfig(scale="db")
|
||||||
|
module = build_spectrogram_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, ScaleAmplitude)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_spectrogram_transform_peak_normalize():
|
||||||
|
config = PeakNormalizeConfig()
|
||||||
|
module = build_spectrogram_transform(config, samplerate=SAMPLERATE)
|
||||||
|
assert isinstance(module, PeakNormalize)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.preprocess import (
|
|||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.spectrogram import (
|
||||||
ScaleAmplitudeConfig,
|
ScaleAmplitudeConfig,
|
||||||
SpectralMeanSubstractionConfig,
|
SpectralMeanSubtractionConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
DEFAULT_ANCHOR,
|
DEFAULT_ANCHOR,
|
||||||
@ -597,7 +597,7 @@ def test_build_roi_mapper_for_peak_energy_bbox():
|
|||||||
preproc_config = PreprocessingConfig(
|
preproc_config = PreprocessingConfig(
|
||||||
spectrogram_transforms=[
|
spectrogram_transforms=[
|
||||||
ScaleAmplitudeConfig(scale="db"),
|
ScaleAmplitudeConfig(scale="db"),
|
||||||
SpectralMeanSubstractionConfig(),
|
SpectralMeanSubtractionConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
config = PeakEnergyBBoxMapperConfig(
|
config = PeakEnergyBBoxMapperConfig(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user