mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
Preprocessing in pytorch
This commit is contained in:
parent
61115d562c
commit
667b18a54d
@@ -28,13 +28,12 @@ This module provides the primary interface:

"""

from typing import Optional, Union
from typing import Optional

import numpy as np
import xarray as xr
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
@@ -44,28 +43,23 @@ from batdetect2.preprocess.audio import (
    AudioConfig,
    ResampleConfig,
    build_audio_loader,
    build_audio_pipeline,
)
from batdetect2.preprocess.spectrogram import (
    MAX_FREQ,
    MIN_FREQ,
    ConfigurableSpectrogramBuilder,
    FrequencyConfig,
    PcenConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    SpectrogramPipeline,
    STFTConfig,
    build_spectrogram_builder,
    get_spectrogram_resolution,
)
from batdetect2.typing.preprocess import (
    AudioLoader,
    PreprocessorProtocol,
    SpectrogramBuilder,
    build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol

__all__ = [
    "AudioConfig",
    "ConfigurableSpectrogramBuilder",
    "DEFAULT_DURATION",
    "FrequencyConfig",
    "MAX_FREQ",
@@ -75,16 +69,11 @@ __all__ = [
    "ResampleConfig",
    "SCALE_RAW_AUDIO",
    "STFTConfig",
    "SpecSizeConfig",
    "SpectrogramConfig",
    "StandardPreprocessor",
    "TARGET_SAMPLERATE_HZ",
    "build_audio_loader",
    "build_preprocessor",
    "build_spectrogram_builder",
    "get_spectrogram_resolution",
    "load_preprocessing_config",
    "get_default_preprocessor",
]


@@ -110,343 +99,61 @@ class PreprocessingConfig(BaseConfig):
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)

class StandardPreprocessor(PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol.

def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)

    Orchestrates the audio loading and spectrogram generation pipeline using
    an `AudioLoader` and a `SpectrogramBuilder` internally, which are
    configured according to a `PreprocessingConfig`.

    This class is typically instantiated using the `build_preprocessor`
    factory function.

class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol."""

    Attributes
    ----------
    audio_loader : AudioLoader
        The configured audio loader instance used for waveform loading and
        initial processing.
    spectrogram_builder : SpectrogramBuilder
        The configured spectrogram builder instance used for generating
        spectrograms from waveforms.
    default_samplerate : int
        The sample rate (in Hz) assumed for input waveforms when they are
        provided as raw NumPy arrays without coordinate information (e.g.,
        when calling `compute_spectrogram` directly with `np.ndarray`).
        This value is derived from the `AudioConfig` (target resample rate
        or default if resampling is off) and also serves as documentation
        for the pipeline's intended operating sample rate. Note that when
        processing `xr.DataArray` inputs that have coordinate information
        (the standard internal workflow), the sample rate embedded in the
        coordinates takes precedence over this default value during
        spectrogram calculation.
    """

    audio_loader: AudioLoader
    spectrogram_builder: SpectrogramBuilder
    default_samplerate: int
    samplerate: int
    max_freq: float
    min_freq: float

    def __init__(
        self,
        audio_loader: AudioLoader,
        spectrogram_builder: SpectrogramBuilder,
        default_samplerate: int,
        audio_pipeline: torch.nn.Module,
        spectrogram_pipeline: SpectrogramPipeline,
        samplerate: int,
        max_freq: float,
        min_freq: float,
    ) -> None:
        """Initialize the StandardPreprocessor.

        Parameters
        ----------
        audio_loader : AudioLoader
            An initialized audio loader conforming to the AudioLoader
            protocol.
        spectrogram_builder : SpectrogramBuilder
            An initialized spectrogram builder conforming to the
            SpectrogramBuilder protocol.
        default_samplerate : int
            The sample rate to assume for NumPy array inputs and potentially
            reflecting the target rate of the audio config.
        """
        self.audio_loader = audio_loader
        self.spectrogram_builder = spectrogram_builder
        self.default_samplerate = default_samplerate
        super().__init__()
        self.audio_pipeline = audio_pipeline
        self.spectrogram_pipeline = spectrogram_pipeline
        self.samplerate = samplerate
        self.max_freq = max_freq
        self.min_freq = min_freq

    def load_file_audio(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform from a file path.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_file(
            path,
            audio_dir=audio_dir,
        )

    def load_recording_audio(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Recording.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_recording(
            recording,
            audio_dir=audio_dir,
        )

    def load_clip_audio(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Clip.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform segment (typically
            first channel).
        """
        return self.audio_loader.load_clip(
            clip,
            audio_dir=audio_dir,
        )

    def preprocess_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio from a file and compute the final processed spectrogram.

        Performs the full pipeline:

        Load -> Preprocess Audio -> Compute Spectrogram.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_file_audio(path, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Recording and compute the processed spectrogram.

        Performs the full pipeline for the entire duration of the recording.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_recording_audio(recording, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Clip and compute the final processed spectrogram.

        Performs the full pipeline for the specified clip segment.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the audio segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_clip_audio(clip, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def compute_spectrogram(
        self, wav: Union[xr.DataArray, np.ndarray]
    ) -> xr.DataArray:
        """Compute the spectrogram from a pre-loaded audio waveform.

        Applies the configured spectrogram generation steps
        (STFT, scaling, etc.) using the internal `spectrogram_builder`.

        If `wav` is a NumPy array, the `default_samplerate` stored in this
        preprocessor instance will be used. If `wav` is an xarray DataArray
        with time coordinates, the sample rate derived from those coordinates
        will take precedence over `default_samplerate`.

        Parameters
        ----------
        wav : Union[xr.DataArray, np.ndarray]
            The input audio waveform. If numpy array, `default_samplerate`
            stored in this object will be assumed.

        Returns
        -------
        xr.DataArray
            The computed spectrogram.
        """
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

def load_preprocessing_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    """Load the unified preprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PreprocessingConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the preprocessing configuration (e.g., "train.preprocessing"). If
        None, the entire file content is validated as the
        PreprocessingConfig.

    Returns
    -------
    PreprocessingConfig
        Loaded and validated preprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to PreprocessingConfig.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=PreprocessingConfig, field=field)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        wav = self.audio_pipeline(wav)
        return self.spectrogram_pipeline(wav)

def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration.

    Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
    based on the provided `PreprocessingConfig` (or defaults if config is
    None), determines the effective default sample rate, and initializes the
    `StandardPreprocessor`.

    Parameters
    ----------
    config : PreprocessingConfig, optional
        The unified preprocessing configuration object. If None, default
        configurations for audio and spectrogram processing will be used.

    Returns
    -------
    Preprocessor
        An initialized `StandardPreprocessor` instance ready to process audio
        according to the configuration.
    """
    """Factory function to build the standard preprocessor from configuration."""
    config = config or PreprocessingConfig()
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    default_samplerate = (
        config.audio.resample.samplerate
        if config.audio.resample
        else TARGET_SAMPLERATE_HZ
    )
    samplerate = config.audio.samplerate

    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq

    return StandardPreprocessor(
        audio_loader=build_audio_loader(config.audio),
        spectrogram_builder=build_spectrogram_builder(config.spectrogram),
        default_samplerate=default_samplerate,
        audio_pipeline=build_audio_pipeline(config.audio),
        spectrogram_pipeline=build_spectrogram_pipeline(
            samplerate, config.spectrogram
        ),
        samplerate=samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )
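For orientation, a minimal usage sketch of the new torch-based pipeline. The
shapes and the reliance on the default configuration are assumptions for
illustration, not part of this commit:

import torch

from batdetect2.preprocess import build_preprocessor

# Hypothetical usage; assumes the default config wires up a working
# spectrogram pipeline via build_spectrogram_pipeline.
preprocessor = build_preprocessor()  # default PreprocessingConfig
wav = torch.randn(1, preprocessor.samplerate)  # one second of dummy audio
spec = preprocessor(wav)  # forward(): audio_pipeline -> spectrogram_pipeline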
@@ -1,53 +1,31 @@
"""Handles loading and initial preprocessing of audio waveforms.
"""Handles loading and initial preprocessing of audio waveforms."""

This module provides components for loading audio data associated with
`soundevent` objects (Clips, Recordings, or raw files) and applying
fundamental waveform processing steps. These steps typically include:

1. Loading the raw audio data.
2. Adjusting the audio clip to a fixed duration (optional).
3. Resampling the audio to a target sample rate (optional).
4. Centering the waveform (DC offset removal) (optional).
5. Scaling the waveform amplitude (optional).

The processing pipeline is configurable via the `AudioConfig` data structure,
allowing for reproducible preprocessing consistent between model training and
inference. It uses the `soundevent` library for audio loading and basic array
operations, and `scipy` for resampling implementations.

The primary interface is the `AudioLoader` protocol, with
`ConfigurableAudioLoader` providing a concrete implementation driven by the
`AudioConfig`.
"""

from typing import Optional
from typing import Annotated, List, Literal, Optional, Union

import numpy as np
import xarray as xr
import torch
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.configs import BaseConfig
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.typing import AudioLoader

__all__ = [
    "ResampleConfig",
    "AudioConfig",
    "ConfigurableAudioLoader",
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "adjust_audio_duration",
    "resample_audio",
    "TARGET_SAMPLERATE_HZ",
    "SCALE_RAW_AUDIO",
    "DEFAULT_DURATION",
    "convert_to_xr",
]

TARGET_SAMPLERATE_HZ = 256_000
@@ -76,192 +54,69 @@ class ResampleConfig(BaseConfig):
    resampling factors differently.
    """

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    enabled: bool = True
    method: str = "poly"


class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing.

    Defines the sequence of operations applied to raw audio waveforms after
    loading, controlling steps like resampling, scaling, centering, and
    duration adjustment.

    Attributes
    ----------
    resample : ResampleConfig, optional
        Configuration for resampling. If provided (or defaulted), audio will
        be resampled to the specified `samplerate` using the specified
        `method`. If set to `None` in the config file, resampling is skipped.
        Defaults to a ResampleConfig instance with standard settings.
    scale : bool, default=False
        If True, scales the audio waveform using peak normalization so that
        its maximum absolute amplitude is approximately 1.0. If False
        (default), no amplitude scaling is applied.
    center : bool, default=False
        If True, centers the waveform by subtracting its mean (DC offset
        removal). If False (default), the waveform is not centered.
    duration : float, optional
        If set to a float value (seconds), the loaded audio clip will be
        adjusted (cropped or padded with zeros) to exactly this duration.
        If None (default), the original duration is kept.
    """

    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    scale: bool = SCALE_RAW_AUDIO
    center: bool = False
    duration: Optional[float] = DEFAULT_DURATION

class ConfigurableAudioLoader:
    """Concrete implementation of the `AudioLoader` driven by `AudioConfig`.

    This class loads audio and applies preprocessing steps (resampling,
    scaling, centering, duration adjustment) based on the settings provided
    in an `AudioConfig` object during initialization. It delegates the actual
    work to module-level functions.
    """
class SoundEventAudioLoader:
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        config: AudioConfig,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        """Initialize the ConfigurableAudioLoader.

        Parameters
        ----------
        config : AudioConfig
            The configuration object specifying the desired preprocessing
            steps and parameters.
        """
        self.config = config
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess audio directly from a file path.

        Implements the `AudioLoader.load_file` method by delegating to the
        `load_file_audio` function, passing the stored configuration.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            Loaded and preprocessed waveform (first channel).
        """
        return load_file_audio(path, config=self.config, audio_dir=audio_dir)
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess the entire audio for a Recording object.

        Implements the `AudioLoader.load_recording` method by delegating to
        the `load_recording_audio` function, passing the stored configuration.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            Loaded and preprocessed waveform (first channel).
        """
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording, config=self.config, audio_dir=audio_dir
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess the audio segment defined by a Clip object.

        Implements the `AudioLoader.load_clip` method by delegating to the
        `load_clip_audio` function, passing the stored configuration.

        Parameters
        ----------
        clip : data.Clip
            The Clip object specifying the segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            Loaded and preprocessed waveform segment (first channel).
        """
        return load_clip_audio(clip, config=self.config, audio_dir=audio_dir)


def build_audio_loader(
    config: AudioConfig,
) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration.

    Instantiates and returns a `ConfigurableAudioLoader` initialized with
    the provided `AudioConfig`. The return type is `AudioLoader`, adhering
    to the protocol.

    Parameters
    ----------
    config : AudioConfig
        The configuration object specifying preprocessing steps.

    Returns
    -------
    AudioLoader
        An instance of `ConfigurableAudioLoader` ready to load and process
        audio according to the configuration.
    """
    return ConfigurableAudioLoader(config=config)
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

def load_file_audio(
    path: data.PathLike,
    config: Optional[AudioConfig] = None,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Load and preprocess audio from a file path using specified config.

    Creates a `soundevent.data.Recording` object from the file path and then
    delegates the loading and processing to `load_recording_audio`.

    Parameters
    ----------
    path : PathLike
        Path to the audio file.
    config : AudioConfig, optional
        Audio processing configuration. If None, default settings defined
        in `AudioConfig` are used.
    audio_dir : PathLike, optional
        Directory prefix if `path` is relative.
    dtype : DTypeLike, default=np.float32
        Target NumPy data type for the loaded audio array.

    Returns
    -------
    xr.DataArray
        Loaded and preprocessed waveform (first channel only).
    """
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
@@ -271,6 +126,7 @@ def load_file_audio(

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
@@ -279,33 +135,12 @@ def load_file_audio(

def load_recording_audio(
    recording: data.Recording,
    config: Optional[AudioConfig] = None,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Load and preprocess the entire audio content of a recording using config.

    Creates a `soundevent.data.Clip` spanning the full duration of the
    recording and then delegates the loading and processing to
    `load_clip_audio`.

    Parameters
    ----------
    recording : data.Recording
        The Recording object containing metadata and file path.
    config : AudioConfig, optional
        Audio processing configuration. If None, default settings are used.
    audio_dir : PathLike, optional
        Directory containing the audio file, used if the path in `recording`
        is relative.
    dtype : DTypeLike, default=np.float32
        Target NumPy data type for the loaded audio array.

    Returns
    -------
    xr.DataArray
        Loaded and preprocessed waveform (first channel only).
    """
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
@@ -313,6 +148,7 @@ def load_recording_audio(
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
@@ -321,56 +157,12 @@ def load_recording_audio(

def load_clip_audio(
    clip: data.Clip,
    config: Optional[AudioConfig] = None,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Load and preprocess a specific audio clip segment based on config.

    This is the core function performing the configured processing pipeline:
    1. Loads the specified clip segment using `soundevent.audio.load_clip`.
    2. Selects the first audio channel.
    3. Resamples if `config.resample` is configured.
    4. Centers (DC offset removal) if `config.center` is True.
    5. Scales (peak normalization) if `config.scale` is True.
    6. Adjusts duration (crop/pad) if `config.duration` is set.

    Parameters
    ----------
    clip : data.Clip
        The Clip object defining the audio segment and source recording.
    config : AudioConfig, optional
        Audio processing configuration. If None, a default `AudioConfig` is
        used.
    audio_dir : PathLike, optional
        Directory containing the source audio file specified in the clip's
        recording.
    dtype : DTypeLike, default=np.float32
        Target NumPy data type for the processed audio array.

    Returns
    -------
    xr.DataArray
        The loaded and preprocessed waveform segment as an xarray DataArray
        with time coordinates.

    Raises
    ------
    FileNotFoundError
        If the underlying audio file cannot be found.
    Exception
        If audio loading or processing fails for other reasons (e.g., invalid
        format, resampling error).

    Notes
    -----
    - **Mono Processing:** This function currently loads and processes only
      the **first channel** (channel 0) of the audio file. Any other
      channels are ignored.
    """
    config = config or AudioConfig()

    with xr.set_options(keep_attrs=True):
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
@@ -383,195 +175,48 @@ def load_clip_audio(
            f"Error: {e}"
        ) from e

        if config.resample:
            wav = resample_audio(
                wav,
                samplerate=config.resample.samplerate,
                dtype=dtype,
            )
    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

        if config.center:
            wav = ops.center(wav)

        if config.scale:
            wav = scale_audio(wav)

        if config.duration is not None:
            wav = adjust_audio_duration(wav, duration=config.duration)

        return wav.astype(dtype)

def scale_audio(
    wave: xr.DataArray,
) -> xr.DataArray:
    """
    Scale the audio waveform to have a maximum absolute value of 1.0.

    This function normalizes the waveform by dividing it by its maximum
    absolute value. If the maximum value is zero, the waveform is returned
    unchanged. Also known as peak normalization, this process ensures that
    the waveform's amplitude is within a standard range, which can be useful
    for audio processing and analysis.
    """
    max_val = np.max(np.abs(wave))

    if max_val == 0:
        return wave

    return ops.scale(wave, 1 / max_val)


def adjust_audio_duration(
    wave: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    """Adjust the duration of an audio waveform array via cropping or padding.

    If the current duration is longer than the target, it crops the array
    from the beginning. If shorter, it pads the array with zeros at the end
    using `soundevent.arrays.extend_dim`.

    Parameters
    ----------
    wave : xr.DataArray
        The input audio waveform with a 'time' dimension and coordinates.
    duration : float
        The target duration in seconds.

    Returns
    -------
    xr.DataArray
        The waveform adjusted to the target duration. Returns the input
        unmodified if duration already matches or if the wave is empty.

    Raises
    ------
    ValueError
        If `duration` is negative.
    """
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
    step = arrays.get_dim_step(wave, dim="time")
    current_duration = end_time - start_time + step

    if current_duration == duration:
        return wave

    with xr.set_options(keep_attrs=True):
        if current_duration > duration:
            return arrays.crop_dim(
                wave,
                dim="time",
                start=start_time,
                stop=start_time + duration - step / 2,
                right_closed=True,
            )

        return arrays.extend_dim(
            wave,
            dim="time",
            start=start_time,
            stop=start_time + duration - step / 2,
            eps=0,
            right_closed=True,
    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )

def resample_audio(
    wav: xr.DataArray,
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Resample an audio waveform DataArray to a target sample rate.

    Updates the 'time' coordinate axis according to the new sample rate and
    number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
    or Fourier method (`scipy.signal.resample`) based on the `method`.

    Parameters
    ----------
    wav : xr.DataArray
        Input audio waveform with 'time' dimension and coordinates.
    samplerate : int, default=TARGET_SAMPLERATE_HZ
        Target sample rate in Hz.
    method : str, default="poly"
        Resampling algorithm: "poly" or "fourier".
    dtype : DTypeLike, default=np.float32
        Target data type for the resampled array.

    Returns
    -------
    xr.DataArray
        Resampled waveform with updated time coordinates. Returns the input
        unmodified (but dtype cast) if the sample rate is already correct or
        if the input array is empty.

    Raises
    ------
    ValueError
        If `wav` lacks a 'time' dimension, the original sample rate cannot
        be determined, `samplerate` is non-positive, or `method` is invalid.
    """
    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore
    step = arrays.get_dim_step(wav, dim="time")
    original_samplerate = int(1 / step)

    if original_samplerate == samplerate:
        return wav.astype(dtype).assign_attrs(original_samplerate=samplerate)
) -> np.ndarray:
    """Resample an audio waveform DataArray to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        resampled = resample_audio_poly(
        return resample_audio_poly(
            wav,
            sr_orig=original_samplerate,
            sr_orig=sr,
            sr_new=samplerate,
            axis=time_axis,
        )
    elif method == "fourier":
        resampled = resample_audio_fourier(
        return resample_audio_fourier(
            wav,
            sr_orig=original_samplerate,
            sr_orig=sr,
            sr_new=samplerate,
            axis=time_axis,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )

    start, stop = arrays.get_dim_range(wav, dim="time")
    times = np.linspace(
        start,
        stop + step,
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": arrays.create_time_dim_from_array(
                times,
                samplerate=samplerate,
            ),
        },
        attrs={
            **wav.attrs,
            "samplerate": samplerate,
            "original_samplerate": original_samplerate,
        },
    )


def resample_audio_poly(
    array: xr.DataArray,
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
@@ -605,7 +250,7 @@ def resample_audio_poly(
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array.values,
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
@@ -613,7 +258,7 @@ def resample_audio_poly(

def resample_audio_fourier(
    array: xr.DataArray,
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
@@ -649,66 +294,89 @@ def resample_audio_fourier(
    )

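As a worked example of the polyphase path (rates chosen purely for
illustration): `resample_poly` needs integer up/down factors, which
`resample_audio_poly` derives from the gcd of the two sample rates.

import numpy as np

sr_orig, sr_new = 192_000, 256_000  # illustrative rates, not project values
gcd = np.gcd(sr_orig, sr_new)  # 64_000
up, down = sr_new // gcd, sr_orig // gcd  # 4, 3
# resample_poly(array, 4, 3) upsamples by 4, filters, then decimates by 3.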
def convert_to_xr(
    wav: np.ndarray,
class CenterAudioConfig(BaseConfig):
    name: Literal["center_audio"] = "center_audio"


class ScaleAudioConfig(BaseConfig):
    name: Literal["scale_audio"] = "scale_audio"


class FixDurationConfig(BaseConfig):
    name: Literal["fix_duration"] = "fix_duration"
    duration: float = 0.5


class FixDuration(torch.nn.Module):
    def __init__(self, samplerate: int, duration: float):
        super().__init__()
        self.samplerate = samplerate
        self.duration = duration
        self.length = int(samplerate * duration)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        length = wav.shape[-1]

        if length == self.length:
            return wav

        if length > self.length:
            # Crop along the time (last) axis, matching the shape[-1] check
            # above and the padding below.
            return wav[..., : self.length]

        return torch.nn.functional.pad(wav, (0, self.length - length))

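A quick sketch of `FixDuration` behaviour on dummy tensors (the shapes are
illustrative assumptions only):

import torch

fix = FixDuration(samplerate=256_000, duration=0.5)  # target length 128_000
too_short = torch.randn(1, 100_000)  # hypothetical input
too_long = torch.randn(1, 200_000)   # hypothetical input
assert fix(too_short).shape[-1] == 128_000  # zero-padded at the end
assert fix(too_long).shape[-1] == 128_000   # cropped to the first 128_000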
AudioTransform = Annotated[
    Union[
        FixDurationConfig,
        ScaleAudioConfig,
        CenterAudioConfig,
    ],
    Field(discriminator="name"),
]

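The `name` literal acts as the pydantic discriminator, so a plain dict (e.g.
parsed from YAML) resolves to the right config class. A small sketch,
assuming pydantic v2's `TypeAdapter`:

from pydantic import TypeAdapter

# Assumed pydantic v2 API; the dict values are illustrative.
step = TypeAdapter(AudioTransform).validate_python(
    {"name": "fix_duration", "duration": 0.25}
)
assert isinstance(step, FixDurationConfig)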
class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    transforms: List[AudioTransform] = Field(default_factory=list)


def build_audio_loader(
    config: Optional[AudioConfig] = None,
) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
    )

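For example (the file path is a placeholder), the rebuilt loader returns a
plain NumPy array resampled to the configured rate:

loader = build_audio_loader(AudioConfig(samplerate=256_000))
wav = loader.load_file("recording.wav")  # placeholder path; np.ndarray out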
def build_audio_transform_step(
    config: AudioTransform,
    samplerate: int,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Convert a NumPy array to an xarray DataArray with time coordinates.
) -> torch.nn.Module:
    if config.name == "fix_duration":
        return FixDuration(samplerate=samplerate, duration=config.duration)

    Parameters
    ----------
    wav : np.ndarray
        The input waveform array. Expected to be 1D or 2D (with the first
        axis as the channel dimension).
    samplerate : int
        The sample rate in Hz.
    dtype : DTypeLike, default=np.float32
        Target data type for the xarray DataArray.
    if config.name == "scale_audio":
        return PeakNormalize()

    Returns
    -------
    xr.DataArray
        The waveform as an xarray DataArray with time coordinates.
    if config.name == "center_audio":
        return CenterTensor()

    Raises
    ------
    ValueError
        If the input array is not 1D or 2D, or if the sample rate is
        non-positive. If the input array is empty.
    """

    if wav.ndim == 2:
        wav = wav[0, :]

    if wav.ndim != 1:
        raise ValueError(
            "Audio must be 1D array or 2D channel where the first "
            "axis is the channel dimension"
    raise NotImplementedError(
        f"Audio preprocessing step {config.name} not implemented"
    )

    if wav.size == 0:
        raise ValueError("Audio array is empty")

    if samplerate <= 0:
        raise ValueError("Sample rate must be positive")

    times = np.linspace(
        0,
        wav.shape[0] / samplerate,
        wav.shape[0],
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=wav.astype(dtype),
        dims=["time"],
        coords={
            "time": arrays.create_time_dim_from_array(
                times,
                samplerate=samplerate,
            ),
        },
        attrs={"samplerate": samplerate},
def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            build_audio_transform_step(step, samplerate=config.samplerate)
            for step in config.transforms
        ]
    )

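Putting the pieces together, a sketch of building and running an audio
pipeline from config (the transform list is an illustrative choice, not a
project default):

import torch

config = AudioConfig(
    samplerate=256_000,
    transforms=[  # assumed example transforms, not defaults
        FixDurationConfig(duration=0.5),
        CenterAudioConfig(),
        ScaleAudioConfig(),
    ],
)
pipeline = build_audio_pipeline(config)
# Sequential(FixDuration, CenterTensor, PeakNormalize)
out = pipeline(torch.randn(1, 200_000))  # -> shape (1, 128_000)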
src/batdetect2/preprocess/common.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import torch

__all__ = [
    "CenterTensor",
    "PeakNormalize",
]


class CenterTensor(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        return wav - wav.mean()


class PeakNormalize(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        # Divide by the peak absolute amplitude, guarding against silence.
        max_value = wav.abs().max()

        denominator = torch.where(
            max_value == 0,
            torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
            max_value,
        )

        return wav / denominator
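Both modules are stateless and shape-agnostic; a quick numeric sketch with an
arbitrary tensor:

import torch

wav = torch.tensor([2.0, -1.0, 2.0])  # arbitrary example values
assert torch.allclose(CenterTensor()(wav), torch.tensor([1.0, -2.0, 1.0]))
assert torch.allclose(PeakNormalize()(wav), torch.tensor([1.0, -0.5, 1.0]))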
@@ -1,48 +1,22 @@
"""Computes spectrograms from audio waveforms with configurable parameters.
"""Computes spectrograms from audio waveforms with configurable parameters."""

This module provides the functionality to convert preprocessed audio waveforms
(typically output from the `batdetect2.preprocessing.audio` module) into
spectrogram representations suitable for input into deep learning models like
BatDetect2.

It offers a configurable pipeline including:
1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
2. Frequency axis cropping to a relevant range.
3. Per-Channel Energy Normalization (PCEN) (optional).
4. Amplitude scaling/representation (dB, power, or linear amplitude).
5. Simple spectral mean subtraction denoising (optional).
6. Resizing to target dimensions (optional).
7. Final peak normalization (optional).

Configuration is managed via the `SpectrogramConfig` class, allowing for
reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""

from typing import Callable, Literal, Optional, Union
from typing import Annotated, Callable, List, Literal, Optional, Union

import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
import torch
import torchaudio
from pydantic import Field
from scipy import signal
from soundevent import arrays, audio
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.common import PeakNormalize
from batdetect2.typing.preprocess import SpectrogramBuilder

__all__ = [
    "STFTConfig",
    "FrequencyConfig",
    "SpecSizeConfig",
    "PcenConfig",
    "SpectrogramConfig",
    "ConfigurableSpectrogramBuilder",
    "build_spectrogram_builder",
    "compute_spectrogram",
    "get_spectrogram_resolution",
    "MIN_FREQ",
    "MAX_FREQ",
]
@@ -79,6 +53,47 @@ class STFTConfig(BaseConfig):
    window_fn: str = "hann"


def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
    if name == "hann":
        return torch.hann_window

    if name == "hamming":
        return torch.hamming_window

    if name == "kaiser":
        return torch.kaiser_window

    if name == "blackman":
        return torch.blackman_window

    if name == "bartlett":
        return torch.bartlett_window

    raise NotImplementedError(
        f"Spectrogram window function {name} not implemented"
    )


def _spec_params_from_config(samplerate: int, conf: STFTConfig):
    n_fft = int(samplerate * conf.window_duration)
    hop_length = int(n_fft * (1 - conf.window_overlap))
    return n_fft, hop_length


def build_spectrogram_builder(
    samplerate: int,
    conf: STFTConfig,
) -> SpectrogramBuilder:
    n_fft, hop_length = _spec_params_from_config(samplerate, conf)
    return torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        window_fn=get_spectrogram_window(conf.window_fn),
        center=False,
        power=1,
    )

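Worked numbers for the parameter derivation (the window settings below are
assumed for illustration; this hunk does not show the STFTConfig defaults):

conf = STFTConfig(window_duration=0.002, window_overlap=0.75)  # assumed
n_fft, hop_length = _spec_params_from_config(256_000, conf)
assert (n_fft, hop_length) == (512, 128)
# -> n_fft // 2 + 1 = 257 frequency bins, one STFT frame every 0.5 ms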
class FrequencyConfig(BaseConfig):
    """Configuration for frequency axis parameters.

@@ -96,644 +111,282 @@ class FrequencyConfig(BaseConfig):
    min_freq: int = Field(default=10_000, ge=0)


class SpecSizeConfig(BaseConfig):
    """Configuration for the final size and shape of the spectrogram.
def _frequency_to_index(
    freq: float,
    samplerate: int,
    n_fft: int,
) -> Optional[int]:
    alpha = freq * 2 / samplerate
    height = np.floor(n_fft / 2) + 1
    index = int(np.floor(alpha * height))

    Attributes
    ----------
    height : int, default=128
        Target height of the spectrogram in pixels (frequency bins). The
        frequency axis will be resized (e.g., via interpolation) to match
        this height after frequency cropping and amplitude scaling. Must
        be > 0.
    resize_factor : float, optional
        Factor by which to resize the spectrogram along the time axis *after*
        STFT calculation. A value of 0.5 halves the number of time bins,
        2.0 doubles it. If None (default), no resizing along the time axis is
        performed relative to the STFT output width. Must be > 0 if provided.
    """
    if index <= 0:
        return None

    height: int = 128
    resize_factor: Optional[float] = 0.5
    if index >= height:
        return None

    return index


class FrequencyClip(torch.nn.Module):
    def __init__(
        self,
        low_index: Optional[int] = None,
        high_index: Optional[int] = None,
    ):
        super().__init__()
        self.low_index = low_index
        self.high_index = high_index

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Slice the frequency axis (second to last for torchaudio
        # spectrograms) so batched and unbatched inputs behave the same.
        return spec[..., self.low_index : self.high_index, :]

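A worked example of the index mapping (STFT settings assumed, matching the
illustrative numbers above):

# samplerate=256_000 and n_fft=512 give height = 512 // 2 + 1 = 257 bins
# spanning 0..128_000 Hz.
low = _frequency_to_index(10_000, samplerate=256_000, n_fft=512)    # 20
high = _frequency_to_index(120_000, samplerate=256_000, n_fft=512)  # 240
clip = FrequencyClip(low_index=low, high_index=high)
# keeps bins 20..239, i.e. roughly the 10-120 kHz band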
class PcenConfig(BaseConfig):
    """Configuration for Per-Channel Energy Normalization (PCEN).
    """Configuration for Per-Channel Energy Normalization (PCEN)."""

    PCEN is an adaptive gain control method that can help emphasize
    transients and suppress stationary noise. Applied after STFT and
    frequency cropping, but before final amplitude scaling (dB, power,
    amplitude).

    Attributes
    ----------
    time_constant : float, default=0.4
        Time constant (in seconds) for the PCEN smoothing filter. Controls
        how quickly the normalization adapts to energy changes.
    gain : float, default=0.98
        Gain factor (alpha). Controls the adaptive gain component.
    bias : float, default=2.0
        Bias factor (delta). Added before the exponentiation.
    power : float, default=0.5
        Exponent (r). Controls the compression characteristic.
    """

    time_constant: float = 0.01
    name: Literal["pcen"] = "pcen"
    time_constant: float = 0.4
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5


class SpectrogramConfig(BaseConfig):
    """Unified configuration for spectrogram generation pipeline.

    Aggregates settings for all steps involved in converting a preprocessed
    audio waveform into a final spectrogram representation suitable for
    model input.

    Attributes
    ----------
    stft : STFTConfig
        Configuration for the initial Short-Time Fourier Transform.
        Defaults to standard settings via `STFTConfig`.
    frequencies : FrequencyConfig
        Configuration for cropping the frequency range after STFT.
        Defaults to standard settings via `FrequencyConfig`.
    pcen : PcenConfig, optional
        Configuration for applying Per-Channel Energy Normalization (PCEN).
        If provided, PCEN is applied after frequency cropping. If None or
        omitted (default), PCEN is skipped.
    scale : Literal["dB", "amplitude", "power"], default="amplitude"
        Determines the final amplitude representation *after* optional PCEN.
        - "amplitude": Use linear magnitude values (output of STFT or PCEN).
        - "power": Use power values (magnitude squared).
        - "dB": Use logarithmic (decibel-like) scaling applied to the
          magnitude (or PCEN output if enabled). Calculated as
          `log1p(C * S)`.
    size : SpecSizeConfig, optional
        Configuration for resizing the spectrogram dimensions
        (frequency height, optional time width factor). Applied after PCEN
        and scaling. If None (default), no resizing is performed.
    spectral_mean_substraction : bool, default=True
        If True (default), applies simple spectral mean subtraction denoising
        *after* PCEN and amplitude scaling, but *before* resizing.
    peak_normalize : bool, default=False
        If True, applies a final peak normalization to the spectrogram
        *after* all other steps (including resizing), scaling the overall
        maximum value to 1.0. If False (default), this final normalization
        is skipped.
    """

    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
    scale: Literal["dB", "amplitude", "power"] = "amplitude"
    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
    spectral_mean_substraction: bool = True
    peak_normalize: bool = False


class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
    """Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`.

    This class computes spectrograms according to the parameters specified
    in a `SpectrogramConfig` object provided during initialization. It
    handles both numpy array and xarray DataArray inputs for the waveform.
    """

class PCEN(torch.nn.Module):
    def __init__(
        self,
        config: SpectrogramConfig,
        dtype: DTypeLike = np.float32,  # type: ignore
    ) -> None:
        """Initialize the ConfigurableSpectrogramBuilder.

        Parameters
        ----------
        config : SpectrogramConfig
            The configuration object specifying all spectrogram parameters.
        dtype : DTypeLike, default=np.float32
            The target NumPy data type for the computed spectrogram array.
        """
        self.config = config
        smoothing_constant: float,
        gain: float = 0.98,
        bias: float = 2.0,
        power: float = 0.5,
        eps: float = 1e-6,
        dtype=torch.float64,
    ):
        super().__init__()
        self.smoothing_constant = smoothing_constant
        self.gain = torch.tensor(gain, dtype=dtype)
        self.bias = torch.tensor(bias, dtype=dtype)
        self.power = torch.tensor(power, dtype=dtype)
        self.eps = torch.tensor(eps, dtype=dtype)
        self.dtype = dtype

    def __call__(
        self,
        wav: Union[np.ndarray, xr.DataArray],
        samplerate: Optional[int] = None,
    ) -> xr.DataArray:
        """Generate a spectrogram from an audio waveform using the config.

        Implements the `SpectrogramBuilder` protocol. If the input `wav` is
        a numpy array, `samplerate` must be provided; the array will be
        converted to an xarray DataArray internally. If `wav` is already an
        xarray DataArray with time coordinates, `samplerate` is ignored.
        Delegates the main computation to `compute_spectrogram`.

        Parameters
        ----------
        wav : Union[np.ndarray, xr.DataArray]
            The input audio waveform.
        samplerate : int, optional
            The sample rate in Hz (required only if `wav` is np.ndarray).

        Returns
        -------
        xr.DataArray
            The computed spectrogram.

        Raises
        ------
        ValueError
            If `wav` is np.ndarray and `samplerate` is None.
        """
        if isinstance(wav, np.ndarray):
            if samplerate is None:
                raise ValueError(
                    "Samplerate must be provided when passing a numpy array."
                )
            wav = convert_to_xr(
                wav,
                samplerate=samplerate,
                dtype=self.dtype,
        self._b = torch.tensor([self.smoothing_constant, 0.0], dtype=dtype)
        self._a = torch.tensor(
            [1.0, self.smoothing_constant - 1.0], dtype=dtype
        )

        return compute_spectrogram(
            wav,
            config=self.config,
            dtype=self.dtype,
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        S = spec.to(self.dtype) * 2**31

        M = (
            torchaudio.functional.lfilter(
                S,
                self._a,
                self._b,
                clamp=False,
            )
        ).clamp(min=0)

        smooth = torch.exp(
            -self.gain * (torch.log(self.eps) + torch.log1p(M / self.eps))
        )


def build_spectrogram_builder(
    config: SpectrogramConfig,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> SpectrogramBuilder:
    """Factory function to create a SpectrogramBuilder based on configuration.

    Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized
    with the provided `SpectrogramConfig`.

    Parameters
    ----------
    config : SpectrogramConfig
        The configuration object specifying spectrogram parameters.
    dtype : DTypeLike, default=np.float32
        The target NumPy data type for the computed spectrogram array.

    Returns
    -------
    SpectrogramBuilder
        An instance of `ConfigurableSpectrogramBuilder` ready to compute
        spectrograms according to the configuration.
    """
    return ConfigurableSpectrogramBuilder(config=config, dtype=dtype)
        return (
            (self.bias**self.power)
            * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
        ).to(spec.dtype)

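The PCEN forward above is an algebraic rearrangement of the familiar PCEN
definition out = (S / (eps + M)**gain + bias)**power - bias**power, where M is
an IIR-smoothed copy of S along time (the lfilter with b = [s, 0] and
a = [1, s - 1] is the standard one-pole smoother). A quick numerical check of
the rearrangement, using random tensors as stand-ins for real spectrograms:

import torch

S = torch.rand(257, 100, dtype=torch.float64)  # placeholder spectrogram
M = torch.rand(257, 100, dtype=torch.float64)  # placeholder smoothed energy
gain, bias, power = 0.98, 2.0, 0.5
eps = torch.tensor(1e-6, dtype=torch.float64)

smooth = torch.exp(-gain * (torch.log(eps) + torch.log1p(M / eps)))
lhs = (bias**power) * torch.expm1(power * torch.log1p(S * smooth / bias))
rhs = (S / (eps + M) ** gain + bias) ** power - bias**power
assert torch.allclose(lhs, rhs)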
def compute_spectrogram(
|
||||
wav: xr.DataArray,
|
||||
config: Optional[SpectrogramConfig] = None,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Compute a spectrogram from a waveform using specified configurations.
|
||||
|
||||
Applies a sequence of operations based on the `config`:
|
||||
1. Compute STFT magnitude (`stft`).
|
||||
2. Crop frequency axis (`crop_spectrogram_frequencies`).
|
||||
3. Apply PCEN if configured (`apply_pcen`).
|
||||
4. Apply final amplitude scaling (dB, power, amplitude)
|
||||
(`scale_spectrogram`).
|
||||
5. Apply spectral mean subtraction denoising if enabled.
|
||||
6. Resize dimensions if specified (`resize_spectrogram`).
|
||||
7. Apply final peak normalization if enabled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : xr.DataArray
|
||||
Input audio waveform with a 'time' dimension and coordinates from
|
||||
which the sample rate can be inferred.
|
||||
config : SpectrogramConfig, optional
|
||||
Configuration object specifying spectrogram parameters. If None,
|
||||
default settings from `SpectrogramConfig` are used.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target NumPy data type for the final spectrogram array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The computed and processed spectrogram with 'time' and 'frequency'
|
||||
coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `wav` lacks necessary 'time' coordinates or dimensions.
|
||||
"""
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
spec = stft(
|
||||
wav,
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
)
|
||||
|
||||
spec = crop_spectrogram_frequencies(
|
||||
spec,
|
||||
min_freq=config.frequencies.min_freq,
|
||||
max_freq=config.frequencies.max_freq,
|
||||
)
|
||||
|
||||
if config.pcen:
|
||||
spec = apply_pcen(
|
||||
spec,
|
||||
time_constant=config.pcen.time_constant,
|
||||
gain=config.pcen.gain,
|
||||
power=config.pcen.power,
|
||||
bias=config.pcen.bias,
|
||||
)
|
||||
|
||||
spec = scale_spectrogram(spec, scale=config.scale)
|
||||
|
||||
if config.spectral_mean_substraction:
|
||||
spec = remove_spectral_mean(spec)
|
||||
|
||||
if config.size:
|
||||
spec = resize_spectrogram(
|
||||
spec,
|
||||
height=config.size.height,
|
||||
resize_factor=config.size.resize_factor,
|
||||
)
|
||||
|
||||
if config.peak_normalize:
|
||||
spec = ops.normalize(spec)
|
||||
|
||||
return spec.astype(dtype)
|
||||
def _compute_smoothing_constant(
|
||||
samplerate: int,
|
||||
time_constant: float,
|
||||
) -> float:
|
||||
# NOTE: These were taken to match the original implementation
|
||||
hop_length = 512
|
||||
sr = samplerate / 10
|
||||
time_constant = time_constant
|
||||
t_frames = time_constant * sr / float(hop_length)
|
||||
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
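

# Worked example (editorial, not part of the diff): for a 256 kHz recording
# and the default PCEN time constant of 0.01 s used by `PcenConfig`,
#
#     t_frames = 0.01 * (256_000 / 10) / 512 = 0.5
#     smoothing_constant = (sqrt(1 + 4 * 0.25) - 1) / (2 * 0.25)
#                        = (sqrt(2) - 1) / 0.5 ≈ 0.828
#
# i.e. _compute_smoothing_constant(256_000, 0.01) ≈ 0.8284.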


def crop_spectrogram_frequencies(
    spec: xr.DataArray,
    min_freq: int = 10_000,
    max_freq: int = 120_000,
) -> xr.DataArray:
    """Crop the frequency axis of a spectrogram to a specified range.

    Uses `soundevent.arrays.crop_dim` to select the frequency bins
    corresponding to the range [`min_freq`, `max_freq`].

    Parameters
    ----------
    spec : xr.DataArray
        Input spectrogram with a 'frequency' dimension and coordinates.
    min_freq : int, default=MIN_FREQ
        Minimum frequency (Hz) to keep.
    max_freq : int, default=MAX_FREQ
        Maximum frequency (Hz) to keep.

    Returns
    -------
    xr.DataArray
        Spectrogram cropped along the frequency axis. Preserves dtype.
    """
    start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")

    return arrays.crop_dim(
        spec,
        dim="frequency",
        start=min_freq if start_freq < min_freq else None,
        stop=max_freq if end_freq > max_freq else None,
    ).astype(spec.dtype)


class ScaleAmplitudeConfig(BaseConfig):
    name: Literal["scale_amplitude"] = "scale_amplitude"
    scale: Literal["power", "db"] = "db"


def stft(
    wave: xr.DataArray,
    window_duration: float,
    window_overlap: float,
    window_fn: str = "hann",
) -> xr.DataArray:
    """Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.

    Calculates STFT parameters (N-FFT, hop length) based on the window
    duration, overlap, and waveform sample rate. Returns an xarray DataArray
    with correctly calculated 'time' and 'frequency' coordinates.

    Parameters
    ----------
    wave : xr.DataArray
        Input audio waveform with 'time' coordinates.
    window_duration : float
        Duration of the STFT window in seconds.
    window_overlap : float
        Fractional overlap between consecutive windows.
    window_fn : str, default="hann"
        Name of the window function (e.g., "hann", "hamming").

    Returns
    -------
    xr.DataArray
        Magnitude spectrogram with 'time' and 'frequency' dimensions and
        coordinates. STFT parameters are stored in the `attrs`.

    Raises
    ------
    ValueError
        If the sample rate cannot be determined from `wave` coordinates.
    """
    if "channel" not in wave.coords:
        wave = wave.assign_coords(channel=0)

    return audio.compute_spectrogram(
        wave,
        window_size=window_duration,
        hop_size=(1 - window_overlap) * window_duration,
        window_type=window_fn,
        scale="amplitude",
        sort_dims=False,
    )


def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
    """Apply simple spectral mean subtraction for denoising.

    Subtracts the mean value of each frequency bin (calculated across time)
    from that bin, then clips negative values to zero.

    Parameters
    ----------
    spec : xr.DataArray
        Input spectrogram with 'time' and 'frequency' dimensions.

    Returns
    -------
    xr.DataArray
        Denoised spectrogram with the same dimensions, coordinates, and dtype.
    """
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )
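

# Editorial note (hedged): the same operation in plain numpy, for intuition.
# Rows are frequency bins, columns are time frames:
#
#     spec = np.array([[2.0, 4.0], [1.0, 1.0]])
#     denoised = np.clip(spec - spec.mean(axis=1, keepdims=True), 0, None)
#     # -> [[0.0, 1.0], [0.0, 0.0]]
#
# This mirrors the torch `SpectralMeanSubstraction` module defined later in
# this diff.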


def scale_spectrogram(
    spec: xr.DataArray,
    scale: Literal["dB", "power", "amplitude"],
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Apply the final amplitude scaling/representation to the spectrogram.

    Converts the input magnitude spectrogram based on the `scale` type:
    - "dB": applies logarithmic scaling `log10(S)`.
    - "power": squares the magnitude values `S^2`.
    - "amplitude": returns the input magnitude values `S` unchanged.

    Parameters
    ----------
    spec : xr.DataArray
        Input magnitude spectrogram (potentially after PCEN).
    scale : Literal["dB", "power", "amplitude"]
        The target amplitude representation.
    dtype : DTypeLike, default=np.float32
        Target data type for the output scaled spectrogram.

    Returns
    -------
    xr.DataArray
        Spectrogram with the specified amplitude scaling applied.
    """
    if scale == "dB":
        return arrays.to_db(spec).astype(dtype)

    if scale == "power":
        return spec**2

    return spec


class ToPower(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return spec**2


def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
    if conf.scale == "db":
        return torchaudio.transforms.AmplitudeToDB()

    if conf.scale == "power":
        return ToPower()

    raise NotImplementedError(
        f"Amplitude scaling {conf.scale} not implemented"
    )


def apply_pcen(
    spec: xr.DataArray,
    time_constant: float = 0.4,
    gain: float = 0.98,
    bias: float = 2,
    eps: float = 1e-6,
    power: float = 0.5,
) -> xr.DataArray:
    """Apply Per-Channel Energy Normalization (PCEN) to a spectrogram.

    Parameters
    ----------
    spec : xr.DataArray
        Input magnitude spectrogram with required attributes like
        'processing_original_samplerate'.
    time_constant : float, default=0.4
        PCEN time constant in seconds.
    gain : float, default=0.98
        Gain factor (alpha).
    bias : float, default=2.0
        Bias factor (delta).
    eps : float, default=1e-6
        Small constant to avoid division by zero.
    power : float, default=0.5
        Exponent (r).

    Returns
    -------
    xr.DataArray
        PCEN-scaled spectrogram.
    """
    samplerate = 1 / spec.time.attrs["step"]
    hop_size = spec.attrs["hop_size"]

    hop_length = int(hop_size * samplerate)

    t_frames = time_constant * samplerate / hop_length

    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)

    axis = spec.get_axis_num("time")

    shape = tuple([1] * spec.ndim)
    zi = np.empty(shape)
    zi[:] = signal.lfilter_zi(
        [smoothing_constant],
        [1, smoothing_constant - 1],
    )[:]

    spec_data = spec.data * (2**31)

    # Smooth the input array along the time axis.
    smoothed, _ = signal.lfilter(
        [smoothing_constant],
        [1, smoothing_constant - 1],
        spec_data,
        zi=zi,
        axis=axis,  # type: ignore
    )

    smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
    data = (bias**power) * np.expm1(
        power * np.log1p(spec_data * smooth / bias)
    )

    return xr.DataArray(
        data.astype(spec.dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )
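

# Editorial aside (hedged): the expm1/log1p pairing above is a numerically
# stable form of the textbook PCEN output. With x = S * smooth / bias,
#
#     (bias**power) * np.expm1(power * np.log1p(x))
#         == (bias**power) * ((1 + x) ** power - 1)
#
# for x > -1, which avoids catastrophic cancellation when x is tiny.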


def scale_log(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,  # type: ignore
    ref: Union[float, Callable] = np.max,
    amin: float = 1e-10,
    top_db: Optional[float] = 80.0,
) -> xr.DataArray:
    """Apply logarithmic scaling to a magnitude spectrogram.

    Calculates `log10(S)`, where S is the input magnitude spectrogram.

    Parameters
    ----------
    spec : xr.DataArray
        Input magnitude spectrogram with required attributes like
        'processing_original_samplerate', 'processing_nfft'.
    dtype : DTypeLike, default=np.float32
        Target data type for the output spectrogram.
    ref : float or callable, default=np.max
        Reference value (or a function computing one from `spec`) against
        which the spectrogram is scaled.
    amin : float, default=1e-10
        Floor applied before taking the logarithm.
    top_db : float, optional
        If given, clip the output to at most `top_db` below its peak.

    Returns
    -------
    xr.DataArray
        Log-scaled spectrogram.

    Raises
    ------
    KeyError
        If required attributes are missing from `spec.attrs`.
    ValueError
        If attributes are non-numeric or the window function is invalid.

    Notes
    -----
    Implementation mainly taken from the librosa `power_to_db` function.
    """
    if callable(ref):
        ref_value = ref(spec)
    else:
        ref_value = np.abs(ref)

    # Match librosa's power_to_db: both log terms are scaled by 10.
    log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - 10.0 * np.log10(
        np.maximum(amin, ref_value)
    )

    if top_db is not None:
        if top_db < 0:
            raise ValueError("top_db must be non-negative")
        log_spec = np.maximum(log_spec, log_spec.max() - top_db)

    return xr.DataArray(
        data=log_spec.astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs=spec.attrs,
    )


class SpectralMeanSubstractionConfig(BaseConfig):
    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"


class SpectralMeanSubstraction(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Subtract the per-frequency mean over time, then clip negatives.
        mean = spec.mean(-1, keepdim=True)
        return (spec - mean).clamp(min=0)


class ResizeConfig(BaseConfig):
    name: Literal["resize_spec"] = "resize_spec"
    height: int = 128
    resize_factor: float = 0.5


class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: float):
        super().__init__()
        self.height = height
        self.time_factor = time_factor

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        current_length = spec.shape[-1]
        target_length = int(self.time_factor * current_length)
        # Returns a [1, 1, height, target_length] tensor; batch and channel
        # dimensions are added for `interpolate` and kept for downstream use.
        return torch.nn.functional.interpolate(
            spec.unsqueeze(0).unsqueeze(0),
            size=(self.height, target_length),
            mode="bilinear",
        )


def resize_spectrogram(
    spec: xr.DataArray,
    height: int = 128,
    resize_factor: Optional[float] = 0.5,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Resize a spectrogram to target dimensions using interpolation.

    Resizes the frequency axis to `height` bins and optionally resizes the
    time axis by `resize_factor`.

    Parameters
    ----------
    spec : xr.DataArray
        Input spectrogram with 'time' and 'frequency' dimensions.
    height : int, default=128
        Target number of frequency bins (vertical dimension).
    resize_factor : float, optional
        Factor by which to resize the time dimension. If 1.0 or None, the
        time dimension is unchanged; if 0.5, it is halved, and so on.

    Returns
    -------
    xr.DataArray
        Resized spectrogram, with coordinates adjusted to the new sizes.
    """
    resize_factor = resize_factor or 1
    current_width = spec.sizes["time"]

    target_sizes = {
        "time": int(current_width * resize_factor),
        "frequency": height,
    }

    new_coords = {}
    for dim in ["time", "frequency"]:
        step = arrays.get_dim_step(spec, dim)
        start, stop = arrays.get_dim_range(spec, dim)
        new_coords[dim] = arrays.create_range_dim(
            name=dim,
            start=start,
            stop=stop + step,
            size=target_sizes[dim],
            dtype=dtype,
        )

    return spec.interp(
        coords=new_coords, method="linear", kwargs=dict(fill_value=0)
    )


class PeakNormalizeConfig(BaseConfig):
    name: Literal["peak_normalize"] = "peak_normalize"


SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
        ScaleAmplitudeConfig,
        SpectralMeanSubstractionConfig,
        PeakNormalizeConfig,
    ],
    Field(discriminator="name"),
]


class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)
    transforms: List[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )
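

# Hedged usage sketch (illustrative, not part of the diff): `transforms` is a
# pydantic discriminated union, so a config file selects each step by its
# `name` tag. Assuming pydantic v2 and that `PcenConfig` carries a
# `name: Literal["pcen"]` discriminator like its siblings:
#
#     config = SpectrogramConfig.model_validate(
#         {
#             "stft": {"window_duration": 0.002, "window_overlap": 0.75},
#             "size": {"name": "resize_spec", "height": 128},
#             "transforms": [
#                 {"name": "pcen"},
#                 {"name": "spectral_mean_substraction"},
#                 {"name": "peak_normalize"},
#             ],
#         }
#     )
#
# Each entry is dispatched to the matching config class via the "name"
# discriminator before `build_spectrogram_transform` turns the list into
# torch modules.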


def get_spectrogram_resolution(
    config: SpectrogramConfig,
) -> tuple[float, float]:
    """Calculate the approximate resolution of the final spectrogram.

    Computes the width of each frequency bin (Hz/bin) and the duration
    of each time bin (seconds/bin) based on the configuration parameters.

    Parameters
    ----------
    config : SpectrogramConfig
        The spectrogram configuration object.

    Returns
    -------
    Tuple[float, float]
        A tuple containing:
        - frequency_resolution (float): Approximate Hz per frequency bin.
        - time_resolution (float): Approximate seconds per time bin.

    Raises
    ------
    ValueError
        If required configuration fields (like `config.size`) are missing
        or invalid.
    """
    max_freq = config.frequencies.max_freq
    min_freq = config.frequencies.min_freq

    if config.size is None:
        raise ValueError("Spectrogram size configuration is required.")

    spec_height = config.size.height
    resize_factor = config.size.resize_factor or 1
    freq_bin_width = (max_freq - min_freq) / spec_height
    hop_duration = config.stft.window_duration * (
        1 - config.stft.window_overlap
    )
    return freq_bin_width, hop_duration / resize_factor


def _build_spectrogram_transform_step(
    step: SpectrogramTransform,
    samplerate: int,
) -> torch.nn.Module:
    if step.name == "pcen":
        return PCEN(
            smoothing_constant=_compute_smoothing_constant(
                samplerate=samplerate,
                time_constant=step.time_constant,
            ),
            gain=step.gain,
            bias=step.bias,
            power=step.power,
        )

    if step.name == "scale_amplitude":
        return _build_amplitude_scaler(step)

    if step.name == "spectral_mean_substraction":
        return SpectralMeanSubstraction()

    if step.name == "peak_normalize":
        return PeakNormalize()

    raise NotImplementedError(
        f"Spectrogram preprocessing step {step.name} not implemented"
    )


def build_spectrogram_transform(
    samplerate: int,
    conf: SpectrogramConfig,
) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            _build_spectrogram_transform_step(step, samplerate=samplerate)
            for step in conf.transforms
        ]
    )


class SpectrogramPipeline(torch.nn.Module):
    def __init__(
        self,
        spec_builder: SpectrogramBuilder,
        freq_cutter: torch.nn.Module,
        transforms: torch.nn.Module,
        resizer: torch.nn.Module,
    ):
        super().__init__()
        self.spec_builder = spec_builder
        self.freq_cutter = freq_cutter
        self.transforms = transforms
        self.resizer = resizer

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        spec = self.spec_builder(wav)
        spec = self.freq_cutter(spec)
        spec = self.transforms(spec)
        return self.resizer(spec)

    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return self.spec_builder(wav)

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
        return self.freq_cutter(spec)

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.transforms(spec)

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.resizer(spec)


def build_spectrogram_pipeline(
    samplerate: int,
    conf: SpectrogramConfig,
) -> SpectrogramPipeline:
    spec_builder = build_spectrogram_builder(samplerate, conf.stft)
    n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
    cutter = FrequencyClip(
        low_index=_frequency_to_index(
            conf.frequencies.min_freq, samplerate, n_fft
        ),
        high_index=_frequency_to_index(
            conf.frequencies.max_freq, samplerate, n_fft
        ),
    )
    transforms = build_spectrogram_transform(samplerate, conf)
    resizer = ResizeSpec(
        height=conf.size.height,
        time_factor=conf.size.resize_factor,
    )
    return SpectrogramPipeline(
        spec_builder=spec_builder,
        freq_cutter=cutter,
        transforms=transforms,
        resizer=resizer,
    )
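

# Hedged usage sketch (illustrative, not part of the diff): wiring the torch
# pipeline from a config. Assuming the spectrogram builder emits a 2-D
# [frequency, time] tensor:
#
#     pipeline = build_spectrogram_pipeline(
#         samplerate=256_000,
#         conf=SpectrogramConfig(),
#     )
#     wav = torch.zeros(256_000)     # one second of silence
#     spec = pipeline(wav)           # [1, 1, height, time] after ResizeSpec
#
# The staged methods (`compute_spectrogram`, `select_frequencies`,
# `transform_spectrogram`, `resize_spectrogram`) expose each intermediate
# step for inspection or testing.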

@@ -28,8 +28,10 @@ from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, Size
from batdetect2.utils.arrays import spec_to_xarray

__all__ = [
    "Anchor",
@@ -365,6 +367,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        audio_loader: AudioLoader,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
        loading_buffer: float = 0.01,
@@ -383,6 +386,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
            Buffer in seconds to add when loading audio clips.
        """
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale
        self.loading_buffer = loading_buffer
@@ -422,6 +426,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):

        time, freq = get_peak_energy_coordinates(
            recording=sound_event.recording,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            start_time=start_time,
            end_time=end_time,
@@ -511,8 +516,10 @@ def build_roi_mapper(

    if config.name == "peak_energy_bbox":
        preprocessor = build_preprocessor(config.preprocessing)
        audio_loader = build_audio_loader(config.preprocessing.audio)
        return PeakEnergyBBoxMapper(
            preprocessor=preprocessor,
            audio_loader=audio_loader,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
            loading_buffer=config.loading_buffer,
@@ -617,6 +624,7 @@ def _build_bounding_box(

def get_peak_energy_coordinates(
    recording: data.Recording,
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    start_time: float = 0,
    end_time: Optional[float] = None,
@@ -669,7 +677,15 @@ def get_peak_energy_coordinates(
        end_time=clip_end,
    )

    spec = preprocessor.preprocess_clip(clip)
    wav = audio_loader.load_clip(clip)
    spec = preprocessor.process_numpy(wav)
    spec = spec_to_xarray(
        spec,
        clip.start_time,
        clip.end_time,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
    )
    low_freq = max(low_freq, preprocessor.min_freq)
    high_freq = min(high_freq, preprocessor.max_freq)
    selection = spec.sel(

@@ -129,9 +129,7 @@ def mix_examples(
    with xr.set_options(keep_attrs=True):
        combined = weight * audio1 + (1 - weight) * audio2

    spectrogram = preprocessor.compute_spectrogram(
        combined.rename({"audio_time": "time"})
    ).data
    spectrogram = preprocessor.process_numpy(combined.data)

    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a
@@ -241,9 +239,7 @@ def add_echo(
    with xr.set_options(keep_attrs=True):
        audio = audio + weight * audio_delay

    spectrogram = preprocessor.compute_spectrogram(
        audio.rename({"audio_time": "time"}),
    ).data
    spectrogram = preprocessor.process_numpy(audio.data)

    # NOTE: The subclip's spectrogram might be slightly longer than the
    # spectrogram computed from the subclip's audio. This is due to a

@@ -21,10 +21,12 @@ class ClipingConfig(BaseConfig):
class Clipper(ClipperProtocol):
    def __init__(
        self,
        samplerate: int,
        duration: float = 0.5,
        max_empty: float = 0.2,
        random: bool = True,
    ):
        self.samplerate = samplerate
        self.duration = duration
        self.random = random
        self.max_empty = max_empty

@@ -25,6 +25,8 @@ from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence

import numpy as np
import torch
import xarray as xr
from loguru import logger
from pydantic import Field
@@ -34,9 +36,12 @@ from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import audio_to_xarray

__all__ = [
    "preprocess_annotations",
@@ -76,6 +81,7 @@ def preprocess_dataset(
    targets = build_targets(config=config.targets)
    preprocessor = build_preprocessor(config=config.preprocess)
    labeller = build_clip_labeler(targets, config=config.labels)
    audio_loader = build_audio_loader(config=config.preprocess.audio)

    if not output.exists():
        logger.debug("Creating directory {directory}", directory=output)
@@ -84,6 +90,7 @@ def preprocess_dataset(
    preprocess_annotations(
        dataset,
        output_dir=output,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
        replace=force,
@@ -93,6 +100,7 @@ def preprocess_dataset(

def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
) -> xr.Dataset:
@@ -140,9 +148,15 @@ def generate_train_example(
    - The original `ClipAnnotation` metadata is stored as a JSON string in the
      Dataset's attributes for provenance.
    """
    wave = preprocessor.load_clip_audio(clip_annotation.clip)
    wave = audio_loader.load_clip(clip_annotation.clip)

    spectrogram = preprocessor.compute_spectrogram(wave)
    spectrogram = _spec_to_xr(
        preprocessor(torch.tensor(wave)),
        start_time=clip_annotation.clip.start_time,
        end_time=clip_annotation.clip.end_time,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
    )

    heatmaps = labeller(clip_annotation, spectrogram)

@@ -152,7 +166,12 @@ def generate_train_example(
            # the spectrogram time dimension, otherwise xarray will interpolate
            # the spectrogram and the heatmaps to the same temporal resolution
            # as the waveform.
            "audio": wave.rename({"time": "audio_time"}),
            "audio": audio_to_xarray(
                wave,
                start_time=clip_annotation.clip.start_time,
                end_time=clip_annotation.clip.end_time,
                time_axis="audio_time",
            ),
            "spectrogram": spectrogram,
            "detection": heatmaps.detection,
            "class": heatmaps.classes,
@@ -170,6 +189,32 @@ def generate_train_example(
    )


def _spec_to_xr(
    spec: torch.Tensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> xr.DataArray:
    data = spec.numpy()[0, 0]

    height, width = data.shape

    return xr.DataArray(
        data=data,
        dims=[
            "frequency",
            "time",
        ],
        coords={
            "frequency": np.linspace(
                min_freq, max_freq, height, endpoint=False
            ),
            "time": np.linspace(start_time, end_time, width, endpoint=False),
        },
    )


def _save_xr_dataset_to_file(
    dataset: xr.Dataset,
    path: data.PathLike,
@@ -206,6 +251,7 @@ def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: data.PathLike,
    preprocessor: PreprocessorProtocol,
    audio_loader: AudioLoader,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
@@ -275,6 +321,7 @@ def preprocess_annotations(
            output_dir=output_dir,
            filename_fn=filename_fn,
            replace=replace,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            labeller=labeller,
        ),
@@ -290,6 +337,7 @@ def preprocess_annotations(
def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: data.PathLike,
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
@@ -335,6 +383,7 @@ def preprocess_single_annotation(
    try:
        sample = generate_train_example(
            clip_annotation,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            labeller=labeller,
        )

@ -10,10 +10,10 @@ pipeline can interact consistently, regardless of the specific underlying
|
||||
implementation (e.g., different libraries or custom configurations).
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol, Union
|
||||
from typing import Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
@ -36,7 +36,7 @@ class AudioLoader(Protocol):
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess audio directly from a file path.
|
||||
|
||||
Parameters
|
||||
@ -46,12 +46,6 @@ class AudioLoader(Protocol):
|
||||
audio_dir : PathLike, optional
|
||||
A directory prefix to prepend to the path if `path` is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform as an xarray DataArray
|
||||
with time coordinates. Typically loads only the first channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
@ -65,7 +59,7 @@ class AudioLoader(Protocol):
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the entire audio for a Recording object.
|
||||
|
||||
Parameters
|
||||
@ -95,7 +89,7 @@ class AudioLoader(Protocol):
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
) -> np.ndarray:
|
||||
"""Load and preprocess the audio segment defined by a Clip object.
|
||||
|
||||
Parameters
|
||||
@@ -124,264 +118,41 @@ class AudioLoader(Protocol):


class SpectrogramBuilder(Protocol):
    """Defines the interface for a spectrogram generation component.

    A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
    and produces a spectrogram (as an xarray DataArray) based on its internal
    configuration or implementation.
    """

    def __call__(
        self,
        wav: Union[np.ndarray, xr.DataArray],
        samplerate: Optional[int] = None,
    ) -> xr.DataArray:
        """Generate a spectrogram from an audio waveform.

        Parameters
        ----------
        wav : Union[np.ndarray, xr.DataArray]
            The input audio waveform. If a numpy array, `samplerate` must
            also be provided. If an xarray DataArray, it must have a 'time'
            coordinate from which the sample rate can be inferred.
        samplerate : int, optional
            The sample rate of the audio in Hz. Required if `wav` is a
            numpy array. If `wav` is an xarray DataArray, this parameter is
            ignored as the sample rate is derived from the coordinates.

        Returns
        -------
        xr.DataArray
            The computed spectrogram as an xarray DataArray with 'time' and
            'frequency' coordinates.

        Raises
        ------
        ValueError
            If `wav` is a numpy array and `samplerate` is not provided, or
            if `wav` is an xarray DataArray without a valid 'time' coordinate.
        """
        ...

    """Defines the interface for a spectrogram generation component."""

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Generate a spectrogram from an audio waveform."""
        ...

class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline.

    A Preprocessor combines audio loading and spectrogram generation steps.
    It provides methods to go directly from source descriptions (file paths,
    Recording objects, Clip objects) to the final spectrogram representation
    needed by the model. It may also expose intermediate steps like audio
    loading or spectrogram computation from a waveform.
    """

    def preprocess_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio from a file and compute the final processed spectrogram.

        Performs the full pipeline:

        Load -> Preprocess Audio -> Compute Spectrogram.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.

        Raises
        ------
        FileNotFoundError
            If the audio file cannot be found.
        Exception
            If any step in the loading or preprocessing fails.
        """
        ...

    def preprocess_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Recording and compute the processed spectrogram.

        Performs the full pipeline for the entire duration of the recording.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.

        Raises
        ------
        FileNotFoundError
            If the audio file cannot be found.
        Exception
            If any step in the loading or preprocessing fails.
        """
        ...

    def preprocess_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Clip and compute the final processed spectrogram.

        Performs the full pipeline for the specified clip segment.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the audio segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.

        Raises
        ------
        FileNotFoundError
            If the audio file cannot be found.
        Exception
            If any step in the loading or preprocessing fails.
        """
        ...

    def load_file_audio(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform from a file path.

        Performs the initial audio loading and waveform processing steps
        (like resampling, scaling), but stops *before* spectrogram generation.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform.

        Raises
        ------
        FileNotFoundError, Exception
            If audio loading/preprocessing fails.
        """
        ...

    def load_recording_audio(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Recording.

        Performs the initial audio loading and waveform processing steps
        for the entire recording duration.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform.

        Raises
        ------
        FileNotFoundError, Exception
            If audio loading/preprocessing fails.
        """
        ...

    def load_clip_audio(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Clip.

        Performs the initial audio loading and waveform processing steps
        for the specified clip segment.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform segment.

        Raises
        ------
        FileNotFoundError, Exception
            If audio loading/preprocessing fails.
        """
        ...

    def compute_spectrogram(
        self,
        wav: Union[xr.DataArray, np.ndarray],
    ) -> xr.DataArray:
        """Compute the spectrogram from a pre-loaded audio waveform.

        Applies the spectrogram generation steps (STFT, scaling, etc.) defined
        by the `SpectrogramBuilder` component of the preprocessor to an
        already loaded (and potentially preprocessed) waveform.

        Parameters
        ----------
        wav : Union[xr.DataArray, np.ndarray]
            The input audio waveform.

        Returns
        -------
        xr.DataArray
            The computed spectrogram.

        Raises
        ------
        ValueError, Exception
            If waveform input is invalid or spectrogram computation fails.
        """
        ...


class AudioPipeline(Protocol):
    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class SpectrogramPipeline(Protocol):
    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline."""

    max_freq: float

    min_freq: float

    audio_pipeline: AudioPipeline

    spectrogram_pipeline: SpectrogramPipeline

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        return self(torch.tensor(wav)).numpy()[0, 0]
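
    # Hedged usage sketch (illustrative, not part of the diff):
    #
    #     wav = audio_loader.load_clip(clip)       # 1-D np.ndarray
    #     spec = preprocessor.process_numpy(wav)   # 2-D [frequency, time]
    #
    # The `[0, 0]` indexing strips the batch and channel dimensions that the
    # torch pipeline adds, so callers get a plain 2-D spectrogram back.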

@@ -2,6 +2,62 @@ import numpy as np
import xarray as xr


def spec_to_xarray(
    spec: np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> xr.DataArray:
    if spec.ndim != 2:
        raise ValueError(
            "Input numpy spectrogram array should be 2-dimensional"
        )

    height, width = spec.shape
    return xr.DataArray(
        data=spec,
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(
                min_freq,
                max_freq,
                height,
                endpoint=False,
            ),
            "time": np.linspace(
                start_time,
                end_time,
                width,
                endpoint=False,
            ),
        },
    )


def audio_to_xarray(
    wav: np.ndarray,
    start_time: float,
    end_time: float,
    time_axis: str = "time",
) -> xr.DataArray:
    if wav.ndim != 1:
        raise ValueError("Input numpy audio array should be 1-dimensional")

    return xr.DataArray(
        data=wav,
        dims=[time_axis],
        coords={
            time_axis: np.linspace(
                start_time,
                end_time,
                len(wav),
                endpoint=False,
            ),
        },
    )
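

# Hedged usage sketch (illustrative, not part of the diff): round-tripping a
# torch pipeline output into coordinate-aware xarray space, mirroring
# `get_peak_energy_coordinates` above:
#
#     wav = audio_loader.load_clip(clip)
#     spec = spec_to_xarray(
#         preprocessor.process_numpy(wav),
#         start_time=clip.start_time,
#         end_time=clip.end_time,
#         min_freq=preprocessor.min_freq,
#         max_freq=preprocessor.max_freq,
#     )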


def extend_width(
    array: np.ndarray,
    extra: int,

@@ -12,6 +12,7 @@ from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import (
    TargetConfig,
    TermRegistry,
@@ -27,6 +28,7 @@ from batdetect2.typing import (
    PreprocessorProtocol,
    TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader


@pytest.fixture
@@ -368,6 +370,11 @@ def sample_preprocessor() -> PreprocessorProtocol:
    return build_preprocessor()


@pytest.fixture
def sample_audio_loader() -> AudioLoader:
    return build_audio_loader()


@pytest.fixture
def bat_tag() -> TagInfo:
    return TagInfo(key="class", value="bat")

@@ -1,13 +1,10 @@
import pathlib
import uuid
from pathlib import Path

import numpy as np
import pytest
import soundfile as sf
import xarray as xr
from soundevent import data
from soundevent.arrays import Dimensions, create_time_dim_from_array

from batdetect2.preprocess import audio

@@ -30,44 +27,6 @@ def create_dummy_wave(
    return wave.astype(dtype)


def create_xr_wave(
    samplerate: int,
    duration: float,
    num_channels: int = 1,
    freq: float = 440.0,
    amplitude: float = 0.5,
    start_time: float = 0.0,
) -> xr.DataArray:
    """Generates a simple xarray waveform."""
    num_samples = int(samplerate * duration)
    times = np.linspace(
        start_time,
        start_time + duration,
        num_samples,
        endpoint=False,
    )
    coords = {
        Dimensions.time.value: create_time_dim_from_array(
            times, samplerate=samplerate, start_time=start_time
        )
    }
    dims = [Dimensions.time.value]

    wave_data = amplitude * np.sin(2 * np.pi * freq * times)

    if num_channels > 1:
        coords[Dimensions.channel.value] = np.arange(num_channels)
        dims = [Dimensions.channel.value] + dims
        wave_data = np.stack([wave_data] * num_channels, axis=0)

    return xr.DataArray(
        wave_data.astype(np.float32),
        coords=coords,
        dims=dims,
        attrs={"samplerate": samplerate},
    )

@pytest.fixture
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
    """Creates a dummy WAV file and returns its path."""
@@ -99,408 +58,3 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
@pytest.fixture
def default_audio_config() -> audio.AudioConfig:
    return audio.AudioConfig()


@pytest.fixture
def no_resample_config() -> audio.AudioConfig:
    return audio.AudioConfig(resample=None)


@pytest.fixture
def fixed_duration_config() -> audio.AudioConfig:
    return audio.AudioConfig(duration=0.5)


@pytest.fixture
def scale_config() -> audio.AudioConfig:
    return audio.AudioConfig(scale=True, center=False)


@pytest.fixture
def no_center_config() -> audio.AudioConfig:
    return audio.AudioConfig(center=False)


@pytest.fixture
def resample_fourier_config() -> audio.AudioConfig:
    return audio.AudioConfig(
        resample=audio.ResampleConfig(
            samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier"
        )
    )


def test_resample_config_defaults():
    config = audio.ResampleConfig()
    assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
    assert config.method == "poly"


def test_audio_config_defaults():
    config = audio.AudioConfig()
    assert config.resample is not None
    assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
    assert config.resample.method == "poly"
    assert config.scale == audio.SCALE_RAW_AUDIO
    assert config.center is False
    assert config.duration == audio.DEFAULT_DURATION


def test_audio_config_override():
    resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier")
    config = audio.AudioConfig(
        resample=resample_cfg,
        scale=True,
        center=False,
        duration=1.0,
    )
    assert config.resample == resample_cfg
    assert config.scale is True
    assert config.center is False
    assert config.duration == 1.0


def test_audio_config_no_resample():
    config = audio.AudioConfig(resample=None)
    assert config.resample is None


@pytest.mark.parametrize(
    "orig_sr, orig_dur, target_dur",
    [
        (256_000, 1.0, 0.5),
        (256_000, 0.5, 1.0),
        (256_000, 1.0, 1.0),
    ],
)
def test_adjust_audio_duration(orig_sr, orig_dur, target_dur):
    wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur)
    adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur)
    expected_samples = int(target_dur * orig_sr)
    assert adjusted_wave.sizes["time"] == expected_samples
    assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr
    assert adjusted_wave.dtype == wave.dtype
    if orig_dur > 0 and target_dur > orig_dur:
        padding_start_index = int(orig_dur * orig_sr) + 1
        assert np.all(adjusted_wave.values[padding_start_index:] == 0)


def test_adjust_audio_duration_negative_target_raises():
    wave = create_xr_wave(1000, 1.0)
    with pytest.raises(ValueError):
        audio.adjust_audio_duration(wave, duration=-0.5)


@pytest.mark.parametrize(
    "orig_sr, target_sr, mode",
    [
        (48000, 96000, "poly"),
        (96000, 48000, "poly"),
        (48000, 96000, "fourier"),
        (96000, 48000, "fourier"),
        (48000, 44100, "poly"),
        (48000, 44100, "fourier"),
    ],
)
def test_resample_audio(orig_sr, target_sr, mode):
    duration = 0.1
    wave = create_xr_wave(orig_sr, duration)
    resampled_wave = audio.resample_audio(
        wave, samplerate=target_sr, method=mode, dtype=np.float32
    )
    expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
    assert resampled_wave.sizes["time"] == expected_samples
    assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr
    assert np.isclose(
        resampled_wave.coords["time"].values[-1]
        - resampled_wave.coords["time"].values[0],
        duration,
        atol=2 / target_sr,
    )
    assert resampled_wave.dtype == np.float32


def test_resample_audio_same_samplerate():
    sr = 48000
    duration = 0.1
    wave = create_xr_wave(sr, duration)
    resampled_wave = audio.resample_audio(
        wave, samplerate=sr, dtype=np.float64
    )
    xr.testing.assert_equal(wave.astype(np.float64), resampled_wave)


def test_resample_audio_invalid_mode_raises():
    wave = create_xr_wave(48000, 0.1)
    with pytest.raises(NotImplementedError):
        audio.resample_audio(wave, samplerate=96000, method="invalid_mode")


def test_resample_audio_no_time_dim_raises():
    wave = xr.DataArray(np.random.rand(100), dims=["samples"])
    with pytest.raises(ValueError, match="Audio must have a time dimension"):
        audio.resample_audio(wave, samplerate=96000)


def test_load_clip_audio_default_config(
    dummy_clip: data.Clip,
    default_audio_config: audio.AudioConfig,
    tmp_path: Path,
):
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.sizes["time"] == expected_samples
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_clip_audio_no_resample(
    dummy_clip: data.Clip,
    no_resample_config: audio.AudioConfig,
    tmp_path: Path,
):
    orig_sr = dummy_clip.recording.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * orig_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=no_resample_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / orig_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)


def test_load_clip_audio_fixed_duration_crop(
    dummy_clip: data.Clip,
    fixed_duration_config: audio.AudioConfig,
    tmp_path: Path,
):
    target_sr = audio.TARGET_SAMPLERATE_HZ
    target_duration = fixed_duration_config.duration
    assert target_duration is not None
    expected_samples = int(target_duration * target_sr)

    assert dummy_clip.duration > target_duration

    wav = audio.load_clip_audio(
        dummy_clip, config=fixed_duration_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples


def test_load_clip_audio_fixed_duration_pad(
    dummy_clip: data.Clip,
    tmp_path: Path,
):
    target_duration = dummy_clip.duration * 2
    config = audio.AudioConfig(duration=target_duration)

    assert config.resample is not None
    target_sr = config.resample.samplerate
    expected_samples = int(target_duration * target_sr)

    wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path)

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples

    original_samples_after_resample = int(dummy_clip.duration * target_sr)
    assert np.allclose(
        wav.values[original_samples_after_resample:], 0.0, atol=1e-6
    )


def test_load_clip_audio_scale(
    dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip,
        config=scale_config,
        audio_dir=tmp_path,
    )

    assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5)


def test_load_clip_audio_no_center(
    dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip, config=no_center_config, audio_dir=tmp_path
    )

    raw_wav, _ = sf.read(
        dummy_clip.recording.path,
        start=int(dummy_clip.start_time * dummy_clip.recording.samplerate),
        stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate),
        dtype=np.float32,  # type: ignore
    )
    raw_wav_mono = raw_wav[:, 0]

    if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7):
        assert not np.isclose(wav.mean(), 0.0, atol=1e-6)


def test_load_clip_audio_resample_fourier(
    dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path
):
    assert resample_fourier_config.resample is not None
    target_sr = resample_fourier_config.resample.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=resample_fourier_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples


def test_load_clip_audio_dtype(
    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip,
        config=default_audio_config,
        audio_dir=tmp_path,
        dtype=np.float64,
    )
    assert wav.dtype == np.float64


def test_load_clip_audio_file_not_found(
    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    dummy_clip.recording = data.Recording(
        path=non_existent_path,
        duration=1,
        channels=1,
        samplerate=256000,
    )
    with pytest.raises(FileNotFoundError):
        audio.load_clip_audio(
            dummy_clip, config=default_audio_config, audio_dir=tmp_path
        )


def test_load_recording_audio(
    dummy_recording: data.Recording,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    orig_duration = dummy_recording.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_recording_audio(
        dummy_recording, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_recording_audio_file_not_found(
    dummy_recording: data.Recording,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    dummy_recording = data.Recording(
        path=non_existent_path,
        duration=1,
        channels=1,
        samplerate=256000,
    )
    with pytest.raises(FileNotFoundError):
        audio.load_recording_audio(
            dummy_recording, config=default_audio_config, audio_dir=tmp_path
        )


def test_load_file_audio(
    dummy_wav_path: pathlib.Path,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    info = sf.info(dummy_wav_path)
    orig_duration = info.duration
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_file_audio(
        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_file_audio_file_not_found(
    default_audio_config: audio.AudioConfig, tmp_path
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    with pytest.raises(FileNotFoundError):
        audio.load_file_audio(
            non_existent_path, config=default_audio_config, audio_dir=tmp_path
        )


def test_build_audio_loader(default_audio_config: audio.AudioConfig):
    loader = audio.build_audio_loader(config=default_audio_config)
    assert isinstance(loader, audio.ConfigurableAudioLoader)
    assert loader.config == default_audio_config


def test_configurable_audio_loader_methods(
    default_audio_config: audio.AudioConfig,
    dummy_wav_path: pathlib.Path,
    dummy_recording: data.Recording,
    dummy_clip: data.Clip,
    tmp_path,
):
    loader = audio.build_audio_loader(config=default_audio_config)

    expected_wav_file = audio.load_file_audio(
        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_file, loaded_wav_file)

    expected_wav_rec = audio.load_recording_audio(
        dummy_recording, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec)

    expected_wav_clip = audio.load_clip_audio(
        dummy_clip, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip)

@ -1,32 +1,7 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Callable, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent import arrays
|
||||
|
||||
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
MAX_FREQ,
|
||||
MIN_FREQ,
|
||||
ConfigurableSpectrogramBuilder,
|
||||
FrequencyConfig,
|
||||
PcenConfig,
|
||||
SpecSizeConfig,
|
||||
SpectrogramConfig,
|
||||
STFTConfig,
|
||||
apply_pcen,
|
||||
build_spectrogram_builder,
|
||||
compute_spectrogram,
|
||||
crop_spectrogram_frequencies,
|
||||
get_spectrogram_resolution,
|
||||
remove_spectral_mean,
|
||||
resize_spectrogram,
|
||||
scale_spectrogram,
|
||||
stft,
|
||||
)
|
||||
|
||||
SAMPLERATE = 250_000
|
||||
DURATION = 0.1
|
||||
@ -61,389 +36,3 @@ def constant_wave_xr() -> xr.DataArray:
|
||||
dims=["time"],
|
||||
attrs={"samplerate": SAMPLERATE},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
    """Generate a basic spectrogram for testing downstream functions."""
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.5),
        frequencies=FrequencyConfig(
            min_freq=0,
            max_freq=int(SAMPLERATE / 2),
        ),
        size=None,
        pcen=None,
        spectral_mean_substraction=False,
        peak_normalize=False,
        scale="amplitude",
    )
    spec = stft(
        sine_wave_xr,
        window_duration=config.stft.window_duration,
        window_overlap=config.stft.window_overlap,
        window_fn=config.stft.window_fn,
    )
    return spec

def test_stft_config_defaults():
    config = STFTConfig()
    assert config.window_duration == 0.002
    assert config.window_overlap == 0.75
    assert config.window_fn == "hann"


def test_frequency_config_defaults():
    config = FrequencyConfig()
    assert config.min_freq == MIN_FREQ
    assert config.max_freq == MAX_FREQ


def test_spec_size_config_defaults():
    config = SpecSizeConfig()
    assert config.height == 128
    assert config.resize_factor == 0.5


def test_pcen_config_defaults():
    config = PcenConfig()
    assert config.time_constant == 0.01
    assert config.gain == 0.98
    assert config.bias == 2
    assert config.power == 0.5


def test_spectrogram_config_defaults():
    config = SpectrogramConfig()
    assert isinstance(config.stft, STFTConfig)
    assert isinstance(config.frequencies, FrequencyConfig)
    assert isinstance(config.pcen, PcenConfig)
    assert config.scale == "amplitude"
    assert isinstance(config.size, SpecSizeConfig)
    assert config.spectral_mean_substraction is True
    assert config.peak_normalize is False

def test_stft_output_properties(sine_wave_xr: xr.DataArray):
    window_duration = 0.002
    window_overlap = 0.5
    samplerate = sine_wave_xr.attrs["samplerate"]
    nfft = int(window_duration * samplerate)
    hop_len = nfft - int(window_overlap * nfft)

    spec = stft(
        sine_wave_xr,
        window_duration=window_duration,
        window_overlap=window_overlap,
        window_fn="hann",
    )

    assert isinstance(spec, xr.DataArray)
    assert spec.dims == ("frequency", "time")
    assert spec.dtype == np.float32
    assert "frequency" in spec.coords
    assert "time" in spec.coords

    time_step = arrays.get_dim_step(spec, "time")
    freq_step = arrays.get_dim_step(spec, "frequency")
    freq_start, freq_end = arrays.get_dim_range(spec, "frequency")
    assert np.isclose(freq_step, samplerate / nfft)
    assert np.isclose(time_step, hop_len / samplerate)
    assert spec.frequency.min() >= 0
    assert freq_start == 0
    assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2)
    assert np.isclose(spec.time.min(), 0)
    assert spec.time.max() < DURATION

    assert spec.attrs["samplerate"] == samplerate
    assert spec.attrs["window_size"] == window_duration
    assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap)

    assert np.all(spec.data >= 0)

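With the module constant SAMPLERATE = 250_000 and this test's parameters, the spacing assertions reduce to round numbers; a quick arithmetic check:

# Worked example of the nfft/hop arithmetic asserted above (pure arithmetic).
samplerate = 250_000
window_duration, window_overlap = 0.002, 0.5
nfft = int(window_duration * samplerate)     # 500 samples per window
hop_len = nfft - int(window_overlap * nfft)  # 250 samples between frames
assert samplerate / nfft == 500.0            # frequency bin width in Hz
assert hop_len / samplerate == 0.001         # time step in seconds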
@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
    spec = stft(
        sine_wave_xr,
        window_duration=0.002,
        window_overlap=0.5,
        window_fn=window_fn,
    )
    assert isinstance(spec, xr.DataArray)
    assert np.all(spec.data >= 0)

def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
    min_f, max_f = 20_000, 80_000
    cropped_spec = crop_spectrogram_frequencies(
        sample_spec, min_freq=min_f, max_freq=max_f
    )

    assert cropped_spec.dims == sample_spec.dims
    assert cropped_spec.dtype == sample_spec.dtype
    assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
    assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
    assert cropped_spec.coords["frequency"].min() >= min_f

    assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)


def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
    samplerate = sample_spec.attrs["samplerate"]
    min_f, max_f = 0, samplerate / 2
    cropped_spec = crop_spectrogram_frequencies(
        sample_spec, min_freq=min_f, max_freq=max_f
    )

    assert cropped_spec.sizes == sample_spec.sizes
    assert np.allclose(cropped_spec.data, sample_spec.data)

def test_apply_pcen(sample_spec: xr.DataArray):
    pcen_config = PcenConfig()
    pcen_spec = apply_pcen(
        sample_spec,
        time_constant=pcen_config.time_constant,
        gain=pcen_config.gain,
        bias=pcen_config.bias,
        power=pcen_config.power,
    )

    assert pcen_spec.dims == sample_spec.dims
    assert pcen_spec.sizes == sample_spec.sizes
    assert pcen_spec.dtype == sample_spec.dtype
    assert np.all(pcen_spec.data >= 0)

    assert not np.allclose(pcen_spec.data, sample_spec.data)

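apply_pcen is only checked here for shape and non-negativity. For orientation, the canonical per-channel energy normalization formula (Wang et al., 2017) that PCEN implementations follow is sketched below with this suite's default parameters; this is the textbook definition, not necessarily batdetect2's exact code path:

import numpy as np

def pcen_sketch(spec, smoothed, gain=0.98, bias=2.0, power=0.5, eps=1e-6):
    # PCEN(S) = (S / (eps + M)**gain + bias)**power - bias**power, where M is
    # a temporally smoothed copy of S (an exponential moving average in time).
    return (spec / (eps + smoothed) ** gain + bias) ** power - bias ** power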
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
    scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
    assert np.allclose(scaled_spec.data, sample_spec.data)
    assert scaled_spec.dtype == sample_spec.dtype


def test_scale_spectrogram_power(sample_spec: xr.DataArray):
    scaled_spec = scale_spectrogram(sample_spec, scale="power")
    assert np.allclose(scaled_spec.data, sample_spec.data**2)
    assert scaled_spec.dtype == sample_spec.dtype


def test_scale_spectrogram_db(sample_spec: xr.DataArray):
    scaled_spec = scale_spectrogram(sample_spec, scale="dB")
    log_spec_expected = arrays.to_db(sample_spec)
    xr.testing.assert_allclose(scaled_spec, log_spec_expected)

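The three scale options relate simply: "power" squares the amplitude spectrogram and "dB" log-compresses it via soundevent.arrays.to_db. A standalone sanity check, assuming the common 20·log10 amplitude convention for the dB case:

import numpy as np

amp = np.array([0.5, 1.0, 2.0], dtype=np.float32)  # stand-in amplitude values
power = amp**2                                     # "power" scale
db = 20 * np.log10(amp)                            # "dB" scale (assumed convention)
assert np.allclose(power, [0.25, 1.0, 4.0])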
def test_remove_spectral_mean(sample_spec: xr.DataArray):
    spec_noisy = sample_spec.copy() + 0.1
    denoised_spec = remove_spectral_mean(spec_noisy)

    assert denoised_spec.dims == spec_noisy.dims
    assert denoised_spec.sizes == spec_noisy.sizes
    assert denoised_spec.dtype == spec_noisy.dtype
    assert np.all(denoised_spec.data >= 0)


def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
    const_spec = stft(constant_wave_xr, 0.002, 0.5)
    denoised_spec = remove_spectral_mean(const_spec)
    assert np.all(denoised_spec.data >= 0)

@pytest.mark.parametrize(
    "height, resize_factor, expected_freq_size, expected_time_factor",
    [
        (128, 1.0, 128, 1.0),
        (64, 0.5, 64, 0.5),
        (256, None, 256, 1.0),
        (100, 2.0, 100, 2.0),
    ],
)
def test_resize_spectrogram(
    sample_spec: xr.DataArray,
    height: int,
    resize_factor: Union[float, None],
    expected_freq_size: int,
    expected_time_factor: float,
):
    original_time_size = sample_spec.sizes["time"]
    resized_spec = resize_spectrogram(
        sample_spec,
        height=height,
        resize_factor=resize_factor,
    )

    assert resized_spec.dims == ("frequency", "time")
    assert resized_spec.sizes["frequency"] == expected_freq_size
    expected_time_size = int(original_time_size * expected_time_factor)

    assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1

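The parametrization encodes the resize contract: height pins the frequency axis exactly, while resize_factor scales the time axis (None leaves it untouched). Worked through for one case:

# Worked example of the resize contract checked above.
original_time = 301                   # arbitrary frame count
height, resize_factor = 64, 0.5
expected_shape = (height, int(original_time * (resize_factor or 1.0)))
assert expected_shape == (64, 150)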
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
    config = SpectrogramConfig()
    spec = compute_spectrogram(sine_wave_xr, config=config)

    assert isinstance(spec, xr.DataArray)
    assert spec.dims == ("frequency", "time")
    assert spec.dtype == np.float32
    assert config.size is not None
    assert spec.sizes["frequency"] == config.size.height

    temp_stft = stft(
        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
    )
    assert config.size.resize_factor is not None
    expected_time_size = int(
        temp_stft.sizes["time"] * config.size.resize_factor
    )
    assert abs(spec.sizes["time"] - expected_time_size) <= 1

    assert spec.coords["frequency"].min() >= config.frequencies.min_freq
    assert np.isclose(
        spec.coords["frequency"].max(),
        config.frequencies.max_freq,
        rtol=0.1,
    )

def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
    sine_wave_xr: xr.DataArray,
):
    config = SpectrogramConfig(
        pcen=None,
        spectral_mean_substraction=False,
        size=None,
        scale="power",
        frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
    )
    spec = compute_spectrogram(sine_wave_xr, config=config)

    stft_direct = stft(
        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
    )
    expected_spec = scale_spectrogram(stft_direct, scale="power")

    assert spec.sizes == expected_spec.sizes
    assert np.allclose(spec.data, expected_spec.data)
    assert spec.dtype == expected_spec.dtype

def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
    config = SpectrogramConfig(peak_normalize=True, pcen=None)
    spec = compute_spectrogram(sine_wave_xr, config=config)
    assert np.isclose(spec.data.max(), 1.0, atol=1e-6)

    config = SpectrogramConfig(peak_normalize=False)
    spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
    assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)

def test_get_spectrogram_resolution_calculation():
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        size=SpecSizeConfig(height=100, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
    )

    freq_res, time_res = get_spectrogram_resolution(config)

    expected_freq_res = (110_000 - 10_000) / 100
    expected_hop_duration = 0.002 * (1 - 0.75)
    expected_time_res = expected_hop_duration / 0.5

    assert np.isclose(freq_res, expected_freq_res)
    assert np.isclose(time_res, expected_time_res)

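Spelled out, the expected values in this test are round numbers, which makes the resolution formula easy to audit:

# freq_res = (max_freq - min_freq) / height; time_res = hop_duration / resize_factor.
freq_res = (110_000 - 10_000) / 100   # 1000.0 Hz per output row
hop_duration = 0.002 * (1 - 0.75)     # 0.0005 s between STFT frames
time_res = hop_duration / 0.5         # 0.001 s per output column
assert freq_res == 1000.0
assert abs(time_res - 0.001) < 1e-12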
def test_get_spectrogram_resolution_no_resize_factor():
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
        size=SpecSizeConfig(height=200, resize_factor=None),
        frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
    )
    freq_res, time_res = get_spectrogram_resolution(config)
    expected_freq_res = (120_000 - 20_000) / 200
    expected_hop_duration = 0.004 * (1 - 0.5)
    expected_time_res = expected_hop_duration / 1.0

    assert np.isclose(freq_res, expected_freq_res)
    assert np.isclose(time_res, expected_time_res)


def test_get_spectrogram_resolution_no_size_config():
    config = SpectrogramConfig(size=None)
    with pytest.raises(
        ValueError, match="Spectrogram size configuration is required"
    ):
        get_spectrogram_resolution(config)

def test_configurable_spectrogram_builder_init():
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
    assert builder.config is config
    assert builder.dtype == np.float16


def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config)
    spec_builder = builder(sine_wave_xr)
    spec_direct = compute_spectrogram(sine_wave_xr, config=config)
    assert isinstance(spec_builder, xr.DataArray)
    assert np.allclose(spec_builder.data, spec_direct.data)
    assert spec_builder.dtype == spec_direct.dtype


def test_configurable_spectrogram_builder_call_np_no_samplerate(
    sine_wave_xr: xr.DataArray,
):
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config)
    wav_np = sine_wave_xr.data
    with pytest.raises(ValueError, match="Samplerate must be provided"):
        builder(wav_np, samplerate=None)

def test_build_spectrogram_builder():
    config = SpectrogramConfig(peak_normalize=True)
    builder = build_spectrogram_builder(config=config, dtype=np.float64)
    assert isinstance(builder, ConfigurableSpectrogramBuilder)
    assert builder.config is config
    assert builder.dtype == np.float64


def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=0.2, samplerate=256_000)

    audio_data = load_file_audio(
        path,
        config=AudioConfig(resample=None, duration=None),
    )

    config = SpectrogramConfig(
        stft=STFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )

    spec = compute_spectrogram(audio_data, config=config)

    freq_res, time_res = get_spectrogram_resolution(config)

    assert math.isclose(
        arrays.get_dim_step(spec, dim="frequency"),
        freq_res,
        rel_tol=0.1,
    )
    assert math.isclose(
        arrays.get_dim_step(spec, dim="time"),
        time_res,
        rel_tol=0.1,
    )

@ -3,7 +3,11 @@ import pytest
import soundfile as sf
from soundevent import data

from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess import (
    PreprocessingConfig,
    build_preprocessor,
)
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets.rois import (
    DEFAULT_ANCHOR,
    DEFAULT_FREQUENCY_SCALE,
@ -275,6 +279,8 @@ def test_get_peak_energy_coordinates(generate_whistle):
    # Build a preprocessor (default config should be fine for this test)
    preprocessor = build_preprocessor()

    audio_loader = build_audio_loader()

    # Define a region of interest that contains the whistle
    start_time = 0.2
    end_time = 0.7
@ -285,6 +291,7 @@ def test_get_peak_energy_coordinates(generate_whistle):
    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=preprocessor,
        audio_loader=audio_loader,
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
@ -356,6 +363,7 @@ def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=preprocessor,
        audio_loader=build_audio_loader(),
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
@ -389,6 +397,7 @@ def test_get_peak_energy_coordinates_silent_region(create_recording):
    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=preprocessor,
        audio_loader=build_audio_loader(),
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
@ -443,17 +452,11 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):

    # Instantiate the mapper with a preprocessor
    preprocessor = build_preprocessor(
        PreprocessingConfig.model_validate(
            {
                "spectrogram": {
                    "pcen": None,
                    "spectral_mean_substraction": False,
                }
            }
        )
        PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
    )
    mapper = PeakEnergyBBoxMapper(
        preprocessor=preprocessor,
        audio_loader=build_audio_loader(),
        time_scale=time_scale,
        frequency_scale=freq_scale,
    )
@ -493,6 +496,7 @@ def test_peak_energy_bbox_mapper_decode():

    mapper = PeakEnergyBBoxMapper(
        preprocessor=build_preprocessor(),
        audio_loader=build_audio_loader(),
        time_scale=time_scale,
        frequency_scale=freq_scale,
    )
@ -553,7 +557,11 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
            }
        )
    )
    mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor)
    audio_loader = build_audio_loader()
    mapper = PeakEnergyBBoxMapper(
        preprocessor=preprocessor,
        audio_loader=audio_loader,
    )

    # When
    # Encode the sound event, then immediately decode the result.

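The thread running through these rois hunks is one API change: get_peak_energy_coordinates and PeakEnergyBBoxMapper now receive an explicit audio_loader rather than loading audio through the preprocessor. A sketch of the new wiring at a call site (the PeakEnergyBBoxMapper import path is assumed from the surrounding test module; scale arguments omitted):

from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets.rois import PeakEnergyBBoxMapper  # assumed import path

mapper = PeakEnergyBBoxMapper(
    preprocessor=build_preprocessor(),
    audio_loader=build_audio_loader(),  # new explicit dependency in this commit
)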
@ -11,11 +11,12 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol


def test_mix_examples(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
@ -30,11 +31,13 @@ def test_mix_examples(

    example1 = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    example2 = generate_train_example(
        clip_annotation_2,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
@ -51,6 +54,7 @@ def test_mix_examples(
@pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
def test_mix_examples_of_different_durations(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
    duration1: float,
@ -67,11 +71,13 @@ def test_mix_examples_of_different_durations(

    example1 = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
    example2 = generate_train_example(
        clip_annotation_2,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
@ -87,6 +93,7 @@ def test_mix_examples_of_different_durations(

def test_add_echo(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
@ -96,6 +103,7 @@ def test_add_echo(

    original = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
@ -109,6 +117,7 @@ def test_add_echo(

def test_selected_random_subclip_has_the_correct_width(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
@ -118,6 +127,7 @@ def test_selected_random_subclip_has_the_correct_width(

    original = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
@ -128,6 +138,7 @@ def test_selected_random_subclip_has_the_correct_width(

def test_add_echo_after_subclip(
    sample_preprocessor: PreprocessorProtocol,
    sample_audio_loader: AudioLoader,
    sample_labeller: ClipLabeller,
    create_recording: Callable[..., data.Recording],
):
@ -136,6 +147,7 @@ def test_add_echo_after_subclip(
    clip_annotation_1 = data.ClipAnnotation(clip=clip1)
    original = generate_train_example(
        clip_annotation_1,
        audio_loader=sample_audio_loader,
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )
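The augmentations hunks repeat one mechanical change: every generate_train_example call gains an audio_loader argument, supplied by a sample_audio_loader fixture typed as batdetect2.typing.AudioLoader. The updated call shape, as used throughout the file:

# Sketch of the updated signature (fixture names as in the tests above).
example = generate_train_example(
    clip_annotation,                   # a data.ClipAnnotation
    audio_loader=sample_audio_loader,  # new: explicit AudioLoader dependency
    preprocessor=sample_preprocessor,
    labeller=sample_labeller,
)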