From 19febf22167cec02cf26c0eee07ce6bc3ff8563f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 17 Apr 2025 15:36:21 +0100 Subject: [PATCH] Separated the protocols to separate types module --- batdetect2/preprocess/__init__.py | 196 ++++++++++---- batdetect2/preprocess/audio.py | 128 ++------- batdetect2/preprocess/config.py | 31 --- batdetect2/preprocess/spectrogram.py | 233 ++++++----------- batdetect2/preprocess/types.py | 377 +++++++++++++++++++++++++++ 5 files changed, 627 insertions(+), 338 deletions(-) delete mode 100644 batdetect2/preprocess/config.py create mode 100644 batdetect2/preprocess/types.py diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py index 62f9c82..b1a8022 100644 --- a/batdetect2/preprocess/__init__.py +++ b/batdetect2/preprocess/__init__.py @@ -1,87 +1,193 @@ """Module containing functions for preprocessing audio clips.""" -from functools import partial -from typing import Callable, Optional, Protocol +from typing import Optional, Union +import numpy as np import xarray as xr +from pydantic import Field from soundevent import data +from batdetect2.configs import BaseConfig, load_config from batdetect2.preprocess.audio import ( + DEFAULT_DURATION, + SCALE_RAW_AUDIO, + TARGET_SAMPLERATE_HZ, AudioConfig, ResampleConfig, + adjust_audio_duration, + build_audio_loader, + convert_to_xr, load_clip_audio, -) -from batdetect2.preprocess.config import ( - PreprocessingConfig, - load_preprocessing_config, + load_file_audio, + load_recording_audio, + resample_audio, ) from batdetect2.preprocess.spectrogram import ( MAX_FREQ, MIN_FREQ, - AmplitudeScaleConfig, + ConfigurableSpectrogramBuilder, FrequencyConfig, - LogScaleConfig, - PcenScaleConfig, - Scales, + PcenConfig, SpecSizeConfig, SpectrogramConfig, STFTConfig, + build_spectrogram_builder, compute_spectrogram, + get_spectrogram_resolution, +) +from batdetect2.preprocess.types import ( + AudioLoader, + Preprocessor, + SpectrogramBuilder, ) __all__ = [ - "AmplitudeScaleConfig", "AudioConfig", + "AudioLoader", + "ConfigurableSpectrogramBuilder", + "DEFAULT_DURATION", + "FrequencyConfig", "FrequencyConfig", - "LogScaleConfig", "MAX_FREQ", "MIN_FREQ", - "PcenScaleConfig", + "PcenConfig", + "PcenConfig", "PreprocessingConfig", "ResampleConfig", + "SCALE_RAW_AUDIO", + "STFTConfig", "STFTConfig", - "Scales", "SpecSizeConfig", + "SpecSizeConfig", + "SpectrogramBuilder", "SpectrogramConfig", + "SpectrogramConfig", + "TARGET_SAMPLERATE_HZ", + "adjust_audio_duration", + "build_audio_loader", + "build_spectrogram_builder", + "compute_spectrogram", + "convert_to_xr", + "get_spectrogram_resolution", + "load_clip_audio", + "load_file_audio", "load_preprocessing_config", - "preprocess_audio_clip", + "load_recording_audio", + "resample_audio", ] -class AudioPreprocessor(Protocol): - def __call__( +class PreprocessingConfig(BaseConfig): + """Configuration for preprocessing data.""" + + audio: AudioConfig = Field(default_factory=AudioConfig) + spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) + + +class StandardPreprocessor(Preprocessor): + audio_loader: AudioLoader + spectrogram_builder: SpectrogramBuilder + default_samplerate: int + + def __init__( + self, + audio_loader: AudioLoader, + spectrogram_builder: SpectrogramBuilder, + default_samplerate: int, + ) -> None: + self.audio_loader = audio_loader + self.spectrogram_builder = spectrogram_builder + self.default_samplerate = default_samplerate + + def load_file_audio( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + return self.audio_loader.load_file( + path, + audio_dir=audio_dir, + ) + + def load_recording_audio( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + return self.audio_loader.load_recording( + recording, + audio_dir=audio_dir, + ) + + def load_clip_audio( self, clip: data.Clip, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: ... + ) -> xr.DataArray: + return self.audio_loader.load_clip( + clip, + audio_dir=audio_dir, + ) + + def preprocess_file( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + wav = self.load_file_audio(path, audio_dir=audio_dir) + return self.spectrogram_builder( + wav, + samplerate=self.default_samplerate, + ) + + def preprocess_recording( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + wav = self.load_recording_audio(recording, audio_dir=audio_dir) + return self.spectrogram_builder( + wav, + samplerate=self.default_samplerate, + ) + + def preprocess_clip( + self, + clip: data.Clip, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + wav = self.load_clip_audio(clip, audio_dir=audio_dir) + return self.spectrogram_builder( + wav, + samplerate=self.default_samplerate, + ) + + def compute_spectrogram( + self, wav: Union[xr.DataArray, np.ndarray] + ) -> xr.DataArray: + return self.spectrogram_builder( + wav, + samplerate=self.default_samplerate, + ) + + +def load_preprocessing_config( + path: data.PathLike, + field: Optional[str] = None, +) -> PreprocessingConfig: + return load_config(path, schema=PreprocessingConfig, field=field) def build_preprocessor_from_config( config: PreprocessingConfig, -) -> AudioPreprocessor: - return partial(preprocess_audio_clip, config=config) - - -def preprocess_audio_clip( - clip: data.Clip, - config: Optional[PreprocessingConfig] = None, - audio_dir: Optional[data.PathLike] = None, -) -> xr.DataArray: - """Preprocesses audio clip to generate spectrogram. - - Parameters - ---------- - clip - The audio clip to preprocess. - config - Configuration for preprocessing. - - Returns - ------- - xr.DataArray - Preprocessed spectrogram. - - """ - config = config or PreprocessingConfig() - wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir) - return compute_spectrogram(wav, config=config.spectrogram) +) -> Preprocessor: + default_samplerate = ( + config.audio.resample.samplerate + if config.audio.resample + else TARGET_SAMPLERATE_HZ + ) + return StandardPreprocessor( + audio_loader=build_audio_loader(config.audio), + spectrogram_builder=build_spectrogram_builder(config.spectrogram), + default_samplerate=default_samplerate, + ) diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py index 6c4d68e..481e1de 100644 --- a/batdetect2/preprocess/audio.py +++ b/batdetect2/preprocess/audio.py @@ -20,7 +20,7 @@ The primary interface is the `AudioLoader` protocol, with `AudioConfig`. """ -from typing import Optional, Protocol +from typing import Optional import numpy as np import xarray as xr @@ -32,9 +32,9 @@ from soundevent.arrays import operations as ops from soundfile import LibsndfileError from batdetect2.configs import BaseConfig +from batdetect2.preprocess.types import AudioLoader __all__ = [ - "AudioLoader", "ResampleConfig", "AudioConfig", "ConfigurableAudioLoader", @@ -60,106 +60,6 @@ DEFAULT_DURATION = None """Default setting for target audio duration in seconds.""" -class AudioLoader(Protocol): - """Defines the interface for an audio loading and processing component. - - An AudioLoader is responsible for retrieving audio data corresponding to - different soundevent objects (files, Recordings, Clips) and applying a - configured set of initial preprocessing steps. Adhering to this protocol - allows for different loading strategies or implementations. - """ - - def load_file( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess audio directly from a file path. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix to prepend to the path if `path` is relative. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform as an xarray DataArray - with time coordinates. Typically loads only the first channel. - - Raises - ------ - FileNotFoundError - If the audio file cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - def load_recording( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess the entire audio for a Recording object. - - Parameters - ---------- - recording : data.Recording - The Recording object containing metadata about the audio file. - audio_dir : PathLike, optional - A directory where the audio file associated with the recording - can be found, especially if the path in the recording is relative. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform. Typically loads only - the first channel. - - Raises - ------ - FileNotFoundError - If the audio file associated with the recording cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - def load_clip( - self, - clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess the audio segment defined by a Clip object. - - Parameters - ---------- - clip : data.Clip - The Clip object specifying the recording and the start/end times - of the segment to load. - audio_dir : PathLike, optional - A directory where the audio file associated with the clip's - recording can be found. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform for the specified clip - duration. Typically loads only the first channel. - - Raises - ------ - FileNotFoundError - If the audio file associated with the clip cannot be found. - Exception - If the audio file cannot be loaded or processed. - """ - ... - - class ResampleConfig(BaseConfig): """Configuration for audio resampling. @@ -167,7 +67,7 @@ class ResampleConfig(BaseConfig): ---------- samplerate : int, default=256000 The target sample rate in Hz to resample the audio to. Must be > 0. - mode : str, default="poly" + method : str, default="poly" The resampling algorithm to use. Options: - "poly": Polyphase resampling using `scipy.signal.resample_poly`. Generally fast. @@ -177,7 +77,7 @@ class ResampleConfig(BaseConfig): """ samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) - mode: str = "poly" + method: str = "poly" class AudioConfig(BaseConfig): @@ -191,8 +91,8 @@ class AudioConfig(BaseConfig): ---------- resample : ResampleConfig, optional Configuration for resampling. If provided (or defaulted), audio will - be resampled to the specified `samplerate` using the specified `mode`. - If set to `None` in the config file, resampling is skipped. + be resampled to the specified `samplerate` using the specified + `method`. If set to `None` in the config file, resampling is skipped. Defaults to a ResampleConfig instance with standard settings. scale : bool, default=False If True, scales the audio waveform using peak normalization so that @@ -579,14 +479,14 @@ def adjust_audio_duration( def resample_audio( wav: xr.DataArray, samplerate: int = TARGET_SAMPLERATE_HZ, - mode: str = "poly", + method: str = "poly", dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: """Resample an audio waveform DataArray to a target sample rate. Updates the 'time' coordinate axis according to the new sample rate and number of samples. Uses either polyphase (`scipy.signal.resample_poly`) - or Fourier method (`scipy.signal.resample`) based on the `mode`. + or Fourier method (`scipy.signal.resample`) based on the `method`. Parameters ---------- @@ -594,7 +494,7 @@ def resample_audio( Input audio waveform with 'time' dimension and coordinates. samplerate : int, default=TARGET_SAMPLERATE_HZ Target sample rate in Hz. - mode : str, default="poly" + method : str, default="poly" Resampling algorithm: "poly" or "fourier". dtype : DTypeLike, default=np.float32 Target data type for the resampled array. @@ -610,7 +510,7 @@ def resample_audio( ------ ValueError If `wav` lacks a 'time' dimension, the original sample rate cannot - be determined, `samplerate` is non-positive, or `mode` is invalid. + be determined, `samplerate` is non-positive, or `method` is invalid. """ if "time" not in wav.dims: raise ValueError("Audio must have a time dimension") @@ -622,14 +522,14 @@ def resample_audio( if original_samplerate == samplerate: return wav.astype(dtype) - if mode == "poly": + if method == "poly": resampled = resample_audio_poly( wav, sr_orig=original_samplerate, sr_new=samplerate, axis=time_axis, ) - elif mode == "fourier": + elif method == "fourier": resampled = resample_audio_fourier( wav, sr_orig=original_samplerate, @@ -637,7 +537,9 @@ def resample_audio( axis=time_axis, ) else: - raise NotImplementedError(f"Resampling mode '{mode}' not implemented") + raise NotImplementedError( + f"Resampling method '{method}' not implemented" + ) start, stop = arrays.get_dim_range(wav, dim="time") times = np.linspace( diff --git a/batdetect2/preprocess/config.py b/batdetect2/preprocess/config.py deleted file mode 100644 index 4baac44..0000000 --- a/batdetect2/preprocess/config.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Optional - -from pydantic import Field -from soundevent.data import PathLike - -from batdetect2.configs import BaseConfig, load_config -from batdetect2.preprocess.audio import ( - AudioConfig, -) -from batdetect2.preprocess.spectrogram import ( - SpectrogramConfig, -) - -__all__ = [ - "PreprocessingConfig", - "load_preprocessing_config", -] - - -class PreprocessingConfig(BaseConfig): - """Configuration for preprocessing data.""" - - audio: AudioConfig = Field(default_factory=AudioConfig) - spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) - - -def load_preprocessing_config( - path: PathLike, - field: Optional[str] = None, -) -> PreprocessingConfig: - return load_config(path, schema=PreprocessingConfig, field=field) diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index 8e7e963..b667f67 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -6,20 +6,20 @@ spectrogram representations suitable for input into deep learning models like BatDetect2. It offers a configurable pipeline including: -1. Short-Time Fourier Transform (STFT) calculation. +1. Short-Time Fourier Transform (STFT) calculation to get magnitude. 2. Frequency axis cropping to a relevant range. -3. Amplitude scaling (e.g., Logarithmic, Per-Channel Energy Normalization - - PCEN). -4. Simple denoising (optional). -5. Resizing to target dimensions (optional). -6. Final peak normalization (optional). +3. Per-Channel Energy Normalization (PCEN) (optional). +4. Amplitude scaling/representation (dB, power, or linear amplitude). +5. Simple spectral mean subtraction denoising (optional). +6. Resizing to target dimensions (optional). +7. Final peak normalization (optional). Configuration is managed via the `SpectrogramConfig` class, allowing for reproducible spectrogram generation consistent between training and inference. The core computation is performed by `compute_spectrogram`. """ -from typing import Literal, Optional, Protocol, Union +from typing import Literal, Optional, Union import librosa import librosa.core.spectrum @@ -32,65 +32,23 @@ from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig from batdetect2.preprocess.audio import convert_to_xr +from batdetect2.preprocess.types import SpectrogramBuilder __all__ = [ - "SpectrogramBuilder", "STFTConfig", "FrequencyConfig", "SpecSizeConfig", - "LogScaleConfig", - "PcenScaleConfig", - "AmplitudeScaleConfig", - "Scales", + "PcenConfig", "SpectrogramConfig", "ConfigurableSpectrogramBuilder", "build_spectrogram_builder", "compute_spectrogram", "get_spectrogram_resolution", + "MIN_FREQ", + "MAX_FREQ", ] -class SpectrogramBuilder(Protocol): - """Defines the interface for a spectrogram generation component. - - A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray) - and produces a spectrogram (as an xarray DataArray) based on its internal - configuration or implementation. - """ - - def __call__( - self, - wav: Union[np.ndarray, xr.DataArray], - samplerate: Optional[int] = None, - ) -> xr.DataArray: - """Generate a spectrogram from an audio waveform. - - Parameters - ---------- - wav : Union[np.ndarray, xr.DataArray] - The input audio waveform. If a numpy array, `samplerate` must - also be provided. If an xarray DataArray, it must have a 'time' - coordinate from which the sample rate can be inferred. - samplerate : int, optional - The sample rate of the audio in Hz. Required if `wav` is a - numpy array. If `wav` is an xarray DataArray, this parameter is - ignored as the sample rate is derived from the coordinates. - - Returns - ------- - xr.DataArray - The computed spectrogram as an xarray DataArray with 'time' and - 'frequency' coordinates. - - Raises - ------ - ValueError - If `wav` is a numpy array and `samplerate` is not provided, or - if `wav` is an xarray DataArray without a valid 'time' coordinate. - """ - ... - - MIN_FREQ = 10_000 """Default minimum frequency (Hz) for spectrogram frequency cropping.""" @@ -151,109 +109,84 @@ class SpecSizeConfig(BaseConfig): resize_factor : float, optional Factor by which to resize the spectrogram along the time axis *after* STFT calculation. A value of 0.5 halves the number of time bins, - 2.0 doubles it. If None (default), no resizing along the time axis - is performed relative to the STFT output width. Must be > 0 if provided. + 2.0 doubles it. If None (default), no resizing along the time axis is + performed relative to the STFT output width. Must be > 0 if provided. """ height: int = 128 resize_factor: Optional[float] = 0.5 -class LogScaleConfig(BaseConfig): - """Configuration marker for using Logarithmic Amplitude Scaling.""" +class PcenConfig(BaseConfig): + """Configuration for Per-Channel Energy Normalization (PCEN). - name: Literal["log"] = "log" - - -class PcenScaleConfig(BaseConfig): - """Configuration for Per-Channel Energy Normalization (PCEN) scaling. - - PCEN is an adaptive gain control method often used for audio event - detection. + PCEN is an adaptive gain control method that can help emphasize transients + and suppress stationary noise. Applied after STFT and frequency cropping, + but before final amplitude scaling (dB, power, amplitude). Attributes ---------- - name : Literal["pcen"] - Discriminator field identifying this scaling type. time_constant : float, default=0.4 - Time constant (in seconds) for the PCEN smoothing filter. Controls how - quickly the normalization adapts to energy changes. + Time constant (in seconds) for the PCEN smoothing filter. Controls + how quickly the normalization adapts to energy changes. gain : float, default=0.98 - Gain factor (alpha in some formulations). Controls the AGC behavior. + Gain factor (alpha). Controls the adaptive gain component. bias : float, default=2.0 - Bias factor (delta in some formulations). Added before the - exponentiation. + Bias factor (delta). Added before the exponentiation. power : float, default=0.5 - Exponent (r in some formulations). Controls the compression - characteristic. + Exponent (r). Controls the compression characteristic. """ - name: Literal["pcen"] = "pcen" time_constant: float = 0.4 gain: float = 0.98 bias: float = 2 power: float = 0.5 -class AmplitudeScaleConfig(BaseConfig): - """Configuration marker for using Linear Amplitude (no scaling applied). - - Note: The actual output is typically magnitude from STFT, not raw amplitude. - This option essentially skips log or PCEN scaling. - """ - - name: Literal["amplitude"] = "amplitude" - - -Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig] -"""Type alias for the different amplitude scaling configuration options.""" - - class SpectrogramConfig(BaseConfig): - """Unified configuration for spectrogram generation. + """Unified configuration for spectrogram generation pipeline. - Aggregates settings for STFT, frequency selection, amplitude scaling, - resizing, and optional post-processing steps like denoising and final - normalization. + Aggregates settings for all steps involved in converting a preprocessed + audio waveform into a final spectrogram representation suitable for model input. Attributes ---------- stft : STFTConfig - Configuration for the Short-Time Fourier Transform. Defaults to standard - settings via `STFTConfig`. + Configuration for the initial Short-Time Fourier Transform. + Defaults to standard settings via `STFTConfig`. frequencies : FrequencyConfig - Configuration for cropping the frequency range. Defaults to standard - settings via `FrequencyConfig`. - scale : Scales - Configuration for amplitude scaling. Determines whether to apply - log scaling, PCEN, or leave as linear magnitude. Defaults to PCEN - via `PcenScaleConfig`. Use the `name` field ("log", "pcen", "amplitude") - in config files to select the type and provide relevant parameters. + Configuration for cropping the frequency range after STFT. + Defaults to standard settings via `FrequencyConfig`. + pcen : PcenConfig, optional + Configuration for applying Per-Channel Energy Normalization (PCEN). If + provided, PCEN is applied after frequency cropping. If None or omitted + (default), PCEN is skipped. + scale : Literal["dB", "amplitude", "power"], default="amplitude" + Determines the final amplitude representation *after* optional PCEN. + - "amplitude": Use linear magnitude values (output of STFT or PCEN). + - "power": Use power values (magnitude squared). + - "dB": Use logarithmic (decibel-like) scaling applied to the magnitude + (or PCEN output if enabled). Calculated as `log1p(C * S)`. size : SpecSizeConfig, optional - Configuration for resizing the final spectrogram dimensions (height in - frequency bins, optional time resizing factor). If None or omitted, - no resizing is performed after STFT and frequency cropping. Defaults - to standard settings via `SpecSizeConfig`. - denoise : bool, default=True - If True (default), applies a simple spectral mean subtraction denoising - step after amplitude scaling. - max_scale : bool, default=False + Configuration for resizing the spectrogram dimensions + (frequency height, optional time width factor). Applied after PCEN and + scaling. If None (default), no resizing is performed. + spectral_mean_substraction : bool, default=True + If True (default), applies simple spectral mean subtraction denoising + *after* PCEN and amplitude scaling, but *before* resizing. + peak_normalize : bool, default=False If True, applies a final peak normalization to the spectrogram *after* - all other steps (including log/PCEN scaling and resizing), scaling the - maximum value across the entire spectrogram to 1.0. If False (default), - this final scaling is skipped. **Note:** Applying this after log or PCEN - scaling will alter the characteristics of those scales. + all other steps (including resizing), scaling the overall maximum value + to 1.0. If False (default), this final normalization is skipped. """ stft: STFTConfig = Field(default_factory=STFTConfig) frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) - scale: Scales = Field( - default_factory=PcenScaleConfig, - discriminator="name", - ) + pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig) + scale: Literal["dB", "amplitude", "power"] = "amplitude" size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) - denoise: bool = True - max_scale: bool = False + spectral_mean_substraction: bool = True + peak_normalize: bool = False class ConfigurableSpectrogramBuilder(SpectrogramBuilder): @@ -362,13 +295,13 @@ def compute_spectrogram( """Compute a spectrogram from a waveform using specified configurations. Applies a sequence of operations based on the `config`: - 1. Compute STFT magnitude (`stft`). 2. Crop frequency axis (`crop_spectrogram_frequencies`). - 3. Apply amplitude scaling (log, PCEN, or none) (`scale_spectrogram`). - 4. Apply denoising if enabled (`denoise_spectrogram`). - 5. Resize dimensions if specified (`resize_spectrogram`). - 6. Apply final peak normalization if enabled (`max_scale`). + 3. Apply PCEN if configured (`apply_pcen`). + 4. Apply final amplitude scaling (dB, power, amplitude) (`scale_spectrogram`). + 5. Apply spectral mean subtraction denoising if enabled. + 6. Resize dimensions if specified (`resize_spectrogram`). + 7. Apply final peak normalization if enabled. Parameters ---------- @@ -411,10 +344,19 @@ def compute_spectrogram( max_freq=config.frequencies.max_freq, ) + if config.pcen: + spec = apply_pcen( + spec, + time_constant=config.pcen.time_constant, + gain=config.pcen.gain, + power=config.pcen.power, + bias=config.pcen.bias, + ) + spec = scale_spectrogram(spec, scale=config.scale) - if config.denoise: - spec = denoise_spectrogram(spec) + if config.spectral_mean_substraction: + spec = remove_spectral_mean(spec) if config.size: spec = resize_spectrogram( @@ -423,7 +365,7 @@ def compute_spectrogram( resize_factor=config.size.resize_factor, ) - if config.max_scale: + if config.peak_normalize: spec = ops.scale(spec, 1 / (10e-6 + np.max(spec))) return spec.astype(dtype) @@ -550,7 +492,7 @@ def stft( ) -def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray: +def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray: """Apply simple spectral mean subtraction for denoising. Subtracts the mean value of each frequency bin (calculated across time) @@ -576,23 +518,22 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray: def scale_spectrogram( spec: xr.DataArray, - scale: Scales, + scale: Literal["dB", "power", "amplitude"], dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: - """Apply configured amplitude scaling to the spectrogram. + """Apply final amplitude scaling/representation to the spectrogram. - Dispatches to the appropriate scaling function (log, PCEN) based on the - `scale` configuration object's `name` field. If `scale.name` is - "amplitude", the spectrogram is returned unchanged (as it's already - magnitude/amplitude). + Converts the input magnitude spectrogram based on the `scale` type: + - "dB": Applies logarithmic scaling `log1p(C * S)`. + - "power": Squares the magnitude values `S^2`. + - "amplitude": Returns the input magnitude values `S` unchanged. Parameters ---------- spec : xr.DataArray - Input magnitude spectrogram. - scale : Scales - The configuration object specifying the scaling method and parameters - (instance of LogScaleConfig, PcenScaleConfig, or AmplitudeScaleConfig). + Input magnitude spectrogram (potentially after PCEN). + scale : Literal["dB", "power", "amplitude"] + The target amplitude representation. dtype : DTypeLike, default=np.float32 Target data type for the output scaled spectrogram. @@ -601,22 +542,16 @@ def scale_spectrogram( xr.DataArray Spectrogram with the specified amplitude scaling applied. """ - if scale.name == "log": + if scale == "dB": return scale_log(spec, dtype=dtype) - if scale.name == "pcen": - return scale_pcen( - spec, - time_constant=scale.time_constant, - gain=scale.gain, - power=scale.power, - bias=scale.bias, - ) + if scale == "power": + return spec**2 return spec -def scale_pcen( +def apply_pcen( spec: xr.DataArray, time_constant: float = 0.4, gain: float = 0.98, diff --git a/batdetect2/preprocess/types.py b/batdetect2/preprocess/types.py new file mode 100644 index 0000000..f7346f5 --- /dev/null +++ b/batdetect2/preprocess/types.py @@ -0,0 +1,377 @@ +"""Defines common interfaces (Protocols) for preprocessing components. + +This module centralizes the Protocol definitions used throughout the +`batdetect2.preprocess` package. Protocols define expected methods and +signatures, allowing for flexible and interchangeable implementations of +components like audio loaders and spectrogram builders. + +Using these protocols ensures that different parts of the preprocessing +pipeline can interact consistently, regardless of the specific underlying +implementation (e.g., different libraries or custom configurations). +""" + +from typing import Optional, Protocol, Union + +import numpy as np +import xarray as xr +from soundevent import data + + +class AudioLoader(Protocol): + """Defines the interface for an audio loading and processing component. + + An AudioLoader is responsible for retrieving audio data corresponding to + different soundevent objects (files, Recordings, Clips) and applying a + configured set of initial preprocessing steps. Adhering to this protocol + allows for different loading strategies or implementations. + """ + + def load_file( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess audio directly from a file path. + + Parameters + ---------- + path : PathLike + Path to the audio file. + audio_dir : PathLike, optional + A directory prefix to prepend to the path if `path` is relative. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform as an xarray DataArray + with time coordinates. Typically loads only the first channel. + + Raises + ------ + FileNotFoundError + If the audio file cannot be found. + Exception + If the audio file cannot be loaded or processed. + """ + ... + + def load_recording( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess the entire audio for a Recording object. + + Parameters + ---------- + recording : data.Recording + The Recording object containing metadata about the audio file. + audio_dir : PathLike, optional + A directory where the audio file associated with the recording + can be found, especially if the path in the recording is relative. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform. Typically loads only + the first channel. + + Raises + ------ + FileNotFoundError + If the audio file associated with the recording cannot be found. + Exception + If the audio file cannot be loaded or processed. + """ + ... + + def load_clip( + self, + clip: data.Clip, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess the audio segment defined by a Clip object. + + Parameters + ---------- + clip : data.Clip + The Clip object specifying the recording and the start/end times + of the segment to load. + audio_dir : PathLike, optional + A directory where the audio file associated with the clip's + recording can be found. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform for the specified clip + duration. Typically loads only the first channel. + + Raises + ------ + FileNotFoundError + If the audio file associated with the clip cannot be found. + Exception + If the audio file cannot be loaded or processed. + """ + ... + + +class SpectrogramBuilder(Protocol): + """Defines the interface for a spectrogram generation component. + + A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray) + and produces a spectrogram (as an xarray DataArray) based on its internal + configuration or implementation. + """ + + def __call__( + self, + wav: Union[np.ndarray, xr.DataArray], + samplerate: Optional[int] = None, + ) -> xr.DataArray: + """Generate a spectrogram from an audio waveform. + + Parameters + ---------- + wav : Union[np.ndarray, xr.DataArray] + The input audio waveform. If a numpy array, `samplerate` must + also be provided. If an xarray DataArray, it must have a 'time' + coordinate from which the sample rate can be inferred. + samplerate : int, optional + The sample rate of the audio in Hz. Required if `wav` is a + numpy array. If `wav` is an xarray DataArray, this parameter is + ignored as the sample rate is derived from the coordinates. + + Returns + ------- + xr.DataArray + The computed spectrogram as an xarray DataArray with 'time' and + 'frequency' coordinates. + + Raises + ------ + ValueError + If `wav` is a numpy array and `samplerate` is not provided, or + if `wav` is an xarray DataArray without a valid 'time' coordinate. + """ + ... + + +class Preprocessor(Protocol): + """Defines a high-level interface for the complete preprocessing pipeline. + + A Preprocessor combines audio loading and spectrogram generation steps. + It provides methods to go directly from source descriptions (file paths, + Recording objects, Clip objects) to the final spectrogram representation + needed by the model. It may also expose intermediate steps like audio + loading or spectrogram computation from a waveform. + """ + + def preprocess_file( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load audio from a file and compute the final processed spectrogram. + + Performs the full pipeline: + + Load -> Preprocess Audio -> Compute Spectrogram. + + Parameters + ---------- + path : PathLike + Path to the audio file. + audio_dir : PathLike, optional + A directory prefix if `path` is relative. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + + Raises + ------ + FileNotFoundError + If the audio file cannot be found. + Exception + If any step in the loading or preprocessing fails. + """ + ... + + def preprocess_recording( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load audio for a Recording and compute the processed spectrogram. + + Performs the full pipeline for the entire duration of the recording. + + Parameters + ---------- + recording : data.Recording + The Recording object. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + + Raises + ------ + FileNotFoundError + If the audio file cannot be found. + Exception + If any step in the loading or preprocessing fails. + """ + ... + + def preprocess_clip( + self, + clip: data.Clip, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load audio for a Clip and compute the final processed spectrogram. + + Performs the full pipeline for the specified clip segment. + + Parameters + ---------- + clip : data.Clip + The Clip object defining the audio segment. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The final processed spectrogram. + + Raises + ------ + FileNotFoundError + If the audio file cannot be found. + Exception + If any step in the loading or preprocessing fails. + """ + ... + + def load_file_audio( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform from a file path. + + Performs the initial audio loading and waveform processing steps + (like resampling, scaling), but stops *before* spectrogram generation. + + Parameters + ---------- + path : PathLike + Path to the audio file. + audio_dir : PathLike, optional + A directory prefix if `path` is relative. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform. + + Raises + ------ + FileNotFoundError, Exception + If audio loading/preprocessing fails. + """ + ... + + def load_recording_audio( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform for a Recording. + + Performs the initial audio loading and waveform processing steps + for the entire recording duration. + + Parameters + ---------- + recording : data.Recording + The Recording object. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform. + + Raises + ------ + FileNotFoundError, Exception + If audio loading/preprocessing fails. + """ + ... + + def load_clip_audio( + self, + clip: data.Clip, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.DataArray: + """Load and preprocess *only* the audio waveform for a Clip. + + Performs the initial audio loading and waveform processing steps + for the specified clip segment. + + Parameters + ---------- + clip : data.Clip + The Clip object defining the segment. + audio_dir : PathLike, optional + Directory containing the audio file. + + Returns + ------- + xr.DataArray + The loaded and preprocessed audio waveform segment. + + Raises + ------ + FileNotFoundError, Exception + If audio loading/preprocessing fails. + """ + ... + + def compute_spectrogram( + self, + wav: Union[xr.DataArray, np.ndarray], + ) -> xr.DataArray: + """Compute the spectrogram from a pre-loaded audio waveform. + + Applies the spectrogram generation steps (STFT, scaling, etc.) defined + by the `SpectrogramBuilder` component of the preprocessor to an + already loaded (and potentially preprocessed) waveform. + + Parameters + ---------- + wav : Union[xr.DataArray, np.ndarray] + The input audio waveform. If numpy array, `samplerate` is required. + samplerate : int, optional + Sample rate in Hz (required if `wav` is np.ndarray). + + Returns + ------- + xr.DataArray + The computed spectrogram. + + Raises + ------ + ValueError, Exception + If waveform input is invalid or spectrogram computation fails. + """ + ...