From 667b18a54de259d8ff26ff380b8a9447173aeeb5 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 11:41:55 +0100 Subject: [PATCH] Preprocessing in pytorch --- src/batdetect2/preprocess/__init__.py | 357 +------- src/batdetect2/preprocess/audio.py | 658 ++++---------- src/batdetect2/preprocess/common.py | 24 + src/batdetect2/preprocess/spectrogram.py | 893 ++++++------------- src/batdetect2/targets/rois.py | 20 +- src/batdetect2/train/augmentations.py | 8 +- src/batdetect2/train/clips.py | 2 + src/batdetect2/train/preprocess.py | 55 +- src/batdetect2/typing/preprocess.py | 289 +----- src/batdetect2/utils/arrays.py | 56 ++ tests/conftest.py | 7 + tests/test_preprocessing/test_audio.py | 446 --------- tests/test_preprocessing/test_spectrogram.py | 411 --------- tests/test_targets/test_rois.py | 28 +- tests/test_train/test_augmentations.py | 14 +- 15 files changed, 690 insertions(+), 2578 deletions(-) create mode 100644 src/batdetect2/preprocess/common.py diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index 37f7ad0..f8df745 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -28,13 +28,12 @@ This module provides the primary interface: """ -from typing import Optional, Union +from typing import Optional -import numpy as np -import xarray as xr +import torch from loguru import logger from pydantic import Field -from soundevent import data +from soundevent.data import PathLike from batdetect2.configs import BaseConfig, load_config from batdetect2.preprocess.audio import ( @@ -44,28 +43,23 @@ from batdetect2.preprocess.audio import ( AudioConfig, ResampleConfig, build_audio_loader, + build_audio_pipeline, ) from batdetect2.preprocess.spectrogram import ( MAX_FREQ, MIN_FREQ, - ConfigurableSpectrogramBuilder, FrequencyConfig, PcenConfig, - SpecSizeConfig, SpectrogramConfig, + SpectrogramPipeline, STFTConfig, build_spectrogram_builder, - get_spectrogram_resolution, -) -from batdetect2.typing.preprocess import ( - AudioLoader, - PreprocessorProtocol, - SpectrogramBuilder, + build_spectrogram_pipeline, ) +from batdetect2.typing import PreprocessorProtocol __all__ = [ "AudioConfig", - "ConfigurableSpectrogramBuilder", "DEFAULT_DURATION", "FrequencyConfig", "MAX_FREQ", @@ -75,16 +69,11 @@ __all__ = [ "ResampleConfig", "SCALE_RAW_AUDIO", "STFTConfig", - "SpecSizeConfig", "SpectrogramConfig", - "StandardPreprocessor", "TARGET_SAMPLERATE_HZ", "build_audio_loader", - "build_preprocessor", "build_spectrogram_builder", - "get_spectrogram_resolution", "load_preprocessing_config", - "get_default_preprocessor", ] @@ -110,343 +99,61 @@ class PreprocessingConfig(BaseConfig): spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) -class StandardPreprocessor(PreprocessorProtocol): - """Standard implementation of the `Preprocessor` protocol. +def load_preprocessing_config( + path: PathLike, + field: Optional[str] = None, +) -> PreprocessingConfig: + return load_config(path, schema=PreprocessingConfig, field=field) - Orchestrates the audio loading and spectrogram generation pipeline using - an `AudioLoader` and a `SpectrogramBuilder` internally, which are - configured according to a `PreprocessingConfig`. - This class is typically instantiated using the `build_preprocessor` - factory function. 
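The torch-based replacement below collapses the old loader/builder pair into a single `torch.nn.Module`. For orientation, a minimal usage sketch assuming the default configuration (the waveform here is fabricated for illustration):

import torch

from batdetect2.preprocess import PreprocessingConfig, build_preprocessor

preprocessor = build_preprocessor(PreprocessingConfig())
wav = torch.randn(256_000)  # hypothetical 1 s clip at TARGET_SAMPLERATE_HZ
spec = preprocessor(wav)    # audio transforms, then the spectrogram pipeline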
+class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): + """Standard implementation of the `Preprocessor` protocol.""" - Attributes - ---------- - audio_loader : AudioLoader - The configured audio loader instance used for waveform loading and - initial processing. - spectrogram_builder : SpectrogramBuilder - The configured spectrogram builder instance used for generating - spectrograms from waveforms. - default_samplerate : int - The sample rate (in Hz) assumed for input waveforms when they are - provided as raw NumPy arrays without coordinate information (e.g., - when calling `compute_spectrogram` directly with `np.ndarray`). - This value is derived from the `AudioConfig` (target resample rate - or default if resampling is off) and also serves as documentation - for the pipeline's intended operating sample rate. Note that when - processing `xr.DataArray` inputs that have coordinate information - (the standard internal workflow), the sample rate embedded in the - coordinates takes precedence over this default value during - spectrogram calculation. - """ - - audio_loader: AudioLoader - spectrogram_builder: SpectrogramBuilder - default_samplerate: int + samplerate: int max_freq: float min_freq: float def __init__( self, - audio_loader: AudioLoader, - spectrogram_builder: SpectrogramBuilder, - default_samplerate: int, + audio_pipeline: torch.nn.Module, + spectrogram_pipeline: SpectrogramPipeline, + samplerate: int, max_freq: float, min_freq: float, ) -> None: - """Initialize the StandardPreprocessor. - - Parameters - ---------- - audio_loader : AudioLoader - An initialized audio loader conforming to the AudioLoader protocol. - spectrogram_builder : SpectrogramBuilder - An initialized spectrogram builder conforming to the - SpectrogramBuilder protocol. - default_samplerate : int - The sample rate to assume for NumPy array inputs and potentially - reflecting the target rate of the audio config. - """ - self.audio_loader = audio_loader - self.spectrogram_builder = spectrogram_builder - self.default_samplerate = default_samplerate + super().__init__() + self.audio_pipeline = audio_pipeline + self.spectrogram_pipeline = spectrogram_pipeline + self.samplerate = samplerate self.max_freq = max_freq self.min_freq = min_freq - def load_file_audio( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform from a file path. - - Delegates to the internal `audio_loader`. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix if `path` is relative. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform (typically first - channel). - """ - return self.audio_loader.load_file( - path, - audio_dir=audio_dir, - ) - - def load_recording_audio( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform for a Recording. - - Delegates to the internal `audio_loader`. - - Parameters - ---------- - recording : data.Recording - The Recording object. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform (typically first - channel). 
- """ - return self.audio_loader.load_recording( - recording, - audio_dir=audio_dir, - ) - - def load_clip_audio( - self, - clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform for a Clip. - - Delegates to the internal `audio_loader`. - - Parameters - ---------- - clip : data.Clip - The Clip object defining the segment. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform segment (typically first - channel). - """ - return self.audio_loader.load_clip( - clip, - audio_dir=audio_dir, - ) - - def preprocess_file( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio from a file and compute the final processed spectrogram. - - Performs the full pipeline: - - Load -> Preprocess Audio -> Compute Spectrogram. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix if `path` is relative. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - """ - wav = self.load_file_audio(path, audio_dir=audio_dir) - return self.spectrogram_builder( - wav, - samplerate=self.default_samplerate, - ) - - def preprocess_recording( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio for a Recording and compute the processed spectrogram. - - Performs the full pipeline for the entire duration of the recording. - - Parameters - ---------- - recording : data.Recording - The Recording object. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - """ - wav = self.load_recording_audio(recording, audio_dir=audio_dir) - return self.spectrogram_builder( - wav, - samplerate=self.default_samplerate, - ) - - def preprocess_clip( - self, - clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio for a Clip and compute the final processed spectrogram. - - Performs the full pipeline for the specified clip segment. - - Parameters - ---------- - clip : data.Clip - The Clip object defining the audio segment. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - """ - wav = self.load_clip_audio(clip, audio_dir=audio_dir) - return self.spectrogram_builder( - wav, - samplerate=self.default_samplerate, - ) - - def compute_spectrogram( - self, wav: Union[xr.DataArray, np.ndarray] - ) -> xr.DataArray: - """Compute the spectrogram from a pre-loaded audio waveform. - - Applies the configured spectrogram generation steps - (STFT, scaling, etc.) using the internal `spectrogram_builder`. - - If `wav` is a NumPy array, the `default_samplerate` stored in this - preprocessor instance will be used. If `wav` is an xarray DataArray - with time coordinates, the sample rate derived from those coordinates - will take precedence over `default_samplerate`. - - Parameters - ---------- - wav : Union[xr.DataArray, np.ndarray] - The input audio waveform. If numpy array, `default_samplerate` - stored in this object will be assumed. - - Returns - ------- - xr.DataArray - The computed spectrogram. 
- """ - return self.spectrogram_builder( - wav, - samplerate=self.default_samplerate, - ) - - -def load_preprocessing_config( - path: data.PathLike, - field: Optional[str] = None, -) -> PreprocessingConfig: - """Load the unified preprocessing configuration from a file. - - Reads a configuration file (YAML) and validates it against the - `PreprocessingConfig` schema, potentially extracting data from a nested - field. - - Parameters - ---------- - path : PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - preprocessing configuration (e.g., "train.preprocessing"). If None, the - entire file content is validated as the PreprocessingConfig. - - Returns - ------- - PreprocessingConfig - Loaded and validated preprocessing configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. - pydantic.ValidationError - If the loaded config data does not conform to PreprocessingConfig. - KeyError, TypeError - If `field` specifies an invalid path. - """ - return load_config(path, schema=PreprocessingConfig, field=field) + def forward(self, wav: torch.Tensor) -> torch.Tensor: + wav = self.audio_pipeline(wav) + return self.spectrogram_pipeline(wav) def build_preprocessor( config: Optional[PreprocessingConfig] = None, ) -> PreprocessorProtocol: - """Factory function to build the standard preprocessor from configuration. - - Creates instances of the required `AudioLoader` and `SpectrogramBuilder` - based on the provided `PreprocessingConfig` (or defaults if config is None), - determines the effective default sample rate, and initializes the - `StandardPreprocessor`. - - Parameters - ---------- - config : PreprocessingConfig, optional - The unified preprocessing configuration object. If None, default - configurations for audio and spectrogram processing will be used. - - Returns - ------- - Preprocessor - An initialized `StandardPreprocessor` instance ready to process audio - according to the configuration. - """ + """Factory function to build the standard preprocessor from configuration.""" config = config or PreprocessingConfig() logger.opt(lazy=True).debug( "Building preprocessor with config: \n{}", lambda: config.to_yaml_string(), ) - default_samplerate = ( - config.audio.resample.samplerate - if config.audio.resample - else TARGET_SAMPLERATE_HZ - ) + samplerate = config.audio.samplerate min_freq = config.spectrogram.frequencies.min_freq max_freq = config.spectrogram.frequencies.max_freq return StandardPreprocessor( - audio_loader=build_audio_loader(config.audio), - spectrogram_builder=build_spectrogram_builder(config.spectrogram), - default_samplerate=default_samplerate, + audio_pipeline=build_audio_pipeline(config.audio), + spectrogram_pipeline=build_spectrogram_pipeline( + samplerate, config.spectrogram + ), + samplerate=samplerate, min_freq=min_freq, max_freq=max_freq, ) diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index a67c065..c5c72e2 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -1,53 +1,31 @@ -"""Handles loading and initial preprocessing of audio waveforms. +"""Handles loading and initial preprocessing of audio waveforms.""" -This module provides components for loading audio data associated with -`soundevent` objects (Clips, Recordings, or raw files) and applying -fundamental waveform processing steps. 
These steps typically include: - -1. Loading the raw audio data. -2. Adjusting the audio clip to a fixed duration (optional). -3. Resampling the audio to a target sample rate (optional). -4. Centering the waveform (DC offset removal) (optional). -5. Scaling the waveform amplitude (optional). - -The processing pipeline is configurable via the `AudioConfig` data structure, -allowing for reproducible preprocessing consistent between model training and -inference. It uses the `soundevent` library for audio loading and basic array -operations, and `scipy` for resampling implementations. - -The primary interface is the `AudioLoader` protocol, with -`ConfigurableAudioLoader` providing a concrete implementation driven by the -`AudioConfig`. -""" - -from typing import Optional +from typing import Annotated, List, Literal, Optional, Union import numpy as np -import xarray as xr +import torch from numpy.typing import DTypeLike from pydantic import Field from scipy.signal import resample, resample_poly -from soundevent import arrays, audio, data -from soundevent.arrays import operations as ops +from soundevent import audio, data from soundfile import LibsndfileError from batdetect2.configs import BaseConfig -from batdetect2.typing.preprocess import AudioLoader +from batdetect2.preprocess.common import CenterTensor, PeakNormalize +from batdetect2.typing import AudioLoader __all__ = [ "ResampleConfig", "AudioConfig", - "ConfigurableAudioLoader", + "SoundEventAudioLoader", "build_audio_loader", "load_file_audio", "load_recording_audio", "load_clip_audio", - "adjust_audio_duration", "resample_audio", "TARGET_SAMPLERATE_HZ", "SCALE_RAW_AUDIO", "DEFAULT_DURATION", - "convert_to_xr", ] TARGET_SAMPLERATE_HZ = 256_000 @@ -76,192 +54,69 @@ class ResampleConfig(BaseConfig): resampling factors differently. """ - samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) + enabled: bool = True method: str = "poly" -class AudioConfig(BaseConfig): - """Configuration for loading and initial audio preprocessing. - - Defines the sequence of operations applied to raw audio waveforms after - loading, controlling steps like resampling, scaling, centering, and - duration adjustment. - - Attributes - ---------- - resample : ResampleConfig, optional - Configuration for resampling. If provided (or defaulted), audio will - be resampled to the specified `samplerate` using the specified - `method`. If set to `None` in the config file, resampling is skipped. - Defaults to a ResampleConfig instance with standard settings. - scale : bool, default=False - If True, scales the audio waveform using peak normalization so that - its maximum absolute amplitude is approximately 1.0. If False - (default), no amplitude scaling is applied. - center : bool, default=True - If True (default), centers the waveform by subtracting its mean - (DC offset removal). If False, the waveform is not centered. - duration : float, optional - If set to a float value (seconds), the loaded audio clip will be - adjusted (cropped or padded with zeros) to exactly this duration. - If None (default), the original duration is kept. - """ - - resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig) - scale: bool = SCALE_RAW_AUDIO - center: bool = False - duration: Optional[float] = DEFAULT_DURATION - - -class ConfigurableAudioLoader: - """Concrete implementation of the `AudioLoader` driven by `AudioConfig`. 
- - This class loads audio and applies preprocessing steps (resampling, - scaling, centering, duration adjustment) based on the settings provided - in an `AudioConfig` object during initialization. It delegates the actual - work to module-level functions. - """ +class SoundEventAudioLoader: + """Concrete implementation of the `AudioLoader`.""" def __init__( self, - config: AudioConfig, + samplerate: int = TARGET_SAMPLERATE_HZ, + config: Optional[ResampleConfig] = None, ): - """Initialize the ConfigurableAudioLoader. - - Parameters - ---------- - config : AudioConfig - The configuration object specifying the desired preprocessing steps - and parameters. - """ - self.config = config + self.samplerate = samplerate + self.config = config or ResampleConfig() def load_file( self, path: data.PathLike, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess audio directly from a file path. - - Implements the `AudioLoader.load_file` method by delegating to the - `load_file_audio` function, passing the stored configuration. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix if `path` is relative. - - Returns - ------- - xr.DataArray - Loaded and preprocessed waveform (first channel). - """ - return load_file_audio(path, config=self.config, audio_dir=audio_dir) + ) -> np.ndarray: + """Load and preprocess audio directly from a file path.""" + return load_file_audio( + path, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, + ) def load_recording( self, recording: data.Recording, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess the entire audio for a Recording object. - - Implements the `AudioLoader.load_recording` method by delegating to the - `load_recording_audio` function, passing the stored configuration. - - Parameters - ---------- - recording : data.Recording - The Recording object. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - Loaded and preprocessed waveform (first channel). - """ + ) -> np.ndarray: + """Load and preprocess the entire audio for a Recording object.""" return load_recording_audio( - recording, config=self.config, audio_dir=audio_dir + recording, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, ) def load_clip( self, clip: data.Clip, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess the audio segment defined by a Clip object. - - Implements the `AudioLoader.load_clip` method by delegating to the - `load_clip_audio` function, passing the stored configuration. - - Parameters - ---------- - clip : data.Clip - The Clip object specifying the segment. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - Loaded and preprocessed waveform segment (first channel). - """ - return load_clip_audio(clip, config=self.config, audio_dir=audio_dir) - - -def build_audio_loader( - config: AudioConfig, -) -> AudioLoader: - """Factory function to create an AudioLoader based on configuration. - - Instantiates and returns a `ConfigurableAudioLoader` initialized with - the provided `AudioConfig`. The return type is `AudioLoader`, adhering - to the protocol. - - Parameters - ---------- - config : AudioConfig - The configuration object specifying preprocessing steps. 
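A short sketch of the slimmed-down loader above (the file path is hypothetical; resampling follows the `ResampleConfig` defaults and the result is a plain `np.ndarray`):

from batdetect2.preprocess.audio import SoundEventAudioLoader

loader = SoundEventAudioLoader(samplerate=256_000)
wav = loader.load_file("recording.wav")  # hypothetical path; returns np.ndarray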
- - Returns - ------- - AudioLoader - An instance of `ConfigurableAudioLoader` ready to load and process audio - according to the configuration. - """ - return ConfigurableAudioLoader(config=config) + ) -> np.ndarray: + """Load and preprocess the audio segment defined by a Clip object.""" + return load_clip_audio( + clip, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, + ) def load_file_audio( path: data.PathLike, - config: Optional[AudioConfig] = None, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, audio_dir: Optional[data.PathLike] = None, dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Load and preprocess audio from a file path using specified config. - - Creates a `soundevent.data.Recording` object from the file path and then - delegates the loading and processing to `load_recording_audio`. - - Parameters - ---------- - path : PathLike - Path to the audio file. - config : AudioConfig, optional - Audio processing configuration. If None, default settings defined - in `AudioConfig` are used. - audio_dir : PathLike, optional - Directory prefix if `path` is relative. - dtype : DTypeLike, default=np.float32 - Target NumPy data type for the loaded audio array. - - Returns - ------- - xr.DataArray - Loaded and preprocessed waveform (first channel only). - """ +) -> np.ndarray: + """Load and preprocess audio from a file path using specified config.""" try: recording = data.Recording.from_file(path) except LibsndfileError as e: @@ -271,6 +126,7 @@ def load_file_audio( return load_recording_audio( recording, + samplerate=samplerate, config=config, dtype=dtype, audio_dir=audio_dir, @@ -279,33 +135,12 @@ def load_file_audio( def load_recording_audio( recording: data.Recording, - config: Optional[AudioConfig] = None, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, audio_dir: Optional[data.PathLike] = None, dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Load and preprocess the entire audio content of a recording using config. - - Creates a `soundevent.data.Clip` spanning the full duration of the - recording and then delegates the loading and processing to - `load_clip_audio`. - - Parameters - ---------- - recording : data.Recording - The Recording object containing metadata and file path. - config : AudioConfig, optional - Audio processing configuration. If None, default settings are used. - audio_dir : PathLike, optional - Directory containing the audio file, used if the path in `recording` - is relative. - dtype : DTypeLike, default=np.float32 - Target NumPy data type for the loaded audio array. - - Returns - ------- - xr.DataArray - Loaded and preprocessed waveform (first channel only). - """ +) -> np.ndarray: + """Load and preprocess the entire audio content of a recording using config.""" clip = data.Clip( recording=recording, start_time=0, @@ -313,6 +148,7 @@ def load_recording_audio( ) return load_clip_audio( clip, + samplerate=samplerate, config=config, dtype=dtype, audio_dir=audio_dir, @@ -321,257 +157,66 @@ def load_recording_audio( def load_clip_audio( clip: data.Clip, - config: Optional[AudioConfig] = None, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, audio_dir: Optional[data.PathLike] = None, dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Load and preprocess a specific audio clip segment based on config. - - This is the core function performing the configured processing pipeline: - 1. 
Loads the specified clip segment using `soundevent.audio.load_clip`. - 2. Selects the first audio channel. - 3. Resamples if `config.resample` is configured. - 4. Centers (DC offset removal) if `config.center` is True. - 5. Scales (peak normalization) if `config.scale` is True. - 6. Adjusts duration (crop/pad) if `config.duration` is set. - - Parameters - ---------- - clip : data.Clip - The Clip object defining the audio segment and source recording. - config : AudioConfig, optional - Audio processing configuration. If None, a default `AudioConfig` is - used. - audio_dir : PathLike, optional - Directory containing the source audio file specified in the clip's - recording. - dtype : DTypeLike, default=np.float32 - Target NumPy data type for the processed audio array. - - Returns - ------- - xr.DataArray - The loaded and preprocessed waveform segment as an xarray DataArray - with time coordinates. - - Raises - ------ - FileNotFoundError - If the underlying audio file cannot be found. - Exception - If audio loading or processing fails for other reasons (e.g., invalid - format, resampling error). - - Notes - ----- - - **Mono Processing:** This function currently loads and processes only the - **first channel** (channel 0) of the audio file. Any other channels - are ignored. - """ - config = config or AudioConfig() - - with xr.set_options(keep_attrs=True): - try: - wav = ( - audio.load_clip(clip, audio_dir=audio_dir) - .sel(channel=0) - .astype(dtype) - ) - except LibsndfileError as e: - raise FileNotFoundError( - f"Could not load the recording at path: {clip.recording.path}. " - f"Error: {e}" - ) from e - - if config.resample: - wav = resample_audio( - wav, - samplerate=config.resample.samplerate, - dtype=dtype, - ) - - if config.center: - wav = ops.center(wav) - - if config.scale: - wav = scale_audio(wav) - - if config.duration is not None: - wav = adjust_audio_duration(wav, duration=config.duration) - - return wav.astype(dtype) - - -def scale_audio( - wave: xr.DataArray, -) -> xr.DataArray: - """ - Scale the audio waveform to have a maximum absolute value of 1.0. - - This function normalizes the waveform by dividing it by its maximum - absolute value. If the maximum value is zero, the waveform is returned - unchanged. Also known as peak normalization, this process ensures that the - waveform's amplitude is within a standard range, which can be useful for - audio processing and analysis. - - """ - max_val = np.max(np.abs(wave)) - - if max_val == 0: - return wave - - return ops.scale(wave, 1 / max_val) - - -def adjust_audio_duration( - wave: xr.DataArray, - duration: float, -) -> xr.DataArray: - """Adjust the duration of an audio waveform array via cropping or padding. - - If the current duration is longer than the target, it crops the array - from the beginning. If shorter, it pads the array with zeros at the end - using `soundevent.arrays.extend_dim`. - - Parameters - ---------- - wave : xr.DataArray - The input audio waveform with a 'time' dimension and coordinates. - duration : float - The target duration in seconds. - - Returns - ------- - xr.DataArray - The waveform adjusted to the target duration. Returns the input - unmodified if duration already matches or if the wave is empty. - - Raises - ------ - ValueError - If `duration` is negative. 
- """ - start_time, end_time = arrays.get_dim_range(wave, dim="time") - step = arrays.get_dim_step(wave, dim="time") - current_duration = end_time - start_time + step - - if current_duration == duration: - return wave - - with xr.set_options(keep_attrs=True): - if current_duration > duration: - return arrays.crop_dim( - wave, - dim="time", - start=start_time, - stop=start_time + duration - step / 2, - right_closed=True, - ) - - return arrays.extend_dim( - wave, - dim="time", - start=start_time, - stop=start_time + duration - step / 2, - eps=0, - right_closed=True, +) -> np.ndarray: + """Load and preprocess a specific audio clip segment based on config.""" + try: + wav = ( + audio.load_clip(clip, audio_dir=audio_dir) + .sel(channel=0) + .astype(dtype) ) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {clip.recording.path}. " + f"Error: {e}" + ) from e + + if not config or not config.enabled or samplerate is None: + return wav.data.astype(dtype) + + sr = int(1 / wav.time.attrs["step"]) + return resample_audio( + wav.data, + sr=sr, + samplerate=samplerate, + method=config.method, + ) def resample_audio( - wav: xr.DataArray, + wav: np.ndarray, + sr: int, samplerate: int = TARGET_SAMPLERATE_HZ, method: str = "poly", - dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Resample an audio waveform DataArray to a target sample rate. - - Updates the 'time' coordinate axis according to the new sample rate and - number of samples. Uses either polyphase (`scipy.signal.resample_poly`) - or Fourier method (`scipy.signal.resample`) based on the `method`. - - Parameters - ---------- - wav : xr.DataArray - Input audio waveform with 'time' dimension and coordinates. - samplerate : int, default=TARGET_SAMPLERATE_HZ - Target sample rate in Hz. - method : str, default="poly" - Resampling algorithm: "poly" or "fourier". - dtype : DTypeLike, default=np.float32 - Target data type for the resampled array. - - Returns - ------- - xr.DataArray - Resampled waveform with updated time coordinates. Returns the input - unmodified (but dtype cast) if the sample rate is already correct or - if the input array is empty. - - Raises - ------ - ValueError - If `wav` lacks a 'time' dimension, the original sample rate cannot - be determined, `samplerate` is non-positive, or `method` is invalid. 
-    """
-    if "time" not in wav.dims:
-        raise ValueError("Audio must have a time dimension")
-
-    time_axis: int = wav.get_axis_num("time")  # type: ignore
-    step = arrays.get_dim_step(wav, dim="time")
-    original_samplerate = int(1 / step)
-
-    if original_samplerate == samplerate:
-        return wav.astype(dtype).assign_attrs(original_samplerate=samplerate)
+) -> np.ndarray:
+    """Resample an audio waveform array to a target sample rate."""
+    if sr == samplerate:
+        return wav
 
     if method == "poly":
-        resampled = resample_audio_poly(
+        return resample_audio_poly(
             wav,
-            sr_orig=original_samplerate,
+            sr_orig=sr,
             sr_new=samplerate,
-            axis=time_axis,
         )
     elif method == "fourier":
-        resampled = resample_audio_fourier(
+        return resample_audio_fourier(
             wav,
-            sr_orig=original_samplerate,
+            sr_orig=sr,
             sr_new=samplerate,
-            axis=time_axis,
         )
     else:
         raise NotImplementedError(
             f"Resampling method '{method}' not implemented"
         )
 
-    start, stop = arrays.get_dim_range(wav, dim="time")
-    times = np.linspace(
-        start,
-        stop + step,
-        len(resampled),
-        endpoint=False,
-        dtype=dtype,
-    )
-
-    return xr.DataArray(
-        data=resampled.astype(dtype),
-        dims=wav.dims,
-        coords={
-            **wav.coords,
-            "time": arrays.create_time_dim_from_array(
-                times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs={
-            **wav.attrs,
-            "samplerate": samplerate,
-            "original_samplerate": original_samplerate,
-        },
-    )
-
 
 def resample_audio_poly(
-    array: xr.DataArray,
+    array: np.ndarray,
     sr_orig: int,
     sr_new: int,
     axis: int = -1,
@@ -605,7 +250,7 @@
     """
     gcd = np.gcd(sr_orig, sr_new)
     return resample_poly(
-        array.values,
+        array,
         sr_new // gcd,
         sr_orig // gcd,
         axis=axis,
@@ -613,7 +258,7 @@
 
 
 def resample_audio_fourier(
-    array: xr.DataArray,
+    array: np.ndarray,
     sr_orig: int,
     sr_new: int,
     axis: int = -1,
@@ -649,66 +294,89 @@
     )
 
 
-def convert_to_xr(
-    wav: np.ndarray,
+class CenterAudioConfig(BaseConfig):
+    name: Literal["center_audio"] = "center_audio"
+
+
+class ScaleAudioConfig(BaseConfig):
+    name: Literal["scale_audio"] = "scale_audio"
+
+
+class FixDurationConfig(BaseConfig):
+    name: Literal["fix_duration"] = "fix_duration"
+    duration: float = 0.5
+
+
+class FixDuration(torch.nn.Module):
+    def __init__(self, samplerate: int, duration: float):
+        super().__init__()
+        self.samplerate = samplerate
+        self.duration = duration
+        self.length = int(samplerate * duration)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        length = wav.shape[-1]
+
+        if length == self.length:
+            return wav
+
+        if length > self.length:
+            return wav[..., : self.length]  # crop the time (last) axis
+
+        return torch.nn.functional.pad(wav, (0, self.length - length))
+
+
+AudioTransform = Annotated[
+    Union[
+        FixDurationConfig,
+        ScaleAudioConfig,
+        CenterAudioConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+class AudioConfig(BaseConfig):
+    """Configuration for loading and initial audio preprocessing."""
+
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    transforms: List[AudioTransform] = Field(default_factory=list)
+
+
+def build_audio_loader(
+    config: Optional[AudioConfig] = None,
+) -> AudioLoader:
+    """Factory function to create an AudioLoader based on configuration."""
+    config = config or AudioConfig()
+    return SoundEventAudioLoader(
+        samplerate=config.samplerate,
+        config=config.resample,
+    )
+
+
+def build_audio_transform_step(
+    config: AudioTransform,
     samplerate: int,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Convert a NumPy array to an xarray DataArray with time coordinates.
+) -> torch.nn.Module:
+    if config.name == "fix_duration":
+        return FixDuration(samplerate=samplerate, duration=config.duration)
 
-    Parameters
-    ----------
-    wav : np.ndarray
-        The input waveform array. Expected to be 1D or 2D (with the first
-        axis as the channel dimension).
-    samplerate : int
-        The sample rate in Hz.
-    dtype : DTypeLike, default=np.float32
-        Target data type for the xarray DataArray.
+    if config.name == "scale_audio":
+        return PeakNormalize()
 
-    Returns
-    -------
-    xr.DataArray
-        The waveform as an xarray DataArray with time coordinates.
+    if config.name == "center_audio":
+        return CenterTensor()
 
-    Raises
-    ------
-    ValueError
-        If the input array is not 1D or 2D, or if the sample rate is
-        non-positive. If the input array is empty.
-    """
-
-    if wav.ndim == 2:
-        wav = wav[0, :]
-
-    if wav.ndim != 1:
-        raise ValueError(
-            "Audio must be 1D array or 2D channel where the first "
-            "axis is the channel dimension"
-        )
-
-    if wav.size == 0:
-        raise ValueError("Audio array is empty")
-
-    if samplerate <= 0:
-        raise ValueError("Sample rate must be positive")
-
-    times = np.linspace(
-        0,
-        wav.shape[0] / samplerate,
-        wav.shape[0],
-        endpoint=False,
-        dtype=dtype,
+    raise NotImplementedError(
+        f"Audio preprocessing step {config.name} not implemented"
     )
 
-    return xr.DataArray(
-        data=wav.astype(dtype),
-        dims=["time"],
-        coords={
-            "time": arrays.create_time_dim_from_array(
-                times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs={"samplerate": samplerate},
+
+def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
+    return torch.nn.Sequential(
+        *[
+            build_audio_transform_step(step, samplerate=config.samplerate)
+            for step in config.transforms
+        ]
     )
diff --git a/src/batdetect2/preprocess/common.py b/src/batdetect2/preprocess/common.py
new file mode 100644
index 0000000..37bf2a2
--- /dev/null
+++ b/src/batdetect2/preprocess/common.py
@@ -0,0 +1,24 @@
+import torch
+
+__all__ = [
+    "CenterTensor",
+    "PeakNormalize",
+]
+
+
+class CenterTensor(torch.nn.Module):
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        return wav - wav.mean()
+
+
+class PeakNormalize(torch.nn.Module):
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        max_value = wav.abs().max()
+
+        denominator = torch.where(
+            max_value == 0,
+            torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
+            max_value,
+        )
+
+        return wav / denominator
diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py
index 77a311d..e416501 100644
--- a/src/batdetect2/preprocess/spectrogram.py
+++ b/src/batdetect2/preprocess/spectrogram.py
@@ -1,48 +1,22 @@
-"""Computes spectrograms from audio waveforms with configurable parameters.
+"""Computes spectrograms from audio waveforms with configurable parameters."""
 
-This module provides the functionality to convert preprocessed audio waveforms
-(typically output from the `batdetect2.preprocessing.audio` module) into
-spectrogram representations suitable for input into deep learning models like
-BatDetect2.
-
-It offers a configurable pipeline including:
-1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
-2. Frequency axis cropping to a relevant range.
-3. Per-Channel Energy Normalization (PCEN) (optional).
-4. Amplitude scaling/representation (dB, power, or linear amplitude).
-5. Simple spectral mean subtraction denoising (optional).
-6. Resizing to target dimensions (optional).
-7. Final peak normalization (optional).
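The two helpers in the new `common.py` are ordinary modules and compose with `torch.nn.Sequential`; a minimal sketch with a fabricated input tensor:

import torch

from batdetect2.preprocess.common import CenterTensor, PeakNormalize

pipeline = torch.nn.Sequential(CenterTensor(), PeakNormalize())
wav = torch.tensor([0.5, 1.5, 2.5])
out = pipeline(wav)  # centered to [-1.0, 0.0, 1.0], already at unit peak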
- -Configuration is managed via the `SpectrogramConfig` class, allowing for -reproducible spectrogram generation consistent between training and inference. -The core computation is performed by `compute_spectrogram`. -""" - -from typing import Callable, Literal, Optional, Union +from typing import Annotated, Callable, List, Literal, Optional, Union import numpy as np -import xarray as xr -from numpy.typing import DTypeLike +import torch +import torchaudio from pydantic import Field -from scipy import signal -from soundevent import arrays, audio -from soundevent.arrays import operations as ops from batdetect2.configs import BaseConfig -from batdetect2.preprocess.audio import convert_to_xr +from batdetect2.preprocess.common import PeakNormalize from batdetect2.typing.preprocess import SpectrogramBuilder __all__ = [ "STFTConfig", "FrequencyConfig", - "SpecSizeConfig", "PcenConfig", "SpectrogramConfig", - "ConfigurableSpectrogramBuilder", "build_spectrogram_builder", - "compute_spectrogram", - "get_spectrogram_resolution", "MIN_FREQ", "MAX_FREQ", ] @@ -79,6 +53,47 @@ class STFTConfig(BaseConfig): window_fn: str = "hann" +def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: + if name == "hann": + return torch.hann_window + + if name == "hamming": + return torch.hamming_window + + if name == "kaiser": + return torch.kaiser_window + + if name == "blackman": + return torch.blackman_window + + if name == "bartlett": + return torch.bartlett_window + + raise NotImplementedError( + f"Spectrogram window function {name} not implemented" + ) + + +def _spec_params_from_config(samplerate: int, conf: STFTConfig): + n_fft = int(samplerate * conf.window_duration) + hop_length = int(n_fft * (1 - conf.window_overlap)) + return n_fft, hop_length + + +def build_spectrogram_builder( + samplerate: int, + conf: STFTConfig, +) -> SpectrogramBuilder: + n_fft, hop_length = _spec_params_from_config(samplerate, conf) + return torchaudio.transforms.Spectrogram( + n_fft=n_fft, + hop_length=hop_length, + window_fn=get_spectrogram_window(conf.window_fn), + center=False, + power=1, + ) + + class FrequencyConfig(BaseConfig): """Configuration for frequency axis parameters. @@ -96,644 +111,282 @@ class FrequencyConfig(BaseConfig): min_freq: int = Field(default=10_000, ge=0) -class SpecSizeConfig(BaseConfig): - """Configuration for the final size and shape of the spectrogram. +def _frequency_to_index( + freq: float, + samplerate: int, + n_fft: int, +) -> Optional[int]: + alpha = freq * 2 / samplerate + height = np.floor(n_fft / 2) + 1 + index = int(np.floor(alpha * height)) - Attributes - ---------- - height : int, default=128 - Target height of the spectrogram in pixels (frequency bins). The - frequency axis will be resized (e.g., via interpolation) to match this - height after frequency cropping and amplitude scaling. Must be > 0. - resize_factor : float, optional - Factor by which to resize the spectrogram along the time axis *after* - STFT calculation. A value of 0.5 halves the number of time bins, - 2.0 doubles it. If None (default), no resizing along the time axis is - performed relative to the STFT output width. Must be > 0 if provided. 
- """ + if index <= 0: + return None - height: int = 128 - resize_factor: Optional[float] = 0.5 + if index >= height: + return None + + return index + + +class FrequencyClip(torch.nn.Module): + def __init__( + self, + low_index: Optional[int] = None, + high_index: Optional[int] = None, + ): + super().__init__() + self.low_index = low_index + self.high_index = high_index + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + return spec[self.low_index : self.high_index] class PcenConfig(BaseConfig): - """Configuration for Per-Channel Energy Normalization (PCEN). + """Configuration for Per-Channel Energy Normalization (PCEN).""" - PCEN is an adaptive gain control method that can help emphasize transients - and suppress stationary noise. Applied after STFT and frequency cropping, - but before final amplitude scaling (dB, power, amplitude). - - Attributes - ---------- - time_constant : float, default=0.4 - Time constant (in seconds) for the PCEN smoothing filter. Controls - how quickly the normalization adapts to energy changes. - gain : float, default=0.98 - Gain factor (alpha). Controls the adaptive gain component. - bias : float, default=2.0 - Bias factor (delta). Added before the exponentiation. - power : float, default=0.5 - Exponent (r). Controls the compression characteristic. - """ - - time_constant: float = 0.01 + name: Literal["pcen"] = "pcen" + time_constant: float = 0.4 gain: float = 0.98 bias: float = 2 power: float = 0.5 -class SpectrogramConfig(BaseConfig): - """Unified configuration for spectrogram generation pipeline. - - Aggregates settings for all steps involved in converting a preprocessed - audio waveform into a final spectrogram representation suitable for model - input. - - Attributes - ---------- - stft : STFTConfig - Configuration for the initial Short-Time Fourier Transform. - Defaults to standard settings via `STFTConfig`. - frequencies : FrequencyConfig - Configuration for cropping the frequency range after STFT. - Defaults to standard settings via `FrequencyConfig`. - pcen : PcenConfig, optional - Configuration for applying Per-Channel Energy Normalization (PCEN). If - provided, PCEN is applied after frequency cropping. If None or omitted - (default), PCEN is skipped. - scale : Literal["dB", "amplitude", "power"], default="amplitude" - Determines the final amplitude representation *after* optional PCEN. - - "amplitude": Use linear magnitude values (output of STFT or PCEN). - - "power": Use power values (magnitude squared). - - "dB": Use logarithmic (decibel-like) scaling applied to the magnitude - (or PCEN output if enabled). Calculated as `log1p(C * S)`. - size : SpecSizeConfig, optional - Configuration for resizing the spectrogram dimensions - (frequency height, optional time width factor). Applied after PCEN and - scaling. If None (default), no resizing is performed. - spectral_mean_substraction : bool, default=True - If True (default), applies simple spectral mean subtraction denoising - *after* PCEN and amplitude scaling, but *before* resizing. - peak_normalize : bool, default=False - If True, applies a final peak normalization to the spectrogram *after* - all other steps (including resizing), scaling the overall maximum value - to 1.0. If False (default), this final normalization is skipped. 
- """ - - stft: STFTConfig = Field(default_factory=STFTConfig) - frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) - pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig) - scale: Literal["dB", "amplitude", "power"] = "amplitude" - size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) - spectral_mean_substraction: bool = True - peak_normalize: bool = False - - -class ConfigurableSpectrogramBuilder(SpectrogramBuilder): - """Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`. - - This class computes spectrograms according to the parameters specified in a - `SpectrogramConfig` object provided during initialization. It handles both - numpy array and xarray DataArray inputs for the waveform. - """ - +class PCEN(torch.nn.Module): def __init__( self, - config: SpectrogramConfig, - dtype: DTypeLike = np.float32, # type: ignore - ) -> None: - """Initialize the ConfigurableSpectrogramBuilder. - - Parameters - ---------- - config : SpectrogramConfig - The configuration object specifying all spectrogram parameters. - dtype : DTypeLike, default=np.float32 - The target NumPy data type for the computed spectrogram array. - """ - self.config = config + smoothing_constant: float, + gain: float = 0.98, + bias: float = 2.0, + power: float = 0.5, + eps: float = 1e-6, + dtype=torch.float64, + ): + super().__init__() + self.smoothing_constant = smoothing_constant + self.gain = torch.tensor(gain, dtype=dtype) + self.bias = torch.tensor(bias, dtype=dtype) + self.power = torch.tensor(power, dtype=dtype) + self.eps = torch.tensor(eps, dtype=dtype) self.dtype = dtype - def __call__( - self, - wav: Union[np.ndarray, xr.DataArray], - samplerate: Optional[int] = None, - ) -> xr.DataArray: - """Generate a spectrogram from an audio waveform using the config. - - Implements the `SpectrogramBuilder` protocol. If the input `wav` is - a numpy array, `samplerate` must be provided; the array will be - converted to an xarray DataArray internally. If `wav` is already an - xarray DataArray with time coordinates, `samplerate` is ignored. - Delegates the main computation to `compute_spectrogram`. - - Parameters - ---------- - wav : Union[np.ndarray, xr.DataArray] - The input audio waveform. - samplerate : int, optional - The sample rate in Hz (required only if `wav` is np.ndarray). - - Returns - ------- - xr.DataArray - The computed spectrogram. - - Raises - ------ - ValueError - If `wav` is np.ndarray and `samplerate` is None. - """ - if isinstance(wav, np.ndarray): - if samplerate is None: - raise ValueError( - "Samplerate must be provided when passing a numpy array." - ) - wav = convert_to_xr( - wav, - samplerate=samplerate, - dtype=self.dtype, - ) - - return compute_spectrogram( - wav, - config=self.config, - dtype=self.dtype, + self._b = torch.tensor([self.smoothing_constant, 0.0], dtype=dtype) + self._a = torch.tensor( + [1.0, self.smoothing_constant - 1.0], dtype=dtype ) + def forward(self, spec: torch.Tensor) -> torch.Tensor: + S = spec.to(self.dtype) * 2**31 -def build_spectrogram_builder( - config: SpectrogramConfig, - dtype: DTypeLike = np.float32, # type: ignore -) -> SpectrogramBuilder: - """Factory function to create a SpectrogramBuilder based on configuration. + M = ( + torchaudio.functional.lfilter( + S, + self._a, + self._b, + clamp=False, + ) + ).clamp(min=0) - Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized - with the provided `SpectrogramConfig`. 
- - Parameters - ---------- - config : SpectrogramConfig - The configuration object specifying spectrogram parameters. - dtype : DTypeLike, default=np.float32 - The target NumPy data type for the computed spectrogram array. - - Returns - ------- - SpectrogramBuilder - An instance of `ConfigurableSpectrogramBuilder` ready to compute - spectrograms according to the configuration. - """ - return ConfigurableSpectrogramBuilder(config=config, dtype=dtype) - - -def compute_spectrogram( - wav: xr.DataArray, - config: Optional[SpectrogramConfig] = None, - dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Compute a spectrogram from a waveform using specified configurations. - - Applies a sequence of operations based on the `config`: - 1. Compute STFT magnitude (`stft`). - 2. Crop frequency axis (`crop_spectrogram_frequencies`). - 3. Apply PCEN if configured (`apply_pcen`). - 4. Apply final amplitude scaling (dB, power, amplitude) - (`scale_spectrogram`). - 5. Apply spectral mean subtraction denoising if enabled. - 6. Resize dimensions if specified (`resize_spectrogram`). - 7. Apply final peak normalization if enabled. - - Parameters - ---------- - wav : xr.DataArray - Input audio waveform with a 'time' dimension and coordinates from - which the sample rate can be inferred. - config : SpectrogramConfig, optional - Configuration object specifying spectrogram parameters. If None, - default settings from `SpectrogramConfig` are used. - dtype : DTypeLike, default=np.float32 - Target NumPy data type for the final spectrogram array. - - Returns - ------- - xr.DataArray - The computed and processed spectrogram with 'time' and 'frequency' - coordinates. - - Raises - ------ - ValueError - If `wav` lacks necessary 'time' coordinates or dimensions. - """ - config = config or SpectrogramConfig() - - with xr.set_options(keep_attrs=True): - spec = stft( - wav, - window_duration=config.stft.window_duration, - window_overlap=config.stft.window_overlap, - window_fn=config.stft.window_fn, + smooth = torch.exp( + -self.gain * (torch.log(self.eps) + torch.log1p(M / self.eps)) ) - spec = crop_spectrogram_frequencies( - spec, - min_freq=config.frequencies.min_freq, - max_freq=config.frequencies.max_freq, - ) - - if config.pcen: - spec = apply_pcen( - spec, - time_constant=config.pcen.time_constant, - gain=config.pcen.gain, - power=config.pcen.power, - bias=config.pcen.bias, - ) - - spec = scale_spectrogram(spec, scale=config.scale) - - if config.spectral_mean_substraction: - spec = remove_spectral_mean(spec) - - if config.size: - spec = resize_spectrogram( - spec, - height=config.size.height, - resize_factor=config.size.resize_factor, - ) - - if config.peak_normalize: - spec = ops.normalize(spec) - - return spec.astype(dtype) + return ( + (self.bias**self.power) + * torch.expm1(self.power * torch.log1p(S * smooth / self.bias)) + ).to(spec.dtype) -def crop_spectrogram_frequencies( - spec: xr.DataArray, - min_freq: int = 10_000, - max_freq: int = 120_000, -) -> xr.DataArray: - """Crop the frequency axis of a spectrogram to a specified range. - - Uses `soundevent.arrays.crop_dim` to select the frequency bins - corresponding to the range [`min_freq`, `max_freq`]. - - Parameters - ---------- - spec : xr.DataArray - Input spectrogram with 'frequency' dimension and coordinates. - min_freq : int, default=MIN_FREQ - Minimum frequency (Hz) to keep. - max_freq : int, default=MAX_FREQ - Maximum frequency (Hz) to keep. - - Returns - ------- - xr.DataArray - Spectrogram cropped along the frequency axis. 
Preserves dtype. - """ - start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency") - - return arrays.crop_dim( - spec, - dim="frequency", - start=min_freq if start_freq < min_freq else None, - stop=max_freq if end_freq > max_freq else None, - ).astype(spec.dtype) +def _compute_smoothing_constant( + samplerate: int, + time_constant: float, +) -> float: + # NOTE: These were taken to match the original implementation + hop_length = 512 + sr = samplerate / 10 + time_constant = time_constant + t_frames = time_constant * sr / float(hop_length) + return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) -def stft( - wave: xr.DataArray, - window_duration: float, - window_overlap: float, - window_fn: str = "hann", -) -> xr.DataArray: - """Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram. - - Calculates STFT parameters (N-FFT, hop length) based on the window - duration, overlap, and waveform sample rate. Returns an xarray DataArray - with correctly calculated 'time' and 'frequency' coordinates. - - Parameters - ---------- - wave : xr.DataArray - Input audio waveform with 'time' coordinates. - window_duration : float - Duration of the STFT window in seconds. - window_overlap : float - Fractional overlap between consecutive windows. - window_fn : str, default="hann" - Name of the window function (e.g., "hann", "hamming"). - - Returns - ------- - xr.DataArray - Magnitude spectrogram with 'time' and 'frequency' dimensions and - coordinates. STFT parameters are stored in the `attrs`. - - Raises - ------ - ValueError - If sample rate cannot be determined from `wave` coordinates. - """ - if "channel" not in wave.coords: - wave = wave.assign_coords(channel=0) - - return audio.compute_spectrogram( - wave, - window_size=window_duration, - hop_size=(1 - window_overlap) * window_duration, - window_type=window_fn, - scale="amplitude", - sort_dims=False, - ) +class ScaleAmplitudeConfig(BaseConfig): + name: Literal["scale_amplitude"] = "scale_amplitude" + scale: Literal["power", "db"] = "db" -def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray: - """Apply simple spectral mean subtraction for denoising. - - Subtracts the mean value of each frequency bin (calculated across time) - from that bin, then clips negative values to zero. - - Parameters - ---------- - spec : xr.DataArray - Input spectrogram with 'time' and 'frequency' dimensions. - - Returns - ------- - xr.DataArray - Denoised spectrogram with the same dimensions, coordinates, and dtype. - """ - return xr.DataArray( - data=(spec - spec.mean("time")).clip(0), - dims=spec.dims, - coords=spec.coords, - attrs=spec.attrs, - ) - - -def scale_spectrogram( - spec: xr.DataArray, - scale: Literal["dB", "power", "amplitude"], - dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Apply final amplitude scaling/representation to the spectrogram. - - Converts the input magnitude spectrogram based on the `scale` type: - - "dB": Applies logarithmic scaling `log10(S)`. - - "power": Squares the magnitude values `S^2`. - - "amplitude": Returns the input magnitude values `S` unchanged. - - Parameters - ---------- - spec : xr.DataArray - Input magnitude spectrogram (potentially after PCEN). - scale : Literal["dB", "power", "amplitude"] - The target amplitude representation. - dtype : DTypeLike, default=np.float32 - Target data type for the output scaled spectrogram. - - Returns - ------- - xr.DataArray - Spectrogram with the specified amplitude scaling applied. 
- """ - if scale == "dB": - return arrays.to_db(spec).astype(dtype) - - if scale == "power": +class ToPower(torch.nn.Module): + def forward(self, spec: torch.Tensor) -> torch.Tensor: return spec**2 - return spec +def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module: + if conf.scale == "db": + return torchaudio.transforms.AmplitudeToDB() -def apply_pcen( - spec: xr.DataArray, - time_constant: float = 0.4, - gain: float = 0.98, - bias: float = 2, - eps: float = 1e-6, - power: float = 0.5, -) -> xr.DataArray: - """Apply Per-Channel Energy Normalization (PCEN) to a spectrogram. + if conf.scale == "power": + return ToPower() - Parameters - ---------- - spec : xr.DataArray - Input magnitude spectrogram with required attributes like - 'processing_original_samplerate'. - time_constant : float, default=0.4 - PCEN time constant in seconds. - gain : float, default=0.98 - Gain factor (alpha). - bias : float, default=2.0 - Bias factor (delta). - power : float, default=0.5 - Exponent (r). - dtype : DTypeLike, default=np.float32 - Target data type for the output spectrogram. - - Returns - ------- - xr.DataArray - PCEN-scaled spectrogram. - """ - samplerate = 1 / spec.time.attrs["step"] - hop_size = spec.attrs["hop_size"] - - hop_length = int(hop_size * samplerate) - - t_frames = time_constant * samplerate / hop_length - - smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) - - axis = spec.get_axis_num("time") - - shape = tuple([1] * spec.ndim) - zi = np.empty(shape) - zi[:] = signal.lfilter_zi( - [smoothing_constant], - [1, smoothing_constant - 1], - )[:] - - spec_data = spec.data * (2**31) - - # Smooth the input array along the given axis - smoothed, _ = signal.lfilter( - [smoothing_constant], - [1, smoothing_constant - 1], - spec_data, - zi=zi, - axis=axis, # type: ignore - ) - - smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps))) - data = (bias**power) * np.expm1( - power * np.log1p(spec_data * smooth / bias) - ) - - return xr.DataArray( - data.astype(spec.dtype), - dims=spec.dims, - coords=spec.coords, - attrs=spec.attrs, + raise NotImplementedError( + f"Amplitude scaling {conf.scale} not implemented" ) -def scale_log( - spec: xr.DataArray, - dtype: DTypeLike = np.float32, # type: ignore - ref: Union[float, Callable] = np.max, - amin: float = 1e-10, - top_db: Optional[float] = 80.0, -) -> xr.DataArray: - """Apply logarithmic scaling to a magnitude spectrogram. - - Calculates `log10(S)`, where S is the input magnitude spectrogram. - - Parameters - ---------- - spec : xr.DataArray - Input magnitude spectrogram with required attributes like - 'processing_original_samplerate', 'processing_nfft'. - dtype : DTypeLike, default=np.float32 - Target data type for the output spectrogram. - - Returns - ------- - xr.DataArray - Log-scaled spectrogram. - - Raises - ------ - KeyError - If required attributes are missing from `spec.attrs`. - ValueError - If attributes are non-numeric or window function is invalid. 
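A small sketch of the denoising step defined above, applied to a fabricated (frequency, time) magnitude tensor:

import torch

from batdetect2.preprocess.spectrogram import SpectralMeanSubstraction

denoise = SpectralMeanSubstraction()
spec = torch.rand(128, 512)  # fabricated magnitudes
clean = denoise(spec)  # per-frequency mean removed, negatives clamped to zero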
+class SpectralMeanSubstractionConfig(BaseConfig): + name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" - Notes - ----- - Implementation mainly taken from librosa `power_to_db` function - """ - - if callable(ref): - ref_value = ref(spec) - else: - ref_value = np.abs(ref) - - log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10( - np.maximum(amin, ref_value) - ) - - if top_db is not None: - if top_db < 0: - raise ValueError("top_db must be non-negative") - log_spec = np.maximum(log_spec, log_spec.max() - top_db) - - return xr.DataArray( - data=log_spec.astype(dtype), - dims=spec.dims, - coords=spec.coords, - attrs=spec.attrs, - ) +class SpectralMeanSubstraction(torch.nn.Module): + def forward(self, spec: torch.Tensor) -> torch.Tensor: + mean = spec.mean(-1, keepdim=True) + return (spec - mean).clamp(min=0) -def resize_spectrogram( - spec: xr.DataArray, - height: int = 128, - resize_factor: Optional[float] = 0.5, - dtype: DTypeLike = np.float32, # type: ignore -) -> xr.DataArray: - """Resize a spectrogram to target dimensions using interpolation. +class ResizeConfig(BaseConfig): + name: Literal["resize_spec"] = "resize_spec" + height: int = 128 + resize_factor: float = 0.5 - Resizes the frequency axis to `height` bins and optionally resizes the - time axis by `resize_factor`. - Parameters - ---------- - spec : xr.DataArray - Input spectrogram with 'time' and 'frequency' dimensions. - height : int, default=128 - Target number of frequency bins (vertical dimension). - resize_factor : float, optional - Factor to resize the time dimension. If 1.0 or None, time dimension - is unchanged. If 0.5, time dimension is halved, etc. +class ResizeSpec(torch.nn.Module): + def __init__(self, height: int, time_factor: float): + super().__init__() + self.height = height + self.time_factor = time_factor - Returns - ------- - xr.DataArray - Resized spectrogram. Coordinates are typically adjusted by the - underlying resize operation if implemented in `ops.resize`. - The dtype is currently hardcoded to float32 by ops.resize call. 
- """ - resize_factor = resize_factor or 1 - current_width = spec.sizes["time"] - - target_sizes = { - "time": int(current_width * resize_factor), - "frequency": height, - } - - new_coords = {} - for dim in ["time", "frequency"]: - step = arrays.get_dim_step(spec, dim) - start, stop = arrays.get_dim_range(spec, dim) - new_coords[dim] = arrays.create_range_dim( - name=dim, - start=start, - stop=stop + step, - size=target_sizes[dim], - dtype=dtype, + def forward(self, spec: torch.Tensor) -> torch.Tensor: + current_length = spec.shape[-1] + target_length = int(self.time_factor * current_length) + return torch.nn.functional.interpolate( + spec.unsqueeze(0).unsqueeze(0), + size=(self.height, target_length), + mode="bilinear", ) - return spec.interp( - coords=new_coords, method="linear", kwargs=dict(fill_value=0) + +class PeakNormalizeConfig(BaseConfig): + name: Literal["peak_normalize"] = "peak_normalize" + + +SpectrogramTransform = Annotated[ + Union[ + PcenConfig, + ScaleAmplitudeConfig, + SpectralMeanSubstractionConfig, + PeakNormalizeConfig, + ], + Field(discriminator="name"), +] + + +class SpectrogramConfig(BaseConfig): + stft: STFTConfig = Field(default_factory=STFTConfig) + frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) + size: ResizeConfig = Field(default_factory=ResizeConfig) + transforms: List[SpectrogramTransform] = Field( + default_factory=lambda: [ + PcenConfig(), + SpectralMeanSubstractionConfig(), + ] ) -def get_spectrogram_resolution( - config: SpectrogramConfig, -) -> tuple[float, float]: - """Calculate the approximate resolution of the final spectrogram. +def _build_spectrogram_transform_step( + step: SpectrogramTransform, + samplerate: int, +) -> torch.nn.Module: + if step.name == "pcen": + return PCEN( + smoothing_constant=_compute_smoothing_constant( + samplerate=samplerate, + time_constant=step.time_constant, + ), + gain=step.gain, + bias=step.bias, + power=step.power, + ) - Computes the width of each frequency bin (Hz/bin) and the duration - of each time bin (seconds/bin) based on the configuration parameters. + if step.name == "scale_amplitude": + return _build_amplitude_scaler(step) - Parameters - ---------- - config : SpectrogramConfig - The spectrogram configuration object. - samplerate : int, optional - The sample rate of the audio *before* STFT. Required if needed to - calculate hop duration accurately from STFT config, but the current - implementation calculates hop_duration directly from STFT config times. + if step.name == "spectral_mean_substraction": + return SpectralMeanSubstraction() - Returns - ------- - Tuple[float, float] - A tuple containing: - - frequency_resolution (float): Approximate Hz per frequency bin. - - time_resolution (float): Approximate seconds per time bin. + if step.name == "peak_normalize": + return PeakNormalize() - Raises - ------ - ValueError - If required configuration fields (like `config.size`) are missing - or invalid. 
- """ - max_freq = config.frequencies.max_freq - min_freq = config.frequencies.min_freq - - if config.size is None: - raise ValueError("Spectrogram size configuration is required.") - - spec_height = config.size.height - resize_factor = config.size.resize_factor or 1 - freq_bin_width = (max_freq - min_freq) / spec_height - hop_duration = config.stft.window_duration * ( - 1 - config.stft.window_overlap + raise NotImplementedError( + f"Spectrogram preprocessing step {step.name} not implemented" + ) + + +def build_spectrogram_transform( + samplerate: int, + conf: SpectrogramConfig, +) -> torch.nn.Module: + return torch.nn.Sequential( + *[ + _build_spectrogram_transform_step(step, samplerate=samplerate) + for step in conf.transforms + ] + ) + + +class SpectrogramPipeline(torch.nn.Module): + def __init__( + self, + spec_builder: SpectrogramBuilder, + freq_cutter: torch.nn.Module, + transforms: torch.nn.Module, + resizer: torch.nn.Module, + ): + super().__init__() + self.spec_builder = spec_builder + self.freq_cutter = freq_cutter + self.transforms = transforms + self.resizer = resizer + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + spec = self.spec_builder(wav) + spec = self.freq_cutter(spec) + spec = self.transforms(spec) + return self.resizer(spec) + + def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: + return self.spec_builder(wav) + + def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: + return self.freq_cutter(spec) + + def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: + return self.transforms(spec) + + def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: + return self.resizer(spec) + + +def build_spectrogram_pipeline( + samplerate: int, + conf: SpectrogramConfig, +) -> SpectrogramPipeline: + spec_builder = build_spectrogram_builder(samplerate, conf.stft) + n_fft, _ = _spec_params_from_config(samplerate, conf.stft) + cutter = FrequencyClip( + low_index=_frequency_to_index( + conf.frequencies.min_freq, samplerate, n_fft + ), + high_index=_frequency_to_index( + conf.frequencies.max_freq, samplerate, n_fft + ), + ) + transforms = build_spectrogram_transform(samplerate, conf) + resizer = ResizeSpec( + height=conf.size.height, + time_factor=conf.size.resize_factor, + ) + return SpectrogramPipeline( + spec_builder=spec_builder, + freq_cutter=cutter, + transforms=transforms, + resizer=resizer, ) - return freq_bin_width, hop_duration / resize_factor diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 120459f..8b7745c 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -28,8 +28,10 @@ from soundevent import data from batdetect2.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.preprocess.audio import build_audio_loader +from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol from batdetect2.typing.targets import Position, Size +from batdetect2.utils.arrays import spec_to_xarray __all__ = [ "Anchor", @@ -365,6 +367,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): def __init__( self, preprocessor: PreprocessorProtocol, + audio_loader: AudioLoader, time_scale: float = DEFAULT_TIME_SCALE, frequency_scale: float = DEFAULT_FREQUENCY_SCALE, loading_buffer: float = 0.01, @@ -383,6 +386,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): Buffer in seconds to add when loading audio clips. 
""" self.preprocessor = preprocessor + self.audio_loader = audio_loader self.time_scale = time_scale self.frequency_scale = frequency_scale self.loading_buffer = loading_buffer @@ -422,6 +426,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): time, freq = get_peak_energy_coordinates( recording=sound_event.recording, + audio_loader=self.audio_loader, preprocessor=self.preprocessor, start_time=start_time, end_time=end_time, @@ -511,8 +516,10 @@ def build_roi_mapper( if config.name == "peak_energy_bbox": preprocessor = build_preprocessor(config.preprocessing) + audio_loader = build_audio_loader(config.preprocessing.audio) return PeakEnergyBBoxMapper( preprocessor=preprocessor, + audio_loader=audio_loader, time_scale=config.time_scale, frequency_scale=config.frequency_scale, loading_buffer=config.loading_buffer, @@ -617,6 +624,7 @@ def _build_bounding_box( def get_peak_energy_coordinates( recording: data.Recording, + audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, start_time: float = 0, end_time: Optional[float] = None, @@ -669,7 +677,15 @@ def get_peak_energy_coordinates( end_time=clip_end, ) - spec = preprocessor.preprocess_clip(clip) + wav = audio_loader.load_clip(clip) + spec = preprocessor.process_numpy(wav) + spec = spec_to_xarray( + spec, + clip.start_time, + clip.end_time, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) low_freq = max(low_freq, preprocessor.min_freq) high_freq = min(high_freq, preprocessor.max_freq) selection = spec.sel( diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index a991e32..c779d7e 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -129,9 +129,7 @@ def mix_examples( with xr.set_options(keep_attrs=True): combined = weight * audio1 + (1 - weight) * audio2 - spectrogram = preprocessor.compute_spectrogram( - combined.rename({"audio_time": "time"}) - ).data + spectrogram = preprocessor.process_numpy(combined.data) # NOTE: The subclip's spectrogram might be slightly longer than the # spectrogram computed from the subclip's audio. This is due to a @@ -241,9 +239,7 @@ def add_echo( with xr.set_options(keep_attrs=True): audio = audio + weight * audio_delay - spectrogram = preprocessor.compute_spectrogram( - audio.rename({"audio_time": "time"}), - ).data + spectrogram = preprocessor.process_numpy(audio.data) # NOTE: The subclip's spectrogram might be slightly longer than the # spectrogram computed from the subclip's audio. 
This is due to a diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 2d8c62e..090a6a5 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -21,10 +21,12 @@ class ClipingConfig(BaseConfig): class Clipper(ClipperProtocol): def __init__( self, + samplerate: int, duration: float = 0.5, max_empty: float = 0.2, random: bool = True, ): + self.samplerate = samplerate self.duration = duration self.random = random self.max_empty = max_empty diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py index 5c7c907..a8b5bff 100644 --- a/src/batdetect2/train/preprocess.py +++ b/src/batdetect2/train/preprocess.py @@ -25,6 +25,8 @@ from multiprocessing import Pool from pathlib import Path from typing import Callable, Optional, Sequence +import numpy as np +import torch import xarray as xr from loguru import logger from pydantic import Field @@ -34,9 +36,12 @@ from tqdm.auto import tqdm from batdetect2.configs import BaseConfig, load_config from batdetect2.data.datasets import Dataset from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.preprocess.audio import build_audio_loader from batdetect2.targets import TargetConfig, build_targets from batdetect2.train.labels import LabelConfig, build_clip_labeler from batdetect2.typing import ClipLabeller, PreprocessorProtocol +from batdetect2.typing.preprocess import AudioLoader +from batdetect2.utils.arrays import audio_to_xarray __all__ = [ "preprocess_annotations", @@ -76,6 +81,7 @@ def preprocess_dataset( targets = build_targets(config=config.targets) preprocessor = build_preprocessor(config=config.preprocess) labeller = build_clip_labeler(targets, config=config.labels) + audio_loader = build_audio_loader(config=config.preprocess.audio) if not output.exists(): logger.debug("Creating directory {directory}", directory=output) @@ -84,6 +90,7 @@ def preprocess_dataset( preprocess_annotations( dataset, output_dir=output, + audio_loader=audio_loader, preprocessor=preprocessor, labeller=labeller, replace=force, @@ -93,6 +100,7 @@ def preprocess_dataset( def generate_train_example( clip_annotation: data.ClipAnnotation, + audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, labeller: ClipLabeller, ) -> xr.Dataset: @@ -140,9 +148,15 @@ def generate_train_example( - The original `ClipAnnotation` metadata is stored as a JSON string in the Dataset's attributes for provenance. """ - wave = preprocessor.load_clip_audio(clip_annotation.clip) + wave = audio_loader.load_clip(clip_annotation.clip) - spectrogram = preprocessor.compute_spectrogram(wave) + spectrogram = _spec_to_xr( + preprocessor(torch.tensor(wave)), + start_time=clip_annotation.clip.start_time, + end_time=clip_annotation.clip.end_time, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) heatmaps = labeller(clip_annotation, spectrogram) @@ -152,7 +166,12 @@ def generate_train_example( # the spectrogram time dimension, otherwise xarray will interpolate # the spectrogram and the heatmaps to the same temporal resolution # as the waveform. 
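        # As an illustration (hypothetical numbers): a 0.5 s clip at 256 kHz
        # yields a 128_000-sample waveform, while the resized spectrogram has
        # only a few hundred time bins; if both carried a shared "time"
        # coordinate, assembling the Dataset would force the two onto one
        # temporal axis instead of keeping each at its native resolution.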
- "audio": wave.rename({"time": "audio_time"}), + "audio": audio_to_xarray( + wave, + start_time=clip_annotation.clip.start_time, + end_time=clip_annotation.clip.end_time, + time_axis="audio_time", + ), "spectrogram": spectrogram, "detection": heatmaps.detection, "class": heatmaps.classes, @@ -170,6 +189,32 @@ def generate_train_example( ) +def _spec_to_xr( + spec: torch.Tensor, + start_time: float, + end_time: float, + min_freq: float, + max_freq: float, +) -> xr.DataArray: + data = spec.numpy()[0, 0] + + height, width = data.shape + + return xr.DataArray( + data=data, + dims=[ + "frequency", + "time", + ], + coords={ + "frequency": np.linspace( + min_freq, max_freq, height, endpoint=False + ), + "time": np.linspace(start_time, end_time, width, endpoint=False), + }, + ) + + def _save_xr_dataset_to_file( dataset: xr.Dataset, path: data.PathLike, @@ -206,6 +251,7 @@ def preprocess_annotations( clip_annotations: Sequence[data.ClipAnnotation], output_dir: data.PathLike, preprocessor: PreprocessorProtocol, + audio_loader: AudioLoader, labeller: ClipLabeller, filename_fn: FilenameFn = _get_filename, replace: bool = False, @@ -275,6 +321,7 @@ def preprocess_annotations( output_dir=output_dir, filename_fn=filename_fn, replace=replace, + audio_loader=audio_loader, preprocessor=preprocessor, labeller=labeller, ), @@ -290,6 +337,7 @@ def preprocess_annotations( def preprocess_single_annotation( clip_annotation: data.ClipAnnotation, output_dir: data.PathLike, + audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, labeller: ClipLabeller, filename_fn: FilenameFn = _get_filename, @@ -335,6 +383,7 @@ def preprocess_single_annotation( try: sample = generate_train_example( clip_annotation, + audio_loader=audio_loader, preprocessor=preprocessor, labeller=labeller, ) diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py index 71ad9bd..9f02ab8 100644 --- a/src/batdetect2/typing/preprocess.py +++ b/src/batdetect2/typing/preprocess.py @@ -10,10 +10,10 @@ pipeline can interact consistently, regardless of the specific underlying implementation (e.g., different libraries or custom configurations). """ -from typing import Optional, Protocol, Union +from typing import Optional, Protocol import numpy as np -import xarray as xr +import torch from soundevent import data __all__ = [ @@ -36,7 +36,7 @@ class AudioLoader(Protocol): self, path: data.PathLike, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: + ) -> np.ndarray: """Load and preprocess audio directly from a file path. Parameters @@ -46,12 +46,6 @@ class AudioLoader(Protocol): audio_dir : PathLike, optional A directory prefix to prepend to the path if `path` is relative. - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform as an xarray DataArray - with time coordinates. Typically loads only the first channel. - Raises ------ FileNotFoundError @@ -65,7 +59,7 @@ class AudioLoader(Protocol): self, recording: data.Recording, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: + ) -> np.ndarray: """Load and preprocess the entire audio for a Recording object. Parameters @@ -95,7 +89,7 @@ class AudioLoader(Protocol): self, clip: data.Clip, audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: + ) -> np.ndarray: """Load and preprocess the audio segment defined by a Clip object. Parameters @@ -124,264 +118,41 @@ class AudioLoader(Protocol): class SpectrogramBuilder(Protocol): - """Defines the interface for a spectrogram generation component. 
+ """Defines the interface for a spectrogram generation component.""" - A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray) - and produces a spectrogram (as an xarray DataArray) based on its internal - configuration or implementation. - """ - - def __call__( - self, - wav: Union[np.ndarray, xr.DataArray], - samplerate: Optional[int] = None, - ) -> xr.DataArray: - """Generate a spectrogram from an audio waveform. - - Parameters - ---------- - wav : Union[np.ndarray, xr.DataArray] - The input audio waveform. If a numpy array, `samplerate` must - also be provided. If an xarray DataArray, it must have a 'time' - coordinate from which the sample rate can be inferred. - samplerate : int, optional - The sample rate of the audio in Hz. Required if `wav` is a - numpy array. If `wav` is an xarray DataArray, this parameter is - ignored as the sample rate is derived from the coordinates. - - Returns - ------- - xr.DataArray - The computed spectrogram as an xarray DataArray with 'time' and - 'frequency' coordinates. - - Raises - ------ - ValueError - If `wav` is a numpy array and `samplerate` is not provided, or - if `wav` is an xarray DataArray without a valid 'time' coordinate. - """ + def __call__(self, wav: torch.Tensor) -> torch.Tensor: + """Generate a spectrogram from an audio waveform.""" ... -class PreprocessorProtocol(Protocol): - """Defines a high-level interface for the complete preprocessing pipeline. +class AudioPipeline(Protocol): + def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... - A Preprocessor combines audio loading and spectrogram generation steps. - It provides methods to go directly from source descriptions (file paths, - Recording objects, Clip objects) to the final spectrogram representation - needed by the model. It may also expose intermediate steps like audio - loading or spectrogram computation from a waveform. - """ + +class SpectrogramPipeline(Protocol): + def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ... + + def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ... + + def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... + + def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... + + def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... + + +class PreprocessorProtocol(Protocol): + """Defines a high-level interface for the complete preprocessing pipeline.""" max_freq: float min_freq: float - def preprocess_file( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio from a file and compute the final processed spectrogram. + audio_pipeline: AudioPipeline - Performs the full pipeline: + spectrogram_pipeline: SpectrogramPipeline - Load -> Preprocess Audio -> Compute Spectrogram. + def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix if `path` is relative. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - - Raises - ------ - FileNotFoundError - If the audio file cannot be found. - Exception - If any step in the loading or preprocessing fails. - """ - ... - - def preprocess_recording( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio for a Recording and compute the processed spectrogram. - - Performs the full pipeline for the entire duration of the recording. 
- - Parameters - ---------- - recording : data.Recording - The Recording object. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - - Raises - ------ - FileNotFoundError - If the audio file cannot be found. - Exception - If any step in the loading or preprocessing fails. - """ - ... - - def preprocess_clip( - self, - clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load audio for a Clip and compute the final processed spectrogram. - - Performs the full pipeline for the specified clip segment. - - Parameters - ---------- - clip : data.Clip - The Clip object defining the audio segment. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The final processed spectrogram. - - Raises - ------ - FileNotFoundError - If the audio file cannot be found. - Exception - If any step in the loading or preprocessing fails. - """ - ... - - def load_file_audio( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform from a file path. - - Performs the initial audio loading and waveform processing steps - (like resampling, scaling), but stops *before* spectrogram generation. - - Parameters - ---------- - path : PathLike - Path to the audio file. - audio_dir : PathLike, optional - A directory prefix if `path` is relative. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform. - - Raises - ------ - FileNotFoundError, Exception - If audio loading/preprocessing fails. - """ - ... - - def load_recording_audio( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform for a Recording. - - Performs the initial audio loading and waveform processing steps - for the entire recording duration. - - Parameters - ---------- - recording : data.Recording - The Recording object. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform. - - Raises - ------ - FileNotFoundError, Exception - If audio loading/preprocessing fails. - """ - ... - - def load_clip_audio( - self, - clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> xr.DataArray: - """Load and preprocess *only* the audio waveform for a Clip. - - Performs the initial audio loading and waveform processing steps - for the specified clip segment. - - Parameters - ---------- - clip : data.Clip - The Clip object defining the segment. - audio_dir : PathLike, optional - Directory containing the audio file. - - Returns - ------- - xr.DataArray - The loaded and preprocessed audio waveform segment. - - Raises - ------ - FileNotFoundError, Exception - If audio loading/preprocessing fails. - """ - ... - - def compute_spectrogram( - self, - wav: Union[xr.DataArray, np.ndarray], - ) -> xr.DataArray: - """Compute the spectrogram from a pre-loaded audio waveform. - - Applies the spectrogram generation steps (STFT, scaling, etc.) defined - by the `SpectrogramBuilder` component of the preprocessor to an - already loaded (and potentially preprocessed) waveform. - - Parameters - ---------- - wav : Union[xr.DataArray, np.ndarray] - The input audio waveform. If numpy array, `samplerate` is required. 
- samplerate : int, optional - Sample rate in Hz (required if `wav` is np.ndarray). - - Returns - ------- - xr.DataArray - The computed spectrogram. - - Raises - ------ - ValueError, Exception - If waveform input is invalid or spectrogram computation fails. - """ - ... + def process_numpy(self, wav: np.ndarray) -> np.ndarray: + return self(torch.tensor(wav)).numpy()[0, 0] diff --git a/src/batdetect2/utils/arrays.py b/src/batdetect2/utils/arrays.py index bf00ee7..60a8bd3 100644 --- a/src/batdetect2/utils/arrays.py +++ b/src/batdetect2/utils/arrays.py @@ -2,6 +2,62 @@ import numpy as np import xarray as xr +def spec_to_xarray( + spec: np.ndarray, + start_time: float, + end_time: float, + min_freq: float, + max_freq: float, +) -> xr.DataArray: + if spec.ndim != 2: + raise ValueError( + "Input numpy spectrogram array should be 2-dimensional" + ) + + height, width = spec.shape + return xr.DataArray( + data=spec, + dims=["frequency", "time"], + coords={ + "frequency": np.linspace( + min_freq, + max_freq, + height, + endpoint=False, + ), + "time": np.linspace( + start_time, + end_time, + width, + endpoint=False, + ), + }, + ) + + +def audio_to_xarray( + wav: np.ndarray, + start_time: float, + end_time: float, + time_axis: str = "time", +) -> xr.DataArray: + if wav.ndim != 1: + raise ValueError("Input numpy audio array should be 1-dimensional") + + return xr.DataArray( + data=wav, + dims=[time_axis], + coords={ + time_axis: np.linspace( + start_time, + end_time, + len(wav), + endpoint=False, + ), + }, + ) + + def extend_width( array: np.ndarray, extra: int, diff --git a/tests/conftest.py b/tests/conftest.py index b56bce8..0f1f806 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from soundevent import data, terms from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess.audio import build_audio_loader from batdetect2.targets import ( TargetConfig, TermRegistry, @@ -27,6 +28,7 @@ from batdetect2.typing import ( PreprocessorProtocol, TargetProtocol, ) +from batdetect2.typing.preprocess import AudioLoader @pytest.fixture @@ -368,6 +370,11 @@ def sample_preprocessor() -> PreprocessorProtocol: return build_preprocessor() +@pytest.fixture +def sample_audio_loader() -> AudioLoader: + return build_audio_loader() + + @pytest.fixture def bat_tag() -> TagInfo: return TagInfo(key="class", value="bat") diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index 70a63f6..e820027 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -1,13 +1,10 @@ import pathlib import uuid -from pathlib import Path import numpy as np import pytest import soundfile as sf -import xarray as xr from soundevent import data -from soundevent.arrays import Dimensions, create_time_dim_from_array from batdetect2.preprocess import audio @@ -30,44 +27,6 @@ def create_dummy_wave( return wave.astype(dtype) -def create_xr_wave( - samplerate: int, - duration: float, - num_channels: int = 1, - freq: float = 440.0, - amplitude: float = 0.5, - start_time: float = 0.0, -) -> xr.DataArray: - """Generates a simple xarray waveform.""" - num_samples = int(samplerate * duration) - times = np.linspace( - start_time, - start_time + duration, - num_samples, - endpoint=False, - ) - coords = { - Dimensions.time.value: create_time_dim_from_array( - times, samplerate=samplerate, 
start_time=start_time - ) - } - dims = [Dimensions.time.value] - - wave_data = amplitude * np.sin(2 * np.pi * freq * times) - - if num_channels > 1: - coords[Dimensions.channel.value] = np.arange(num_channels) - dims = [Dimensions.channel.value] + dims - wave_data = np.stack([wave_data] * num_channels, axis=0) - - return xr.DataArray( - wave_data.astype(np.float32), - coords=coords, - dims=dims, - attrs={"samplerate": samplerate}, - ) - - @pytest.fixture def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path: """Creates a dummy WAV file and returns its path.""" @@ -99,408 +58,3 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip: @pytest.fixture def default_audio_config() -> audio.AudioConfig: return audio.AudioConfig() - - -@pytest.fixture -def no_resample_config() -> audio.AudioConfig: - return audio.AudioConfig(resample=None) - - -@pytest.fixture -def fixed_duration_config() -> audio.AudioConfig: - return audio.AudioConfig(duration=0.5) - - -@pytest.fixture -def scale_config() -> audio.AudioConfig: - return audio.AudioConfig(scale=True, center=False) - - -@pytest.fixture -def no_center_config() -> audio.AudioConfig: - return audio.AudioConfig(center=False) - - -@pytest.fixture -def resample_fourier_config() -> audio.AudioConfig: - return audio.AudioConfig( - resample=audio.ResampleConfig( - samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier" - ) - ) - - -def test_resample_config_defaults(): - config = audio.ResampleConfig() - assert config.samplerate == audio.TARGET_SAMPLERATE_HZ - assert config.method == "poly" - - -def test_audio_config_defaults(): - config = audio.AudioConfig() - assert config.resample is not None - assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ - assert config.resample.method == "poly" - assert config.scale == audio.SCALE_RAW_AUDIO - assert config.center is False - assert config.duration == audio.DEFAULT_DURATION - - -def test_audio_config_override(): - resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier") - config = audio.AudioConfig( - resample=resample_cfg, - scale=True, - center=False, - duration=1.0, - ) - assert config.resample == resample_cfg - assert config.scale is True - assert config.center is False - assert config.duration == 1.0 - - -def test_audio_config_no_resample(): - config = audio.AudioConfig(resample=None) - assert config.resample is None - - -@pytest.mark.parametrize( - "orig_sr, orig_dur, target_dur", - [ - (256_000, 1.0, 0.5), - (256_000, 0.5, 1.0), - (256_000, 1.0, 1.0), - ], -) -def test_adjust_audio_duration(orig_sr, orig_dur, target_dur): - wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur) - adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur) - expected_samples = int(target_dur * orig_sr) - assert adjusted_wave.sizes["time"] == expected_samples - assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr - assert adjusted_wave.dtype == wave.dtype - if orig_dur > 0 and target_dur > orig_dur: - padding_start_index = int(orig_dur * orig_sr) + 1 - assert np.all(adjusted_wave.values[padding_start_index:] == 0) - - -def test_adjust_audio_duration_negative_target_raises(): - wave = create_xr_wave(1000, 1.0) - with pytest.raises(ValueError): - audio.adjust_audio_duration(wave, duration=-0.5) - - -@pytest.mark.parametrize( - "orig_sr, target_sr, mode", - [ - (48000, 96000, "poly"), - (96000, 48000, "poly"), - (48000, 96000, "fourier"), - (96000, 48000, "fourier"), - (48000, 44100, "poly"), - (48000, 44100, "fourier"), - ], -) -def 
test_resample_audio(orig_sr, target_sr, mode): - duration = 0.1 - wave = create_xr_wave(orig_sr, duration) - resampled_wave = audio.resample_audio( - wave, samplerate=target_sr, method=mode, dtype=np.float32 - ) - expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr)) - assert resampled_wave.sizes["time"] == expected_samples - assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr - assert np.isclose( - resampled_wave.coords["time"].values[-1] - - resampled_wave.coords["time"].values[0], - duration, - atol=2 / target_sr, - ) - assert resampled_wave.dtype == np.float32 - - -def test_resample_audio_same_samplerate(): - sr = 48000 - duration = 0.1 - wave = create_xr_wave(sr, duration) - resampled_wave = audio.resample_audio( - wave, samplerate=sr, dtype=np.float64 - ) - xr.testing.assert_equal(wave.astype(np.float64), resampled_wave) - - -def test_resample_audio_invalid_mode_raises(): - wave = create_xr_wave(48000, 0.1) - with pytest.raises(NotImplementedError): - audio.resample_audio(wave, samplerate=96000, method="invalid_mode") - - -def test_resample_audio_no_time_dim_raises(): - wave = xr.DataArray(np.random.rand(100), dims=["samples"]) - with pytest.raises(ValueError, match="Audio must have a time dimension"): - audio.resample_audio(wave, samplerate=96000) - - -def test_load_clip_audio_default_config( - dummy_clip: data.Clip, - default_audio_config: audio.AudioConfig, - tmp_path: Path, -): - assert default_audio_config.resample is not None - target_sr = default_audio_config.resample.samplerate - orig_duration = dummy_clip.duration - expected_samples = int(orig_duration * target_sr) - - wav = audio.load_clip_audio( - dummy_clip, config=default_audio_config, audio_dir=tmp_path - ) - - assert isinstance(wav, xr.DataArray) - assert wav.dims == ("time",) - assert wav.sizes["time"] == expected_samples - assert wav.coords["time"].attrs["step"] == 1 / target_sr - assert np.isclose(wav.mean(), 0.0, atol=1e-6) - assert wav.dtype == np.float32 - - -def test_load_clip_audio_no_resample( - dummy_clip: data.Clip, - no_resample_config: audio.AudioConfig, - tmp_path: Path, -): - orig_sr = dummy_clip.recording.samplerate - orig_duration = dummy_clip.duration - expected_samples = int(orig_duration * orig_sr) - - wav = audio.load_clip_audio( - dummy_clip, config=no_resample_config, audio_dir=tmp_path - ) - - assert wav.coords["time"].attrs["step"] == 1 / orig_sr - assert wav.sizes["time"] == expected_samples - assert np.isclose(wav.mean(), 0.0, atol=1e-6) - - -def test_load_clip_audio_fixed_duration_crop( - dummy_clip: data.Clip, - fixed_duration_config: audio.AudioConfig, - tmp_path: Path, -): - target_sr = audio.TARGET_SAMPLERATE_HZ - target_duration = fixed_duration_config.duration - assert target_duration is not None - expected_samples = int(target_duration * target_sr) - - assert dummy_clip.duration > target_duration - - wav = audio.load_clip_audio( - dummy_clip, config=fixed_duration_config, audio_dir=tmp_path - ) - - assert wav.coords["time"].attrs["step"] == 1 / target_sr - assert wav.sizes["time"] == expected_samples - - -def test_load_clip_audio_fixed_duration_pad( - dummy_clip: data.Clip, - tmp_path: Path, -): - target_duration = dummy_clip.duration * 2 - config = audio.AudioConfig(duration=target_duration) - - assert config.resample is not None - target_sr = config.resample.samplerate - expected_samples = int(target_duration * target_sr) - - wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path) - - assert wav.coords["time"].attrs["step"] == 1 / 
target_sr - assert wav.sizes["time"] == expected_samples - - original_samples_after_resample = int(dummy_clip.duration * target_sr) - assert np.allclose( - wav.values[original_samples_after_resample:], 0.0, atol=1e-6 - ) - - -def test_load_clip_audio_scale( - dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path -): - wav = audio.load_clip_audio( - dummy_clip, - config=scale_config, - audio_dir=tmp_path, - ) - - assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5) - - -def test_load_clip_audio_no_center( - dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path -): - wav = audio.load_clip_audio( - dummy_clip, config=no_center_config, audio_dir=tmp_path - ) - - raw_wav, _ = sf.read( - dummy_clip.recording.path, - start=int(dummy_clip.start_time * dummy_clip.recording.samplerate), - stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate), - dtype=np.float32, # type: ignore - ) - raw_wav_mono = raw_wav[:, 0] - - if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7): - assert not np.isclose(wav.mean(), 0.0, atol=1e-6) - - -def test_load_clip_audio_resample_fourier( - dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path -): - assert resample_fourier_config.resample is not None - target_sr = resample_fourier_config.resample.samplerate - orig_duration = dummy_clip.duration - expected_samples = int(orig_duration * target_sr) - - wav = audio.load_clip_audio( - dummy_clip, config=resample_fourier_config, audio_dir=tmp_path - ) - - assert wav.coords["time"].attrs["step"] == 1 / target_sr - assert wav.sizes["time"] == expected_samples - - -def test_load_clip_audio_dtype( - dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path -): - wav = audio.load_clip_audio( - dummy_clip, - config=default_audio_config, - audio_dir=tmp_path, - dtype=np.float64, - ) - assert wav.dtype == np.float64 - - -def test_load_clip_audio_file_not_found( - dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path -): - non_existent_path = tmp_path / "not_a_real_file.wav" - dummy_clip.recording = data.Recording( - path=non_existent_path, - duration=1, - channels=1, - samplerate=256000, - ) - with pytest.raises(FileNotFoundError): - audio.load_clip_audio( - dummy_clip, config=default_audio_config, audio_dir=tmp_path - ) - - -def test_load_recording_audio( - dummy_recording: data.Recording, - default_audio_config: audio.AudioConfig, - tmp_path, -): - assert default_audio_config.resample is not None - target_sr = default_audio_config.resample.samplerate - orig_duration = dummy_recording.duration - expected_samples = int(orig_duration * target_sr) - - wav = audio.load_recording_audio( - dummy_recording, config=default_audio_config, audio_dir=tmp_path - ) - - assert isinstance(wav, xr.DataArray) - assert wav.dims == ("time",) - assert wav.coords["time"].attrs["step"] == 1 / target_sr - assert wav.sizes["time"] == expected_samples - assert np.isclose(wav.mean(), 0.0, atol=1e-6) - assert wav.dtype == np.float32 - - -def test_load_recording_audio_file_not_found( - dummy_recording: data.Recording, - default_audio_config: audio.AudioConfig, - tmp_path, -): - non_existent_path = tmp_path / "not_a_real_file.wav" - dummy_recording = data.Recording( - path=non_existent_path, - duration=1, - channels=1, - samplerate=256000, - ) - with pytest.raises(FileNotFoundError): - audio.load_recording_audio( - dummy_recording, config=default_audio_config, audio_dir=tmp_path - ) - - -def test_load_file_audio( - dummy_wav_path: pathlib.Path, - 
default_audio_config: audio.AudioConfig, - tmp_path, -): - info = sf.info(dummy_wav_path) - orig_duration = info.duration - assert default_audio_config.resample is not None - target_sr = default_audio_config.resample.samplerate - expected_samples = int(orig_duration * target_sr) - - wav = audio.load_file_audio( - dummy_wav_path, config=default_audio_config, audio_dir=tmp_path - ) - - assert isinstance(wav, xr.DataArray) - assert wav.dims == ("time",) - assert wav.coords["time"].attrs["step"] == 1 / target_sr - assert wav.sizes["time"] == expected_samples - assert np.isclose(wav.mean(), 0.0, atol=1e-6) - assert wav.dtype == np.float32 - - -def test_load_file_audio_file_not_found( - default_audio_config: audio.AudioConfig, tmp_path -): - non_existent_path = tmp_path / "not_a_real_file.wav" - with pytest.raises(FileNotFoundError): - audio.load_file_audio( - non_existent_path, config=default_audio_config, audio_dir=tmp_path - ) - - -def test_build_audio_loader(default_audio_config: audio.AudioConfig): - loader = audio.build_audio_loader(config=default_audio_config) - assert isinstance(loader, audio.ConfigurableAudioLoader) - assert loader.config == default_audio_config - - -def test_configurable_audio_loader_methods( - default_audio_config: audio.AudioConfig, - dummy_wav_path: pathlib.Path, - dummy_recording: data.Recording, - dummy_clip: data.Clip, - tmp_path, -): - loader = audio.build_audio_loader(config=default_audio_config) - - expected_wav_file = audio.load_file_audio( - dummy_wav_path, config=default_audio_config, audio_dir=tmp_path - ) - loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path) - xr.testing.assert_equal(expected_wav_file, loaded_wav_file) - - expected_wav_rec = audio.load_recording_audio( - dummy_recording, config=default_audio_config, audio_dir=tmp_path - ) - loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path) - xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec) - - expected_wav_clip = audio.load_clip_audio( - dummy_clip, config=default_audio_config, audio_dir=tmp_path - ) - loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path) - xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip) diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index 9f64494..b79fa48 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -1,32 +1,7 @@ -import math -from pathlib import Path -from typing import Callable, Union import numpy as np import pytest import xarray as xr -from soundevent import arrays - -from batdetect2.preprocess.audio import AudioConfig, load_file_audio -from batdetect2.preprocess.spectrogram import ( - MAX_FREQ, - MIN_FREQ, - ConfigurableSpectrogramBuilder, - FrequencyConfig, - PcenConfig, - SpecSizeConfig, - SpectrogramConfig, - STFTConfig, - apply_pcen, - build_spectrogram_builder, - compute_spectrogram, - crop_spectrogram_frequencies, - get_spectrogram_resolution, - remove_spectral_mean, - resize_spectrogram, - scale_spectrogram, - stft, -) SAMPLERATE = 250_000 DURATION = 0.1 @@ -61,389 +36,3 @@ def constant_wave_xr() -> xr.DataArray: dims=["time"], attrs={"samplerate": SAMPLERATE}, ) - - -@pytest.fixture -def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray: - """Generate a basic spectrogram for testing downstream functions.""" - config = SpectrogramConfig( - stft=STFTConfig(window_duration=0.002, window_overlap=0.5), - frequencies=FrequencyConfig( - min_freq=0, - max_freq=int(SAMPLERATE / 
2), - ), - size=None, - pcen=None, - spectral_mean_substraction=False, - peak_normalize=False, - scale="amplitude", - ) - spec = stft( - sine_wave_xr, - window_duration=config.stft.window_duration, - window_overlap=config.stft.window_overlap, - window_fn=config.stft.window_fn, - ) - return spec - - -def test_stft_config_defaults(): - config = STFTConfig() - assert config.window_duration == 0.002 - assert config.window_overlap == 0.75 - assert config.window_fn == "hann" - - -def test_frequency_config_defaults(): - config = FrequencyConfig() - assert config.min_freq == MIN_FREQ - assert config.max_freq == MAX_FREQ - - -def test_spec_size_config_defaults(): - config = SpecSizeConfig() - assert config.height == 128 - assert config.resize_factor == 0.5 - - -def test_pcen_config_defaults(): - config = PcenConfig() - assert config.time_constant == 0.01 - assert config.gain == 0.98 - assert config.bias == 2 - assert config.power == 0.5 - - -def test_spectrogram_config_defaults(): - config = SpectrogramConfig() - assert isinstance(config.stft, STFTConfig) - assert isinstance(config.frequencies, FrequencyConfig) - assert isinstance(config.pcen, PcenConfig) - assert config.scale == "amplitude" - assert isinstance(config.size, SpecSizeConfig) - assert config.spectral_mean_substraction is True - assert config.peak_normalize is False - - -def test_stft_output_properties(sine_wave_xr: xr.DataArray): - window_duration = 0.002 - window_overlap = 0.5 - samplerate = sine_wave_xr.attrs["samplerate"] - nfft = int(window_duration * samplerate) - hop_len = nfft - int(window_overlap * nfft) - - spec = stft( - sine_wave_xr, - window_duration=window_duration, - window_overlap=window_overlap, - window_fn="hann", - ) - - assert isinstance(spec, xr.DataArray) - assert spec.dims == ("frequency", "time") - assert spec.dtype == np.float32 - assert "frequency" in spec.coords - assert "time" in spec.coords - - time_step = arrays.get_dim_step(spec, "time") - freq_step = arrays.get_dim_step(spec, "frequency") - freq_start, freq_end = arrays.get_dim_range(spec, "frequency") - assert np.isclose(freq_step, samplerate / nfft) - assert np.isclose(time_step, hop_len / samplerate) - assert spec.frequency.min() >= 0 - assert freq_start == 0 - assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2) - assert np.isclose(spec.time.min(), 0) - assert spec.time.max() < DURATION - - assert spec.attrs["samplerate"] == samplerate - assert spec.attrs["window_size"] == window_duration - assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap) - - assert np.all(spec.data >= 0) - - -@pytest.mark.parametrize("window_fn", ["hann", "hamming"]) -def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str): - spec = stft( - sine_wave_xr, - window_duration=0.002, - window_overlap=0.5, - window_fn=window_fn, - ) - assert isinstance(spec, xr.DataArray) - assert np.all(spec.data >= 0) - - -def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray): - min_f, max_f = 20_000, 80_000 - cropped_spec = crop_spectrogram_frequencies( - sample_spec, min_freq=min_f, max_freq=max_f - ) - - assert cropped_spec.dims == sample_spec.dims - assert cropped_spec.dtype == sample_spec.dtype - assert cropped_spec.sizes["time"] == sample_spec.sizes["time"] - assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"] - assert cropped_spec.coords["frequency"].min() >= min_f - - assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1) - - -def test_crop_spectrogram_full_range(sample_spec: xr.DataArray): - 
samplerate = sample_spec.attrs["samplerate"] - min_f, max_f = 0, samplerate / 2 - cropped_spec = crop_spectrogram_frequencies( - sample_spec, min_freq=min_f, max_freq=max_f - ) - - assert cropped_spec.sizes == sample_spec.sizes - assert np.allclose(cropped_spec.data, sample_spec.data) - - -def test_apply_pcen(sample_spec: xr.DataArray): - pcen_config = PcenConfig() - pcen_spec = apply_pcen( - sample_spec, - time_constant=pcen_config.time_constant, - gain=pcen_config.gain, - bias=pcen_config.bias, - power=pcen_config.power, - ) - - assert pcen_spec.dims == sample_spec.dims - assert pcen_spec.sizes == sample_spec.sizes - assert pcen_spec.dtype == sample_spec.dtype - assert np.all(pcen_spec.data >= 0) - - assert not np.allclose(pcen_spec.data, sample_spec.data) - - -def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray): - scaled_spec = scale_spectrogram(sample_spec, scale="amplitude") - assert np.allclose(scaled_spec.data, sample_spec.data) - assert scaled_spec.dtype == sample_spec.dtype - - -def test_scale_spectrogram_power(sample_spec: xr.DataArray): - scaled_spec = scale_spectrogram(sample_spec, scale="power") - assert np.allclose(scaled_spec.data, sample_spec.data**2) - assert scaled_spec.dtype == sample_spec.dtype - - -def test_scale_spectrogram_db(sample_spec: xr.DataArray): - scaled_spec = scale_spectrogram(sample_spec, scale="dB") - log_spec_expected = arrays.to_db(sample_spec) - xr.testing.assert_allclose(scaled_spec, log_spec_expected) - - -def test_remove_spectral_mean(sample_spec: xr.DataArray): - spec_noisy = sample_spec.copy() + 0.1 - denoised_spec = remove_spectral_mean(spec_noisy) - - assert denoised_spec.dims == spec_noisy.dims - assert denoised_spec.sizes == spec_noisy.sizes - assert denoised_spec.dtype == spec_noisy.dtype - assert np.all(denoised_spec.data >= 0) - - -def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray): - const_spec = stft(constant_wave_xr, 0.002, 0.5) - denoised_spec = remove_spectral_mean(const_spec) - assert np.all(denoised_spec.data >= 0) - - -@pytest.mark.parametrize( - "height, resize_factor, expected_freq_size, expected_time_factor", - [ - (128, 1.0, 128, 1.0), - (64, 0.5, 64, 0.5), - (256, None, 256, 1.0), - (100, 2.0, 100, 2.0), - ], -) -def test_resize_spectrogram( - sample_spec: xr.DataArray, - height: int, - resize_factor: Union[float, None], - expected_freq_size: int, - expected_time_factor: float, -): - original_time_size = sample_spec.sizes["time"] - resized_spec = resize_spectrogram( - sample_spec, - height=height, - resize_factor=resize_factor, - ) - - assert resized_spec.dims == ("frequency", "time") - assert resized_spec.sizes["frequency"] == expected_freq_size - expected_time_size = int(original_time_size * expected_time_factor) - - assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1 - - -def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray): - config = SpectrogramConfig() - spec = compute_spectrogram(sine_wave_xr, config=config) - - assert isinstance(spec, xr.DataArray) - assert spec.dims == ("frequency", "time") - assert spec.dtype == np.float32 - assert config.size is not None - assert spec.sizes["frequency"] == config.size.height - - temp_stft = stft( - sine_wave_xr, config.stft.window_duration, config.stft.window_overlap - ) - assert config.size.resize_factor is not None - expected_time_size = int( - temp_stft.sizes["time"] * config.size.resize_factor - ) - assert abs(spec.sizes["time"] - expected_time_size) <= 1 - - assert spec.coords["frequency"].min() >= 
config.frequencies.min_freq - assert np.isclose( - spec.coords["frequency"].max(), - config.frequencies.max_freq, - rtol=0.1, - ) - - -def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize( - sine_wave_xr: xr.DataArray, -): - config = SpectrogramConfig( - pcen=None, - spectral_mean_substraction=False, - size=None, - scale="power", - frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)), - ) - spec = compute_spectrogram(sine_wave_xr, config=config) - - stft_direct = stft( - sine_wave_xr, config.stft.window_duration, config.stft.window_overlap - ) - expected_spec = scale_spectrogram(stft_direct, scale="power") - - assert spec.sizes == expected_spec.sizes - assert np.allclose(spec.data, expected_spec.data) - assert spec.dtype == expected_spec.dtype - - -def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray): - config = SpectrogramConfig(peak_normalize=True, pcen=None) - spec = compute_spectrogram(sine_wave_xr, config=config) - assert np.isclose(spec.data.max(), 1.0, atol=1e-6) - - config = SpectrogramConfig(peak_normalize=False) - spec_no_norm = compute_spectrogram(sine_wave_xr, config=config) - assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6) - - -def test_get_spectrogram_resolution_calculation(): - config = SpectrogramConfig( - stft=STFTConfig(window_duration=0.002, window_overlap=0.75), - size=SpecSizeConfig(height=100, resize_factor=0.5), - frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000), - ) - - freq_res, time_res = get_spectrogram_resolution(config) - - expected_freq_res = (110_000 - 10_000) / 100 - expected_hop_duration = 0.002 * (1 - 0.75) - expected_time_res = expected_hop_duration / 0.5 - - assert np.isclose(freq_res, expected_freq_res) - assert np.isclose(time_res, expected_time_res) - - -def test_get_spectrogram_resolution_no_resize_factor(): - config = SpectrogramConfig( - stft=STFTConfig(window_duration=0.004, window_overlap=0.5), - size=SpecSizeConfig(height=200, resize_factor=None), - frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000), - ) - freq_res, time_res = get_spectrogram_resolution(config) - expected_freq_res = (120_000 - 20_000) / 200 - expected_hop_duration = 0.004 * (1 - 0.5) - expected_time_res = expected_hop_duration / 1.0 - - assert np.isclose(freq_res, expected_freq_res) - assert np.isclose(time_res, expected_time_res) - - -def test_get_spectrogram_resolution_no_size_config(): - config = SpectrogramConfig(size=None) - with pytest.raises( - ValueError, match="Spectrogram size configuration is required" - ): - get_spectrogram_resolution(config) - - -def test_configurable_spectrogram_builder_init(): - config = SpectrogramConfig() - builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16) - assert builder.config is config - assert builder.dtype == np.float16 - - -def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray): - config = SpectrogramConfig() - builder = ConfigurableSpectrogramBuilder(config=config) - spec_builder = builder(sine_wave_xr) - spec_direct = compute_spectrogram(sine_wave_xr, config=config) - assert isinstance(spec_builder, xr.DataArray) - assert np.allclose(spec_builder.data, spec_direct.data) - assert spec_builder.dtype == spec_direct.dtype - - -def test_configurable_spectrogram_builder_call_np_no_samplerate( - sine_wave_xr: xr.DataArray, -): - config = SpectrogramConfig() - builder = ConfigurableSpectrogramBuilder(config=config) - wav_np = sine_wave_xr.data - with pytest.raises(ValueError, match="Samplerate must be provided"): - 
builder(wav_np, samplerate=None) - - -def test_build_spectrogram_builder(): - config = SpectrogramConfig(peak_normalize=True) - builder = build_spectrogram_builder(config=config, dtype=np.float64) - assert isinstance(builder, ConfigurableSpectrogramBuilder) - assert builder.config is config - assert builder.dtype == np.float64 - - -def test_can_estimate_spectrogram_resolution( - wav_factory: Callable[..., Path], -): - path = wav_factory(duration=0.2, samplerate=256_000) - - audio_data = load_file_audio( - path, - config=AudioConfig(resample=None, duration=None), - ) - - config = SpectrogramConfig( - stft=STFTConfig(), - size=SpecSizeConfig(height=256, resize_factor=0.5), - frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000), - ) - - spec = compute_spectrogram(audio_data, config=config) - - freq_res, time_res = get_spectrogram_resolution(config) - - assert math.isclose( - arrays.get_dim_step(spec, dim="frequency"), - freq_res, - rel_tol=0.1, - ) - assert math.isclose( - arrays.get_dim_step(spec, dim="time"), - time_res, - rel_tol=0.1, - ) diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 8b92045..8abf18f 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -3,7 +3,11 @@ import pytest import soundfile as sf from soundevent import data -from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.preprocess import ( + PreprocessingConfig, + build_preprocessor, +) +from batdetect2.preprocess.audio import build_audio_loader from batdetect2.targets.rois import ( DEFAULT_ANCHOR, DEFAULT_FREQUENCY_SCALE, @@ -275,6 +279,8 @@ def test_get_peak_energy_coordinates(generate_whistle): # Build a preprocessor (default config should be fine for this test) preprocessor = build_preprocessor() + audio_loader = build_audio_loader() + # Define a region of interest that contains the whistle start_time = 0.2 end_time = 0.7 @@ -285,6 +291,7 @@ def test_get_peak_energy_coordinates(generate_whistle): peak_time, peak_freq = get_peak_energy_coordinates( recording=recording, preprocessor=preprocessor, + audio_loader=audio_loader, start_time=start_time, end_time=end_time, low_freq=low_freq, @@ -356,6 +363,7 @@ def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle): peak_time, peak_freq = get_peak_energy_coordinates( recording=recording, preprocessor=preprocessor, + audio_loader=build_audio_loader(), start_time=start_time, end_time=end_time, low_freq=low_freq, @@ -389,6 +397,7 @@ def test_get_peak_energy_coordinates_silent_region(create_recording): peak_time, peak_freq = get_peak_energy_coordinates( recording=recording, preprocessor=preprocessor, + audio_loader=build_audio_loader(), start_time=start_time, end_time=end_time, low_freq=low_freq, @@ -443,17 +452,11 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle): # Instantiate the mapper with a preprocessor preprocessor = build_preprocessor( - PreprocessingConfig.model_validate( - { - "spectrogram": { - "pcen": None, - "spectral_mean_substraction": False, - } - } - ) + PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}}) ) mapper = PeakEnergyBBoxMapper( preprocessor=preprocessor, + audio_loader=build_audio_loader(), time_scale=time_scale, frequency_scale=freq_scale, ) @@ -493,6 +496,7 @@ def test_peak_energy_bbox_mapper_decode(): mapper = PeakEnergyBBoxMapper( preprocessor=build_preprocessor(), + audio_loader=build_audio_loader(), time_scale=time_scale, frequency_scale=freq_scale, ) @@ -553,7 +557,11 @@ def 
test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle): } ) ) - mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor) + audio_loader = build_audio_loader() + mapper = PeakEnergyBBoxMapper( + preprocessor=preprocessor, + audio_loader=audio_loader, + ) # When # Encode the sound event, then immediately decode the result. diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index 4dbd2fa..beffc32 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -11,11 +11,12 @@ from batdetect2.train.augmentations import ( ) from batdetect2.train.clips import select_subclip from batdetect2.train.preprocess import generate_train_example -from batdetect2.typing import ClipLabeller, PreprocessorProtocol +from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol def test_mix_examples( sample_preprocessor: PreprocessorProtocol, + sample_audio_loader: AudioLoader, sample_labeller: ClipLabeller, create_recording: Callable[..., data.Recording], ): @@ -30,11 +31,13 @@ def test_mix_examples( example1 = generate_train_example( clip_annotation_1, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) example2 = generate_train_example( clip_annotation_2, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) @@ -51,6 +54,7 @@ def test_mix_examples( @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7]) def test_mix_examples_of_different_durations( sample_preprocessor: PreprocessorProtocol, + sample_audio_loader: AudioLoader, sample_labeller: ClipLabeller, create_recording: Callable[..., data.Recording], duration1: float, @@ -67,11 +71,13 @@ def test_mix_examples_of_different_durations( example1 = generate_train_example( clip_annotation_1, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) example2 = generate_train_example( clip_annotation_2, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) @@ -87,6 +93,7 @@ def test_mix_examples_of_different_durations( def test_add_echo( sample_preprocessor: PreprocessorProtocol, + sample_audio_loader: AudioLoader, sample_labeller: ClipLabeller, create_recording: Callable[..., data.Recording], ): @@ -96,6 +103,7 @@ def test_add_echo( original = generate_train_example( clip_annotation_1, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) @@ -109,6 +117,7 @@ def test_add_echo( def test_selected_random_subclip_has_the_correct_width( sample_preprocessor: PreprocessorProtocol, + sample_audio_loader: AudioLoader, sample_labeller: ClipLabeller, create_recording: Callable[..., data.Recording], ): @@ -118,6 +127,7 @@ def test_selected_random_subclip_has_the_correct_width( original = generate_train_example( clip_annotation_1, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, ) @@ -128,6 +138,7 @@ def test_selected_random_subclip_has_the_correct_width( def test_add_echo_after_subclip( sample_preprocessor: PreprocessorProtocol, + sample_audio_loader: AudioLoader, sample_labeller: ClipLabeller, create_recording: Callable[..., data.Recording], ): @@ -136,6 +147,7 @@ def test_add_echo_after_subclip( clip_annotation_1 = data.ClipAnnotation(clip=clip1) original = generate_train_example( clip_annotation_1, + audio_loader=sample_audio_loader, preprocessor=sample_preprocessor, labeller=sample_labeller, )
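Taken together, the pieces introduced above compose into a short end-to-end flow. A minimal sketch, assuming the default configs and using only names that appear in this patch (`build_preprocessor`, `build_audio_loader`, `process_numpy`, `spec_to_xarray`); the recording path is illustrative:

    from soundevent import data

    from batdetect2.preprocess import build_preprocessor
    from batdetect2.preprocess.audio import build_audio_loader
    from batdetect2.utils.arrays import spec_to_xarray

    preprocessor = build_preprocessor()  # default PreprocessingConfig
    audio_loader = build_audio_loader()  # default AudioConfig

    # Hypothetical recording; any mono ultrasound WAV would do.
    recording = data.Recording.from_file("example.wav")
    clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

    wav = audio_loader.load_clip(clip)      # 1-D np.ndarray waveform
    spec = preprocessor.process_numpy(wav)  # torch pipeline -> 2-D np.ndarray
    spec_xr = spec_to_xarray(
        spec,
        clip.start_time,
        clip.end_time,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
    )

This mirrors the usage in `get_peak_energy_coordinates` above: audio loading stays NumPy-based, the heavy lifting runs through the `torch.nn.Module` pipeline, and xarray coordinates are attached only at the boundary where they are actually needed.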