Separated the protocols to separate types module

This commit is contained in:
mbsantiago 2025-04-17 15:36:21 +01:00
parent 3417c496db
commit 19febf2216
5 changed files with 627 additions and 338 deletions

View File

@ -1,87 +1,193 @@
"""Module containing functions for preprocessing audio clips."""
from functools import partial
from typing import Callable, Optional, Protocol
from typing import Optional, Union
import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
adjust_audio_duration,
build_audio_loader,
convert_to_xr,
load_clip_audio,
)
from batdetect2.preprocess.config import (
PreprocessingConfig,
load_preprocessing_config,
load_file_audio,
load_recording_audio,
resample_audio,
)
from batdetect2.preprocess.spectrogram import (
MAX_FREQ,
MIN_FREQ,
AmplitudeScaleConfig,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
Scales,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
build_spectrogram_builder,
compute_spectrogram,
get_spectrogram_resolution,
)
from batdetect2.preprocess.types import (
AudioLoader,
Preprocessor,
SpectrogramBuilder,
)
__all__ = [
"AmplitudeScaleConfig",
"AudioConfig",
"AudioLoader",
"ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION",
"FrequencyConfig",
"FrequencyConfig",
"LogScaleConfig",
"MAX_FREQ",
"MIN_FREQ",
"PcenScaleConfig",
"PcenConfig",
"PcenConfig",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"STFTConfig",
"Scales",
"SpecSizeConfig",
"SpecSizeConfig",
"SpectrogramBuilder",
"SpectrogramConfig",
"SpectrogramConfig",
"TARGET_SAMPLERATE_HZ",
"adjust_audio_duration",
"build_audio_loader",
"build_spectrogram_builder",
"compute_spectrogram",
"convert_to_xr",
"get_spectrogram_resolution",
"load_clip_audio",
"load_file_audio",
"load_preprocessing_config",
"preprocess_audio_clip",
"load_recording_audio",
"resample_audio",
]
class AudioPreprocessor(Protocol):
def __call__(
class PreprocessingConfig(BaseConfig):
"""Configuration for preprocessing data."""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(Preprocessor):
audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder
default_samplerate: int
def __init__(
self,
audio_loader: AudioLoader,
spectrogram_builder: SpectrogramBuilder,
default_samplerate: int,
) -> None:
self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
return self.audio_loader.load_file(
path,
audio_dir=audio_dir,
)
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
return self.audio_loader.load_recording(
recording,
audio_dir=audio_dir,
)
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ...
) -> xr.DataArray:
return self.audio_loader.load_clip(
clip,
audio_dir=audio_dir,
)
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
wav = self.load_file_audio(path, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def compute_spectrogram(
self, wav: Union[xr.DataArray, np.ndarray]
) -> xr.DataArray:
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def load_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
def build_preprocessor_from_config(
config: PreprocessingConfig,
) -> AudioPreprocessor:
return partial(preprocess_audio_clip, config=config)
def preprocess_audio_clip(
clip: data.Clip,
config: Optional[PreprocessingConfig] = None,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
Parameters
----------
clip
The audio clip to preprocess.
config
Configuration for preprocessing.
Returns
-------
xr.DataArray
Preprocessed spectrogram.
"""
config = config or PreprocessingConfig()
wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
return compute_spectrogram(wav, config=config.spectrogram)
) -> Preprocessor:
default_samplerate = (
config.audio.resample.samplerate
if config.audio.resample
else TARGET_SAMPLERATE_HZ
)
return StandardPreprocessor(
audio_loader=build_audio_loader(config.audio),
spectrogram_builder=build_spectrogram_builder(config.spectrogram),
default_samplerate=default_samplerate,
)

View File

@ -20,7 +20,7 @@ The primary interface is the `AudioLoader` protocol, with
`AudioConfig`.
"""
from typing import Optional, Protocol
from typing import Optional
import numpy as np
import xarray as xr
@ -32,9 +32,9 @@ from soundevent.arrays import operations as ops
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.types import AudioLoader
__all__ = [
"AudioLoader",
"ResampleConfig",
"AudioConfig",
"ConfigurableAudioLoader",
@ -60,106 +60,6 @@ DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""
class AudioLoader(Protocol):
"""Defines the interface for an audio loading and processing component.
An AudioLoader is responsible for retrieving audio data corresponding to
different soundevent objects (files, Recordings, Clips) and applying a
configured set of initial preprocessing steps. Adhering to this protocol
allows for different loading strategies or implementations.
"""
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess audio directly from a file path.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix to prepend to the path if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform as an xarray DataArray
with time coordinates. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the entire audio for a Recording object.
Parameters
----------
recording : data.Recording
The Recording object containing metadata about the audio file.
audio_dir : PathLike, optional
A directory where the audio file associated with the recording
can be found, especially if the path in the recording is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform. Typically loads only
the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the recording cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the audio segment defined by a Clip object.
Parameters
----------
clip : data.Clip
The Clip object specifying the recording and the start/end times
of the segment to load.
audio_dir : PathLike, optional
A directory where the audio file associated with the clip's
recording can be found.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform for the specified clip
duration. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the clip cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
@ -167,7 +67,7 @@ class ResampleConfig(BaseConfig):
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
mode : str, default="poly"
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
@ -177,7 +77,7 @@ class ResampleConfig(BaseConfig):
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
mode: str = "poly"
method: str = "poly"
class AudioConfig(BaseConfig):
@ -191,8 +91,8 @@ class AudioConfig(BaseConfig):
----------
resample : ResampleConfig, optional
Configuration for resampling. If provided (or defaulted), audio will
be resampled to the specified `samplerate` using the specified `mode`.
If set to `None` in the config file, resampling is skipped.
be resampled to the specified `samplerate` using the specified
`method`. If set to `None` in the config file, resampling is skipped.
Defaults to a ResampleConfig instance with standard settings.
scale : bool, default=False
If True, scales the audio waveform using peak normalization so that
@ -579,14 +479,14 @@ def adjust_audio_duration(
def resample_audio(
wav: xr.DataArray,
samplerate: int = TARGET_SAMPLERATE_HZ,
mode: str = "poly",
method: str = "poly",
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Resample an audio waveform DataArray to a target sample rate.
Updates the 'time' coordinate axis according to the new sample rate and
number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
or Fourier method (`scipy.signal.resample`) based on the `mode`.
or Fourier method (`scipy.signal.resample`) based on the `method`.
Parameters
----------
@ -594,7 +494,7 @@ def resample_audio(
Input audio waveform with 'time' dimension and coordinates.
samplerate : int, default=TARGET_SAMPLERATE_HZ
Target sample rate in Hz.
mode : str, default="poly"
method : str, default="poly"
Resampling algorithm: "poly" or "fourier".
dtype : DTypeLike, default=np.float32
Target data type for the resampled array.
@ -610,7 +510,7 @@ def resample_audio(
------
ValueError
If `wav` lacks a 'time' dimension, the original sample rate cannot
be determined, `samplerate` is non-positive, or `mode` is invalid.
be determined, `samplerate` is non-positive, or `method` is invalid.
"""
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
@ -622,14 +522,14 @@ def resample_audio(
if original_samplerate == samplerate:
return wav.astype(dtype)
if mode == "poly":
if method == "poly":
resampled = resample_audio_poly(
wav,
sr_orig=original_samplerate,
sr_new=samplerate,
axis=time_axis,
)
elif mode == "fourier":
elif method == "fourier":
resampled = resample_audio_fourier(
wav,
sr_orig=original_samplerate,
@ -637,7 +537,9 @@ def resample_audio(
axis=time_axis,
)
else:
raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
start, stop = arrays.get_dim_range(wav, dim="time")
times = np.linspace(

View File

@ -1,31 +0,0 @@
from typing import Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
SpectrogramConfig,
)
__all__ = [
"PreprocessingConfig",
"load_preprocessing_config",
]
class PreprocessingConfig(BaseConfig):
"""Configuration for preprocessing data."""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -6,20 +6,20 @@ spectrogram representations suitable for input into deep learning models like
BatDetect2.
It offers a configurable pipeline including:
1. Short-Time Fourier Transform (STFT) calculation.
1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
2. Frequency axis cropping to a relevant range.
3. Amplitude scaling (e.g., Logarithmic, Per-Channel Energy Normalization -
PCEN).
4. Simple denoising (optional).
5. Resizing to target dimensions (optional).
6. Final peak normalization (optional).
3. Per-Channel Energy Normalization (PCEN) (optional).
4. Amplitude scaling/representation (dB, power, or linear amplitude).
5. Simple spectral mean subtraction denoising (optional).
6. Resizing to target dimensions (optional).
7. Final peak normalization (optional).
Configuration is managed via the `SpectrogramConfig` class, allowing for
reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""
from typing import Literal, Optional, Protocol, Union
from typing import Literal, Optional, Union
import librosa
import librosa.core.spectrum
@ -32,65 +32,23 @@ from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.types import SpectrogramBuilder
__all__ = [
"SpectrogramBuilder",
"STFTConfig",
"FrequencyConfig",
"SpecSizeConfig",
"LogScaleConfig",
"PcenScaleConfig",
"AmplitudeScaleConfig",
"Scales",
"PcenConfig",
"SpectrogramConfig",
"ConfigurableSpectrogramBuilder",
"build_spectrogram_builder",
"compute_spectrogram",
"get_spectrogram_resolution",
"MIN_FREQ",
"MAX_FREQ",
]
class SpectrogramBuilder(Protocol):
"""Defines the interface for a spectrogram generation component.
A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
and produces a spectrogram (as an xarray DataArray) based on its internal
configuration or implementation.
"""
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform. If a numpy array, `samplerate` must
also be provided. If an xarray DataArray, it must have a 'time'
coordinate from which the sample rate can be inferred.
samplerate : int, optional
The sample rate of the audio in Hz. Required if `wav` is a
numpy array. If `wav` is an xarray DataArray, this parameter is
ignored as the sample rate is derived from the coordinates.
Returns
-------
xr.DataArray
The computed spectrogram as an xarray DataArray with 'time' and
'frequency' coordinates.
Raises
------
ValueError
If `wav` is a numpy array and `samplerate` is not provided, or
if `wav` is an xarray DataArray without a valid 'time' coordinate.
"""
...
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
@ -151,109 +109,84 @@ class SpecSizeConfig(BaseConfig):
resize_factor : float, optional
Factor by which to resize the spectrogram along the time axis *after*
STFT calculation. A value of 0.5 halves the number of time bins,
2.0 doubles it. If None (default), no resizing along the time axis
is performed relative to the STFT output width. Must be > 0 if provided.
2.0 doubles it. If None (default), no resizing along the time axis is
performed relative to the STFT output width. Must be > 0 if provided.
"""
height: int = 128
resize_factor: Optional[float] = 0.5
class LogScaleConfig(BaseConfig):
"""Configuration marker for using Logarithmic Amplitude Scaling."""
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN).
name: Literal["log"] = "log"
class PcenScaleConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN) scaling.
PCEN is an adaptive gain control method often used for audio event
detection.
PCEN is an adaptive gain control method that can help emphasize transients
and suppress stationary noise. Applied after STFT and frequency cropping,
but before final amplitude scaling (dB, power, amplitude).
Attributes
----------
name : Literal["pcen"]
Discriminator field identifying this scaling type.
time_constant : float, default=0.4
Time constant (in seconds) for the PCEN smoothing filter. Controls how
quickly the normalization adapts to energy changes.
Time constant (in seconds) for the PCEN smoothing filter. Controls
how quickly the normalization adapts to energy changes.
gain : float, default=0.98
Gain factor (alpha in some formulations). Controls the AGC behavior.
Gain factor (alpha). Controls the adaptive gain component.
bias : float, default=2.0
Bias factor (delta in some formulations). Added before the
exponentiation.
Bias factor (delta). Added before the exponentiation.
power : float, default=0.5
Exponent (r in some formulations). Controls the compression
characteristic.
Exponent (r). Controls the compression characteristic.
"""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class AmplitudeScaleConfig(BaseConfig):
"""Configuration marker for using Linear Amplitude (no scaling applied).
Note: The actual output is typically magnitude from STFT, not raw amplitude.
This option essentially skips log or PCEN scaling.
"""
name: Literal["amplitude"] = "amplitude"
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
"""Type alias for the different amplitude scaling configuration options."""
class SpectrogramConfig(BaseConfig):
"""Unified configuration for spectrogram generation.
"""Unified configuration for spectrogram generation pipeline.
Aggregates settings for STFT, frequency selection, amplitude scaling,
resizing, and optional post-processing steps like denoising and final
normalization.
Aggregates settings for all steps involved in converting a preprocessed
audio waveform into a final spectrogram representation suitable for model input.
Attributes
----------
stft : STFTConfig
Configuration for the Short-Time Fourier Transform. Defaults to standard
settings via `STFTConfig`.
Configuration for the initial Short-Time Fourier Transform.
Defaults to standard settings via `STFTConfig`.
frequencies : FrequencyConfig
Configuration for cropping the frequency range. Defaults to standard
settings via `FrequencyConfig`.
scale : Scales
Configuration for amplitude scaling. Determines whether to apply
log scaling, PCEN, or leave as linear magnitude. Defaults to PCEN
via `PcenScaleConfig`. Use the `name` field ("log", "pcen", "amplitude")
in config files to select the type and provide relevant parameters.
Configuration for cropping the frequency range after STFT.
Defaults to standard settings via `FrequencyConfig`.
pcen : PcenConfig, optional
Configuration for applying Per-Channel Energy Normalization (PCEN). If
provided, PCEN is applied after frequency cropping. If None or omitted
(default), PCEN is skipped.
scale : Literal["dB", "amplitude", "power"], default="amplitude"
Determines the final amplitude representation *after* optional PCEN.
- "amplitude": Use linear magnitude values (output of STFT or PCEN).
- "power": Use power values (magnitude squared).
- "dB": Use logarithmic (decibel-like) scaling applied to the magnitude
(or PCEN output if enabled). Calculated as `log1p(C * S)`.
size : SpecSizeConfig, optional
Configuration for resizing the final spectrogram dimensions (height in
frequency bins, optional time resizing factor). If None or omitted,
no resizing is performed after STFT and frequency cropping. Defaults
to standard settings via `SpecSizeConfig`.
denoise : bool, default=True
If True (default), applies a simple spectral mean subtraction denoising
step after amplitude scaling.
max_scale : bool, default=False
Configuration for resizing the spectrogram dimensions
(frequency height, optional time width factor). Applied after PCEN and
scaling. If None (default), no resizing is performed.
spectral_mean_substraction : bool, default=True
If True (default), applies simple spectral mean subtraction denoising
*after* PCEN and amplitude scaling, but *before* resizing.
peak_normalize : bool, default=False
If True, applies a final peak normalization to the spectrogram *after*
all other steps (including log/PCEN scaling and resizing), scaling the
maximum value across the entire spectrogram to 1.0. If False (default),
this final scaling is skipped. **Note:** Applying this after log or PCEN
scaling will alter the characteristics of those scales.
all other steps (including resizing), scaling the overall maximum value
to 1.0. If False (default), this final normalization is skipped.
"""
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
scale: Scales = Field(
default_factory=PcenScaleConfig,
discriminator="name",
)
pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
scale: Literal["dB", "amplitude", "power"] = "amplitude"
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
denoise: bool = True
max_scale: bool = False
spectral_mean_substraction: bool = True
peak_normalize: bool = False
class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
@ -362,13 +295,13 @@ def compute_spectrogram(
"""Compute a spectrogram from a waveform using specified configurations.
Applies a sequence of operations based on the `config`:
1. Compute STFT magnitude (`stft`).
2. Crop frequency axis (`crop_spectrogram_frequencies`).
3. Apply amplitude scaling (log, PCEN, or none) (`scale_spectrogram`).
4. Apply denoising if enabled (`denoise_spectrogram`).
5. Resize dimensions if specified (`resize_spectrogram`).
6. Apply final peak normalization if enabled (`max_scale`).
3. Apply PCEN if configured (`apply_pcen`).
4. Apply final amplitude scaling (dB, power, amplitude) (`scale_spectrogram`).
5. Apply spectral mean subtraction denoising if enabled.
6. Resize dimensions if specified (`resize_spectrogram`).
7. Apply final peak normalization if enabled.
Parameters
----------
@ -411,10 +344,19 @@ def compute_spectrogram(
max_freq=config.frequencies.max_freq,
)
if config.pcen:
spec = apply_pcen(
spec,
time_constant=config.pcen.time_constant,
gain=config.pcen.gain,
power=config.pcen.power,
bias=config.pcen.bias,
)
spec = scale_spectrogram(spec, scale=config.scale)
if config.denoise:
spec = denoise_spectrogram(spec)
if config.spectral_mean_substraction:
spec = remove_spectral_mean(spec)
if config.size:
spec = resize_spectrogram(
@ -423,7 +365,7 @@ def compute_spectrogram(
resize_factor=config.size.resize_factor,
)
if config.max_scale:
if config.peak_normalize:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec.astype(dtype)
@ -550,7 +492,7 @@ def stft(
)
def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
"""Apply simple spectral mean subtraction for denoising.
Subtracts the mean value of each frequency bin (calculated across time)
@ -576,23 +518,22 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def scale_spectrogram(
spec: xr.DataArray,
scale: Scales,
scale: Literal["dB", "power", "amplitude"],
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Apply configured amplitude scaling to the spectrogram.
"""Apply final amplitude scaling/representation to the spectrogram.
Dispatches to the appropriate scaling function (log, PCEN) based on the
`scale` configuration object's `name` field. If `scale.name` is
"amplitude", the spectrogram is returned unchanged (as it's already
magnitude/amplitude).
Converts the input magnitude spectrogram based on the `scale` type:
- "dB": Applies logarithmic scaling `log1p(C * S)`.
- "power": Squares the magnitude values `S^2`.
- "amplitude": Returns the input magnitude values `S` unchanged.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram.
scale : Scales
The configuration object specifying the scaling method and parameters
(instance of LogScaleConfig, PcenScaleConfig, or AmplitudeScaleConfig).
Input magnitude spectrogram (potentially after PCEN).
scale : Literal["dB", "power", "amplitude"]
The target amplitude representation.
dtype : DTypeLike, default=np.float32
Target data type for the output scaled spectrogram.
@ -601,22 +542,16 @@ def scale_spectrogram(
xr.DataArray
Spectrogram with the specified amplitude scaling applied.
"""
if scale.name == "log":
if scale == "dB":
return scale_log(spec, dtype=dtype)
if scale.name == "pcen":
return scale_pcen(
spec,
time_constant=scale.time_constant,
gain=scale.gain,
power=scale.power,
bias=scale.bias,
)
if scale == "power":
return spec**2
return spec
def scale_pcen(
def apply_pcen(
spec: xr.DataArray,
time_constant: float = 0.4,
gain: float = 0.98,

View File

@ -0,0 +1,377 @@
"""Defines common interfaces (Protocols) for preprocessing components.
This module centralizes the Protocol definitions used throughout the
`batdetect2.preprocess` package. Protocols define expected methods and
signatures, allowing for flexible and interchangeable implementations of
components like audio loaders and spectrogram builders.
Using these protocols ensures that different parts of the preprocessing
pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations).
"""
from typing import Optional, Protocol, Union
import numpy as np
import xarray as xr
from soundevent import data
class AudioLoader(Protocol):
"""Defines the interface for an audio loading and processing component.
An AudioLoader is responsible for retrieving audio data corresponding to
different soundevent objects (files, Recordings, Clips) and applying a
configured set of initial preprocessing steps. Adhering to this protocol
allows for different loading strategies or implementations.
"""
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess audio directly from a file path.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix to prepend to the path if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform as an xarray DataArray
with time coordinates. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the entire audio for a Recording object.
Parameters
----------
recording : data.Recording
The Recording object containing metadata about the audio file.
audio_dir : PathLike, optional
A directory where the audio file associated with the recording
can be found, especially if the path in the recording is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform. Typically loads only
the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the recording cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the audio segment defined by a Clip object.
Parameters
----------
clip : data.Clip
The Clip object specifying the recording and the start/end times
of the segment to load.
audio_dir : PathLike, optional
A directory where the audio file associated with the clip's
recording can be found.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform for the specified clip
duration. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the clip cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
class SpectrogramBuilder(Protocol):
"""Defines the interface for a spectrogram generation component.
A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
and produces a spectrogram (as an xarray DataArray) based on its internal
configuration or implementation.
"""
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform. If a numpy array, `samplerate` must
also be provided. If an xarray DataArray, it must have a 'time'
coordinate from which the sample rate can be inferred.
samplerate : int, optional
The sample rate of the audio in Hz. Required if `wav` is a
numpy array. If `wav` is an xarray DataArray, this parameter is
ignored as the sample rate is derived from the coordinates.
Returns
-------
xr.DataArray
The computed spectrogram as an xarray DataArray with 'time' and
'frequency' coordinates.
Raises
------
ValueError
If `wav` is a numpy array and `samplerate` is not provided, or
if `wav` is an xarray DataArray without a valid 'time' coordinate.
"""
...
class Preprocessor(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline.
A Preprocessor combines audio loading and spectrogram generation steps.
It provides methods to go directly from source descriptions (file paths,
Recording objects, Clip objects) to the final spectrogram representation
needed by the model. It may also expose intermediate steps like audio
loading or spectrogram computation from a waveform.
"""
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Performs the initial audio loading and waveform processing steps
(like resampling, scaling), but stops *before* spectrogram generation.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Performs the initial audio loading and waveform processing steps
for the entire recording duration.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Performs the initial audio loading and waveform processing steps
for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def compute_spectrogram(
self,
wav: Union[xr.DataArray, np.ndarray],
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the spectrogram generation steps (STFT, scaling, etc.) defined
by the `SpectrogramBuilder` component of the preprocessor to an
already loaded (and potentially preprocessed) waveform.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If numpy array, `samplerate` is required.
samplerate : int, optional
Sample rate in Hz (required if `wav` is np.ndarray).
Returns
-------
xr.DataArray
The computed spectrogram.
Raises
------
ValueError, Exception
If waveform input is invalid or spectrogram computation fails.
"""
...