diff --git a/batdetect2/compat/params.py b/batdetect2/compat/params.py index 0cf9b19..14cfdbc 100644 --- a/batdetect2/compat/params.py +++ b/batdetect2/compat/params.py @@ -35,7 +35,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig: audio=AudioConfig( resample=ResampleConfig( samplerate=params["target_samp_rate"], - mode="poly", + method="poly", ), scale=params["scale_raw_audio"], center=params["scale_raw_audio"], diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py index b1a8022..fe791a9 100644 --- a/batdetect2/preprocess/__init__.py +++ b/batdetect2/preprocess/__init__.py @@ -1,4 +1,32 @@ -"""Module containing functions for preprocessing audio clips.""" +"""Main entry point for the BatDetect2 Preprocessing subsystem. + +This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline +for converting raw audio input (from files or data objects) into processed +spectrograms suitable for input to BatDetect2 models. This ensures consistent +data handling between model training and inference. + +The preprocessing pipeline consists of two main stages, configured via nested +data structures: +1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial + processing like resampling, duration adjustment, centering, and scaling. + Configured via `AudioConfig`. +2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from + the processed waveform using STFT, followed by frequency cropping, optional + PCEN, amplitude scaling (dB, power, linear), optional denoising, optional + resizing, and optional peak normalization. Configured via + `SpectrogramConfig`. + +This module provides the primary interface: + +- `PreprocessingConfig`: A unified configuration object holding `AudioConfig` + and `SpectrogramConfig`. +- `load_preprocessing_config`: Function to load the unified configuration. +- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline. +- `StandardPreprocessor`: The default implementation of the `Preprocessor`. +- `build_preprocessor`: A factory function to create a `StandardPreprocessor` + instance from a `PreprocessingConfig`. + +""" from typing import Optional, Union @@ -14,13 +42,7 @@ from batdetect2.preprocess.audio import ( TARGET_SAMPLERATE_HZ, AudioConfig, ResampleConfig, - adjust_audio_duration, build_audio_loader, - convert_to_xr, - load_clip_audio, - load_file_audio, - load_recording_audio, - resample_audio, ) from batdetect2.preprocess.spectrogram import ( MAX_FREQ, @@ -32,7 +54,6 @@ from batdetect2.preprocess.spectrogram import ( SpectrogramConfig, STFTConfig, build_spectrogram_builder, - compute_spectrogram, get_spectrogram_resolution, ) from batdetect2.preprocess.types import ( @@ -47,44 +68,79 @@ __all__ = [ "ConfigurableSpectrogramBuilder", "DEFAULT_DURATION", "FrequencyConfig", - "FrequencyConfig", "MAX_FREQ", "MIN_FREQ", "PcenConfig", - "PcenConfig", "PreprocessingConfig", "ResampleConfig", "SCALE_RAW_AUDIO", "STFTConfig", - "STFTConfig", - "SpecSizeConfig", "SpecSizeConfig", "SpectrogramBuilder", "SpectrogramConfig", - "SpectrogramConfig", + "StandardPreprocessor", "TARGET_SAMPLERATE_HZ", - "adjust_audio_duration", "build_audio_loader", + "build_preprocessor", "build_spectrogram_builder", - "compute_spectrogram", - "convert_to_xr", "get_spectrogram_resolution", - "load_clip_audio", - "load_file_audio", "load_preprocessing_config", - "load_recording_audio", - "resample_audio", ] class PreprocessingConfig(BaseConfig): - """Configuration for preprocessing data.""" + """Unified configuration for the audio preprocessing pipeline. + + Aggregates the configuration for both the initial audio processing stage + and the subsequent spectrogram generation stage. + + Attributes + ---------- + audio : AudioConfig + Configuration settings for the audio loading and initial waveform + processing steps (e.g., resampling, duration adjustment, scaling). + Defaults to default `AudioConfig` settings if omitted. + spectrogram : SpectrogramConfig + Configuration settings for the spectrogram generation process + (e.g., STFT parameters, frequency cropping, scaling, denoising, + resizing). Defaults to default `SpectrogramConfig` settings if omitted. + """ audio: AudioConfig = Field(default_factory=AudioConfig) spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) class StandardPreprocessor(Preprocessor): + """Standard implementation of the `Preprocessor` protocol. + + Orchestrates the audio loading and spectrogram generation pipeline using + an `AudioLoader` and a `SpectrogramBuilder` internally, which are + configured according to a `PreprocessingConfig`. + + This class is typically instantiated using the `build_preprocessor` + factory function. + + Attributes + ---------- + audio_loader : AudioLoader + The configured audio loader instance used for waveform loading and + initial processing. + spectrogram_builder : SpectrogramBuilder + The configured spectrogram builder instance used for generating + spectrograms from waveforms. + default_samplerate : int + The sample rate (in Hz) assumed for input waveforms when they are + provided as raw NumPy arrays without coordinate information (e.g., + when calling `compute_spectrogram` directly with `np.ndarray`). + This value is derived from the `AudioConfig` (target resample rate + or default if resampling is off) and also serves as documentation + for the pipeline's intended operating sample rate. Note that when + processing `xr.DataArray` inputs that have coordinate information + (the standard internal workflow), the sample rate embedded in the + coordinates takes precedence over this default value during + spectrogram calculation. + """ + audio_loader: AudioLoader spectrogram_builder: SpectrogramBuilder default_samplerate: int @@ -95,6 +151,19 @@ class StandardPreprocessor(Preprocessor): spectrogram_builder: SpectrogramBuilder, default_samplerate: int, ) -> None: + """Initialize the StandardPreprocessor. + + Parameters + ---------- + audio_loader : AudioLoader + An initialized audio loader conforming to the AudioLoader protocol. + spectrogram_builder : SpectrogramBuilder + An initialized spectrogram builder conforming to the + SpectrogramBuilder protocol. + default_samplerate : int + The sample rate to assume for NumPy array inputs and potentially + reflecting the target rate of the audio config. + """ self.audio_loader = audio_loader self.spectrogram_builder = spectrogram_builder self.default_samplerate = default_samplerate @@ -104,6 +173,23 @@ class StandardPreprocessor(Preprocessor): path: data.PathLike, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform from a file path. + + Delegates to the internal `audio_loader`. + + Parameters + ---------- + path : PathLike + Path to the audio file. + audio_dir : PathLike, optional + A directory prefix if `path` is relative. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform (typically first + channel). + """ return self.audio_loader.load_file( path, audio_dir=audio_dir, @@ -114,6 +200,23 @@ class StandardPreprocessor(Preprocessor): recording: data.Recording, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform for a Recording. + + Delegates to the internal `audio_loader`. + + Parameters + ---------- + recording : data.Recording + The Recording object. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform (typically first + channel). + """ return self.audio_loader.load_recording( recording, audio_dir=audio_dir, @@ -124,6 +227,23 @@ class StandardPreprocessor(Preprocessor): clip: data.Clip, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform for a Clip. + + Delegates to the internal `audio_loader`. + + Parameters + ---------- + clip : data.Clip + The Clip object defining the segment. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform segment (typically first + channel). + """ return self.audio_loader.load_clip( clip, audio_dir=audio_dir, @@ -134,6 +254,24 @@ class StandardPreprocessor(Preprocessor): path: data.PathLike, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load audio from a file and compute the final processed spectrogram. + + Performs the full pipeline: + + Load -> Preprocess Audio -> Compute Spectrogram. + + Parameters + ---------- + path : PathLike + Path to the audio file. + audio_dir : PathLike, optional + A directory prefix if `path` is relative. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + """ wav = self.load_file_audio(path, audio_dir=audio_dir) return self.spectrogram_builder( wav, @@ -145,6 +283,22 @@ class StandardPreprocessor(Preprocessor): recording: data.Recording, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load audio for a Recording and compute the processed spectrogram. + + Performs the full pipeline for the entire duration of the recording. + + Parameters + ---------- + recording : data.Recording + The Recording object. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + """ wav = self.load_recording_audio(recording, audio_dir=audio_dir) return self.spectrogram_builder( wav, @@ -156,6 +310,22 @@ class StandardPreprocessor(Preprocessor): clip: data.Clip, audio_dir: Optional[data.PathLike] = None, ) -> xr.DataArray: + """Load audio for a Clip and compute the final processed spectrogram. + + Performs the full pipeline for the specified clip segment. + + Parameters + ---------- + clip : data.Clip + The Clip object defining the audio segment. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + """ wav = self.load_clip_audio(clip, audio_dir=audio_dir) return self.spectrogram_builder( wav, @@ -165,6 +335,27 @@ class StandardPreprocessor(Preprocessor): def compute_spectrogram( self, wav: Union[xr.DataArray, np.ndarray] ) -> xr.DataArray: + """Compute the spectrogram from a pre-loaded audio waveform. + + Applies the configured spectrogram generation steps + (STFT, scaling, etc.) using the internal `spectrogram_builder`. + + If `wav` is a NumPy array, the `default_samplerate` stored in this + preprocessor instance will be used. If `wav` is an xarray DataArray + with time coordinates, the sample rate derived from those coordinates + will take precedence over `default_samplerate`. + + Parameters + ---------- + wav : Union[xr.DataArray, np.ndarray] + The input audio waveform. If numpy array, `default_samplerate` + stored in this object will be assumed. + + Returns + ------- + xr.DataArray + The computed spectrogram. + """ return self.spectrogram_builder( wav, samplerate=self.default_samplerate, @@ -175,12 +366,64 @@ def load_preprocessing_config( path: data.PathLike, field: Optional[str] = None, ) -> PreprocessingConfig: + """Load the unified preprocessing configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `PreprocessingConfig` schema, potentially extracting data from a nested + field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + preprocessing configuration (e.g., "train.preprocessing"). If None, the + entire file content is validated as the PreprocessingConfig. + + Returns + ------- + PreprocessingConfig + Loaded and validated preprocessing configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded config data does not conform to PreprocessingConfig. + KeyError, TypeError + If `field` specifies an invalid path. + """ return load_config(path, schema=PreprocessingConfig, field=field) -def build_preprocessor_from_config( - config: PreprocessingConfig, +def build_preprocessor( + config: Optional[PreprocessingConfig] = None, ) -> Preprocessor: + """Factory function to build the standard preprocessor from configuration. + + Creates instances of the required `AudioLoader` and `SpectrogramBuilder` + based on the provided `PreprocessingConfig` (or defaults if config is None), + determines the effective default sample rate, and initializes the + `StandardPreprocessor`. + + Parameters + ---------- + config : PreprocessingConfig, optional + The unified preprocessing configuration object. If None, default + configurations for audio and spectrogram processing will be used. + + Returns + ------- + Preprocessor + An initialized `StandardPreprocessor` instance ready to process audio + according to the configuration. + """ + config = config or PreprocessingConfig() + default_samplerate = ( config.audio.resample.samplerate if config.audio.resample diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index 4a96c32..15fc13d 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -125,7 +125,7 @@ def no_center_config() -> audio.AudioConfig: def resample_fourier_config() -> audio.AudioConfig: return audio.AudioConfig( resample=audio.ResampleConfig( - samplerate=audio.TARGET_SAMPLERATE_HZ // 2, mode="fourier" + samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier" ) ) @@ -133,21 +133,21 @@ def resample_fourier_config() -> audio.AudioConfig: def test_resample_config_defaults(): config = audio.ResampleConfig() assert config.samplerate == audio.TARGET_SAMPLERATE_HZ - assert config.mode == "poly" + assert config.method == "poly" def test_audio_config_defaults(): config = audio.AudioConfig() assert config.resample is not None assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ - assert config.resample.mode == "poly" + assert config.resample.method == "poly" assert config.scale == audio.SCALE_RAW_AUDIO assert config.center is True assert config.duration == audio.DEFAULT_DURATION def test_audio_config_override(): - resample_cfg = audio.ResampleConfig(samplerate=44100, mode="fourier") + resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier") config = audio.AudioConfig( resample=resample_cfg, scale=True, @@ -206,7 +206,7 @@ def test_resample_audio(orig_sr, target_sr, mode): duration = 0.1 wave = create_xr_wave(orig_sr, duration) resampled_wave = audio.resample_audio( - wave, samplerate=target_sr, mode=mode, dtype=np.float32 + wave, samplerate=target_sr, method=mode, dtype=np.float32 ) expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr)) assert resampled_wave.sizes["time"] == expected_samples @@ -233,7 +233,7 @@ def test_resample_audio_same_samplerate(): def test_resample_audio_invalid_mode_raises(): wave = create_xr_wave(48000, 0.1) with pytest.raises(NotImplementedError): - audio.resample_audio(wave, samplerate=96000, mode="invalid_mode") + audio.resample_audio(wave, samplerate=96000, method="invalid_mode") def test_resample_audio_no_time_dim_raises():