diff --git a/example_data/config.yaml b/example_data/config.yaml index 9a95182..42a35d6 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -21,7 +21,7 @@ preprocess: gain: 0.98 bias: 2 power: 0.5 - - name: spectral_mean_substraction + - name: spectral_mean_subtraction postprocess: nms_kernel_size: 9 diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index 89afddc..a872e38 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -1,3 +1,17 @@ +"""Audio-level transforms applied to waveforms before spectrogram computation. + +This module defines ``torch.nn.Module`` transforms that operate on raw +audio tensors and the Pydantic configuration classes that control them. +Each transform is registered in the ``audio_transforms`` registry so that +the pipeline can be assembled from a configuration object. + +The supported transforms are: + +* ``CenterAudio`` — subtract the DC offset (mean) from the waveform. +* ``ScaleAudio`` — peak-normalise the waveform to the range ``[-1, 1]``. +* ``FixDuration`` — truncate or zero-pad the waveform to a fixed length. +""" + from typing import Annotated, Literal import torch @@ -18,14 +32,43 @@ __all__ = [ audio_transforms: Registry[torch.nn.Module, [int]] = Registry( "audio_transform" ) +"""Registry mapping audio transform config classes to their builder methods.""" class CenterAudioConfig(BaseConfig): + """Configuration for the DC-offset removal transform. + + Attributes + ---------- + name : str + Fixed identifier; always ``"center_audio"``. + """ + name: Literal["center_audio"] = "center_audio" class CenterAudio(torch.nn.Module): + """Remove the DC offset from an audio waveform. + + Subtracts the global mean of the waveform from every sample, + centring the signal around zero. This is useful when an analogue + recording chain introduces a constant voltage bias. + """ + def forward(self, wav: torch.Tensor) -> torch.Tensor: + """Subtract the mean from the waveform. + + Parameters + ---------- + wav : torch.Tensor + Input waveform tensor of shape ``(samples,)`` or + ``(channels, samples)``. + + Returns + ------- + torch.Tensor + Zero-centred waveform with the same shape as the input. + """ return center_tensor(wav) @audio_transforms.register(CenterAudioConfig) @@ -35,11 +78,38 @@ class CenterAudio(torch.nn.Module): class ScaleAudioConfig(BaseConfig): + """Configuration for the peak-normalisation transform. + + Attributes + ---------- + name : str + Fixed identifier; always ``"scale_audio"``. + """ + name: Literal["scale_audio"] = "scale_audio" class ScaleAudio(torch.nn.Module): + """Peak-normalise an audio waveform to the range ``[-1, 1]``. + + Divides the waveform by its largest absolute sample value. If the + waveform is identically zero it is returned unchanged. + """ + def forward(self, wav: torch.Tensor) -> torch.Tensor: + """Peak-normalise the waveform. + + Parameters + ---------- + wav : torch.Tensor + Input waveform tensor of any shape. + + Returns + ------- + torch.Tensor + Normalised waveform with the same shape as the input and + values in the range ``[-1, 1]``. + """ return peak_normalize(wav) @audio_transforms.register(ScaleAudioConfig) @@ -49,11 +119,36 @@ class ScaleAudio(torch.nn.Module): class FixDurationConfig(BaseConfig): + """Configuration for the fixed-duration transform. + + Attributes + ---------- + name : str + Fixed identifier; always ``"fix_duration"``. + duration : float, default=0.5 + Target duration in seconds. The waveform will be truncated or + zero-padded to match this length. + """ + name: Literal["fix_duration"] = "fix_duration" duration: float = 0.5 class FixDuration(torch.nn.Module): + """Ensure a waveform has exactly a specified number of samples. + + If the input is longer than the target length it is truncated from + the end. If it is shorter, it is zero-padded at the end. + + Parameters + ---------- + samplerate : int + Sample rate of the audio in Hz. Used with ``duration`` to + compute the target number of samples. + duration : float + Target duration in seconds. + """ + def __init__(self, samplerate: int, duration: float): super().__init__() self.samplerate = samplerate @@ -61,6 +156,20 @@ class FixDuration(torch.nn.Module): self.length = int(samplerate * duration) def forward(self, wav: torch.Tensor) -> torch.Tensor: + """Truncate or pad the waveform to the target length. + + Parameters + ---------- + wav : torch.Tensor + Input waveform tensor of shape ``(samples,)`` or + ``(channels, samples)``. The last dimension is adjusted. + + Returns + ------- + torch.Tensor + Waveform with exactly ``self.length`` samples along the last + dimension. + """ length = wav.shape[-1] if length == self.length: @@ -81,10 +190,34 @@ AudioTransform = Annotated[ FixDurationConfig | ScaleAudioConfig | CenterAudioConfig, Field(discriminator="name"), ] +"""Discriminated union of all audio transform configuration types. + +Use this type when a field should accept any of the supported audio +transforms. Pydantic will select the correct config class based on the +``name`` field. +""" def build_audio_transform( config: AudioTransform, samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: + """Build an audio transform module from a configuration object. + + Parameters + ---------- + config : AudioTransform + A configuration object for one of the supported audio transforms + (``CenterAudioConfig``, ``ScaleAudioConfig``, or + ``FixDurationConfig``). + samplerate : int, default=256000 + Sample rate of the audio in Hz. Passed to the transform builder; + some transforms (e.g. ``FixDuration``) use it to convert seconds + to samples. + + Returns + ------- + torch.nn.Module + The constructed audio transform module. + """ return audio_transforms.build(config, samplerate) diff --git a/src/batdetect2/preprocess/common.py b/src/batdetect2/preprocess/common.py index c498063..be27b18 100644 --- a/src/batdetect2/preprocess/common.py +++ b/src/batdetect2/preprocess/common.py @@ -1,3 +1,10 @@ +"""Shared tensor primitives used across the preprocessing pipeline. + +This module provides small, stateless helper functions that operate on +PyTorch tensors. They are used by both audio-level and spectrogram-level +transforms, and are kept here to avoid duplication. +""" + import torch __all__ = [ @@ -7,11 +14,42 @@ __all__ = [ def center_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Subtract the mean of a tensor from all of its values. + + This centres the signal around zero, removing any constant DC offset. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor of any shape. + + Returns + ------- + torch.Tensor + A new tensor of the same shape and dtype with the global mean + subtracted from every element. + """ return tensor - tensor.mean() def peak_normalize(tensor: torch.Tensor) -> torch.Tensor: - max_value = tensor.abs().min() + """Scale a tensor so that its largest absolute value equals one. + + Divides the tensor by its peak absolute value. If the tensor is + identically zero, it is returned unchanged (no division by zero). + + Parameters + ---------- + tensor : torch.Tensor + Input tensor of any shape. + + Returns + ------- + torch.Tensor + A new tensor of the same shape and dtype with values in the range + ``[-1, 1]`` (or exactly ``[0, 0]`` for a zero tensor). + """ + max_value = tensor.abs().max() denominator = torch.where( max_value == 0, diff --git a/src/batdetect2/preprocess/config.py b/src/batdetect2/preprocess/config.py index 10676c3..086a640 100644 --- a/src/batdetect2/preprocess/config.py +++ b/src/batdetect2/preprocess/config.py @@ -1,3 +1,10 @@ +"""Configuration for the full batdetect2 preprocessing pipeline. + +This module defines :class:`PreprocessingConfig`, which aggregates all +configuration needed to convert a raw audio waveform into a normalised +spectrogram ready for the detection model. +""" + from typing import List from pydantic import Field @@ -9,7 +16,7 @@ from batdetect2.preprocess.spectrogram import ( FrequencyConfig, PcenConfig, ResizeConfig, - SpectralMeanSubstractionConfig, + SpectralMeanSubtractionConfig, SpectrogramTransform, STFTConfig, ) @@ -24,19 +31,30 @@ __all__ = [ class PreprocessingConfig(BaseConfig): """Unified configuration for the audio preprocessing pipeline. - Aggregates the configuration for both the initial audio processing stage - and the subsequent spectrogram generation stage. + Aggregates the parameters for every stage of the pipeline: + audio-level transforms, STFT computation, frequency cropping, + spectrogram-level transforms, and the final resize step. Attributes ---------- - audio : AudioConfig - Configuration settings for the audio loading and initial waveform - processing steps (e.g., resampling, duration adjustment, scaling). - Defaults to default `AudioConfig` settings if omitted. - spectrogram : SpectrogramConfig - Configuration settings for the spectrogram generation process - (e.g., STFT parameters, frequency cropping, scaling, denoising, - resizing). Defaults to default `SpectrogramConfig` settings if omitted. + audio_transforms : list of AudioTransform, default=[] + Ordered list of transforms applied to the raw audio waveform + before the STFT is computed. Each entry is a configuration + object for one of the supported audio transforms + (``"center_audio"``, ``"scale_audio"``, or ``"fix_duration"``). + spectrogram_transforms : list of SpectrogramTransform + Ordered list of transforms applied to the cropped spectrogram + after the STFT and frequency crop steps. Defaults to + ``[PcenConfig(), SpectralMeanSubtractionConfig()]``, which + applies PCEN followed by spectral mean subtraction. + stft : STFTConfig + Parameters for the Short-Time Fourier Transform (window + duration, overlap, and window function). + frequencies : FrequencyConfig + Frequency range (in Hz) to retain after the STFT. + size : ResizeConfig + Target height (number of frequency bins) and time-axis scaling + factor for the final resize step. """ audio_transforms: List[AudioTransform] = Field(default_factory=list) @@ -44,7 +62,7 @@ class PreprocessingConfig(BaseConfig): spectrogram_transforms: List[SpectrogramTransform] = Field( default_factory=lambda: [ PcenConfig(), - SpectralMeanSubstractionConfig(), + SpectralMeanSubtractionConfig(), ] ) @@ -59,4 +77,20 @@ def load_preprocessing_config( path: PathLike, field: str | None = None, ) -> PreprocessingConfig: + """Load a ``PreprocessingConfig`` from a YAML file. + + Parameters + ---------- + path : PathLike + Path to the YAML configuration file. + field : str, optional + If provided, read the config from a nested field within the + YAML document (e.g. ``"preprocessing"`` to read from a top-level + ``preprocessing:`` key). + + Returns + ------- + PreprocessingConfig + The deserialised preprocessing configuration. + """ return load_config(path, schema=PreprocessingConfig, field=field) diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index f5f351b..e32f463 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -1,3 +1,25 @@ +"""Assembles the full batdetect2 preprocessing pipeline. + +This module defines :class:`Preprocessor`, the concrete implementation of +:class:`~batdetect2.typing.PreprocessorProtocol`, and the +:func:`build_preprocessor` factory function that constructs it from a +:class:`~batdetect2.preprocess.config.PreprocessingConfig`. + +The preprocessing pipeline converts a raw audio waveform (as a +``torch.Tensor``) into a normalised, cropped, and resized spectrogram ready +for the detection model. The stages are applied in this order: + +1. **Audio transforms** — optional waveform-level operations such as DC + removal, peak normalisation, or duration fixing. +2. **STFT** — Short-Time Fourier Transform to produce an amplitude + spectrogram. +3. **Frequency crop** — retain only the frequency band of interest. +4. **Spectrogram transforms** — normalisation operations such as PCEN and + spectral mean subtraction. +5. **Resize** — scale the spectrogram to the model's expected height and + reduce the time resolution. +""" + import torch from loguru import logger @@ -20,7 +42,32 @@ __all__ = [ class Preprocessor(torch.nn.Module, PreprocessorProtocol): - """Standard implementation of the `Preprocessor` protocol.""" + """Standard implementation of the :class:`~batdetect2.typing.PreprocessorProtocol`. + + Wraps all preprocessing stages as ``torch.nn.Module`` submodules so + that parameters (e.g. PCEN filter coefficients) can be tracked and + moved between devices. + + Parameters + ---------- + config : PreprocessingConfig + Full pipeline configuration. + input_samplerate : int + Sample rate of the audio that will be passed to this preprocessor, + in Hz. + + Attributes + ---------- + input_samplerate : int + Sample rate of the input audio in Hz. + output_samplerate : float + Effective frame rate of the output spectrogram in frames per second. + Computed from the STFT hop length and the time-axis resize factor. + min_freq : float + Lower bound of the retained frequency band in Hz. + max_freq : float + Upper bound of the retained frequency band in Hz. + """ input_samplerate: int output_samplerate: float @@ -72,17 +119,75 @@ class Preprocessor(torch.nn.Module, PreprocessorProtocol): ) def forward(self, wav: torch.Tensor) -> torch.Tensor: + """Run the full preprocessing pipeline on a waveform. + + Applies audio transforms, then the STFT, then + :meth:`process_spectrogram`. + + Parameters + ---------- + wav : torch.Tensor + Input waveform of shape ``(samples,)``. + + Returns + ------- + torch.Tensor + Preprocessed spectrogram of shape + ``(freq_bins, time_frames)``. + """ wav = self.audio_transforms(wav) spec = self.spectrogram_builder(wav) return self.process_spectrogram(spec) def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: + """Compute the raw STFT spectrogram without any further processing. + + Parameters + ---------- + wav : torch.Tensor + Input waveform of shape ``(samples,)``. + + Returns + ------- + torch.Tensor + Amplitude spectrogram of shape ``(n_fft//2 + 1, time_frames)`` + with no frequency cropping, normalisation, or resizing applied. + """ return self.spectrogram_builder(wav) def process_audio(self, wav: torch.Tensor) -> torch.Tensor: + """Alias for :meth:`forward`. + + Parameters + ---------- + wav : torch.Tensor + Input waveform of shape ``(samples,)``. + + Returns + ------- + torch.Tensor + Preprocessed spectrogram (same as calling the object directly). + """ return self(wav) def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: + """Apply the post-STFT processing stages to an existing spectrogram. + + Applies frequency cropping, spectrogram-level transforms (e.g. + PCEN, spectral mean subtraction), and the final resize step. + + Parameters + ---------- + spec : torch.Tensor + Raw amplitude spectrogram of shape + ``(..., n_fft//2 + 1, time_frames)``. + + Returns + ------- + torch.Tensor + Normalised and resized spectrogram of shape + ``(..., height, scaled_time_frames)``. + """ spec = self.spectrogram_crop(spec) spec = self.spectrogram_transforms(spec) return self.spectrogram_resizer(spec) @@ -92,6 +197,25 @@ def compute_output_samplerate( config: PreprocessingConfig, input_samplerate: int = TARGET_SAMPLERATE_HZ, ) -> float: + """Compute the effective frame rate of the preprocessor's output. + + The output frame rate (in frames per second) depends on the STFT hop + length and the time-axis resize factor applied by the final resize step. + + Parameters + ---------- + config : PreprocessingConfig + Pipeline configuration. + input_samplerate : int, default=256000 + Sample rate of the input audio in Hz. + + Returns + ------- + float + Output frame rate in frames per second. + For example, at the default settings (256 kHz, hop=128, + resize_factor=0.5) this equals ``1000.0``. + """ _, hop_size = _spec_params_from_config( config.stft, samplerate=input_samplerate ) @@ -103,7 +227,24 @@ def build_preprocessor( config: PreprocessingConfig | None = None, input_samplerate: int = TARGET_SAMPLERATE_HZ, ) -> PreprocessorProtocol: - """Factory function to build the standard preprocessor from configuration.""" + """Build the standard preprocessor from a configuration object. + + Parameters + ---------- + config : PreprocessingConfig, optional + Pipeline configuration. If ``None``, the default + ``PreprocessingConfig()`` is used (PCEN + spectral mean + subtraction, 256 kHz, standard STFT parameters). + input_samplerate : int, default=256000 + Sample rate of the audio that will be fed to the preprocessor, + in Hz. + + Returns + ------- + PreprocessorProtocol + A :class:`Preprocessor` instance ready to convert waveforms to + spectrograms. + """ config = config or PreprocessingConfig() logger.opt(lazy=True).debug( "Building preprocessor with config: \n{}", diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index d248fa1..c579c1f 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -1,4 +1,14 @@ -"""Computes spectrograms from audio waveforms with configurable parameters.""" +"""Computes spectrograms from audio waveforms with configurable parameters. + +This module defines the STFT-based spectrogram builder and a collection of +spectrogram-level transforms (PCEN, spectral mean subtraction, amplitude +scaling, peak normalisation, frequency cropping, and resizing) that form the +signal-processing stage of the batdetect2 preprocessing pipeline. + +Each transform is paired with a Pydantic configuration class and registered +in the ``spectrogram_transforms`` registry so that the pipeline can be fully +specified via a YAML or Python configuration object. +""" from typing import Annotated, Callable, Literal @@ -32,17 +42,22 @@ class STFTConfig(BaseConfig): Attributes ---------- window_duration : float, default=0.002 - Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be - > 0. Determines frequency resolution (longer window = finer frequency - resolution). + Duration of the STFT analysis window in seconds (e.g. 0.002 for + 2 ms). Must be > 0. A longer window gives finer frequency resolution + but coarser time resolution. window_overlap : float, default=0.75 - Fraction of overlap between consecutive STFT windows (e.g., 0.75 - for 75%). Must be >= 0 and < 1. Determines time resolution - (higher overlap = finer time resolution). + Fraction of overlap between consecutive windows (e.g. 0.75 for + 75 %). Must be >= 0 and < 1. Higher overlap gives finer time + resolution at the cost of more computation. window_fn : str, default="hann" - Name of the window function to apply before FFT calculation. Common - options include "hann", "hamming", "blackman". See - `scipy.signal.get_window`. + Name of the tapering window applied to each frame before the FFT. + Supported values: ``"hann"``, ``"hamming"``, ``"kaiser"``, + ``"blackman"``, ``"bartlett"``. + + Notes + ----- + At the default sample rate of 256 kHz, ``window_duration=0.002`` and + ``window_overlap=0.75`` give ``n_fft=512`` and ``hop_length=128``. """ window_duration: float = Field(default=0.002, gt=0) @@ -54,6 +69,23 @@ def build_spectrogram_builder( config: STFTConfig, samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: + """Build a torchaudio STFT spectrogram module from an ``STFTConfig``. + + Parameters + ---------- + config : STFTConfig + STFT parameters (window duration, overlap, and window function). + samplerate : int, default=256000 + Sample rate of the input audio in Hz. Used to convert the + window duration into a number of samples. + + Returns + ------- + torch.nn.Module + A ``torchaudio.transforms.Spectrogram`` module configured to + produce an amplitude (``power=1``) spectrogram with centred + frames. + """ n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate) return torchaudio.transforms.Spectrogram( n_fft=n_fft, @@ -65,6 +97,25 @@ def build_spectrogram_builder( def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: + """Return the PyTorch window function matching the given name. + + Parameters + ---------- + name : str + Name of the window function. One of ``"hann"``, ``"hamming"``, + ``"kaiser"``, ``"blackman"``, or ``"bartlett"``. + + Returns + ------- + Callable[..., torch.Tensor] + A PyTorch window function that accepts a window length and returns + a 1-D tensor of weights. + + Raises + ------ + NotImplementedError + If ``name`` does not match any supported window function. + """ if name == "hann": return torch.hann_window @@ -88,7 +139,22 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: def _spec_params_from_config( config: STFTConfig, samplerate: int = TARGET_SAMPLERATE_HZ, -): +) -> tuple[int, int]: + """Compute ``n_fft`` and ``hop_length`` from an ``STFTConfig``. + + Parameters + ---------- + config : STFTConfig + STFT parameters. + samplerate : int, default=256000 + Sample rate of the input audio in Hz. + + Returns + ------- + tuple[int, int] + A pair ``(n_fft, hop_length)`` giving the FFT size and the step + between consecutive frames in samples. + """ n_fft = int(samplerate * config.window_duration) hop_length = int(n_fft * (1 - config.window_overlap)) return n_fft, hop_length @@ -99,6 +165,24 @@ def _frequency_to_index( n_fft: int, samplerate: int = TARGET_SAMPLERATE_HZ, ) -> int | None: + """Convert a frequency in Hz to the nearest STFT frequency bin index. + + Parameters + ---------- + freq : float + Frequency in Hz to convert. + n_fft : int + FFT size used by the STFT. + samplerate : int, default=256000 + Sample rate of the audio in Hz. + + Returns + ------- + int or None + The bin index corresponding to ``freq``, or ``None`` if the + frequency is outside the valid range (i.e. <= 0 Hz or >= the + Nyquist frequency). + """ alpha = freq * 2 / samplerate height = np.floor(n_fft / 2) + 1 index = int(np.floor(alpha * height)) @@ -118,11 +202,11 @@ class FrequencyConfig(BaseConfig): Attributes ---------- max_freq : int, default=120000 - Maximum frequency in Hz to retain in the spectrogram after STFT. - Frequencies above this value will be cropped. Must be > 0. + Maximum frequency in Hz to retain after STFT. Frequency bins + above this value are discarded. Must be >= 0. min_freq : int, default=10000 - Minimum frequency in Hz to retain in the spectrogram after STFT. - Frequencies below this value will be cropped. Must be >= 0. + Minimum frequency in Hz to retain after STFT. Frequency bins + below this value are discarded. Must be >= 0. """ max_freq: int = Field(default=MAX_FREQ, ge=0) @@ -130,6 +214,27 @@ class FrequencyConfig(BaseConfig): class FrequencyCrop(torch.nn.Module): + """Crop a spectrogram to a specified frequency band. + + On construction the Hz boundaries are converted to STFT bin indices. + During the forward pass the spectrogram is sliced along its + frequency axis (second-to-last dimension) to retain only the bins + that fall within ``[min_freq, max_freq)``. + + Parameters + ---------- + samplerate : int + Sample rate of the audio in Hz. + n_fft : int + FFT size used by the STFT. + min_freq : int, optional + Lower frequency bound in Hz. If ``None``, no lower crop is + applied and the DC bin is retained. + max_freq : int, optional + Upper frequency bound in Hz. If ``None``, no upper crop is + applied and all bins up to Nyquist are retained. + """ + def __init__( self, samplerate: int, @@ -162,6 +267,19 @@ class FrequencyCrop(torch.nn.Module): self.high_index = high_index def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Crop the spectrogram to the configured frequency band. + + Parameters + ---------- + spec : torch.Tensor + Spectrogram tensor of shape ``(..., freq_bins, time_frames)``. + + Returns + ------- + torch.Tensor + Cropped spectrogram with shape + ``(..., n_retained_bins, time_frames)``. + """ low_index = self.low_index if low_index is None: low_index = 0 @@ -184,6 +302,24 @@ def build_spectrogram_crop( stft: STFTConfig | None = None, samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: + """Build a ``FrequencyCrop`` module from configuration objects. + + Parameters + ---------- + config : FrequencyConfig + Frequency boundary configuration specifying ``min_freq`` and + ``max_freq`` in Hz. + stft : STFTConfig, optional + STFT configuration used to derive ``n_fft``. Defaults to + ``STFTConfig()`` if not provided. + samplerate : int, default=256000 + Sample rate of the audio in Hz. + + Returns + ------- + torch.nn.Module + A ``FrequencyCrop`` module ready to crop spectrograms. + """ stft = stft or STFTConfig() n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate) return FrequencyCrop( @@ -195,18 +331,61 @@ def build_spectrogram_crop( class ResizeConfig(BaseConfig): + """Configuration for the final spectrogram resize step. + + Attributes + ---------- + name : str + Fixed identifier; always ``"resize_spec"``. + height : int, default=128 + Target number of frequency bins in the output spectrogram. + The spectrogram is resized to this height using bilinear + interpolation. + resize_factor : float, default=0.5 + Fraction by which the time axis is scaled. For example, ``0.5`` + halves the number of time frames, reducing computational cost + downstream. + """ + name: Literal["resize_spec"] = "resize_spec" height: int = 128 resize_factor: float = 0.5 class ResizeSpec(torch.nn.Module): + """Resize a spectrogram to a fixed height and scaled width. + + Uses bilinear interpolation so it handles arbitrary input shapes + gracefully. Input tensors with fewer than four dimensions are + temporarily unsqueezed to satisfy ``torch.nn.functional.interpolate``. + + Parameters + ---------- + height : int + Target number of frequency bins (output height). + time_factor : float + Multiplicative scaling applied to the time axis length. + """ + def __init__(self, height: int, time_factor: float): super().__init__() self.height = height self.time_factor = time_factor def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Resize the spectrogram to the configured output dimensions. + + Parameters + ---------- + spec : torch.Tensor + Input spectrogram of shape ``(..., freq_bins, time_frames)``. + + Returns + ------- + torch.Tensor + Resized spectrogram with shape + ``(..., height, int(time_factor * time_frames))``. + """ current_length = spec.shape[-1] target_length = int(self.time_factor * current_length) @@ -227,6 +406,18 @@ class ResizeSpec(torch.nn.Module): def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module: + """Build a ``ResizeSpec`` module from a ``ResizeConfig``. + + Parameters + ---------- + config : ResizeConfig + Resize configuration specifying ``height`` and ``resize_factor``. + + Returns + ------- + torch.nn.Module + A ``ResizeSpec`` module configured with the given parameters. + """ return ResizeSpec(height=config.height, time_factor=config.resize_factor) @@ -236,7 +427,32 @@ spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry( class PcenConfig(BaseConfig): - """Configuration for Per-Channel Energy Normalization (PCEN).""" + """Configuration for Per-Channel Energy Normalisation (PCEN). + + PCEN is a frontend processing technique that replaces simple log + compression. It applies a learnable automatic gain control followed + by a stabilised root compression, making the representation more + robust to variations in recording level. + + Attributes + ---------- + name : str + Fixed identifier; always ``"pcen"``. + time_constant : float, default=0.4 + Time constant (in seconds) of the IIR smoothing filter used + for the background estimate. Larger values produce a slower- + adapting background. + gain : float, default=0.98 + Exponent controlling how strongly the background estimate + suppresses the signal. + bias : float, default=2 + Stabilisation bias added inside the root-compression step to + avoid division by zero. + power : float, default=0.5 + Root-compression exponent. A value of 0.5 gives square-root + compression, similar to log compression but differentiable at + zero. + """ name: Literal["pcen"] = "pcen" time_constant: float = 0.4 @@ -246,6 +462,35 @@ class PcenConfig(BaseConfig): class PCEN(torch.nn.Module): + """Per-Channel Energy Normalisation (PCEN) transform. + + Applies automatic gain control and root compression to a spectrogram. + The background estimate is computed with a first-order IIR filter + applied along the time axis. + + Parameters + ---------- + smoothing_constant : float + IIR filter coefficient ``alpha``. Derived from the time constant + and sample rate via ``_compute_smoothing_constant``. + gain : float, default=0.98 + AGC gain exponent. + bias : float, default=2.0 + Root-compression stabilisation bias. + power : float, default=0.5 + Root-compression exponent. + eps : float, default=1e-6 + Small constant for numerical stability. + dtype : torch.dtype, default=torch.float32 + Floating-point precision used for internal computation. + + Notes + ----- + The smoothing constant is computed to match the original batdetect2 + implementation for numerical compatibility. See + ``_compute_smoothing_constant`` for details. + """ + def __init__( self, smoothing_constant: float, @@ -269,6 +514,20 @@ class PCEN(torch.nn.Module): ) def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Apply PCEN to a spectrogram. + + Parameters + ---------- + spec : torch.Tensor + Input amplitude spectrogram of shape + ``(..., freq_bins, time_frames)``. + + Returns + ------- + torch.Tensor + PCEN-normalised spectrogram with the same shape and dtype as + the input. + """ S = spec.to(self.dtype) * 2**31 M = ( @@ -305,21 +564,73 @@ def _compute_smoothing_constant( samplerate: int, time_constant: float, ) -> float: - # NOTE: These were taken to match the original implementation + """Compute the IIR smoothing coefficient for PCEN. + + Parameters + ---------- + samplerate : int + Sample rate of the audio in Hz. + time_constant : float + Desired smoothing time constant in seconds. + + Returns + ------- + float + IIR filter coefficient ``alpha`` used by ``PCEN``. + + Notes + ----- + The hop length (512) and the sample-rate divisor (10) are fixed to + reproduce the numerical behaviour of the original batdetect2 + implementation, which used ``librosa.pcen`` with ``sr=samplerate/10`` + and the default ``hop_length=512``. These values do not reflect the + actual STFT hop length used in the pipeline; they are retained + solely for backward compatibility. + """ + # NOTE: These parameters are fixed to match the original implementation. hop_length = 512 sr = samplerate / 10 - time_constant = time_constant t_frames = time_constant * sr / float(hop_length) return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) class ScaleAmplitudeConfig(BaseConfig): + """Configuration for amplitude scaling of a spectrogram. + + Attributes + ---------- + name : str + Fixed identifier; always ``"scale_amplitude"``. + scale : str, default="db" + Scaling mode. Either ``"db"`` (convert amplitude to decibels + using ``torchaudio.transforms.AmplitudeToDB``) or ``"power"`` + (square the amplitude values). + """ + name: Literal["scale_amplitude"] = "scale_amplitude" scale: Literal["power", "db"] = "db" class ToPower(torch.nn.Module): + """Square the values of a spectrogram (amplitude → power). + + Raises each element to the power of two, converting an amplitude + spectrogram into a power spectrogram. + """ + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Square all elements of the spectrogram. + + Parameters + ---------- + spec : torch.Tensor + Input amplitude spectrogram. + + Returns + ------- + torch.Tensor + Power spectrogram (same shape as input). + """ return spec**2 @@ -330,12 +641,34 @@ _scalers = { class ScaleAmplitude(torch.nn.Module): + """Convert spectrogram amplitude values to a different scale. + + Supports conversion to decibels (dB) or to power (squared amplitude). + + Parameters + ---------- + scale : str + Either ``"db"`` or ``"power"``. + """ + def __init__(self, scale: Literal["power", "db"]): super().__init__() self.scale = scale self.scaler = _scalers[scale]() def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Apply the configured amplitude scaling. + + Parameters + ---------- + spec : torch.Tensor + Input spectrogram tensor. + + Returns + ------- + torch.Tensor + Scaled spectrogram with the same shape as the input. + """ return self.scaler(spec) @spectrogram_transforms.register(ScaleAmplitudeConfig) @@ -344,30 +677,86 @@ class ScaleAmplitude(torch.nn.Module): return ScaleAmplitude(scale=config.scale) -class SpectralMeanSubstractionConfig(BaseConfig): - name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" +class SpectralMeanSubtractionConfig(BaseConfig): + """Configuration for spectral mean subtraction. + + Attributes + ---------- + name : str + Fixed identifier; always ``"spectral_mean_subtraction"``. + """ + + name: Literal["spectral_mean_subtraction"] = "spectral_mean_subtraction" -class SpectralMeanSubstraction(torch.nn.Module): +class SpectralMeanSubtraction(torch.nn.Module): + """Remove the time-averaged background noise from a spectrogram. + + For each frequency bin, the mean value across all time frames is + computed and subtracted. The result is then clamped to zero so that + no values fall below the baseline. This is a simple form of spectral + denoising that suppresses stationary background noise. + """ + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Subtract the time-axis mean from each frequency bin. + + Parameters + ---------- + spec : torch.Tensor + Input spectrogram of shape ``(..., freq_bins, time_frames)``. + + Returns + ------- + torch.Tensor + Denoised spectrogram with the same shape as the input. All + values are non-negative (clamped to 0). + """ mean = spec.mean(-1, keepdim=True) return (spec - mean).clamp(min=0) - @spectrogram_transforms.register(SpectralMeanSubstractionConfig) + @spectrogram_transforms.register(SpectralMeanSubtractionConfig) @staticmethod def from_config( - config: SpectralMeanSubstractionConfig, + config: SpectralMeanSubtractionConfig, samplerate: int, ): - return SpectralMeanSubstraction() + return SpectralMeanSubtraction() class PeakNormalizeConfig(BaseConfig): + """Configuration for peak normalisation of a spectrogram. + + Attributes + ---------- + name : str + Fixed identifier; always ``"peak_normalize"``. + """ + name: Literal["peak_normalize"] = "peak_normalize" class PeakNormalize(torch.nn.Module): + """Scale a spectrogram so that its largest absolute value equals one. + + Wraps :func:`batdetect2.preprocess.common.peak_normalize` as a + ``torch.nn.Module`` for use inside a sequential transform pipeline. + """ + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """Peak-normalise the spectrogram. + + Parameters + ---------- + spec : torch.Tensor + Input spectrogram tensor of any shape. + + Returns + ------- + torch.Tensor + Normalised spectrogram where the maximum absolute value is 1. + If the input is identically zero, it is returned unchanged. + """ return peak_normalize(spec) @spectrogram_transforms.register(PeakNormalizeConfig) @@ -379,14 +768,37 @@ class PeakNormalize(torch.nn.Module): SpectrogramTransform = Annotated[ PcenConfig | ScaleAmplitudeConfig - | SpectralMeanSubstractionConfig + | SpectralMeanSubtractionConfig | PeakNormalizeConfig, Field(discriminator="name"), ] +"""Discriminated union of all spectrogram transform configuration types. + +Use this type when a field should accept any of the supported spectrogram +transforms. Pydantic will select the correct config class based on the +``name`` field. +""" def build_spectrogram_transform( config: SpectrogramTransform, samplerate: int, ) -> torch.nn.Module: + """Build a spectrogram transform module from a configuration object. + + Parameters + ---------- + config : SpectrogramTransform + A configuration object for one of the supported spectrogram + transforms (PCEN, amplitude scaling, spectral mean subtraction, + or peak normalisation). + samplerate : int + Sample rate of the audio in Hz. Some transforms (e.g. PCEN) use + this to set internal parameters. + + Returns + ------- + torch.nn.Module + The constructed transform module. + """ return spectrogram_transforms.build(config, samplerate) diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index ab22f6c..273d0a4 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -1,12 +1,31 @@ +"""Tests for audio-level preprocessing transforms. + +Covers :mod:`batdetect2.preprocess.audio` and the shared helper functions +in :mod:`batdetect2.preprocess.common`. +""" + import pathlib import uuid import numpy as np import pytest import soundfile as sf +import torch from soundevent import data from batdetect2.audio import AudioConfig +from batdetect2.preprocess.audio import ( + CenterAudio, + CenterAudioConfig, + FixDuration, + FixDurationConfig, + ScaleAudio, + ScaleAudioConfig, + build_audio_transform, +) +from batdetect2.preprocess.common import center_tensor, peak_normalize + +SAMPLERATE = 256_000 def create_dummy_wave( @@ -15,9 +34,9 @@ def create_dummy_wave( num_channels: int = 1, freq: float = 440.0, amplitude: float = 0.5, - dtype: np.dtype = np.float32, + dtype: type = np.float32, ) -> np.ndarray: - """Generates a simple numpy waveform.""" + """Generate a simple sine-wave waveform as a NumPy array.""" t = np.linspace( 0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype ) @@ -29,7 +48,7 @@ def create_dummy_wave( @pytest.fixture def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path: - """Creates a dummy WAV file and returns its path.""" + """Create a dummy 2-channel WAV file and return its path.""" samplerate = 48000 duration = 2.0 num_channels = 2 @@ -41,13 +60,13 @@ def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path: @pytest.fixture def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording: - """Creates a Recording object pointing to the dummy WAV file.""" + """Create a Recording object pointing to the dummy WAV file.""" return data.Recording.from_file(dummy_wav_path) @pytest.fixture def dummy_clip(dummy_recording: data.Recording) -> data.Clip: - """Creates a Clip object from the dummy recording.""" + """Create a Clip object from the dummy recording.""" return data.Clip( recording=dummy_recording, start_time=0.5, @@ -58,3 +77,165 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip: @pytest.fixture def default_audio_config() -> AudioConfig: return AudioConfig() + + +# --------------------------------------------------------------------------- +# center_tensor +# --------------------------------------------------------------------------- + + +def test_center_tensor_zero_mean(): + """Output tensor should have a mean very close to zero.""" + wav = torch.tensor([1.0, 2.0, 3.0, 4.0]) + result = center_tensor(wav) + assert result.mean().abs().item() < 1e-5 + + +def test_center_tensor_preserves_shape(): + wav = torch.randn(3, 1000) + result = center_tensor(wav) + assert result.shape == wav.shape + + +# --------------------------------------------------------------------------- +# peak_normalize +# --------------------------------------------------------------------------- + + +def test_peak_normalize_max_is_one(): + """After peak normalisation, the maximum absolute value should be 1.""" + wav = torch.tensor([0.1, -0.4, 0.2, 0.8, -0.3]) + result = peak_normalize(wav) + assert abs(result.abs().max().item() - 1.0) < 1e-6 + + +def test_peak_normalize_zero_tensor_unchanged(): + """A zero tensor should be returned unchanged (no division by zero).""" + wav = torch.zeros(100) + result = peak_normalize(wav) + assert (result == 0).all() + + +def test_peak_normalize_preserves_sign(): + """Signs of all elements should be preserved after normalisation.""" + wav = torch.tensor([-2.0, 1.0, -0.5]) + result = peak_normalize(wav) + assert (result < 0).sum() == 2 + assert result[0].item() < 0 + + +def test_peak_normalize_preserves_shape(): + wav = torch.randn(2, 512) + result = peak_normalize(wav) + assert result.shape == wav.shape + + +# --------------------------------------------------------------------------- +# CenterAudio +# --------------------------------------------------------------------------- + + +def test_center_audio_forward_zero_mean(): + module = CenterAudio() + wav = torch.tensor([1.0, 3.0, 5.0]) + result = module(wav) + assert result.mean().abs().item() < 1e-5 + + +def test_center_audio_from_config(): + config = CenterAudioConfig() + module = CenterAudio.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, CenterAudio) + + +# --------------------------------------------------------------------------- +# ScaleAudio +# --------------------------------------------------------------------------- + + +def test_scale_audio_peak_normalises_to_one(): + """ScaleAudio.forward should scale the peak absolute value to 1.""" + module = ScaleAudio() + wav = torch.tensor([0.0, 0.25, -0.5, 0.1]) + result = module(wav) + assert abs(result.abs().max().item() - 1.0) < 1e-6 + + +def test_scale_audio_handles_zero_tensor(): + """ScaleAudio should not raise on a zero tensor.""" + module = ScaleAudio() + wav = torch.zeros(100) + result = module(wav) + assert (result == 0).all() + + +def test_scale_audio_from_config(): + config = ScaleAudioConfig() + module = ScaleAudio.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, ScaleAudio) + + +# --------------------------------------------------------------------------- +# FixDuration +# --------------------------------------------------------------------------- + + +def test_fix_duration_truncates_long_input(): + """Waveform longer than target should be truncated to the target length.""" + target_samples = int(SAMPLERATE * 0.5) + module = FixDuration(samplerate=SAMPLERATE, duration=0.5) + wav = torch.randn(target_samples + 1000) + result = module(wav) + assert result.shape[-1] == target_samples + + +def test_fix_duration_pads_short_input(): + """Waveform shorter than target should be zero-padded to the target length.""" + target_samples = int(SAMPLERATE * 0.5) + module = FixDuration(samplerate=SAMPLERATE, duration=0.5) + short_wav = torch.randn(target_samples - 100) + result = module(short_wav) + assert result.shape[-1] == target_samples + # Padded region should be zero + assert (result[target_samples - 100 :] == 0).all() + + +def test_fix_duration_passthrough_exact_length(): + """Waveform with exactly the right length should be returned unchanged.""" + target_samples = int(SAMPLERATE * 0.5) + module = FixDuration(samplerate=SAMPLERATE, duration=0.5) + wav = torch.randn(target_samples) + result = module(wav) + assert result.shape[-1] == target_samples + assert torch.equal(result, wav) + + +def test_fix_duration_from_config(): + """FixDurationConfig should produce a FixDuration with the correct length.""" + config = FixDurationConfig(duration=0.256) + module = FixDuration.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, FixDuration) + assert module.length == int(SAMPLERATE * 0.256) + + +# --------------------------------------------------------------------------- +# build_audio_transform dispatch +# --------------------------------------------------------------------------- + + +def test_build_audio_transform_center_audio(): + config = CenterAudioConfig() + module = build_audio_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, CenterAudio) + + +def test_build_audio_transform_scale_audio(): + config = ScaleAudioConfig() + module = build_audio_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, ScaleAudio) + + +def test_build_audio_transform_fix_duration(): + config = FixDurationConfig(duration=0.5) + module = build_audio_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, FixDuration) diff --git a/tests/test_preprocessing/test_preprocessor.py b/tests/test_preprocessing/test_preprocessor.py new file mode 100644 index 0000000..c3afd09 --- /dev/null +++ b/tests/test_preprocessing/test_preprocessor.py @@ -0,0 +1,243 @@ +"""Integration and unit tests for the Preprocessor pipeline. + +Covers :mod:`batdetect2.preprocess.preprocessor` — construction, +pipeline output shape/dtype, the ``process_numpy`` helper, attribute +values, output frame rate, and a round-trip YAML → config → build test. +""" + +import pathlib + +import numpy as np +import torch + +from batdetect2.preprocess.audio import FixDurationConfig +from batdetect2.preprocess.config import PreprocessingConfig +from batdetect2.preprocess.preprocessor import ( + Preprocessor, + build_preprocessor, + compute_output_samplerate, +) +from batdetect2.preprocess.spectrogram import ( + FrequencyConfig, + PcenConfig, + ResizeConfig, + SpectralMeanSubtractionConfig, + STFTConfig, +) + +SAMPLERATE = 256_000 +# 0.256 s at 256 kHz = 65536 samples — a convenient power-of-two-sized clip +CLIP_SAMPLES = int(SAMPLERATE * 0.256) + + +def make_sine_wav( + samplerate: int = SAMPLERATE, + duration: float = 0.256, + freq: float = 40_000.0, +) -> torch.Tensor: + """Return a single-channel sine-wave tensor.""" + t = torch.linspace(0.0, duration, int(samplerate * duration)) + return torch.sin(2 * torch.pi * freq * t) + + +# --------------------------------------------------------------------------- +# build_preprocessor — construction +# --------------------------------------------------------------------------- + + +def test_build_preprocessor_returns_protocol(): + """build_preprocessor should return a Preprocessor instance.""" + preprocessor = build_preprocessor() + assert isinstance(preprocessor, Preprocessor) + + +def test_build_preprocessor_with_default_config(): + """build_preprocessor() with no arguments should not raise.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + assert preprocessor is not None + + +def test_build_preprocessor_with_explicit_config(): + config = PreprocessingConfig( + stft=STFTConfig(window_duration=0.002, window_overlap=0.75), + frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000), + size=ResizeConfig(height=128, resize_factor=0.5), + spectrogram_transforms=[PcenConfig(), SpectralMeanSubtractionConfig()], + ) + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + assert isinstance(preprocessor, Preprocessor) + + +# --------------------------------------------------------------------------- +# Output shape and dtype +# --------------------------------------------------------------------------- + + +def test_preprocessor_output_is_2d(): + """The preprocessor output should be a 2-D tensor (freq_bins × time_frames).""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav = make_sine_wav() + result = preprocessor(wav) + assert result.ndim == 2 + + +def test_preprocessor_output_height_matches_config(): + """Output height should match the ResizeConfig.height setting.""" + config = PreprocessingConfig( + size=ResizeConfig(height=64, resize_factor=0.5) + ) + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + wav = make_sine_wav() + result = preprocessor(wav) + assert result.shape[0] == 64 + + +def test_preprocessor_output_dtype_float32(): + """Output tensor should be float32.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav = make_sine_wav() + result = preprocessor(wav) + assert result.dtype == torch.float32 + + +def test_preprocessor_output_is_finite(): + """Output spectrogram should contain no NaN or Inf values.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav = make_sine_wav() + result = preprocessor(wav) + assert torch.isfinite(result).all() + + +# --------------------------------------------------------------------------- +# process_numpy +# --------------------------------------------------------------------------- + + +def test_preprocessor_process_numpy_accepts_ndarray(): + """process_numpy should accept a NumPy array and return a NumPy array.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav_np = make_sine_wav().numpy() + result = preprocessor.process_numpy(wav_np) + assert isinstance(result, np.ndarray) + + +def test_preprocessor_process_numpy_matches_forward(): + """process_numpy and forward should give numerically identical results.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav = make_sine_wav() + result_pt = preprocessor(wav).numpy() + result_np = preprocessor.process_numpy(wav.numpy()) + np.testing.assert_array_almost_equal(result_pt, result_np) + + +# --------------------------------------------------------------------------- +# Attributes +# --------------------------------------------------------------------------- + + +def test_preprocessor_min_max_freq_attributes(): + """min_freq and max_freq should match the FrequencyConfig values.""" + config = PreprocessingConfig( + frequencies=FrequencyConfig(min_freq=15_000, max_freq=100_000) + ) + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + assert preprocessor.min_freq == 15_000 + assert preprocessor.max_freq == 100_000 + + +def test_preprocessor_input_samplerate_attribute(): + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + assert preprocessor.input_samplerate == SAMPLERATE + + +# --------------------------------------------------------------------------- +# compute_output_samplerate +# --------------------------------------------------------------------------- + + +def test_compute_output_samplerate_defaults(): + """At default settings, output_samplerate should equal 1000 fps.""" + config = PreprocessingConfig() + rate = compute_output_samplerate(config, input_samplerate=SAMPLERATE) + assert abs(rate - 1000.0) < 1e-6 + + +def test_preprocessor_output_samplerate_attribute_matches_compute(): + config = PreprocessingConfig() + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + expected = compute_output_samplerate(config, input_samplerate=SAMPLERATE) + assert abs(preprocessor.output_samplerate - expected) < 1e-6 + + +# --------------------------------------------------------------------------- +# generate_spectrogram (raw STFT) +# --------------------------------------------------------------------------- + + +def test_generate_spectrogram_shape(): + """generate_spectrogram should return the full STFT without crop or resize.""" + config = PreprocessingConfig() + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + wav = make_sine_wav() + spec = preprocessor.generate_spectrogram(wav) + # Full STFT: n_fft//2 + 1 = 257 bins at defaults + assert spec.shape[0] == 257 + + +def test_generate_spectrogram_larger_than_forward(): + """Raw spectrogram should have more frequency bins than the processed output.""" + preprocessor = build_preprocessor(input_samplerate=SAMPLERATE) + wav = make_sine_wav() + raw = preprocessor.generate_spectrogram(wav) + processed = preprocessor(wav) + assert raw.shape[0] > processed.shape[0] + + +# --------------------------------------------------------------------------- +# Audio transforms pipeline (FixDuration) +# --------------------------------------------------------------------------- + + +def test_preprocessor_with_fix_duration_audio_transform(): + """A FixDuration audio transform should produce consistent output shapes.""" + config = PreprocessingConfig( + audio_transforms=[FixDurationConfig(duration=0.256)], + ) + preprocessor = build_preprocessor(config, input_samplerate=SAMPLERATE) + # Feed different lengths; output shape should be the same after fix + for n_samples in [CLIP_SAMPLES - 1000, CLIP_SAMPLES, CLIP_SAMPLES + 1000]: + wav = torch.randn(n_samples) + result = preprocessor(wav) + assert result.ndim == 2 + + +# --------------------------------------------------------------------------- +# YAML round-trip +# --------------------------------------------------------------------------- + + +def test_preprocessor_yaml_roundtrip(tmp_path: pathlib.Path): + """PreprocessingConfig serialised to YAML and reloaded should produce + a functionally identical preprocessor.""" + from batdetect2.preprocess.config import load_preprocessing_config + + config = PreprocessingConfig( + stft=STFTConfig(window_duration=0.002, window_overlap=0.75), + frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000), + size=ResizeConfig(height=128, resize_factor=0.5), + ) + + yaml_path = tmp_path / "preprocess_config.yaml" + yaml_path.write_text(config.to_yaml_string()) + + loaded_config = load_preprocessing_config(yaml_path) + + preprocessor = build_preprocessor( + loaded_config, input_samplerate=SAMPLERATE + ) + wav = make_sine_wav() + result = preprocessor(wav) + + assert result.ndim == 2 + assert result.shape[0] == 128 + assert torch.isfinite(result).all() diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index f61bcff..f040c84 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -1,37 +1,316 @@ -import numpy as np -import pytest -import xarray as xr +"""Tests for spectrogram-level preprocessing transforms. -SAMPLERATE = 250_000 -DURATION = 0.1 -TEST_FREQ = 30_000 -N_SAMPLES = int(SAMPLERATE * DURATION) -TIME_COORD = np.linspace( - 0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32 +Covers :mod:`batdetect2.preprocess.spectrogram` — STFT configuration, +frequency cropping, PCEN, spectral mean subtraction, amplitude scaling, +peak normalisation, and resizing. +""" + +import torch + +from batdetect2.preprocess.spectrogram import ( + PCEN, + FrequencyConfig, + FrequencyCrop, + PcenConfig, + PeakNormalize, + PeakNormalizeConfig, + ResizeConfig, + ResizeSpec, + ScaleAmplitude, + ScaleAmplitudeConfig, + SpectralMeanSubtraction, + SpectralMeanSubtractionConfig, + STFTConfig, + _spec_params_from_config, + build_spectrogram_builder, + build_spectrogram_crop, + build_spectrogram_resizer, + build_spectrogram_transform, ) +SAMPLERATE = 256_000 -@pytest.fixture -def sine_wave_xr() -> xr.DataArray: - """Generate a single sine wave as an xr.DataArray.""" - t = TIME_COORD - wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32) - return xr.DataArray( - wav_data, - coords={"time": t}, - dims=["time"], - attrs={"samplerate": SAMPLERATE}, + +# --------------------------------------------------------------------------- +# STFTConfig / _spec_params_from_config +# --------------------------------------------------------------------------- + + +def test_stft_config_defaults_give_correct_params(): + """Default STFTConfig at 256 kHz should give n_fft=512, hop_length=128.""" + config = STFTConfig() + n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE) + assert n_fft == 512 + assert hop_length == 128 + + +def test_stft_config_custom_params(): + """Custom window duration and overlap should produce the expected sizes.""" + config = STFTConfig(window_duration=0.004, window_overlap=0.5) + n_fft, hop_length = _spec_params_from_config(config, samplerate=SAMPLERATE) + assert n_fft == 1024 + assert hop_length == 512 + + +# --------------------------------------------------------------------------- +# build_spectrogram_builder +# --------------------------------------------------------------------------- + + +def test_spectrogram_builder_output_shape(): + """Builder should produce a spectrogram with the expected number of bins.""" + config = STFTConfig() + n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE) + expected_freq_bins = n_fft // 2 + 1 # 257 at defaults + + builder = build_spectrogram_builder(config, samplerate=SAMPLERATE) + n_samples = SAMPLERATE # 1 second of audio + wav = torch.randn(n_samples) + spec = builder(wav) + + assert spec.ndim == 2 + assert spec.shape[0] == expected_freq_bins + + +def test_spectrogram_builder_output_is_nonnegative(): + """Amplitude spectrogram values should all be >= 0.""" + config = STFTConfig() + builder = build_spectrogram_builder(config, samplerate=SAMPLERATE) + wav = torch.randn(SAMPLERATE) + spec = builder(wav) + assert (spec >= 0).all() + + +# --------------------------------------------------------------------------- +# FrequencyCrop +# --------------------------------------------------------------------------- + + +def test_frequency_crop_output_shape(): + """FrequencyCrop should reduce the number of frequency bins.""" + config = STFTConfig() + n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE) + + crop = FrequencyCrop( + samplerate=SAMPLERATE, + n_fft=n_fft, + min_freq=10_000, + max_freq=120_000, ) + spec = torch.ones(n_fft // 2 + 1, 100) + cropped = crop(spec) + + assert cropped.ndim == 2 + # Must be smaller than the full spectrogram + assert cropped.shape[0] < spec.shape[0] + assert cropped.shape[1] == 100 # time axis unchanged -@pytest.fixture -def constant_wave_xr() -> xr.DataArray: - """Generate a constant signal as an xr.DataArray.""" - t = TIME_COORD - wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5 - return xr.DataArray( - wav_data, - coords={"time": t}, - dims=["time"], - attrs={"samplerate": SAMPLERATE}, +def test_frequency_crop_build_from_config(): + """build_spectrogram_crop should return a working FrequencyCrop.""" + freq_config = FrequencyConfig(min_freq=10_000, max_freq=120_000) + stft_config = STFTConfig() + crop = build_spectrogram_crop( + freq_config, stft=stft_config, samplerate=SAMPLERATE ) + assert isinstance(crop, FrequencyCrop) + + +def test_frequency_crop_no_crop_when_bounds_are_none(): + """FrequencyCrop with no bounds should return the full spectrogram.""" + config = STFTConfig() + n_fft, _ = _spec_params_from_config(config, samplerate=SAMPLERATE) + + crop = FrequencyCrop(samplerate=SAMPLERATE, n_fft=n_fft) + spec = torch.ones(n_fft // 2 + 1, 100) + cropped = crop(spec) + + assert cropped.shape == spec.shape + + +# --------------------------------------------------------------------------- +# PCEN +# --------------------------------------------------------------------------- + + +def test_pcen_output_shape_preserved(): + """PCEN should not change the shape of the spectrogram.""" + config = PcenConfig() + pcen = PCEN.from_config(config, samplerate=SAMPLERATE) + spec = torch.rand(64, 200) + 1e-6 # positive values + result = pcen(spec) + assert result.shape == spec.shape + + +def test_pcen_output_is_finite(): + """PCEN applied to a well-formed spectrogram should produce no NaN or Inf.""" + config = PcenConfig() + pcen = PCEN.from_config(config, samplerate=SAMPLERATE) + spec = torch.rand(64, 200) + 1e-6 + result = pcen(spec) + assert torch.isfinite(result).all() + + +def test_pcen_output_dtype_matches_input(): + """PCEN should return a tensor with the same dtype as the input.""" + config = PcenConfig() + pcen = PCEN.from_config(config, samplerate=SAMPLERATE) + spec = torch.rand(64, 200, dtype=torch.float32) + result = pcen(spec) + assert result.dtype == spec.dtype + + +# --------------------------------------------------------------------------- +# SpectralMeanSubtraction +# --------------------------------------------------------------------------- + + +def test_spectral_mean_subtraction_output_nonnegative(): + """SpectralMeanSubtraction clamps output to >= 0.""" + module = SpectralMeanSubtraction() + spec = torch.rand(64, 200) + result = module(spec) + assert (result >= 0).all() + + +def test_spectral_mean_subtraction_shape_preserved(): + module = SpectralMeanSubtraction() + spec = torch.rand(64, 200) + result = module(spec) + assert result.shape == spec.shape + + +def test_spectral_mean_subtraction_reduces_time_mean(): + """After subtraction the time-axis mean per bin should be <= 0 (pre-clamp).""" + module = SpectralMeanSubtraction() + # Constant spectrogram: mean subtraction should produce all zeros before clamp + spec = torch.ones(32, 100) * 3.0 + result = module(spec) + assert (result == 0).all() + + +def test_spectral_mean_subtraction_from_config(): + config = SpectralMeanSubtractionConfig() + module = SpectralMeanSubtraction.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, SpectralMeanSubtraction) + + +# --------------------------------------------------------------------------- +# PeakNormalize (spectrogram-level) +# --------------------------------------------------------------------------- + + +def test_peak_normalize_spec_max_is_one(): + """PeakNormalize should scale the spectrogram peak to 1.""" + module = PeakNormalize() + spec = torch.rand(64, 200) * 5.0 + result = module(spec) + assert abs(result.abs().max().item() - 1.0) < 1e-6 + + +def test_peak_normalize_spec_handles_zero(): + """PeakNormalize on a zero spectrogram should not raise.""" + module = PeakNormalize() + spec = torch.zeros(64, 200) + result = module(spec) + assert (result == 0).all() + + +def test_peak_normalize_from_config(): + config = PeakNormalizeConfig() + module = PeakNormalize.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, PeakNormalize) + + +# --------------------------------------------------------------------------- +# ScaleAmplitude +# --------------------------------------------------------------------------- + + +def test_scale_amplitude_db_output_is_finite(): + """AmplitudeToDB scaling should produce finite values for positive input.""" + module = ScaleAmplitude(scale="db") + spec = torch.rand(64, 200) + 1e-4 + result = module(spec) + assert torch.isfinite(result).all() + + +def test_scale_amplitude_power_output_equals_square(): + """ScaleAmplitude('power') should square every element.""" + module = ScaleAmplitude(scale="power") + spec = torch.tensor([[2.0, 3.0], [4.0, 5.0]]) + result = module(spec) + expected = spec**2 + assert torch.allclose(result, expected) + + +def test_scale_amplitude_from_config(): + config = ScaleAmplitudeConfig(scale="db") + module = ScaleAmplitude.from_config(config, samplerate=SAMPLERATE) + assert isinstance(module, ScaleAmplitude) + assert module.scale == "db" + + +# --------------------------------------------------------------------------- +# ResizeSpec +# --------------------------------------------------------------------------- + + +def test_resize_spec_output_shape(): + """ResizeSpec should produce the target height and scaled width.""" + module = ResizeSpec(height=64, time_factor=0.5) + spec = torch.rand(1, 128, 200) # (batch, freq, time) + result = module(spec) + assert result.shape == (1, 64, 100) + + +def test_resize_spec_2d_input(): + """ResizeSpec should handle 2-D input (no batch or channel dimensions).""" + module = ResizeSpec(height=64, time_factor=0.5) + spec = torch.rand(128, 200) + result = module(spec) + assert result.shape == (64, 100) + + +def test_resize_spec_output_is_finite(): + module = ResizeSpec(height=128, time_factor=0.5) + spec = torch.rand(128, 200) + result = module(spec) + assert torch.isfinite(result).all() + + +def test_resize_spec_from_config(): + config = ResizeConfig(height=64, resize_factor=0.25) + module = build_spectrogram_resizer(config) + assert isinstance(module, ResizeSpec) + assert module.height == 64 + assert module.time_factor == 0.25 + + +# --------------------------------------------------------------------------- +# build_spectrogram_transform dispatch +# --------------------------------------------------------------------------- + + +def test_build_spectrogram_transform_pcen(): + config = PcenConfig() + module = build_spectrogram_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, PCEN) + + +def test_build_spectrogram_transform_spectral_mean_subtraction(): + config = SpectralMeanSubtractionConfig() + module = build_spectrogram_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, SpectralMeanSubtraction) + + +def test_build_spectrogram_transform_scale_amplitude(): + config = ScaleAmplitudeConfig(scale="db") + module = build_spectrogram_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, ScaleAmplitude) + + +def test_build_spectrogram_transform_peak_normalize(): + config = PeakNormalizeConfig() + module = build_spectrogram_transform(config, samplerate=SAMPLERATE) + assert isinstance(module, PeakNormalize) diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 9c517c4..2d91d5e 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -10,7 +10,7 @@ from batdetect2.preprocess import ( ) from batdetect2.preprocess.spectrogram import ( ScaleAmplitudeConfig, - SpectralMeanSubstractionConfig, + SpectralMeanSubtractionConfig, ) from batdetect2.targets.rois import ( DEFAULT_ANCHOR, @@ -597,7 +597,7 @@ def test_build_roi_mapper_for_peak_energy_bbox(): preproc_config = PreprocessingConfig( spectrogram_transforms=[ ScaleAmplitudeConfig(scale="db"), - SpectralMeanSubstractionConfig(), + SpectralMeanSubtractionConfig(), ] ) config = PeakEnergyBBoxMapperConfig(