Preprocessing in PyTorch

This commit is contained in:
mbsantiago 2025-08-25 11:41:55 +01:00
parent 61115d562c
commit 667b18a54d
15 changed files with 690 additions and 2578 deletions

View File

@ -28,13 +28,12 @@ This module provides the primary interface:
"""
from typing import Optional, Union
from typing import Optional
import numpy as np
import xarray as xr
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
@ -44,28 +43,23 @@ from batdetect2.preprocess.audio import (
AudioConfig,
ResampleConfig,
build_audio_loader,
build_audio_pipeline,
)
from batdetect2.preprocess.spectrogram import (
MAX_FREQ,
MIN_FREQ,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
SpectrogramPipeline,
STFTConfig,
build_spectrogram_builder,
get_spectrogram_resolution,
)
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol
__all__ = [
"AudioConfig",
"ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION",
"FrequencyConfig",
"MAX_FREQ",
@ -75,16 +69,11 @@ __all__ = [
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpecSizeConfig",
"SpectrogramConfig",
"StandardPreprocessor",
"TARGET_SAMPLERATE_HZ",
"build_audio_loader",
"build_preprocessor",
"build_spectrogram_builder",
"get_spectrogram_resolution",
"load_preprocessing_config",
"get_default_preprocessor",
]
@ -110,343 +99,61 @@ class PreprocessingConfig(BaseConfig):
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol.
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
Orchestrates the audio loading and spectrogram generation pipeline using
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
configured according to a `PreprocessingConfig`.
This class is typically instantiated using the `build_preprocessor`
factory function.
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
Attributes
----------
audio_loader : AudioLoader
The configured audio loader instance used for waveform loading and
initial processing.
spectrogram_builder : SpectrogramBuilder
The configured spectrogram builder instance used for generating
spectrograms from waveforms.
default_samplerate : int
The sample rate (in Hz) assumed for input waveforms when they are
provided as raw NumPy arrays without coordinate information (e.g.,
when calling `compute_spectrogram` directly with `np.ndarray`).
This value is derived from the `AudioConfig` (target resample rate
or default if resampling is off) and also serves as documentation
for the pipeline's intended operating sample rate. Note that when
processing `xr.DataArray` inputs that have coordinate information
(the standard internal workflow), the sample rate embedded in the
coordinates takes precedence over this default value during
spectrogram calculation.
"""
audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder
default_samplerate: int
samplerate: int
max_freq: float
min_freq: float
def __init__(
self,
audio_loader: AudioLoader,
spectrogram_builder: SpectrogramBuilder,
default_samplerate: int,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
samplerate: int,
max_freq: float,
min_freq: float,
) -> None:
"""Initialize the StandardPreprocessor.
Parameters
----------
audio_loader : AudioLoader
An initialized audio loader conforming to the AudioLoader protocol.
spectrogram_builder : SpectrogramBuilder
An initialized spectrogram builder conforming to the
SpectrogramBuilder protocol.
default_samplerate : int
The sample rate to assume for NumPy array inputs and potentially
reflecting the target rate of the audio config.
"""
self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.samplerate = samplerate
self.max_freq = max_freq
self.min_freq = min_freq
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Delegates to the internal `audio_loader`.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_file(
path,
audio_dir=audio_dir,
)
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Delegates to the internal `audio_loader`.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_recording(
recording,
audio_dir=audio_dir,
)
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Delegates to the internal `audio_loader`.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment (typically first
channel).
"""
return self.audio_loader.load_clip(
clip,
audio_dir=audio_dir,
)
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_file_audio(path, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def compute_spectrogram(
self, wav: Union[xr.DataArray, np.ndarray]
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the configured spectrogram generation steps
(STFT, scaling, etc.) using the internal `spectrogram_builder`.
If `wav` is a NumPy array, the `default_samplerate` stored in this
preprocessor instance will be used. If `wav` is an xarray DataArray
with time coordinates, the sample rate derived from those coordinates
will take precedence over `default_samplerate`.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If numpy array, `default_samplerate`
stored in this object will be assumed.
Returns
-------
xr.DataArray
The computed spectrogram.
"""
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def load_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
"""Load the unified preprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PreprocessingConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
preprocessing configuration (e.g., "train.preprocessing"). If None, the
entire file content is validated as the PreprocessingConfig.
Returns
-------
PreprocessingConfig
Loaded and validated preprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to PreprocessingConfig.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=PreprocessingConfig, field=field)
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration.
Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
based on the provided `PreprocessingConfig` (or defaults if config is None),
determines the effective default sample rate, and initializes the
`StandardPreprocessor`.
Parameters
----------
config : PreprocessingConfig, optional
The unified preprocessing configuration object. If None, default
configurations for audio and spectrogram processing will be used.
Returns
-------
Preprocessor
An initialized `StandardPreprocessor` instance ready to process audio
according to the configuration.
"""
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
default_samplerate = (
config.audio.resample.samplerate
if config.audio.resample
else TARGET_SAMPLERATE_HZ
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
return StandardPreprocessor(
audio_loader=build_audio_loader(config.audio),
spectrogram_builder=build_spectrogram_builder(config.spectrogram),
default_samplerate=default_samplerate,
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
samplerate=samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
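
For orientation, a minimal usage sketch of the torch-based preprocessor (default configurations assumed; the random waveform is a placeholder):

import torch

from batdetect2.preprocess import build_preprocessor

preprocessor = build_preprocessor()  # default audio + spectrogram configs

# One second of placeholder audio at the configured sample rate.
wav = torch.randn(preprocessor.samplerate)

# forward() chains the audio pipeline and the spectrogram pipeline.
spec = preprocessor(wav)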

View File

@ -1,53 +1,31 @@
"""Handles loading and initial preprocessing of audio waveforms.
"""Handles loading and initial preprocessing of audio waveforms."""
This module provides components for loading audio data associated with
`soundevent` objects (Clips, Recordings, or raw files) and applying
fundamental waveform processing steps. These steps typically include:
1. Loading the raw audio data.
2. Adjusting the audio clip to a fixed duration (optional).
3. Resampling the audio to a target sample rate (optional).
4. Centering the waveform (DC offset removal) (optional).
5. Scaling the waveform amplitude (optional).
The processing pipeline is configurable via the `AudioConfig` data structure,
allowing for reproducible preprocessing consistent between model training and
inference. It uses the `soundevent` library for audio loading and basic array
operations, and `scipy` for resampling implementations.
The primary interface is the `AudioLoader` protocol, with
`ConfigurableAudioLoader` providing a concrete implementation driven by the
`AudioConfig`.
"""
from typing import Optional
from typing import Annotated, List, Literal, Optional, Union
import numpy as np
import xarray as xr
import torch
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.typing import AudioLoader
__all__ = [
"ResampleConfig",
"AudioConfig",
"ConfigurableAudioLoader",
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"adjust_audio_duration",
"resample_audio",
"TARGET_SAMPLERATE_HZ",
"SCALE_RAW_AUDIO",
"DEFAULT_DURATION",
"convert_to_xr",
]
TARGET_SAMPLERATE_HZ = 256_000
@ -76,192 +54,69 @@ class ResampleConfig(BaseConfig):
resampling factors differently.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
enabled: bool = True
method: str = "poly"
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing.
Defines the sequence of operations applied to raw audio waveforms after
loading, controlling steps like resampling, scaling, centering, and
duration adjustment.
Attributes
----------
resample : ResampleConfig, optional
Configuration for resampling. If provided (or defaulted), audio will
be resampled to the specified `samplerate` using the specified
`method`. If set to `None` in the config file, resampling is skipped.
Defaults to a ResampleConfig instance with standard settings.
scale : bool, default=False
If True, scales the audio waveform using peak normalization so that
its maximum absolute amplitude is approximately 1.0. If False
(default), no amplitude scaling is applied.
center : bool, default=False
If True, centers the waveform by subtracting its mean
(DC offset removal). If False (default), the waveform is not centered.
duration : float, optional
If set to a float value (seconds), the loaded audio clip will be
adjusted (cropped or padded with zeros) to exactly this duration.
If None (default), the original duration is kept.
"""
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
scale: bool = SCALE_RAW_AUDIO
center: bool = False
duration: Optional[float] = DEFAULT_DURATION
class ConfigurableAudioLoader:
"""Concrete implementation of the `AudioLoader` driven by `AudioConfig`.
This class loads audio and applies preprocessing steps (resampling,
scaling, centering, duration adjustment) based on the settings provided
in an `AudioConfig` object during initialization. It delegates the actual
work to module-level functions.
"""
class SoundEventAudioLoader:
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
config: AudioConfig,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
"""Initialize the ConfigurableAudioLoader.
Parameters
----------
config : AudioConfig
The configuration object specifying the desired preprocessing steps
and parameters.
"""
self.config = config
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess audio directly from a file path.
Implements the `AudioLoader.load_file` method by delegating to the
`load_file_audio` function, passing the stored configuration.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel).
"""
return load_file_audio(path, config=self.config, audio_dir=audio_dir)
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the entire audio for a Recording object.
Implements the `AudioLoader.load_recording` method by delegating to the
`load_recording_audio` function, passing the stored configuration.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel).
"""
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording, config=self.config, audio_dir=audio_dir
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the audio segment defined by a Clip object.
Implements the `AudioLoader.load_clip` method by delegating to the
`load_clip_audio` function, passing the stored configuration.
Parameters
----------
clip : data.Clip
The Clip object specifying the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform segment (first channel).
"""
return load_clip_audio(clip, config=self.config, audio_dir=audio_dir)
def build_audio_loader(
config: AudioConfig,
) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration.
Instantiates and returns a `ConfigurableAudioLoader` initialized with
the provided `AudioConfig`. The return type is `AudioLoader`, adhering
to the protocol.
Parameters
----------
config : AudioConfig
The configuration object specifying preprocessing steps.
Returns
-------
AudioLoader
An instance of `ConfigurableAudioLoader` ready to load and process audio
according to the configuration.
"""
return ConfigurableAudioLoader(config=config)
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_file_audio(
path: data.PathLike,
config: Optional[AudioConfig] = None,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Load and preprocess audio from a file path using specified config.
Creates a `soundevent.data.Recording` object from the file path and then
delegates the loading and processing to `load_recording_audio`.
Parameters
----------
path : PathLike
Path to the audio file.
config : AudioConfig, optional
Audio processing configuration. If None, default settings defined
in `AudioConfig` are used.
audio_dir : PathLike, optional
Directory prefix if `path` is relative.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the loaded audio array.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel only).
"""
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
@ -271,6 +126,7 @@ def load_file_audio(
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
@ -279,33 +135,12 @@ def load_file_audio(
def load_recording_audio(
recording: data.Recording,
config: Optional[AudioConfig] = None,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Load and preprocess the entire audio content of a recording using config.
Creates a `soundevent.data.Clip` spanning the full duration of the
recording and then delegates the loading and processing to
`load_clip_audio`.
Parameters
----------
recording : data.Recording
The Recording object containing metadata and file path.
config : AudioConfig, optional
Audio processing configuration. If None, default settings are used.
audio_dir : PathLike, optional
Directory containing the audio file, used if the path in `recording`
is relative.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the loaded audio array.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel only).
"""
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
@ -313,6 +148,7 @@ def load_recording_audio(
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
@ -321,56 +157,12 @@ def load_recording_audio(
def load_clip_audio(
clip: data.Clip,
config: Optional[AudioConfig] = None,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Load and preprocess a specific audio clip segment based on config.
This is the core function performing the configured processing pipeline:
1. Loads the specified clip segment using `soundevent.audio.load_clip`.
2. Selects the first audio channel.
3. Resamples if `config.resample` is configured.
4. Centers (DC offset removal) if `config.center` is True.
5. Scales (peak normalization) if `config.scale` is True.
6. Adjusts duration (crop/pad) if `config.duration` is set.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment and source recording.
config : AudioConfig, optional
Audio processing configuration. If None, a default `AudioConfig` is
used.
audio_dir : PathLike, optional
Directory containing the source audio file specified in the clip's
recording.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the processed audio array.
Returns
-------
xr.DataArray
The loaded and preprocessed waveform segment as an xarray DataArray
with time coordinates.
Raises
------
FileNotFoundError
If the underlying audio file cannot be found.
Exception
If audio loading or processing fails for other reasons (e.g., invalid
format, resampling error).
Notes
-----
- **Mono Processing:** This function currently loads and processes only the
**first channel** (channel 0) of the audio file. Any other channels
are ignored.
"""
config = config or AudioConfig()
with xr.set_options(keep_attrs=True):
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
@ -383,195 +175,48 @@ def load_clip_audio(
f"Error: {e}"
) from e
if config.resample:
wav = resample_audio(
wav,
samplerate=config.resample.samplerate,
dtype=dtype,
)
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
if config.center:
wav = ops.center(wav)
if config.scale:
wav = scale_audio(wav)
if config.duration is not None:
wav = adjust_audio_duration(wav, duration=config.duration)
return wav.astype(dtype)
def scale_audio(
wave: xr.DataArray,
) -> xr.DataArray:
"""
Scale the audio waveform to have a maximum absolute value of 1.0.
This function normalizes the waveform by dividing it by its maximum
absolute value. If the maximum value is zero, the waveform is returned
unchanged. Also known as peak normalization, this process ensures that the
waveform's amplitude is within a standard range, which can be useful for
audio processing and analysis.
"""
max_val = np.max(np.abs(wave))
if max_val == 0:
return wave
return ops.scale(wave, 1 / max_val)
def adjust_audio_duration(
wave: xr.DataArray,
duration: float,
) -> xr.DataArray:
"""Adjust the duration of an audio waveform array via cropping or padding.
If the current duration is longer than the target, it crops the array
from the beginning. If shorter, it pads the array with zeros at the end
using `soundevent.arrays.extend_dim`.
Parameters
----------
wave : xr.DataArray
The input audio waveform with a 'time' dimension and coordinates.
duration : float
The target duration in seconds.
Returns
-------
xr.DataArray
The waveform adjusted to the target duration. Returns the input
unmodified if duration already matches or if the wave is empty.
Raises
------
ValueError
If `duration` is negative.
"""
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
current_duration = end_time - start_time + step
if current_duration == duration:
return wave
with xr.set_options(keep_attrs=True):
if current_duration > duration:
return arrays.crop_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration - step / 2,
right_closed=True,
)
return arrays.extend_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration - step / 2,
eps=0,
right_closed=True,
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
def resample_audio(
wav: xr.DataArray,
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Resample an audio waveform DataArray to a target sample rate.
Updates the 'time' coordinate axis according to the new sample rate and
number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
or Fourier method (`scipy.signal.resample`) based on the `method`.
Parameters
----------
wav : xr.DataArray
Input audio waveform with 'time' dimension and coordinates.
samplerate : int, default=TARGET_SAMPLERATE_HZ
Target sample rate in Hz.
method : str, default="poly"
Resampling algorithm: "poly" or "fourier".
dtype : DTypeLike, default=np.float32
Target data type for the resampled array.
Returns
-------
xr.DataArray
Resampled waveform with updated time coordinates. Returns the input
unmodified (but dtype cast) if the sample rate is already correct or
if the input array is empty.
Raises
------
ValueError
If `wav` lacks a 'time' dimension, the original sample rate cannot
be determined, `samplerate` is non-positive, or `method` is invalid.
"""
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
time_axis: int = wav.get_axis_num("time") # type: ignore
step = arrays.get_dim_step(wav, dim="time")
original_samplerate = int(1 / step)
if original_samplerate == samplerate:
return wav.astype(dtype).assign_attrs(original_samplerate=samplerate)
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
resampled = resample_audio_poly(
return resample_audio_poly(
wav,
sr_orig=original_samplerate,
sr_orig=sr,
sr_new=samplerate,
axis=time_axis,
)
elif method == "fourier":
resampled = resample_audio_fourier(
return resample_audio_fourier(
wav,
sr_orig=original_samplerate,
sr_orig=sr,
sr_new=samplerate,
axis=time_axis,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
start, stop = arrays.get_dim_range(wav, dim="time")
times = np.linspace(
start,
stop + step,
len(resampled),
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=resampled.astype(dtype),
dims=wav.dims,
coords={
**wav.coords,
"time": arrays.create_time_dim_from_array(
times,
samplerate=samplerate,
),
},
attrs={
**wav.attrs,
"samplerate": samplerate,
"original_samplerate": original_samplerate,
},
)
def resample_audio_poly(
array: xr.DataArray,
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
@ -605,7 +250,7 @@ def resample_audio_poly(
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array.values,
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
@ -613,7 +258,7 @@ def resample_audio_poly(
def resample_audio_fourier(
array: xr.DataArray,
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
@ -649,66 +294,89 @@ def resample_audio_fourier(
)
def convert_to_xr(
wav: np.ndarray,
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class FixDuration(torch.nn.Module):
def __init__(self, samplerate: int, duration: float):
super().__init__()
self.samplerate = samplerate
self.duration = duration
self.length = int(samplerate * duration)
def forward(self, wav: torch.Tensor) -> torch.Tensor:
length = wav.shape[-1]
if length == self.length:
return wav
if length > self.length:
return wav[..., : self.length]
return torch.nn.functional.pad(wav, (0, self.length - length))
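
A quick sketch of the crop/pad behaviour (illustrative values; the target length is samplerate * duration samples):

import torch

fix = FixDuration(samplerate=256_000, duration=0.5)  # length = 128_000

assert fix(torch.zeros(100_000)).shape[-1] == 128_000  # zero-padded at the end
assert fix(torch.zeros(200_000)).shape[-1] == 128_000  # cropped to length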
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
def build_audio_loader(
config: Optional[AudioConfig] = None,
) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
def build_audio_transform_step(
config: AudioTransform,
samplerate: int,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Convert a NumPy array to an xarray DataArray with time coordinates.
) -> torch.nn.Module:
if config.name == "fix_duration":
return FixDuration(samplerate=samplerate, duration=config.duration)
Parameters
----------
wav : np.ndarray
The input waveform array. Expected to be 1D or 2D (with the first
axis as the channel dimension).
samplerate : int
The sample rate in Hz.
dtype : DTypeLike, default=np.float32
Target data type for the xarray DataArray.
if config.name == "scale_audio":
return PeakNormalize()
Returns
-------
xr.DataArray
The waveform as an xarray DataArray with time coordinates.
if config.name == "center_audio":
return CenterTensor()
Raises
------
ValueError
If the input array is not 1D or 2D, or if the sample rate is
non-positive. If the input array is empty.
"""
if wav.ndim == 2:
wav = wav[0, :]
if wav.ndim != 1:
raise ValueError(
"Audio must be 1D array or 2D channel where the first "
"axis is the channel dimension"
raise NotImplementedError(
f"Audio preprocessing step {config.name} not implemented"
)
if wav.size == 0:
raise ValueError("Audio array is empty")
if samplerate <= 0:
raise ValueError("Sample rate must be positive")
times = np.linspace(
0,
wav.shape[0] / samplerate,
wav.shape[0],
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=wav.astype(dtype),
dims=["time"],
coords={
"time": arrays.create_time_dim_from_array(
times,
samplerate=samplerate,
),
},
attrs={"samplerate": samplerate},
def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
return torch.nn.Sequential(
*[
build_audio_transform_step(step, samplerate=config.samplerate)
for step in config.transforms
]
)
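
Since the transform list is a discriminated union, pipelines can be declared in config and built in one call; a sketch (steps run in list order):

from batdetect2.preprocess.audio import (
    AudioConfig,
    CenterAudioConfig,
    FixDurationConfig,
    ScaleAudioConfig,
    build_audio_pipeline,
)

config = AudioConfig(
    samplerate=256_000,
    transforms=[
        CenterAudioConfig(),  # remove DC offset
        ScaleAudioConfig(),  # peak-normalize to unit amplitude
        FixDurationConfig(duration=0.5),  # crop/pad to 0.5 s
    ],
)
pipeline = build_audio_pipeline(config)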

View File

@ -0,0 +1,24 @@
import torch
__all__ = [
"CenterTensor",
"PeakNormalize",
]
class CenterTensor(torch.nn.Module):
def forward(self, wav: torch.Tensor):
return wav - wav.mean()
class PeakNormalize(torch.nn.Module):
def forward(self, wav: torch.Tensor):
max_value = wav.abs().max()
denominator = torch.where(
max_value == 0,
torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
max_value,
)
return wav / denominator
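
A small sanity check of the two modules (illustrative values):

import torch

from batdetect2.preprocess.common import CenterTensor, PeakNormalize

wav = torch.tensor([0.5, -2.0, 1.0])

centered = CenterTensor()(wav)  # subtracts the mean, here -1/6
normalized = PeakNormalize()(wav)  # divides by the peak magnitude, 2.0
silent = PeakNormalize()(torch.zeros(4))  # all-zero input passes through unchanged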

View File

@ -1,48 +1,22 @@
"""Computes spectrograms from audio waveforms with configurable parameters.
"""Computes spectrograms from audio waveforms with configurable parameters."""
This module provides the functionality to convert preprocessed audio waveforms
(typically output from the `batdetect2.preprocessing.audio` module) into
spectrogram representations suitable for input into deep learning models like
BatDetect2.
It offers a configurable pipeline including:
1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
2. Frequency axis cropping to a relevant range.
3. Per-Channel Energy Normalization (PCEN) (optional).
4. Amplitude scaling/representation (dB, power, or linear amplitude).
5. Simple spectral mean subtraction denoising (optional).
6. Resizing to target dimensions (optional).
7. Final peak normalization (optional).
Configuration is managed via the `SpectrogramConfig` class, allowing for
reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""
from typing import Callable, Literal, Optional, Union
from typing import Annotated, Callable, List, Literal, Optional, Union
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
import torch
import torchaudio
from pydantic import Field
from scipy import signal
from soundevent import arrays, audio
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.common import PeakNormalize
from batdetect2.typing.preprocess import SpectrogramBuilder
__all__ = [
"STFTConfig",
"FrequencyConfig",
"SpecSizeConfig",
"PcenConfig",
"SpectrogramConfig",
"ConfigurableSpectrogramBuilder",
"build_spectrogram_builder",
"compute_spectrogram",
"get_spectrogram_resolution",
"MIN_FREQ",
"MAX_FREQ",
]
@ -79,6 +53,47 @@ class STFTConfig(BaseConfig):
window_fn: str = "hann"
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
if name == "hann":
return torch.hann_window
if name == "hamming":
return torch.hamming_window
if name == "kaiser":
return torch.kaiser_window
if name == "blackman":
return torch.blackman_window
if name == "bartlett":
return torch.bartlett_window
raise NotImplementedError(
f"Spectrogram window function {name} not implemented"
)
def _spec_params_from_config(samplerate: int, conf: STFTConfig):
n_fft = int(samplerate * conf.window_duration)
hop_length = int(n_fft * (1 - conf.window_overlap))
return n_fft, hop_length
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
) -> SpectrogramBuilder:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
center=False,
power=1,
)
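
As a worked example (hypothetical STFT settings, not necessarily the defaults): at 256 kHz, a 2 ms window with 75% overlap gives

n_fft, hop_length = _spec_params_from_config(
    256_000,
    STFTConfig(window_duration=0.002, window_overlap=0.75),
)
# n_fft = int(256_000 * 0.002) = 512
# hop_length = int(512 * (1 - 0.75)) = 128
assert (n_fft, hop_length) == (512, 128)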
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
@ -96,644 +111,282 @@ class FrequencyConfig(BaseConfig):
min_freq: int = Field(default=10_000, ge=0)
class SpecSizeConfig(BaseConfig):
"""Configuration for the final size and shape of the spectrogram.
def _frequency_to_index(
freq: float,
samplerate: int,
n_fft: int,
) -> Optional[int]:
alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height))
Attributes
----------
height : int, default=128
Target height of the spectrogram in pixels (frequency bins). The
frequency axis will be resized (e.g., via interpolation) to match this
height after frequency cropping and amplitude scaling. Must be > 0.
resize_factor : float, optional
Factor by which to resize the spectrogram along the time axis *after*
STFT calculation. A value of 0.5 halves the number of time bins,
2.0 doubles it. If None (default), no resizing along the time axis is
performed relative to the STFT output width. Must be > 0 if provided.
"""
if index <= 0:
return None
height: int = 128
resize_factor: Optional[float] = 0.5
if index >= height:
return None
return index
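
For example, with the (assumed) values samplerate=256_000 and n_fft=512, a 10 kHz cutoff maps to bin 20:

# alpha = 10_000 * 2 / 256_000 = 0.078125
# height = floor(512 / 2) + 1 = 257 frequency bins
# index = floor(0.078125 * 257) = 20
assert _frequency_to_index(10_000, samplerate=256_000, n_fft=512) == 20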
class FrequencyClip(torch.nn.Module):
def __init__(
self,
low_index: Optional[int] = None,
high_index: Optional[int] = None,
):
super().__init__()
self.low_index = low_index
self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor:
# Slice the leading frequency axis of a (frequency, time) spectrogram.
return spec[self.low_index : self.high_index]
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN).
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
PCEN is an adaptive gain control method that can help emphasize transients
and suppress stationary noise. Applied after STFT and frequency cropping,
but before final amplitude scaling (dB, power, amplitude).
Attributes
----------
time_constant : float, default=0.4
Time constant (in seconds) for the PCEN smoothing filter. Controls
how quickly the normalization adapts to energy changes.
gain : float, default=0.98
Gain factor (alpha). Controls the adaptive gain component.
bias : float, default=2.0
Bias factor (delta). Added before the exponentiation.
power : float, default=0.5
Exponent (r). Controls the compression characteristic.
"""
time_constant: float = 0.01
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class SpectrogramConfig(BaseConfig):
"""Unified configuration for spectrogram generation pipeline.
Aggregates settings for all steps involved in converting a preprocessed
audio waveform into a final spectrogram representation suitable for model
input.
Attributes
----------
stft : STFTConfig
Configuration for the initial Short-Time Fourier Transform.
Defaults to standard settings via `STFTConfig`.
frequencies : FrequencyConfig
Configuration for cropping the frequency range after STFT.
Defaults to standard settings via `FrequencyConfig`.
pcen : PcenConfig, optional
Configuration for applying Per-Channel Energy Normalization (PCEN). If
provided, PCEN is applied after frequency cropping. If None or omitted
(default), PCEN is skipped.
scale : Literal["dB", "amplitude", "power"], default="amplitude"
Determines the final amplitude representation *after* optional PCEN.
- "amplitude": Use linear magnitude values (output of STFT or PCEN).
- "power": Use power values (magnitude squared).
- "dB": Use logarithmic (decibel-like) scaling applied to the magnitude
(or PCEN output if enabled). Calculated as `log1p(C * S)`.
size : SpecSizeConfig, optional
Configuration for resizing the spectrogram dimensions
(frequency height, optional time width factor). Applied after PCEN and
scaling. If None (default), no resizing is performed.
spectral_mean_substraction : bool, default=True
If True (default), applies simple spectral mean subtraction denoising
*after* PCEN and amplitude scaling, but *before* resizing.
peak_normalize : bool, default=False
If True, applies a final peak normalization to the spectrogram *after*
all other steps (including resizing), scaling the overall maximum value
to 1.0. If False (default), this final normalization is skipped.
"""
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
scale: Literal["dB", "amplitude", "power"] = "amplitude"
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
spectral_mean_substraction: bool = True
peak_normalize: bool = False
class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
"""Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`.
This class computes spectrograms according to the parameters specified in a
`SpectrogramConfig` object provided during initialization. It handles both
numpy array and xarray DataArray inputs for the waveform.
"""
class PCEN(torch.nn.Module):
def __init__(
self,
config: SpectrogramConfig,
dtype: DTypeLike = np.float32, # type: ignore
) -> None:
"""Initialize the ConfigurableSpectrogramBuilder.
Parameters
----------
config : SpectrogramConfig
The configuration object specifying all spectrogram parameters.
dtype : DTypeLike, default=np.float32
The target NumPy data type for the computed spectrogram array.
"""
self.config = config
smoothing_constant: float,
gain: float = 0.98,
bias: float = 2.0,
power: float = 0.5,
eps: float = 1e-6,
dtype=torch.float64,
):
super().__init__()
self.smoothing_constant = smoothing_constant
self.gain = torch.tensor(gain, dtype=dtype)
self.bias = torch.tensor(bias, dtype=dtype)
self.power = torch.tensor(power, dtype=dtype)
self.eps = torch.tensor(eps, dtype=dtype)
self.dtype = dtype
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform using the config.
Implements the `SpectrogramBuilder` protocol. If the input `wav` is
a numpy array, `samplerate` must be provided; the array will be
converted to an xarray DataArray internally. If `wav` is already an
xarray DataArray with time coordinates, `samplerate` is ignored.
Delegates the main computation to `compute_spectrogram`.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform.
samplerate : int, optional
The sample rate in Hz (required only if `wav` is np.ndarray).
Returns
-------
xr.DataArray
The computed spectrogram.
Raises
------
ValueError
If `wav` is np.ndarray and `samplerate` is None.
"""
if isinstance(wav, np.ndarray):
if samplerate is None:
raise ValueError(
"Samplerate must be provided when passing a numpy array."
)
wav = convert_to_xr(
wav,
samplerate=samplerate,
dtype=self.dtype,
self._b = torch.tensor([self.smoothing_constant, 0.0], dtype=dtype)
self._a = torch.tensor(
[1.0, self.smoothing_constant - 1.0], dtype=dtype
)
return compute_spectrogram(
wav,
config=self.config,
dtype=self.dtype,
def forward(self, spec: torch.Tensor) -> torch.Tensor:
S = spec.to(self.dtype) * 2**31
M = (
torchaudio.functional.lfilter(
S,
self._a,
self._b,
clamp=False,
)
).clamp(min=0)
smooth = torch.exp(
-self.gain * (torch.log(self.eps) + torch.log1p(M / self.eps))
)
def build_spectrogram_builder(
config: SpectrogramConfig,
dtype: DTypeLike = np.float32, # type: ignore
) -> SpectrogramBuilder:
"""Factory function to create a SpectrogramBuilder based on configuration.
Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized
with the provided `SpectrogramConfig`.
Parameters
----------
config : SpectrogramConfig
The configuration object specifying spectrogram parameters.
dtype : DTypeLike, default=np.float32
The target NumPy data type for the computed spectrogram array.
Returns
-------
SpectrogramBuilder
An instance of `ConfigurableSpectrogramBuilder` ready to compute
spectrograms according to the configuration.
"""
return ConfigurableSpectrogramBuilder(config=config, dtype=dtype)
return (
(self.bias**self.power)
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype)
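
Unwinding the log-space arithmetic shows the forward pass is the standard PCEN map (a derivation sketch using the module's parameter names):

# smooth = exp(-gain * (log(eps) + log1p(M / eps)))
#        = (eps + M) ** -gain
# out    = bias**power * expm1(power * log1p(S * smooth / bias))
#        = (bias + S / (eps + M) ** gain) ** power - bias ** power
# i.e. PCEN with smoother M, gain, bias and compression exponent power,
# computed via log1p/expm1 for numerical stability.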
def compute_spectrogram(
wav: xr.DataArray,
config: Optional[SpectrogramConfig] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Compute a spectrogram from a waveform using specified configurations.
Applies a sequence of operations based on the `config`:
1. Compute STFT magnitude (`stft`).
2. Crop frequency axis (`crop_spectrogram_frequencies`).
3. Apply PCEN if configured (`apply_pcen`).
4. Apply final amplitude scaling (dB, power, amplitude)
(`scale_spectrogram`).
5. Apply spectral mean subtraction denoising if enabled.
6. Resize dimensions if specified (`resize_spectrogram`).
7. Apply final peak normalization if enabled.
Parameters
----------
wav : xr.DataArray
Input audio waveform with a 'time' dimension and coordinates from
which the sample rate can be inferred.
config : SpectrogramConfig, optional
Configuration object specifying spectrogram parameters. If None,
default settings from `SpectrogramConfig` are used.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the final spectrogram array.
Returns
-------
xr.DataArray
The computed and processed spectrogram with 'time' and 'frequency'
coordinates.
Raises
------
ValueError
If `wav` lacks necessary 'time' coordinates or dimensions.
"""
config = config or SpectrogramConfig()
with xr.set_options(keep_attrs=True):
spec = stft(
wav,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
)
spec = crop_spectrogram_frequencies(
spec,
min_freq=config.frequencies.min_freq,
max_freq=config.frequencies.max_freq,
)
if config.pcen:
spec = apply_pcen(
spec,
time_constant=config.pcen.time_constant,
gain=config.pcen.gain,
power=config.pcen.power,
bias=config.pcen.bias,
)
spec = scale_spectrogram(spec, scale=config.scale)
if config.spectral_mean_substraction:
spec = remove_spectral_mean(spec)
if config.size:
spec = resize_spectrogram(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
)
if config.peak_normalize:
spec = ops.normalize(spec)
return spec.astype(dtype)
def _compute_smoothing_constant(
samplerate: int,
time_constant: float,
) -> float:
# NOTE: These values are hard-coded to match the original implementation.
hop_length = 512
sr = samplerate / 10
t_frames = time_constant * sr / float(hop_length)
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
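
For illustration, at samplerate=256_000 with the default time_constant=0.4 the smoothing constant comes out to roughly 0.049:

# sr = 256_000 / 10 = 25_600
# t_frames = 0.4 * 25_600 / 512 = 20.0
# s = (sqrt(1 + 4 * 20**2) - 1) / (2 * 20**2) ≈ 0.0488
s = _compute_smoothing_constant(samplerate=256_000, time_constant=0.4)
assert abs(s - 0.0488) < 1e-3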
def crop_spectrogram_frequencies(
spec: xr.DataArray,
min_freq: int = 10_000,
max_freq: int = 120_000,
) -> xr.DataArray:
"""Crop the frequency axis of a spectrogram to a specified range.
Uses `soundevent.arrays.crop_dim` to select the frequency bins
corresponding to the range [`min_freq`, `max_freq`].
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'frequency' dimension and coordinates.
min_freq : int, default=MIN_FREQ
Minimum frequency (Hz) to keep.
max_freq : int, default=MAX_FREQ
Maximum frequency (Hz) to keep.
Returns
-------
xr.DataArray
Spectrogram cropped along the frequency axis. Preserves dtype.
"""
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
return arrays.crop_dim(
spec,
dim="frequency",
start=min_freq if start_freq < min_freq else None,
stop=max_freq if end_freq > max_freq else None,
).astype(spec.dtype)
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
def stft(
wave: xr.DataArray,
window_duration: float,
window_overlap: float,
window_fn: str = "hann",
) -> xr.DataArray:
"""Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
Calculates STFT parameters (N-FFT, hop length) based on the window
duration, overlap, and waveform sample rate. Returns an xarray DataArray
with correctly calculated 'time' and 'frequency' coordinates.
Parameters
----------
wave : xr.DataArray
Input audio waveform with 'time' coordinates.
window_duration : float
Duration of the STFT window in seconds.
window_overlap : float
Fractional overlap between consecutive windows.
window_fn : str, default="hann"
Name of the window function (e.g., "hann", "hamming").
Returns
-------
xr.DataArray
Magnitude spectrogram with 'time' and 'frequency' dimensions and
coordinates. STFT parameters are stored in the `attrs`.
Raises
------
ValueError
If sample rate cannot be determined from `wave` coordinates.
"""
if "channel" not in wave.coords:
wave = wave.assign_coords(channel=0)
return audio.compute_spectrogram(
wave,
window_size=window_duration,
hop_size=(1 - window_overlap) * window_duration,
window_type=window_fn,
scale="amplitude",
sort_dims=False,
)
def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
"""Apply simple spectral mean subtraction for denoising.
Subtracts the mean value of each frequency bin (calculated across time)
from that bin, then clips negative values to zero.
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'time' and 'frequency' dimensions.
Returns
-------
xr.DataArray
Denoised spectrogram with the same dimensions, coordinates, and dtype.
"""
return xr.DataArray(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
)
def scale_spectrogram(
spec: xr.DataArray,
scale: Literal["dB", "power", "amplitude"],
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Apply final amplitude scaling/representation to the spectrogram.
Converts the input magnitude spectrogram based on the `scale` type:
- "dB": Applies logarithmic scaling `log10(S)`.
- "power": Squares the magnitude values `S^2`.
- "amplitude": Returns the input magnitude values `S` unchanged.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram (potentially after PCEN).
scale : Literal["dB", "power", "amplitude"]
The target amplitude representation.
dtype : DTypeLike, default=np.float32
Target data type for the output scaled spectrogram.
Returns
-------
xr.DataArray
Spectrogram with the specified amplitude scaling applied.
"""
if scale == "dB":
return arrays.to_db(spec).astype(dtype)
if scale == "power":
class ToPower(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec**2
return spec
def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
if conf.scale == "db":
return torchaudio.transforms.AmplitudeToDB()
def apply_pcen(
spec: xr.DataArray,
time_constant: float = 0.4,
gain: float = 0.98,
bias: float = 2,
eps: float = 1e-6,
power: float = 0.5,
) -> xr.DataArray:
"""Apply Per-Channel Energy Normalization (PCEN) to a spectrogram.
if conf.scale == "power":
return ToPower()
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram with required attributes like
'processing_original_samplerate'.
time_constant : float, default=0.4
PCEN time constant in seconds.
gain : float, default=0.98
Gain factor (alpha).
bias : float, default=2.0
Bias factor (delta).
power : float, default=0.5
Exponent (r).
dtype : DTypeLike, default=np.float32
Target data type for the output spectrogram.
Returns
-------
xr.DataArray
PCEN-scaled spectrogram.
"""
samplerate = 1 / spec.time.attrs["step"]
hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
t_frames = time_constant * samplerate / hop_length
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
axis = spec.get_axis_num("time")
shape = tuple([1] * spec.ndim)
zi = np.empty(shape)
zi[:] = signal.lfilter_zi(
[smoothing_constant],
[1, smoothing_constant - 1],
)[:]
spec_data = spec.data * (2**31)
# Smooth the input array along the given axis
smoothed, _ = signal.lfilter(
[smoothing_constant],
[1, smoothing_constant - 1],
spec_data,
zi=zi,
axis=axis, # type: ignore
)
smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
data = (bias**power) * np.expm1(
power * np.log1p(spec_data * smooth / bias)
)
return xr.DataArray(
data.astype(spec.dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
raise NotImplementedError(
f"Amplitude scaling {conf.scale} not implemented"
)
def scale_log(
spec: xr.DataArray,
dtype: DTypeLike = np.float32, # type: ignore
ref: Union[float, Callable] = np.max,
amin: float = 1e-10,
top_db: Optional[float] = 80.0,
) -> xr.DataArray:
"""Apply logarithmic scaling to a magnitude spectrogram.
Calculates `log10(S)`, where S is the input magnitude spectrogram.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram with required attributes like
'processing_original_samplerate', 'processing_nfft'.
dtype : DTypeLike, default=np.float32
Target data type for the output spectrogram.
Returns
-------
xr.DataArray
Log-scaled spectrogram.
Raises
------
KeyError
If required attributes are missing from `spec.attrs`.
ValueError
If attributes are non-numeric or window function is invalid.
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
Notes
-----
Implementation mainly taken from librosa `power_to_db` function
"""
class SpectralMeanSubstraction(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0)
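
A tiny numeric example of the per-frequency-bin mean subtraction:

import torch

spec = torch.tensor([[1.0, 3.0], [2.0, 2.0]])  # (frequency, time)
out = SpectralMeanSubstraction()(spec)
# Row means are both 2.0, so out == [[0.0, 1.0], [0.0, 0.0]]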
if callable(ref):
ref_value = ref(spec)
else:
ref_value = np.abs(ref)
log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10(
np.maximum(amin, ref_value)
)
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
if top_db is not None:
if top_db < 0:
raise ValueError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return xr.DataArray(
data=log_spec.astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
self.height = height
self.time_factor = time_factor
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
# NOTE: interpolate with mode="bilinear" requires a 4D input, so the result
# keeps the added batch and channel dims: shape (1, 1, height, target_length).
return torch.nn.functional.interpolate(
spec.unsqueeze(0).unsqueeze(0),
size=(self.height, target_length),
mode="bilinear",
)
def resize_spectrogram(
spec: xr.DataArray,
height: int = 128,
resize_factor: Optional[float] = 0.5,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Resize a spectrogram to target dimensions using interpolation.
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
Resizes the frequency axis to `height` bins and optionally resizes the
time axis by `resize_factor`.
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'time' and 'frequency' dimensions.
height : int, default=128
Target number of frequency bins (vertical dimension).
resize_factor : float, optional
Factor to resize the time dimension. If 1.0 or None, time dimension
is unchanged. If 0.5, time dimension is halved, etc.
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
Returns
-------
xr.DataArray
Resized spectrogram. Coordinates are typically adjusted by the
underlying resize operation if implemented in `ops.resize`.
The dtype is currently hardcoded to float32 by ops.resize call.
"""
resize_factor = resize_factor or 1
current_width = spec.sizes["time"]
target_sizes = {
"time": int(current_width * resize_factor),
"frequency": height,
}
new_coords = {}
for dim in ["time", "frequency"]:
step = arrays.get_dim_step(spec, dim)
start, stop = arrays.get_dim_range(spec, dim)
new_coords[dim] = arrays.create_range_dim(
name=dim,
start=start,
stop=stop + step,
size=target_sizes[dim],
dtype=dtype,
)
return spec.interp(
coords=new_coords, method="linear", kwargs=dict(fill_value=0)
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
def get_spectrogram_resolution(
config: SpectrogramConfig,
) -> tuple[float, float]:
"""Calculate the approximate resolution of the final spectrogram.
Computes the width of each frequency bin (Hz/bin) and the duration
of each time bin (seconds/bin) based on the configuration parameters.
Parameters
----------
config : SpectrogramConfig
The spectrogram configuration object.
samplerate : int, optional
The sample rate of the audio *before* STFT. Required if needed to
calculate hop duration accurately from STFT config, but the current
implementation calculates hop_duration directly from STFT config times.
Returns
-------
Tuple[float, float]
A tuple containing:
- frequency_resolution (float): Approximate Hz per frequency bin.
- time_resolution (float): Approximate seconds per time bin.
Raises
------
ValueError
If required configuration fields (like `config.size`) are missing
or invalid.
"""
max_freq = config.frequencies.max_freq
min_freq = config.frequencies.min_freq
if config.size is None:
raise ValueError("Spectrogram size configuration is required.")
spec_height = config.size.height
resize_factor = config.size.resize_factor or 1
freq_bin_width = (max_freq - min_freq) / spec_height
hop_duration = config.stft.window_duration * (
1 - config.stft.window_overlap
def _build_spectrogram_transform_step(
step: SpectrogramTransform,
samplerate: int,
) -> torch.nn.Module:
if step.name == "pcen":
return PCEN(
smoothing_constant=_compute_smoothing_constant(
samplerate=samplerate,
time_constant=step.time_constant,
),
gain=step.gain,
bias=step.bias,
power=step.power,
)
if step.name == "scale_amplitude":
return _build_amplitude_scaler(step)
if step.name == "spectral_mean_substraction":
return SpectralMeanSubstraction()
if step.name == "peak_normalize":
return PeakNormalize()
raise NotImplementedError(
f"Spectrogram preprocessing step {step.name} not implemented"
)
def build_spectrogram_transform(
samplerate: int,
conf: SpectrogramConfig,
) -> torch.nn.Module:
return torch.nn.Sequential(
*[
_build_spectrogram_transform_step(step, samplerate=samplerate)
for step in conf.transforms
]
)
class SpectrogramPipeline(torch.nn.Module):
def __init__(
self,
spec_builder: SpectrogramBuilder,
freq_cutter: torch.nn.Module,
transforms: torch.nn.Module,
resizer: torch.nn.Module,
):
super().__init__()
self.spec_builder = spec_builder
self.freq_cutter = freq_cutter
self.transforms = transforms
self.resizer = resizer
def forward(self, wav: torch.Tensor) -> torch.Tensor:
spec = self.spec_builder(wav)
spec = self.freq_cutter(spec)
spec = self.transforms(spec)
return self.resizer(spec)
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
return self.spec_builder(wav)
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
return self.freq_cutter(spec)
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.transforms(spec)
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.resizer(spec)
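The staged methods mirror `forward`, which makes intermediate tensors easy to inspect. A usage sketch with the factory defined just below (the `(batch, channel, time)` input shape is an assumption based on the `[0, 0]` indexing used elsewhere in this commit):

import torch

pipeline = build_spectrogram_pipeline(
    samplerate=256_000,
    conf=SpectrogramConfig(),
)

wav = torch.randn(1, 1, 256_000)  # one second of synthetic audio

# Full pipeline in one call...
spec = pipeline(wav)

# ...or stage by stage, e.g. to debug a single transform.
raw = pipeline.compute_spectrogram(wav)
cut = pipeline.select_frequencies(raw)
out = pipeline.resize_spectrogram(pipeline.transform_spectrogram(cut))
assert torch.allclose(spec, out)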
def build_spectrogram_pipeline(
samplerate: int,
conf: SpectrogramConfig,
) -> SpectrogramPipeline:
spec_builder = build_spectrogram_builder(samplerate, conf.stft)
n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
cutter = FrequencyClip(
low_index=_frequency_to_index(
conf.frequencies.min_freq, samplerate, n_fft
),
high_index=_frequency_to_index(
conf.frequencies.max_freq, samplerate, n_fft
),
)
transforms = build_spectrogram_transform(samplerate, conf)
resizer = ResizeSpec(
height=conf.size.height,
time_factor=conf.size.resize_factor,
)
return SpectrogramPipeline(
spec_builder=spec_builder,
freq_cutter=cutter,
transforms=transforms,
resizer=resizer,
)
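`_frequency_to_index` and `ResizeSpec` are used above but not part of this hunk; minimal sketches, under the assumptions that frequencies map linearly onto STFT bins and that resizing is bilinear interpolation over a `(batch, channel, freq, time)` tensor:

from typing import Optional

import torch

def _frequency_to_index(freq: float, samplerate: int, n_fft: int) -> int:
    # STFT bins are spaced samplerate / n_fft apart; clamp to the valid range.
    index = int(round(freq * n_fft / samplerate))
    return max(0, min(index, n_fft // 2))

class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: Optional[float] = None):
        super().__init__()
        self.height = height
        self.time_factor = time_factor or 1.0

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        width = int(spec.shape[-1] * self.time_factor)
        return torch.nn.functional.interpolate(
            spec, size=(self.height, width), mode="bilinear"
        )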

View File

@ -28,8 +28,10 @@ from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, Size
from batdetect2.utils.arrays import spec_to_xarray
__all__ = [
"Anchor",
@ -365,6 +367,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
def __init__(
self,
preprocessor: PreprocessorProtocol,
audio_loader: AudioLoader,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01,
@ -383,6 +386,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
Buffer in seconds to add when loading audio clips.
"""
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.time_scale = time_scale
self.frequency_scale = frequency_scale
self.loading_buffer = loading_buffer
@ -422,6 +426,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
time, freq = get_peak_energy_coordinates(
recording=sound_event.recording,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
start_time=start_time,
end_time=end_time,
@ -511,8 +516,10 @@ def build_roi_mapper(
if config.name == "peak_energy_bbox":
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
@ -617,6 +624,7 @@ def _build_bounding_box(
def get_peak_energy_coordinates(
recording: data.Recording,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
start_time: float = 0,
end_time: Optional[float] = None,
@ -669,7 +677,15 @@ def get_peak_energy_coordinates(
end_time=clip_end,
)
spec = preprocessor.preprocess_clip(clip)
wav = audio_loader.load_clip(clip)
spec = preprocessor.process_numpy(wav)
spec = spec_to_xarray(
spec,
clip.start_time,
clip.end_time,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
low_freq = max(low_freq, preprocessor.min_freq)
high_freq = min(high_freq, preprocessor.max_freq)
selection = spec.sel(

View File

@ -129,9 +129,7 @@ def mix_examples(
with xr.set_options(keep_attrs=True):
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor.compute_spectrogram(
combined.rename({"audio_time": "time"})
).data
spectrogram = preprocessor.process_numpy(combined.data)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
@ -241,9 +239,7 @@ def add_echo(
with xr.set_options(keep_attrs=True):
audio = audio + weight * audio_delay
spectrogram = preprocessor.compute_spectrogram(
audio.rename({"audio_time": "time"}),
).data
spectrogram = preprocessor.process_numpy(audio.data)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a

View File

@ -21,10 +21,12 @@ class ClipingConfig(BaseConfig):
class Clipper(ClipperProtocol):
def __init__(
self,
samplerate: int,
duration: float = 0.5,
max_empty: float = 0.2,
random: bool = True,
):
self.samplerate = samplerate
self.duration = duration
self.random = random
self.max_empty = max_empty

View File

@ -25,6 +25,8 @@ from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence
import numpy as np
import torch
import xarray as xr
from loguru import logger
from pydantic import Field
@ -34,9 +36,12 @@ from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import audio_to_xarray
__all__ = [
"preprocess_annotations",
@ -76,6 +81,7 @@ def preprocess_dataset(
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
labeller = build_clip_labeler(targets, config=config.labels)
audio_loader = build_audio_loader(config=config.preprocess.audio)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
@ -84,6 +90,7 @@ def preprocess_dataset(
preprocess_annotations(
dataset,
output_dir=output,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
replace=force,
@ -93,6 +100,7 @@ def preprocess_dataset(
def generate_train_example(
clip_annotation: data.ClipAnnotation,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
) -> xr.Dataset:
@ -140,9 +148,15 @@ def generate_train_example(
- The original `ClipAnnotation` metadata is stored as a JSON string in the
Dataset's attributes for provenance.
"""
wave = preprocessor.load_clip_audio(clip_annotation.clip)
wave = audio_loader.load_clip(clip_annotation.clip)
spectrogram = preprocessor.compute_spectrogram(wave)
spectrogram = _spec_to_xr(
preprocessor(torch.tensor(wave)),
start_time=clip_annotation.clip.start_time,
end_time=clip_annotation.clip.end_time,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
heatmaps = labeller(clip_annotation, spectrogram)
@ -152,7 +166,12 @@ def generate_train_example(
# the spectrogram time dimension, otherwise xarray will interpolate
# the spectrogram and the heatmaps to the same temporal resolution
# as the waveform.
"audio": wave.rename({"time": "audio_time"}),
"audio": audio_to_xarray(
wave,
start_time=clip_annotation.clip.start_time,
end_time=clip_annotation.clip.end_time,
time_axis="audio_time",
),
"spectrogram": spectrogram,
"detection": heatmaps.detection,
"class": heatmaps.classes,
@ -170,6 +189,32 @@ def generate_train_example(
)
def _spec_to_xr(
spec: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> xr.DataArray:
data = spec.numpy()[0, 0]
height, width = data.shape
return xr.DataArray(
data=data,
dims=[
"frequency",
"time",
],
coords={
"frequency": np.linspace(
min_freq, max_freq, height, endpoint=False
),
"time": np.linspace(start_time, end_time, width, endpoint=False),
},
)
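A quick sketch of what `_spec_to_xr` produces for a dummy pipeline output (the shape and frequency bounds are arbitrary):

import torch

dummy = torch.zeros(1, 1, 128, 512)  # (batch, channel, frequency, time)
spec = _spec_to_xr(
    dummy,
    start_time=0.0,
    end_time=0.5,
    min_freq=10_000,
    max_freq=120_000,
)
assert spec.dims == ("frequency", "time")
# endpoint=False: coordinates mark bin starts, so the last bin < end_time.
assert float(spec.time.max()) < 0.5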
def _save_xr_dataset_to_file(
dataset: xr.Dataset,
path: data.PathLike,
@ -206,6 +251,7 @@ def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
audio_loader: AudioLoader,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
@ -275,6 +321,7 @@ def preprocess_annotations(
output_dir=output_dir,
filename_fn=filename_fn,
replace=replace,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
),
@ -290,6 +337,7 @@ def preprocess_annotations(
def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: data.PathLike,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
@ -335,6 +383,7 @@ def preprocess_single_annotation(
try:
sample = generate_train_example(
clip_annotation,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
)

View File

@ -10,10 +10,10 @@ pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations).
"""
from typing import Optional, Protocol, Union
from typing import Optional, Protocol
import numpy as np
import xarray as xr
import torch
from soundevent import data
__all__ = [
@ -36,7 +36,7 @@ class AudioLoader(Protocol):
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
) -> np.ndarray:
"""Load and preprocess audio directly from a file path.
Parameters
@ -46,12 +46,6 @@ class AudioLoader(Protocol):
audio_dir : PathLike, optional
A directory prefix to prepend to the path if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform as an xarray DataArray
with time coordinates. Typically loads only the first channel.
Raises
------
FileNotFoundError
@ -65,7 +59,7 @@ class AudioLoader(Protocol):
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object.
Parameters
@ -95,7 +89,7 @@ class AudioLoader(Protocol):
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object.
Parameters
@ -124,264 +118,41 @@ class AudioLoader(Protocol):
class SpectrogramBuilder(Protocol):
"""Defines the interface for a spectrogram generation component.
"""Defines the interface for a spectrogram generation component."""
A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
and produces a spectrogram (as an xarray DataArray) based on its internal
configuration or implementation.
"""
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform. If a numpy array, `samplerate` must
also be provided. If an xarray DataArray, it must have a 'time'
coordinate from which the sample rate can be inferred.
samplerate : int, optional
The sample rate of the audio in Hz. Required if `wav` is a
numpy array. If `wav` is an xarray DataArray, this parameter is
ignored as the sample rate is derived from the coordinates.
Returns
-------
xr.DataArray
The computed spectrogram as an xarray DataArray with 'time' and
'frequency' coordinates.
Raises
------
ValueError
If `wav` is a numpy array and `samplerate` is not provided, or
if `wav` is an xarray DataArray without a valid 'time' coordinate.
"""
def __call__(self, wav: torch.Tensor) -> torch.Tensor:
"""Generate a spectrogram from an audio waveform."""
...
class AudioPipeline(Protocol):
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

class SpectrogramPipeline(Protocol):
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline."""
max_freq: float
min_freq: float
audio_pipeline: AudioPipeline
spectrogram_pipeline: SpectrogramPipeline
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline.
A Preprocessor combines audio loading and spectrogram generation steps.
It provides methods to go directly from source descriptions (file paths,
Recording objects, Clip objects) to the final spectrogram representation
needed by the model. It may also expose intermediate steps like audio
loading or spectrogram computation from a waveform.
"""
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Performs the initial audio loading and waveform processing steps
(like resampling, scaling), but stops *before* spectrogram generation.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Performs the initial audio loading and waveform processing steps
for the entire recording duration.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Performs the initial audio loading and waveform processing steps
for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def compute_spectrogram(
self,
wav: Union[xr.DataArray, np.ndarray],
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the spectrogram generation steps (STFT, scaling, etc.) defined
by the `SpectrogramBuilder` component of the preprocessor to an
already loaded (and potentially preprocessed) waveform.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If numpy array, `samplerate` is required.
samplerate : int, optional
Sample rate in Hz (required if `wav` is np.ndarray).
Returns
-------
xr.DataArray
The computed spectrogram.
Raises
------
ValueError, Exception
If waveform input is invalid or spectrogram computation fails.
"""
...
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy()[0, 0]
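`process_numpy` is the bridge the rest of this commit uses to hand plain waveform arrays to the torch pipeline: it wraps the array in a tensor, runs the preprocessor, and strips the leading batch/channel axes. A sketch (the 256 kHz rate is the assumed pipeline default):

import numpy as np

from batdetect2.preprocess import build_preprocessor

preprocessor = build_preprocessor()

wav = np.random.randn(128_000).astype(np.float32)  # 0.5 s at 256 kHz
spec = preprocessor.process_numpy(wav)

# The result is a plain 2-D (frequency, time) array, ready for xarray wrapping.
assert spec.ndim == 2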

View File

@ -2,6 +2,62 @@ import numpy as np
import xarray as xr
def spec_to_xarray(
spec: np.ndarray,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> xr.DataArray:
if spec.ndim != 2:
raise ValueError(
"Input numpy spectrogram array should be 2-dimensional"
)
height, width = spec.shape
return xr.DataArray(
data=spec,
dims=["frequency", "time"],
coords={
"frequency": np.linspace(
min_freq,
max_freq,
height,
endpoint=False,
),
"time": np.linspace(
start_time,
end_time,
width,
endpoint=False,
),
},
)
def audio_to_xarray(
wav: np.ndarray,
start_time: float,
end_time: float,
time_axis: str = "time",
) -> xr.DataArray:
if wav.ndim != 1:
raise ValueError("Input numpy audio array should be 1-dimensional")
return xr.DataArray(
data=wav,
dims=[time_axis],
coords={
time_axis: np.linspace(
start_time,
end_time,
len(wav),
endpoint=False,
),
},
)
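Together these helpers reattach physical coordinates to the plain arrays the torch pipeline operates on. A short sketch:

import numpy as np

from batdetect2.utils.arrays import audio_to_xarray, spec_to_xarray

audio = audio_to_xarray(
    np.zeros(1000, dtype=np.float32),
    start_time=0.0,
    end_time=0.5,
    time_axis="audio_time",  # keeps the axis distinct from spectrogram time
)
assert audio.dims == ("audio_time",)

spec = spec_to_xarray(
    np.zeros((128, 64), dtype=np.float32),
    start_time=0.0,
    end_time=0.5,
    min_freq=10_000,
    max_freq=120_000,
)
assert spec.sizes == {"frequency": 128, "time": 64}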
def extend_width(
array: np.ndarray,
extra: int,

View File

@ -12,6 +12,7 @@ from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import (
TargetConfig,
TermRegistry,
@ -27,6 +28,7 @@ from batdetect2.typing import (
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
@pytest.fixture
@ -368,6 +370,11 @@ def sample_preprocessor() -> PreprocessorProtocol:
return build_preprocessor()
@pytest.fixture
def sample_audio_loader() -> AudioLoader:
return build_audio_loader()
@pytest.fixture
def bat_tag() -> TagInfo:
return TagInfo(key="class", value="bat")

View File

@ -1,13 +1,10 @@
import pathlib
import uuid
from pathlib import Path
import numpy as np
import pytest
import soundfile as sf
import xarray as xr
from soundevent import data
from soundevent.arrays import Dimensions, create_time_dim_from_array
from batdetect2.preprocess import audio
@ -30,44 +27,6 @@ def create_dummy_wave(
return wave.astype(dtype)
def create_xr_wave(
samplerate: int,
duration: float,
num_channels: int = 1,
freq: float = 440.0,
amplitude: float = 0.5,
start_time: float = 0.0,
) -> xr.DataArray:
"""Generates a simple xarray waveform."""
num_samples = int(samplerate * duration)
times = np.linspace(
start_time,
start_time + duration,
num_samples,
endpoint=False,
)
coords = {
Dimensions.time.value: create_time_dim_from_array(
times, samplerate=samplerate, start_time=start_time
)
}
dims = [Dimensions.time.value]
wave_data = amplitude * np.sin(2 * np.pi * freq * times)
if num_channels > 1:
coords[Dimensions.channel.value] = np.arange(num_channels)
dims = [Dimensions.channel.value] + dims
wave_data = np.stack([wave_data] * num_channels, axis=0)
return xr.DataArray(
wave_data.astype(np.float32),
coords=coords,
dims=dims,
attrs={"samplerate": samplerate},
)
@pytest.fixture
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
"""Creates a dummy WAV file and returns its path."""
@ -99,408 +58,3 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
@pytest.fixture
def default_audio_config() -> audio.AudioConfig:
return audio.AudioConfig()
@pytest.fixture
def no_resample_config() -> audio.AudioConfig:
return audio.AudioConfig(resample=None)
@pytest.fixture
def fixed_duration_config() -> audio.AudioConfig:
return audio.AudioConfig(duration=0.5)
@pytest.fixture
def scale_config() -> audio.AudioConfig:
return audio.AudioConfig(scale=True, center=False)
@pytest.fixture
def no_center_config() -> audio.AudioConfig:
return audio.AudioConfig(center=False)
@pytest.fixture
def resample_fourier_config() -> audio.AudioConfig:
return audio.AudioConfig(
resample=audio.ResampleConfig(
samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier"
)
)
def test_resample_config_defaults():
config = audio.ResampleConfig()
assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
assert config.method == "poly"
def test_audio_config_defaults():
config = audio.AudioConfig()
assert config.resample is not None
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
assert config.resample.method == "poly"
assert config.scale == audio.SCALE_RAW_AUDIO
assert config.center is False
assert config.duration == audio.DEFAULT_DURATION
def test_audio_config_override():
resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier")
config = audio.AudioConfig(
resample=resample_cfg,
scale=True,
center=False,
duration=1.0,
)
assert config.resample == resample_cfg
assert config.scale is True
assert config.center is False
assert config.duration == 1.0
def test_audio_config_no_resample():
config = audio.AudioConfig(resample=None)
assert config.resample is None
@pytest.mark.parametrize(
"orig_sr, orig_dur, target_dur",
[
(256_000, 1.0, 0.5),
(256_000, 0.5, 1.0),
(256_000, 1.0, 1.0),
],
)
def test_adjust_audio_duration(orig_sr, orig_dur, target_dur):
wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur)
adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur)
expected_samples = int(target_dur * orig_sr)
assert adjusted_wave.sizes["time"] == expected_samples
assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr
assert adjusted_wave.dtype == wave.dtype
if orig_dur > 0 and target_dur > orig_dur:
padding_start_index = int(orig_dur * orig_sr) + 1
assert np.all(adjusted_wave.values[padding_start_index:] == 0)
def test_adjust_audio_duration_negative_target_raises():
wave = create_xr_wave(1000, 1.0)
with pytest.raises(ValueError):
audio.adjust_audio_duration(wave, duration=-0.5)
@pytest.mark.parametrize(
"orig_sr, target_sr, mode",
[
(48000, 96000, "poly"),
(96000, 48000, "poly"),
(48000, 96000, "fourier"),
(96000, 48000, "fourier"),
(48000, 44100, "poly"),
(48000, 44100, "fourier"),
],
)
def test_resample_audio(orig_sr, target_sr, mode):
duration = 0.1
wave = create_xr_wave(orig_sr, duration)
resampled_wave = audio.resample_audio(
wave, samplerate=target_sr, method=mode, dtype=np.float32
)
expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
assert resampled_wave.sizes["time"] == expected_samples
assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr
assert np.isclose(
resampled_wave.coords["time"].values[-1]
- resampled_wave.coords["time"].values[0],
duration,
atol=2 / target_sr,
)
assert resampled_wave.dtype == np.float32
def test_resample_audio_same_samplerate():
sr = 48000
duration = 0.1
wave = create_xr_wave(sr, duration)
resampled_wave = audio.resample_audio(
wave, samplerate=sr, dtype=np.float64
)
xr.testing.assert_equal(wave.astype(np.float64), resampled_wave)
def test_resample_audio_invalid_mode_raises():
wave = create_xr_wave(48000, 0.1)
with pytest.raises(NotImplementedError):
audio.resample_audio(wave, samplerate=96000, method="invalid_mode")
def test_resample_audio_no_time_dim_raises():
wave = xr.DataArray(np.random.rand(100), dims=["samples"])
with pytest.raises(ValueError, match="Audio must have a time dimension"):
audio.resample_audio(wave, samplerate=96000)
def test_load_clip_audio_default_config(
dummy_clip: data.Clip,
default_audio_config: audio.AudioConfig,
tmp_path: Path,
):
assert default_audio_config.resample is not None
target_sr = default_audio_config.resample.samplerate
orig_duration = dummy_clip.duration
expected_samples = int(orig_duration * target_sr)
wav = audio.load_clip_audio(
dummy_clip, config=default_audio_config, audio_dir=tmp_path
)
assert isinstance(wav, xr.DataArray)
assert wav.dims == ("time",)
assert wav.sizes["time"] == expected_samples
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert np.isclose(wav.mean(), 0.0, atol=1e-6)
assert wav.dtype == np.float32
def test_load_clip_audio_no_resample(
dummy_clip: data.Clip,
no_resample_config: audio.AudioConfig,
tmp_path: Path,
):
orig_sr = dummy_clip.recording.samplerate
orig_duration = dummy_clip.duration
expected_samples = int(orig_duration * orig_sr)
wav = audio.load_clip_audio(
dummy_clip, config=no_resample_config, audio_dir=tmp_path
)
assert wav.coords["time"].attrs["step"] == 1 / orig_sr
assert wav.sizes["time"] == expected_samples
assert np.isclose(wav.mean(), 0.0, atol=1e-6)
def test_load_clip_audio_fixed_duration_crop(
dummy_clip: data.Clip,
fixed_duration_config: audio.AudioConfig,
tmp_path: Path,
):
target_sr = audio.TARGET_SAMPLERATE_HZ
target_duration = fixed_duration_config.duration
assert target_duration is not None
expected_samples = int(target_duration * target_sr)
assert dummy_clip.duration > target_duration
wav = audio.load_clip_audio(
dummy_clip, config=fixed_duration_config, audio_dir=tmp_path
)
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert wav.sizes["time"] == expected_samples
def test_load_clip_audio_fixed_duration_pad(
dummy_clip: data.Clip,
tmp_path: Path,
):
target_duration = dummy_clip.duration * 2
config = audio.AudioConfig(duration=target_duration)
assert config.resample is not None
target_sr = config.resample.samplerate
expected_samples = int(target_duration * target_sr)
wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path)
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert wav.sizes["time"] == expected_samples
original_samples_after_resample = int(dummy_clip.duration * target_sr)
assert np.allclose(
wav.values[original_samples_after_resample:], 0.0, atol=1e-6
)
def test_load_clip_audio_scale(
dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path
):
wav = audio.load_clip_audio(
dummy_clip,
config=scale_config,
audio_dir=tmp_path,
)
assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5)
def test_load_clip_audio_no_center(
dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path
):
wav = audio.load_clip_audio(
dummy_clip, config=no_center_config, audio_dir=tmp_path
)
raw_wav, _ = sf.read(
dummy_clip.recording.path,
start=int(dummy_clip.start_time * dummy_clip.recording.samplerate),
stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate),
dtype=np.float32, # type: ignore
)
raw_wav_mono = raw_wav[:, 0]
if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7):
assert not np.isclose(wav.mean(), 0.0, atol=1e-6)
def test_load_clip_audio_resample_fourier(
dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path
):
assert resample_fourier_config.resample is not None
target_sr = resample_fourier_config.resample.samplerate
orig_duration = dummy_clip.duration
expected_samples = int(orig_duration * target_sr)
wav = audio.load_clip_audio(
dummy_clip, config=resample_fourier_config, audio_dir=tmp_path
)
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert wav.sizes["time"] == expected_samples
def test_load_clip_audio_dtype(
dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
wav = audio.load_clip_audio(
dummy_clip,
config=default_audio_config,
audio_dir=tmp_path,
dtype=np.float64,
)
assert wav.dtype == np.float64
def test_load_clip_audio_file_not_found(
dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
non_existent_path = tmp_path / "not_a_real_file.wav"
dummy_clip.recording = data.Recording(
path=non_existent_path,
duration=1,
channels=1,
samplerate=256000,
)
with pytest.raises(FileNotFoundError):
audio.load_clip_audio(
dummy_clip, config=default_audio_config, audio_dir=tmp_path
)
def test_load_recording_audio(
dummy_recording: data.Recording,
default_audio_config: audio.AudioConfig,
tmp_path,
):
assert default_audio_config.resample is not None
target_sr = default_audio_config.resample.samplerate
orig_duration = dummy_recording.duration
expected_samples = int(orig_duration * target_sr)
wav = audio.load_recording_audio(
dummy_recording, config=default_audio_config, audio_dir=tmp_path
)
assert isinstance(wav, xr.DataArray)
assert wav.dims == ("time",)
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert wav.sizes["time"] == expected_samples
assert np.isclose(wav.mean(), 0.0, atol=1e-6)
assert wav.dtype == np.float32
def test_load_recording_audio_file_not_found(
dummy_recording: data.Recording,
default_audio_config: audio.AudioConfig,
tmp_path,
):
non_existent_path = tmp_path / "not_a_real_file.wav"
dummy_recording = data.Recording(
path=non_existent_path,
duration=1,
channels=1,
samplerate=256000,
)
with pytest.raises(FileNotFoundError):
audio.load_recording_audio(
dummy_recording, config=default_audio_config, audio_dir=tmp_path
)
def test_load_file_audio(
dummy_wav_path: pathlib.Path,
default_audio_config: audio.AudioConfig,
tmp_path,
):
info = sf.info(dummy_wav_path)
orig_duration = info.duration
assert default_audio_config.resample is not None
target_sr = default_audio_config.resample.samplerate
expected_samples = int(orig_duration * target_sr)
wav = audio.load_file_audio(
dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
)
assert isinstance(wav, xr.DataArray)
assert wav.dims == ("time",)
assert wav.coords["time"].attrs["step"] == 1 / target_sr
assert wav.sizes["time"] == expected_samples
assert np.isclose(wav.mean(), 0.0, atol=1e-6)
assert wav.dtype == np.float32
def test_load_file_audio_file_not_found(
default_audio_config: audio.AudioConfig, tmp_path
):
non_existent_path = tmp_path / "not_a_real_file.wav"
with pytest.raises(FileNotFoundError):
audio.load_file_audio(
non_existent_path, config=default_audio_config, audio_dir=tmp_path
)
def test_build_audio_loader(default_audio_config: audio.AudioConfig):
loader = audio.build_audio_loader(config=default_audio_config)
assert isinstance(loader, audio.ConfigurableAudioLoader)
assert loader.config == default_audio_config
def test_configurable_audio_loader_methods(
default_audio_config: audio.AudioConfig,
dummy_wav_path: pathlib.Path,
dummy_recording: data.Recording,
dummy_clip: data.Clip,
tmp_path,
):
loader = audio.build_audio_loader(config=default_audio_config)
expected_wav_file = audio.load_file_audio(
dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
)
loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path)
xr.testing.assert_equal(expected_wav_file, loaded_wav_file)
expected_wav_rec = audio.load_recording_audio(
dummy_recording, config=default_audio_config, audio_dir=tmp_path
)
loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path)
xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec)
expected_wav_clip = audio.load_clip_audio(
dummy_clip, config=default_audio_config, audio_dir=tmp_path
)
loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path)
xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip)

View File

@ -1,32 +1,7 @@
import math
from pathlib import Path
from typing import Callable, Union
import numpy as np
import pytest
import xarray as xr
from soundevent import arrays
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
MAX_FREQ,
MIN_FREQ,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
apply_pcen,
build_spectrogram_builder,
compute_spectrogram,
crop_spectrogram_frequencies,
get_spectrogram_resolution,
remove_spectral_mean,
resize_spectrogram,
scale_spectrogram,
stft,
)
SAMPLERATE = 250_000
DURATION = 0.1
@ -61,389 +36,3 @@ def constant_wave_xr() -> xr.DataArray:
dims=["time"],
attrs={"samplerate": SAMPLERATE},
)
@pytest.fixture
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
"""Generate a basic spectrogram for testing downstream functions."""
config = SpectrogramConfig(
stft=STFTConfig(window_duration=0.002, window_overlap=0.5),
frequencies=FrequencyConfig(
min_freq=0,
max_freq=int(SAMPLERATE / 2),
),
size=None,
pcen=None,
spectral_mean_substraction=False,
peak_normalize=False,
scale="amplitude",
)
spec = stft(
sine_wave_xr,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
)
return spec
def test_stft_config_defaults():
config = STFTConfig()
assert config.window_duration == 0.002
assert config.window_overlap == 0.75
assert config.window_fn == "hann"
def test_frequency_config_defaults():
config = FrequencyConfig()
assert config.min_freq == MIN_FREQ
assert config.max_freq == MAX_FREQ
def test_spec_size_config_defaults():
config = SpecSizeConfig()
assert config.height == 128
assert config.resize_factor == 0.5
def test_pcen_config_defaults():
config = PcenConfig()
assert config.time_constant == 0.01
assert config.gain == 0.98
assert config.bias == 2
assert config.power == 0.5
def test_spectrogram_config_defaults():
config = SpectrogramConfig()
assert isinstance(config.stft, STFTConfig)
assert isinstance(config.frequencies, FrequencyConfig)
assert isinstance(config.pcen, PcenConfig)
assert config.scale == "amplitude"
assert isinstance(config.size, SpecSizeConfig)
assert config.spectral_mean_substraction is True
assert config.peak_normalize is False
def test_stft_output_properties(sine_wave_xr: xr.DataArray):
window_duration = 0.002
window_overlap = 0.5
samplerate = sine_wave_xr.attrs["samplerate"]
nfft = int(window_duration * samplerate)
hop_len = nfft - int(window_overlap * nfft)
spec = stft(
sine_wave_xr,
window_duration=window_duration,
window_overlap=window_overlap,
window_fn="hann",
)
assert isinstance(spec, xr.DataArray)
assert spec.dims == ("frequency", "time")
assert spec.dtype == np.float32
assert "frequency" in spec.coords
assert "time" in spec.coords
time_step = arrays.get_dim_step(spec, "time")
freq_step = arrays.get_dim_step(spec, "frequency")
freq_start, freq_end = arrays.get_dim_range(spec, "frequency")
assert np.isclose(freq_step, samplerate / nfft)
assert np.isclose(time_step, hop_len / samplerate)
assert spec.frequency.min() >= 0
assert freq_start == 0
assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2)
assert np.isclose(spec.time.min(), 0)
assert spec.time.max() < DURATION
assert spec.attrs["samplerate"] == samplerate
assert spec.attrs["window_size"] == window_duration
assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap)
assert np.all(spec.data >= 0)
@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
spec = stft(
sine_wave_xr,
window_duration=0.002,
window_overlap=0.5,
window_fn=window_fn,
)
assert isinstance(spec, xr.DataArray)
assert np.all(spec.data >= 0)
def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
min_f, max_f = 20_000, 80_000
cropped_spec = crop_spectrogram_frequencies(
sample_spec, min_freq=min_f, max_freq=max_f
)
assert cropped_spec.dims == sample_spec.dims
assert cropped_spec.dtype == sample_spec.dtype
assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
assert cropped_spec.coords["frequency"].min() >= min_f
assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)
def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
samplerate = sample_spec.attrs["samplerate"]
min_f, max_f = 0, samplerate / 2
cropped_spec = crop_spectrogram_frequencies(
sample_spec, min_freq=min_f, max_freq=max_f
)
assert cropped_spec.sizes == sample_spec.sizes
assert np.allclose(cropped_spec.data, sample_spec.data)
def test_apply_pcen(sample_spec: xr.DataArray):
pcen_config = PcenConfig()
pcen_spec = apply_pcen(
sample_spec,
time_constant=pcen_config.time_constant,
gain=pcen_config.gain,
bias=pcen_config.bias,
power=pcen_config.power,
)
assert pcen_spec.dims == sample_spec.dims
assert pcen_spec.sizes == sample_spec.sizes
assert pcen_spec.dtype == sample_spec.dtype
assert np.all(pcen_spec.data >= 0)
assert not np.allclose(pcen_spec.data, sample_spec.data)
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
assert np.allclose(scaled_spec.data, sample_spec.data)
assert scaled_spec.dtype == sample_spec.dtype
def test_scale_spectrogram_power(sample_spec: xr.DataArray):
scaled_spec = scale_spectrogram(sample_spec, scale="power")
assert np.allclose(scaled_spec.data, sample_spec.data**2)
assert scaled_spec.dtype == sample_spec.dtype
def test_scale_spectrogram_db(sample_spec: xr.DataArray):
scaled_spec = scale_spectrogram(sample_spec, scale="dB")
log_spec_expected = arrays.to_db(sample_spec)
xr.testing.assert_allclose(scaled_spec, log_spec_expected)
def test_remove_spectral_mean(sample_spec: xr.DataArray):
spec_noisy = sample_spec.copy() + 0.1
denoised_spec = remove_spectral_mean(spec_noisy)
assert denoised_spec.dims == spec_noisy.dims
assert denoised_spec.sizes == spec_noisy.sizes
assert denoised_spec.dtype == spec_noisy.dtype
assert np.all(denoised_spec.data >= 0)
def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
const_spec = stft(constant_wave_xr, 0.002, 0.5)
denoised_spec = remove_spectral_mean(const_spec)
assert np.all(denoised_spec.data >= 0)
@pytest.mark.parametrize(
"height, resize_factor, expected_freq_size, expected_time_factor",
[
(128, 1.0, 128, 1.0),
(64, 0.5, 64, 0.5),
(256, None, 256, 1.0),
(100, 2.0, 100, 2.0),
],
)
def test_resize_spectrogram(
sample_spec: xr.DataArray,
height: int,
resize_factor: Union[float, None],
expected_freq_size: int,
expected_time_factor: float,
):
original_time_size = sample_spec.sizes["time"]
resized_spec = resize_spectrogram(
sample_spec,
height=height,
resize_factor=resize_factor,
)
assert resized_spec.dims == ("frequency", "time")
assert resized_spec.sizes["frequency"] == expected_freq_size
expected_time_size = int(original_time_size * expected_time_factor)
assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig()
spec = compute_spectrogram(sine_wave_xr, config=config)
assert isinstance(spec, xr.DataArray)
assert spec.dims == ("frequency", "time")
assert spec.dtype == np.float32
assert config.size is not None
assert spec.sizes["frequency"] == config.size.height
temp_stft = stft(
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
)
assert config.size.resize_factor is not None
expected_time_size = int(
temp_stft.sizes["time"] * config.size.resize_factor
)
assert abs(spec.sizes["time"] - expected_time_size) <= 1
assert spec.coords["frequency"].min() >= config.frequencies.min_freq
assert np.isclose(
spec.coords["frequency"].max(),
config.frequencies.max_freq,
rtol=0.1,
)
def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
sine_wave_xr: xr.DataArray,
):
config = SpectrogramConfig(
pcen=None,
spectral_mean_substraction=False,
size=None,
scale="power",
frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
)
spec = compute_spectrogram(sine_wave_xr, config=config)
stft_direct = stft(
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
)
expected_spec = scale_spectrogram(stft_direct, scale="power")
assert spec.sizes == expected_spec.sizes
assert np.allclose(spec.data, expected_spec.data)
assert spec.dtype == expected_spec.dtype
def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig(peak_normalize=True, pcen=None)
spec = compute_spectrogram(sine_wave_xr, config=config)
assert np.isclose(spec.data.max(), 1.0, atol=1e-6)
config = SpectrogramConfig(peak_normalize=False)
spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)
def test_get_spectrogram_resolution_calculation():
config = SpectrogramConfig(
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
size=SpecSizeConfig(height=100, resize_factor=0.5),
frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
)
freq_res, time_res = get_spectrogram_resolution(config)
expected_freq_res = (110_000 - 10_000) / 100
expected_hop_duration = 0.002 * (1 - 0.75)
expected_time_res = expected_hop_duration / 0.5
assert np.isclose(freq_res, expected_freq_res)
assert np.isclose(time_res, expected_time_res)
def test_get_spectrogram_resolution_no_resize_factor():
config = SpectrogramConfig(
stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
size=SpecSizeConfig(height=200, resize_factor=None),
frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
)
freq_res, time_res = get_spectrogram_resolution(config)
expected_freq_res = (120_000 - 20_000) / 200
expected_hop_duration = 0.004 * (1 - 0.5)
expected_time_res = expected_hop_duration / 1.0
assert np.isclose(freq_res, expected_freq_res)
assert np.isclose(time_res, expected_time_res)
def test_get_spectrogram_resolution_no_size_config():
config = SpectrogramConfig(size=None)
with pytest.raises(
ValueError, match="Spectrogram size configuration is required"
):
get_spectrogram_resolution(config)
def test_configurable_spectrogram_builder_init():
config = SpectrogramConfig()
builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
assert builder.config is config
assert builder.dtype == np.float16
def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig()
builder = ConfigurableSpectrogramBuilder(config=config)
spec_builder = builder(sine_wave_xr)
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
assert isinstance(spec_builder, xr.DataArray)
assert np.allclose(spec_builder.data, spec_direct.data)
assert spec_builder.dtype == spec_direct.dtype
def test_configurable_spectrogram_builder_call_np_no_samplerate(
sine_wave_xr: xr.DataArray,
):
config = SpectrogramConfig()
builder = ConfigurableSpectrogramBuilder(config=config)
wav_np = sine_wave_xr.data
with pytest.raises(ValueError, match="Samplerate must be provided"):
builder(wav_np, samplerate=None)
def test_build_spectrogram_builder():
config = SpectrogramConfig(peak_normalize=True)
builder = build_spectrogram_builder(config=config, dtype=np.float64)
assert isinstance(builder, ConfigurableSpectrogramBuilder)
assert builder.config is config
assert builder.dtype == np.float64
def test_can_estimate_spectrogram_resolution(
wav_factory: Callable[..., Path],
):
path = wav_factory(duration=0.2, samplerate=256_000)
audio_data = load_file_audio(
path,
config=AudioConfig(resample=None, duration=None),
)
config = SpectrogramConfig(
stft=STFTConfig(),
size=SpecSizeConfig(height=256, resize_factor=0.5),
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
)
spec = compute_spectrogram(audio_data, config=config)
freq_res, time_res = get_spectrogram_resolution(config)
assert math.isclose(
arrays.get_dim_step(spec, dim="frequency"),
freq_res,
rel_tol=0.1,
)
assert math.isclose(
arrays.get_dim_step(spec, dim="time"),
time_res,
rel_tol=0.1,
)

View File

@ -3,7 +3,11 @@ import pytest
import soundfile as sf
from soundevent import data
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess import (
PreprocessingConfig,
build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets.rois import (
DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
@ -275,6 +279,8 @@ def test_get_peak_energy_coordinates(generate_whistle):
# Build a preprocessor (default config should be fine for this test)
preprocessor = build_preprocessor()
audio_loader = build_audio_loader()
# Define a region of interest that contains the whistle
start_time = 0.2
end_time = 0.7
@ -285,6 +291,7 @@ def test_get_peak_energy_coordinates(generate_whistle):
peak_time, peak_freq = get_peak_energy_coordinates(
recording=recording,
preprocessor=preprocessor,
audio_loader=audio_loader,
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
@ -356,6 +363,7 @@ def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
peak_time, peak_freq = get_peak_energy_coordinates(
recording=recording,
preprocessor=preprocessor,
audio_loader=build_audio_loader(),
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
@ -389,6 +397,7 @@ def test_get_peak_energy_coordinates_silent_region(create_recording):
peak_time, peak_freq = get_peak_energy_coordinates(
recording=recording,
preprocessor=preprocessor,
audio_loader=build_audio_loader(),
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
@ -443,17 +452,11 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):
# Instantiate the mapper with a preprocessor
preprocessor = build_preprocessor(
PreprocessingConfig.model_validate(
{
"spectrogram": {
"pcen": None,
"spectral_mean_substraction": False,
}
}
)
PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
)
mapper = PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=build_audio_loader(),
time_scale=time_scale,
frequency_scale=freq_scale,
)
@ -493,6 +496,7 @@ def test_peak_energy_bbox_mapper_decode():
mapper = PeakEnergyBBoxMapper(
preprocessor=build_preprocessor(),
audio_loader=build_audio_loader(),
time_scale=time_scale,
frequency_scale=freq_scale,
)
@ -553,7 +557,11 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
}
)
)
mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor)
audio_loader = build_audio_loader()
mapper = PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,
)
# When
# Encode the sound event, then immediately decode the result.

View File

@ -11,11 +11,12 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
def test_mix_examples(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
@ -30,11 +31,13 @@ def test_mix_examples(
example1 = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
example2 = generate_train_example(
clip_annotation_2,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
@ -51,6 +54,7 @@ def test_mix_examples(
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
duration1: float,
@ -67,11 +71,13 @@ def test_mix_examples_of_different_durations(
example1 = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
example2 = generate_train_example(
clip_annotation_2,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
@ -87,6 +93,7 @@ def test_mix_examples_of_different_durations(
def test_add_echo(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
@ -96,6 +103,7 @@ def test_add_echo(
original = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
@ -109,6 +117,7 @@ def test_add_echo(
def test_selected_random_subclip_has_the_correct_width(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
@ -118,6 +127,7 @@ def test_selected_random_subclip_has_the_correct_width(
original = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
@ -128,6 +138,7 @@ def test_selected_random_subclip_has_the_correct_width(
def test_add_echo_after_subclip(
sample_preprocessor: PreprocessorProtocol,
sample_audio_loader: AudioLoader,
sample_labeller: ClipLabeller,
create_recording: Callable[..., data.Recording],
):
@ -136,6 +147,7 @@ def test_add_echo_after_subclip(
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
original = generate_train_example(
clip_annotation_1,
audio_loader=sample_audio_loader,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)