Documented the preprocessing module

This commit is contained in:
mbsantiago 2025-04-17 15:56:07 +01:00
parent 19febf2216
commit 638f93fe92
3 changed files with 273 additions and 30 deletions

View File

@ -35,7 +35,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
audio=AudioConfig( audio=AudioConfig(
resample=ResampleConfig( resample=ResampleConfig(
samplerate=params["target_samp_rate"], samplerate=params["target_samp_rate"],
mode="poly", method="poly",
), ),
scale=params["scale_raw_audio"], scale=params["scale_raw_audio"],
center=params["scale_raw_audio"], center=params["scale_raw_audio"],

View File

@ -1,4 +1,32 @@
"""Module containing functions for preprocessing audio clips.""" """Main entry point for the BatDetect2 Preprocessing subsystem.
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional, Union from typing import Optional, Union
@ -14,13 +42,7 @@ from batdetect2.preprocess.audio import (
TARGET_SAMPLERATE_HZ, TARGET_SAMPLERATE_HZ,
AudioConfig, AudioConfig,
ResampleConfig, ResampleConfig,
adjust_audio_duration,
build_audio_loader, build_audio_loader,
convert_to_xr,
load_clip_audio,
load_file_audio,
load_recording_audio,
resample_audio,
) )
from batdetect2.preprocess.spectrogram import ( from batdetect2.preprocess.spectrogram import (
MAX_FREQ, MAX_FREQ,
@ -32,7 +54,6 @@ from batdetect2.preprocess.spectrogram import (
SpectrogramConfig, SpectrogramConfig,
STFTConfig, STFTConfig,
build_spectrogram_builder, build_spectrogram_builder,
compute_spectrogram,
get_spectrogram_resolution, get_spectrogram_resolution,
) )
from batdetect2.preprocess.types import ( from batdetect2.preprocess.types import (
@ -47,44 +68,79 @@ __all__ = [
"ConfigurableSpectrogramBuilder", "ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION", "DEFAULT_DURATION",
"FrequencyConfig", "FrequencyConfig",
"FrequencyConfig",
"MAX_FREQ", "MAX_FREQ",
"MIN_FREQ", "MIN_FREQ",
"PcenConfig", "PcenConfig",
"PcenConfig",
"PreprocessingConfig", "PreprocessingConfig",
"ResampleConfig", "ResampleConfig",
"SCALE_RAW_AUDIO", "SCALE_RAW_AUDIO",
"STFTConfig", "STFTConfig",
"STFTConfig",
"SpecSizeConfig",
"SpecSizeConfig", "SpecSizeConfig",
"SpectrogramBuilder", "SpectrogramBuilder",
"SpectrogramConfig", "SpectrogramConfig",
"SpectrogramConfig", "StandardPreprocessor",
"TARGET_SAMPLERATE_HZ", "TARGET_SAMPLERATE_HZ",
"adjust_audio_duration",
"build_audio_loader", "build_audio_loader",
"build_preprocessor",
"build_spectrogram_builder", "build_spectrogram_builder",
"compute_spectrogram",
"convert_to_xr",
"get_spectrogram_resolution", "get_spectrogram_resolution",
"load_clip_audio",
"load_file_audio",
"load_preprocessing_config", "load_preprocessing_config",
"load_recording_audio",
"resample_audio",
] ]
class PreprocessingConfig(BaseConfig): class PreprocessingConfig(BaseConfig):
"""Configuration for preprocessing data.""" """Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(Preprocessor): class StandardPreprocessor(Preprocessor):
"""Standard implementation of the `Preprocessor` protocol.
Orchestrates the audio loading and spectrogram generation pipeline using
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
configured according to a `PreprocessingConfig`.
This class is typically instantiated using the `build_preprocessor`
factory function.
Attributes
----------
audio_loader : AudioLoader
The configured audio loader instance used for waveform loading and
initial processing.
spectrogram_builder : SpectrogramBuilder
The configured spectrogram builder instance used for generating
spectrograms from waveforms.
default_samplerate : int
The sample rate (in Hz) assumed for input waveforms when they are
provided as raw NumPy arrays without coordinate information (e.g.,
when calling `compute_spectrogram` directly with `np.ndarray`).
This value is derived from the `AudioConfig` (target resample rate
or default if resampling is off) and also serves as documentation
for the pipeline's intended operating sample rate. Note that when
processing `xr.DataArray` inputs that have coordinate information
(the standard internal workflow), the sample rate embedded in the
coordinates takes precedence over this default value during
spectrogram calculation.
"""
audio_loader: AudioLoader audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder spectrogram_builder: SpectrogramBuilder
default_samplerate: int default_samplerate: int
@ -95,6 +151,19 @@ class StandardPreprocessor(Preprocessor):
spectrogram_builder: SpectrogramBuilder, spectrogram_builder: SpectrogramBuilder,
default_samplerate: int, default_samplerate: int,
) -> None: ) -> None:
"""Initialize the StandardPreprocessor.
Parameters
----------
audio_loader : AudioLoader
An initialized audio loader conforming to the AudioLoader protocol.
spectrogram_builder : SpectrogramBuilder
An initialized spectrogram builder conforming to the
SpectrogramBuilder protocol.
default_samplerate : int
The sample rate to assume for NumPy array inputs and potentially
reflecting the target rate of the audio config.
"""
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate self.default_samplerate = default_samplerate
@ -104,6 +173,23 @@ class StandardPreprocessor(Preprocessor):
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Delegates to the internal `audio_loader`.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_file( return self.audio_loader.load_file(
path, path,
audio_dir=audio_dir, audio_dir=audio_dir,
@ -114,6 +200,23 @@ class StandardPreprocessor(Preprocessor):
recording: data.Recording, recording: data.Recording,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Delegates to the internal `audio_loader`.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_recording( return self.audio_loader.load_recording(
recording, recording,
audio_dir=audio_dir, audio_dir=audio_dir,
@ -124,6 +227,23 @@ class StandardPreprocessor(Preprocessor):
clip: data.Clip, clip: data.Clip,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Delegates to the internal `audio_loader`.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment (typically first
channel).
"""
return self.audio_loader.load_clip( return self.audio_loader.load_clip(
clip, clip,
audio_dir=audio_dir, audio_dir=audio_dir,
@ -134,6 +254,24 @@ class StandardPreprocessor(Preprocessor):
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_file_audio(path, audio_dir=audio_dir) wav = self.load_file_audio(path, audio_dir=audio_dir)
return self.spectrogram_builder( return self.spectrogram_builder(
wav, wav,
@ -145,6 +283,22 @@ class StandardPreprocessor(Preprocessor):
recording: data.Recording, recording: data.Recording,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_recording_audio(recording, audio_dir=audio_dir) wav = self.load_recording_audio(recording, audio_dir=audio_dir)
return self.spectrogram_builder( return self.spectrogram_builder(
wav, wav,
@ -156,6 +310,22 @@ class StandardPreprocessor(Preprocessor):
clip: data.Clip, clip: data.Clip,
audio_dir: Optional[data.PathLike] = None, audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray: ) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_clip_audio(clip, audio_dir=audio_dir) wav = self.load_clip_audio(clip, audio_dir=audio_dir)
return self.spectrogram_builder( return self.spectrogram_builder(
wav, wav,
@ -165,6 +335,27 @@ class StandardPreprocessor(Preprocessor):
def compute_spectrogram( def compute_spectrogram(
self, wav: Union[xr.DataArray, np.ndarray] self, wav: Union[xr.DataArray, np.ndarray]
) -> xr.DataArray: ) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the configured spectrogram generation steps
(STFT, scaling, etc.) using the internal `spectrogram_builder`.
If `wav` is a NumPy array, the `default_samplerate` stored in this
preprocessor instance will be used. If `wav` is an xarray DataArray
with time coordinates, the sample rate derived from those coordinates
will take precedence over `default_samplerate`.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If numpy array, `default_samplerate`
stored in this object will be assumed.
Returns
-------
xr.DataArray
The computed spectrogram.
"""
return self.spectrogram_builder( return self.spectrogram_builder(
wav, wav,
samplerate=self.default_samplerate, samplerate=self.default_samplerate,
@ -175,12 +366,64 @@ def load_preprocessing_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
) -> PreprocessingConfig: ) -> PreprocessingConfig:
"""Load the unified preprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PreprocessingConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
preprocessing configuration (e.g., "train.preprocessing"). If None, the
entire file content is validated as the PreprocessingConfig.
Returns
-------
PreprocessingConfig
Loaded and validated preprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to PreprocessingConfig.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=PreprocessingConfig, field=field) return load_config(path, schema=PreprocessingConfig, field=field)
def build_preprocessor_from_config( def build_preprocessor(
config: PreprocessingConfig, config: Optional[PreprocessingConfig] = None,
) -> Preprocessor: ) -> Preprocessor:
"""Factory function to build the standard preprocessor from configuration.
Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
based on the provided `PreprocessingConfig` (or defaults if config is None),
determines the effective default sample rate, and initializes the
`StandardPreprocessor`.
Parameters
----------
config : PreprocessingConfig, optional
The unified preprocessing configuration object. If None, default
configurations for audio and spectrogram processing will be used.
Returns
-------
Preprocessor
An initialized `StandardPreprocessor` instance ready to process audio
according to the configuration.
"""
config = config or PreprocessingConfig()
default_samplerate = ( default_samplerate = (
config.audio.resample.samplerate config.audio.resample.samplerate
if config.audio.resample if config.audio.resample

View File

@ -125,7 +125,7 @@ def no_center_config() -> audio.AudioConfig:
def resample_fourier_config() -> audio.AudioConfig: def resample_fourier_config() -> audio.AudioConfig:
return audio.AudioConfig( return audio.AudioConfig(
resample=audio.ResampleConfig( resample=audio.ResampleConfig(
samplerate=audio.TARGET_SAMPLERATE_HZ // 2, mode="fourier" samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier"
) )
) )
@ -133,21 +133,21 @@ def resample_fourier_config() -> audio.AudioConfig:
def test_resample_config_defaults(): def test_resample_config_defaults():
config = audio.ResampleConfig() config = audio.ResampleConfig()
assert config.samplerate == audio.TARGET_SAMPLERATE_HZ assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
assert config.mode == "poly" assert config.method == "poly"
def test_audio_config_defaults(): def test_audio_config_defaults():
config = audio.AudioConfig() config = audio.AudioConfig()
assert config.resample is not None assert config.resample is not None
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
assert config.resample.mode == "poly" assert config.resample.method == "poly"
assert config.scale == audio.SCALE_RAW_AUDIO assert config.scale == audio.SCALE_RAW_AUDIO
assert config.center is True assert config.center is True
assert config.duration == audio.DEFAULT_DURATION assert config.duration == audio.DEFAULT_DURATION
def test_audio_config_override(): def test_audio_config_override():
resample_cfg = audio.ResampleConfig(samplerate=44100, mode="fourier") resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier")
config = audio.AudioConfig( config = audio.AudioConfig(
resample=resample_cfg, resample=resample_cfg,
scale=True, scale=True,
@ -206,7 +206,7 @@ def test_resample_audio(orig_sr, target_sr, mode):
duration = 0.1 duration = 0.1
wave = create_xr_wave(orig_sr, duration) wave = create_xr_wave(orig_sr, duration)
resampled_wave = audio.resample_audio( resampled_wave = audio.resample_audio(
wave, samplerate=target_sr, mode=mode, dtype=np.float32 wave, samplerate=target_sr, method=mode, dtype=np.float32
) )
expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr)) expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
assert resampled_wave.sizes["time"] == expected_samples assert resampled_wave.sizes["time"] == expected_samples
@ -233,7 +233,7 @@ def test_resample_audio_same_samplerate():
def test_resample_audio_invalid_mode_raises(): def test_resample_audio_invalid_mode_raises():
wave = create_xr_wave(48000, 0.1) wave = create_xr_wave(48000, 0.1)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
audio.resample_audio(wave, samplerate=96000, mode="invalid_mode") audio.resample_audio(wave, samplerate=96000, method="invalid_mode")
def test_resample_audio_no_time_dim_raises(): def test_resample_audio_no_time_dim_raises():