Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-11 09:29:33 +01:00)

Commit 667b18a54d (parent 61115d562c)

    Preprocessing in pytorch
@@ -28,13 +28,12 @@ This module provides the primary interface:

 """

-from typing import Optional, Union
+from typing import Optional

-import numpy as np
-import xarray as xr
+import torch
 from loguru import logger
 from pydantic import Field
-from soundevent import data
+from soundevent.data import PathLike

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.preprocess.audio import (
@@ -44,28 +43,23 @@ from batdetect2.preprocess.audio import (
     AudioConfig,
     ResampleConfig,
     build_audio_loader,
+    build_audio_pipeline,
 )
 from batdetect2.preprocess.spectrogram import (
     MAX_FREQ,
     MIN_FREQ,
-    ConfigurableSpectrogramBuilder,
     FrequencyConfig,
     PcenConfig,
-    SpecSizeConfig,
     SpectrogramConfig,
+    SpectrogramPipeline,
     STFTConfig,
     build_spectrogram_builder,
-    get_spectrogram_resolution,
-)
-from batdetect2.typing.preprocess import (
-    AudioLoader,
-    PreprocessorProtocol,
-    SpectrogramBuilder,
+    build_spectrogram_pipeline,
 )
+from batdetect2.typing import PreprocessorProtocol

 __all__ = [
     "AudioConfig",
-    "ConfigurableSpectrogramBuilder",
     "DEFAULT_DURATION",
     "FrequencyConfig",
     "MAX_FREQ",
@@ -75,16 +69,11 @@ __all__ = [
     "ResampleConfig",
     "SCALE_RAW_AUDIO",
     "STFTConfig",
-    "SpecSizeConfig",
     "SpectrogramConfig",
-    "StandardPreprocessor",
     "TARGET_SAMPLERATE_HZ",
     "build_audio_loader",
-    "build_preprocessor",
     "build_spectrogram_builder",
-    "get_spectrogram_resolution",
     "load_preprocessing_config",
-    "get_default_preprocessor",
 ]


@@ -110,343 +99,61 @@ class PreprocessingConfig(BaseConfig):
     spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)


-class StandardPreprocessor(PreprocessorProtocol):
-    """Standard implementation of the `Preprocessor` protocol.
-
-    Orchestrates the audio loading and spectrogram generation pipeline using
-    an `AudioLoader` and a `SpectrogramBuilder` internally, which are
-    configured according to a `PreprocessingConfig`.
-
-    This class is typically instantiated using the `build_preprocessor`
-    factory function.
-
-    Attributes
-    ----------
-    audio_loader : AudioLoader
-        The configured audio loader instance used for waveform loading and
-        initial processing.
-    spectrogram_builder : SpectrogramBuilder
-        The configured spectrogram builder instance used for generating
-        spectrograms from waveforms.
-    default_samplerate : int
-        The sample rate (in Hz) assumed for input waveforms when they are
-        provided as raw NumPy arrays without coordinate information (e.g.,
-        when calling `compute_spectrogram` directly with `np.ndarray`).
-        This value is derived from the `AudioConfig` (target resample rate
-        or default if resampling is off) and also serves as documentation
-        for the pipeline's intended operating sample rate. Note that when
-        processing `xr.DataArray` inputs that have coordinate information
-        (the standard internal workflow), the sample rate embedded in the
-        coordinates takes precedence over this default value during
-        spectrogram calculation.
-    """
-
-    audio_loader: AudioLoader
-    spectrogram_builder: SpectrogramBuilder
-    default_samplerate: int
+def load_preprocessing_config(
+    path: PathLike,
+    field: Optional[str] = None,
+) -> PreprocessingConfig:
+    return load_config(path, schema=PreprocessingConfig, field=field)
+
+
+class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
+    """Standard implementation of the `Preprocessor` protocol."""
+
+    samplerate: int
     max_freq: float
     min_freq: float

     def __init__(
         self,
-        audio_loader: AudioLoader,
-        spectrogram_builder: SpectrogramBuilder,
-        default_samplerate: int,
+        audio_pipeline: torch.nn.Module,
+        spectrogram_pipeline: SpectrogramPipeline,
+        samplerate: int,
         max_freq: float,
         min_freq: float,
     ) -> None:
-        """Initialize the StandardPreprocessor.
-
-        Parameters
-        ----------
-        audio_loader : AudioLoader
-            An initialized audio loader conforming to the AudioLoader protocol.
-        spectrogram_builder : SpectrogramBuilder
-            An initialized spectrogram builder conforming to the
-            SpectrogramBuilder protocol.
-        default_samplerate : int
-            The sample rate to assume for NumPy array inputs and potentially
-            reflecting the target rate of the audio config.
-        """
-        self.audio_loader = audio_loader
-        self.spectrogram_builder = spectrogram_builder
-        self.default_samplerate = default_samplerate
+        super().__init__()
+        self.audio_pipeline = audio_pipeline
+        self.spectrogram_pipeline = spectrogram_pipeline
+        self.samplerate = samplerate
         self.max_freq = max_freq
         self.min_freq = min_freq

+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        wav = self.audio_pipeline(wav)
+        return self.spectrogram_pipeline(wav)
-    def load_file_audio(
-        self,
-        path: data.PathLike,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform from a file path.
-
-        Delegates to the internal `audio_loader`.
-
-        Parameters
-        ----------
-        path : PathLike
-            Path to the audio file.
-        audio_dir : PathLike, optional
-            A directory prefix if `path` is relative.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform (typically first
-            channel).
-        """
-        return self.audio_loader.load_file(
-            path,
-            audio_dir=audio_dir,
-        )
-
-    def load_recording_audio(
-        self,
-        recording: data.Recording,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform for a Recording.
-
-        Delegates to the internal `audio_loader`.
-
-        Parameters
-        ----------
-        recording : data.Recording
-            The Recording object.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform (typically first
-            channel).
-        """
-        return self.audio_loader.load_recording(
-            recording,
-            audio_dir=audio_dir,
-        )
-
-    def load_clip_audio(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform for a Clip.
-
-        Delegates to the internal `audio_loader`.
-
-        Parameters
-        ----------
-        clip : data.Clip
-            The Clip object defining the segment.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform segment (typically first
-            channel).
-        """
-        return self.audio_loader.load_clip(
-            clip,
-            audio_dir=audio_dir,
-        )
-
-    def preprocess_file(
-        self,
-        path: data.PathLike,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio from a file and compute the final processed spectrogram.
-
-        Performs the full pipeline:
-
-        Load -> Preprocess Audio -> Compute Spectrogram.
-
-        Parameters
-        ----------
-        path : PathLike
-            Path to the audio file.
-        audio_dir : PathLike, optional
-            A directory prefix if `path` is relative.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-        """
-        wav = self.load_file_audio(path, audio_dir=audio_dir)
-        return self.spectrogram_builder(
-            wav,
-            samplerate=self.default_samplerate,
-        )
-
-    def preprocess_recording(
-        self,
-        recording: data.Recording,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio for a Recording and compute the processed spectrogram.
-
-        Performs the full pipeline for the entire duration of the recording.
-
-        Parameters
-        ----------
-        recording : data.Recording
-            The Recording object.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-        """
-        wav = self.load_recording_audio(recording, audio_dir=audio_dir)
-        return self.spectrogram_builder(
-            wav,
-            samplerate=self.default_samplerate,
-        )
-
-    def preprocess_clip(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio for a Clip and compute the final processed spectrogram.
-
-        Performs the full pipeline for the specified clip segment.
-
-        Parameters
-        ----------
-        clip : data.Clip
-            The Clip object defining the audio segment.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-        """
-        wav = self.load_clip_audio(clip, audio_dir=audio_dir)
-        return self.spectrogram_builder(
-            wav,
-            samplerate=self.default_samplerate,
-        )
-
-    def compute_spectrogram(
-        self, wav: Union[xr.DataArray, np.ndarray]
-    ) -> xr.DataArray:
-        """Compute the spectrogram from a pre-loaded audio waveform.
-
-        Applies the configured spectrogram generation steps
-        (STFT, scaling, etc.) using the internal `spectrogram_builder`.
-
-        If `wav` is a NumPy array, the `default_samplerate` stored in this
-        preprocessor instance will be used. If `wav` is an xarray DataArray
-        with time coordinates, the sample rate derived from those coordinates
-        will take precedence over `default_samplerate`.
-
-        Parameters
-        ----------
-        wav : Union[xr.DataArray, np.ndarray]
-            The input audio waveform. If numpy array, `default_samplerate`
-            stored in this object will be assumed.
-
-        Returns
-        -------
-        xr.DataArray
-            The computed spectrogram.
-        """
-        return self.spectrogram_builder(
-            wav,
-            samplerate=self.default_samplerate,
-        )
-
-
-def load_preprocessing_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> PreprocessingConfig:
-    """Load the unified preprocessing configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `PreprocessingConfig` schema, potentially extracting data from a nested
-    field.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file.
-    field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        preprocessing configuration (e.g., "train.preprocessing"). If None, the
-        entire file content is validated as the PreprocessingConfig.
-
-    Returns
-    -------
-    PreprocessingConfig
-        Loaded and validated preprocessing configuration object.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file path does not exist.
-    yaml.YAMLError
-        If the file content is not valid YAML.
-    pydantic.ValidationError
-        If the loaded config data does not conform to PreprocessingConfig.
-    KeyError, TypeError
-        If `field` specifies an invalid path.
-    """
-    return load_config(path, schema=PreprocessingConfig, field=field)


 def build_preprocessor(
     config: Optional[PreprocessingConfig] = None,
 ) -> PreprocessorProtocol:
-    """Factory function to build the standard preprocessor from configuration.
-
-    Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
-    based on the provided `PreprocessingConfig` (or defaults if config is None),
-    determines the effective default sample rate, and initializes the
-    `StandardPreprocessor`.
-
-    Parameters
-    ----------
-    config : PreprocessingConfig, optional
-        The unified preprocessing configuration object. If None, default
-        configurations for audio and spectrogram processing will be used.
-
-    Returns
-    -------
-    Preprocessor
-        An initialized `StandardPreprocessor` instance ready to process audio
-        according to the configuration.
-    """
+    """Factory function to build the standard preprocessor from configuration."""
     config = config or PreprocessingConfig()
     logger.opt(lazy=True).debug(
         "Building preprocessor with config: \n{}",
         lambda: config.to_yaml_string(),
     )

-    default_samplerate = (
-        config.audio.resample.samplerate
-        if config.audio.resample
-        else TARGET_SAMPLERATE_HZ
-    )
+    samplerate = config.audio.samplerate

     min_freq = config.spectrogram.frequencies.min_freq
     max_freq = config.spectrogram.frequencies.max_freq

     return StandardPreprocessor(
-        audio_loader=build_audio_loader(config.audio),
-        spectrogram_builder=build_spectrogram_builder(config.spectrogram),
-        default_samplerate=default_samplerate,
+        audio_pipeline=build_audio_pipeline(config.audio),
+        spectrogram_pipeline=build_spectrogram_pipeline(
+            samplerate, config.spectrogram
+        ),
+        samplerate=samplerate,
         min_freq=min_freq,
         max_freq=max_freq,
     )
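Note: with the hunk above, the preprocessor becomes a pure tensor-to-tensor
computation; file loading now lives entirely in the audio loader, and
`forward` just chains the two pipelines. A minimal usage sketch (the input
tensor, its length, and the import path are illustrative assumptions, not
part of the commit):

    import torch
    from batdetect2.preprocess import build_preprocessor

    preprocessor = build_preprocessor()  # defaults from PreprocessingConfig
    wav = torch.randn(256_000)           # hypothetical 1 s clip at 256 kHz
    spec = preprocessor(wav)             # audio_pipeline -> spectrogram_pipeline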
@@ -1,53 +1,31 @@
-"""Handles loading and initial preprocessing of audio waveforms.
+"""Handles loading and initial preprocessing of audio waveforms."""

-This module provides components for loading audio data associated with
-`soundevent` objects (Clips, Recordings, or raw files) and applying
-fundamental waveform processing steps. These steps typically include:
-
-1. Loading the raw audio data.
-2. Adjusting the audio clip to a fixed duration (optional).
-3. Resampling the audio to a target sample rate (optional).
-4. Centering the waveform (DC offset removal) (optional).
-5. Scaling the waveform amplitude (optional).
-
-The processing pipeline is configurable via the `AudioConfig` data structure,
-allowing for reproducible preprocessing consistent between model training and
-inference. It uses the `soundevent` library for audio loading and basic array
-operations, and `scipy` for resampling implementations.
-
-The primary interface is the `AudioLoader` protocol, with
-`ConfigurableAudioLoader` providing a concrete implementation driven by the
-`AudioConfig`.
-"""
-
-from typing import Optional
+from typing import Annotated, List, Literal, Optional, Union

 import numpy as np
-import xarray as xr
+import torch
 from numpy.typing import DTypeLike
 from pydantic import Field
 from scipy.signal import resample, resample_poly
-from soundevent import arrays, audio, data
-from soundevent.arrays import operations as ops
+from soundevent import audio, data
 from soundfile import LibsndfileError

 from batdetect2.configs import BaseConfig
-from batdetect2.typing.preprocess import AudioLoader
+from batdetect2.preprocess.common import CenterTensor, PeakNormalize
+from batdetect2.typing import AudioLoader

 __all__ = [
     "ResampleConfig",
     "AudioConfig",
-    "ConfigurableAudioLoader",
+    "SoundEventAudioLoader",
     "build_audio_loader",
     "load_file_audio",
     "load_recording_audio",
     "load_clip_audio",
-    "adjust_audio_duration",
     "resample_audio",
     "TARGET_SAMPLERATE_HZ",
     "SCALE_RAW_AUDIO",
     "DEFAULT_DURATION",
-    "convert_to_xr",
 ]

 TARGET_SAMPLERATE_HZ = 256_000
@@ -76,192 +54,69 @@ class ResampleConfig(BaseConfig):
     resampling factors differently.
     """

-    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    enabled: bool = True
     method: str = "poly"


-class AudioConfig(BaseConfig):
-    """Configuration for loading and initial audio preprocessing.
-
-    Defines the sequence of operations applied to raw audio waveforms after
-    loading, controlling steps like resampling, scaling, centering, and
-    duration adjustment.
-
-    Attributes
-    ----------
-    resample : ResampleConfig, optional
-        Configuration for resampling. If provided (or defaulted), audio will
-        be resampled to the specified `samplerate` using the specified
-        `method`. If set to `None` in the config file, resampling is skipped.
-        Defaults to a ResampleConfig instance with standard settings.
-    scale : bool, default=False
-        If True, scales the audio waveform using peak normalization so that
-        its maximum absolute amplitude is approximately 1.0. If False
-        (default), no amplitude scaling is applied.
-    center : bool, default=True
-        If True (default), centers the waveform by subtracting its mean
-        (DC offset removal). If False, the waveform is not centered.
-    duration : float, optional
-        If set to a float value (seconds), the loaded audio clip will be
-        adjusted (cropped or padded with zeros) to exactly this duration.
-        If None (default), the original duration is kept.
-    """
-
-    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
-    scale: bool = SCALE_RAW_AUDIO
-    center: bool = False
-    duration: Optional[float] = DEFAULT_DURATION
-
-
-class ConfigurableAudioLoader:
-    """Concrete implementation of the `AudioLoader` driven by `AudioConfig`.
-
-    This class loads audio and applies preprocessing steps (resampling,
-    scaling, centering, duration adjustment) based on the settings provided
-    in an `AudioConfig` object during initialization. It delegates the actual
-    work to module-level functions.
-    """
+class SoundEventAudioLoader:
+    """Concrete implementation of the `AudioLoader`."""

     def __init__(
         self,
-        config: AudioConfig,
+        samplerate: int = TARGET_SAMPLERATE_HZ,
+        config: Optional[ResampleConfig] = None,
     ):
-        """Initialize the ConfigurableAudioLoader.
-
-        Parameters
-        ----------
-        config : AudioConfig
-            The configuration object specifying the desired preprocessing steps
-            and parameters.
-        """
-        self.config = config
+        self.samplerate = samplerate
+        self.config = config or ResampleConfig()

     def load_file(
         self,
         path: data.PathLike,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess audio directly from a file path.
-
-        Implements the `AudioLoader.load_file` method by delegating to the
-        `load_file_audio` function, passing the stored configuration.
-
-        Parameters
-        ----------
-        path : PathLike
-            Path to the audio file.
-        audio_dir : PathLike, optional
-            A directory prefix if `path` is relative.
-
-        Returns
-        -------
-        xr.DataArray
-            Loaded and preprocessed waveform (first channel).
-        """
-        return load_file_audio(path, config=self.config, audio_dir=audio_dir)
+    ) -> np.ndarray:
+        """Load and preprocess audio directly from a file path."""
+        return load_file_audio(
+            path,
+            samplerate=self.samplerate,
+            config=self.config,
+            audio_dir=audio_dir,
+        )

     def load_recording(
         self,
         recording: data.Recording,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess the entire audio for a Recording object.
-
-        Implements the `AudioLoader.load_recording` method by delegating to the
-        `load_recording_audio` function, passing the stored configuration.
-
-        Parameters
-        ----------
-        recording : data.Recording
-            The Recording object.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            Loaded and preprocessed waveform (first channel).
-        """
+    ) -> np.ndarray:
+        """Load and preprocess the entire audio for a Recording object."""
         return load_recording_audio(
-            recording, config=self.config, audio_dir=audio_dir
+            recording,
+            samplerate=self.samplerate,
+            config=self.config,
+            audio_dir=audio_dir,
         )

     def load_clip(
         self,
         clip: data.Clip,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess the audio segment defined by a Clip object.
-
-        Implements the `AudioLoader.load_clip` method by delegating to the
-        `load_clip_audio` function, passing the stored configuration.
-
-        Parameters
-        ----------
-        clip : data.Clip
-            The Clip object specifying the segment.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            Loaded and preprocessed waveform segment (first channel).
-        """
-        return load_clip_audio(clip, config=self.config, audio_dir=audio_dir)
-
-
-def build_audio_loader(
-    config: AudioConfig,
-) -> AudioLoader:
-    """Factory function to create an AudioLoader based on configuration.
-
-    Instantiates and returns a `ConfigurableAudioLoader` initialized with
-    the provided `AudioConfig`. The return type is `AudioLoader`, adhering
-    to the protocol.
-
-    Parameters
-    ----------
-    config : AudioConfig
-        The configuration object specifying preprocessing steps.
-
-    Returns
-    -------
-    AudioLoader
-        An instance of `ConfigurableAudioLoader` ready to load and process audio
-        according to the configuration.
-    """
-    return ConfigurableAudioLoader(config=config)
+    ) -> np.ndarray:
+        """Load and preprocess the audio segment defined by a Clip object."""
+        return load_clip_audio(
+            clip,
+            samplerate=self.samplerate,
+            config=self.config,
+            audio_dir=audio_dir,
+        )


 def load_file_audio(
     path: data.PathLike,
-    config: Optional[AudioConfig] = None,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
     audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Load and preprocess audio from a file path using specified config.
-
-    Creates a `soundevent.data.Recording` object from the file path and then
-    delegates the loading and processing to `load_recording_audio`.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the audio file.
-    config : AudioConfig, optional
-        Audio processing configuration. If None, default settings defined
-        in `AudioConfig` are used.
-    audio_dir : PathLike, optional
-        Directory prefix if `path` is relative.
-    dtype : DTypeLike, default=np.float32
-        Target NumPy data type for the loaded audio array.
-
-    Returns
-    -------
-    xr.DataArray
-        Loaded and preprocessed waveform (first channel only).
-    """
+) -> np.ndarray:
+    """Load and preprocess audio from a file path using specified config."""
     try:
         recording = data.Recording.from_file(path)
     except LibsndfileError as e:
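Note: `SoundEventAudioLoader` narrows the loader's job to loading plus
optional resampling; centering, scaling, and duration adjustment move into
the torch transform pipeline introduced further down, and the loaders now
return plain `np.ndarray` values rather than `xr.DataArray`. A short sketch
(the file path is hypothetical):

    loader = SoundEventAudioLoader(samplerate=256_000)
    wav = loader.load_file("example_recording.wav")  # 1-D np.ndarray at 256 kHz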
@@ -271,6 +126,7 @@ def load_file_audio(

     return load_recording_audio(
         recording,
+        samplerate=samplerate,
         config=config,
         dtype=dtype,
         audio_dir=audio_dir,
@@ -279,33 +135,12 @@ def load_file_audio(

 def load_recording_audio(
     recording: data.Recording,
-    config: Optional[AudioConfig] = None,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
     audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Load and preprocess the entire audio content of a recording using config.
-
-    Creates a `soundevent.data.Clip` spanning the full duration of the
-    recording and then delegates the loading and processing to
-    `load_clip_audio`.
-
-    Parameters
-    ----------
-    recording : data.Recording
-        The Recording object containing metadata and file path.
-    config : AudioConfig, optional
-        Audio processing configuration. If None, default settings are used.
-    audio_dir : PathLike, optional
-        Directory containing the audio file, used if the path in `recording`
-        is relative.
-    dtype : DTypeLike, default=np.float32
-        Target NumPy data type for the loaded audio array.
-
-    Returns
-    -------
-    xr.DataArray
-        Loaded and preprocessed waveform (first channel only).
-    """
+) -> np.ndarray:
+    """Load and preprocess the entire audio content of a recording using config."""
     clip = data.Clip(
         recording=recording,
         start_time=0,
@@ -313,6 +148,7 @@ def load_recording_audio(
     )
     return load_clip_audio(
         clip,
+        samplerate=samplerate,
         config=config,
         dtype=dtype,
         audio_dir=audio_dir,
@@ -321,257 +157,66 @@ def load_recording_audio(

 def load_clip_audio(
     clip: data.Clip,
-    config: Optional[AudioConfig] = None,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
     audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Load and preprocess a specific audio clip segment based on config.
-
-    This is the core function performing the configured processing pipeline:
-    1. Loads the specified clip segment using `soundevent.audio.load_clip`.
-    2. Selects the first audio channel.
-    3. Resamples if `config.resample` is configured.
-    4. Centers (DC offset removal) if `config.center` is True.
-    5. Scales (peak normalization) if `config.scale` is True.
-    6. Adjusts duration (crop/pad) if `config.duration` is set.
-
-    Parameters
-    ----------
-    clip : data.Clip
-        The Clip object defining the audio segment and source recording.
-    config : AudioConfig, optional
-        Audio processing configuration. If None, a default `AudioConfig` is
-        used.
-    audio_dir : PathLike, optional
-        Directory containing the source audio file specified in the clip's
-        recording.
-    dtype : DTypeLike, default=np.float32
-        Target NumPy data type for the processed audio array.
-
-    Returns
-    -------
-    xr.DataArray
-        The loaded and preprocessed waveform segment as an xarray DataArray
-        with time coordinates.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the underlying audio file cannot be found.
-    Exception
-        If audio loading or processing fails for other reasons (e.g., invalid
-        format, resampling error).
-
-    Notes
-    -----
-    - **Mono Processing:** This function currently loads and processes only the
-      **first channel** (channel 0) of the audio file. Any other channels
-      are ignored.
-    """
-    config = config or AudioConfig()
-
-    with xr.set_options(keep_attrs=True):
-        try:
-            wav = (
-                audio.load_clip(clip, audio_dir=audio_dir)
-                .sel(channel=0)
-                .astype(dtype)
-            )
-        except LibsndfileError as e:
-            raise FileNotFoundError(
-                f"Could not load the recording at path: {clip.recording.path}. "
-                f"Error: {e}"
-            ) from e
-
-        if config.resample:
-            wav = resample_audio(
-                wav,
-                samplerate=config.resample.samplerate,
-                dtype=dtype,
-            )
-
-        if config.center:
-            wav = ops.center(wav)
-
-        if config.scale:
-            wav = scale_audio(wav)
-
-        if config.duration is not None:
-            wav = adjust_audio_duration(wav, duration=config.duration)
-
-        return wav.astype(dtype)
-
-
-def scale_audio(
-    wave: xr.DataArray,
-) -> xr.DataArray:
-    """
-    Scale the audio waveform to have a maximum absolute value of 1.0.
-
-    This function normalizes the waveform by dividing it by its maximum
-    absolute value. If the maximum value is zero, the waveform is returned
-    unchanged. Also known as peak normalization, this process ensures that the
-    waveform's amplitude is within a standard range, which can be useful for
-    audio processing and analysis.
-    """
-    max_val = np.max(np.abs(wave))
-
-    if max_val == 0:
-        return wave
-
-    return ops.scale(wave, 1 / max_val)
-
-
-def adjust_audio_duration(
-    wave: xr.DataArray,
-    duration: float,
-) -> xr.DataArray:
-    """Adjust the duration of an audio waveform array via cropping or padding.
-
-    If the current duration is longer than the target, it crops the array
-    from the beginning. If shorter, it pads the array with zeros at the end
-    using `soundevent.arrays.extend_dim`.
-
-    Parameters
-    ----------
-    wave : xr.DataArray
-        The input audio waveform with a 'time' dimension and coordinates.
-    duration : float
-        The target duration in seconds.
-
-    Returns
-    -------
-    xr.DataArray
-        The waveform adjusted to the target duration. Returns the input
-        unmodified if duration already matches or if the wave is empty.
-
-    Raises
-    ------
-    ValueError
-        If `duration` is negative.
-    """
-    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    step = arrays.get_dim_step(wave, dim="time")
-    current_duration = end_time - start_time + step
-
-    if current_duration == duration:
-        return wave
-
-    with xr.set_options(keep_attrs=True):
-        if current_duration > duration:
-            return arrays.crop_dim(
-                wave,
-                dim="time",
-                start=start_time,
-                stop=start_time + duration - step / 2,
-                right_closed=True,
-            )
-
-        return arrays.extend_dim(
-            wave,
-            dim="time",
-            start=start_time,
-            stop=start_time + duration - step / 2,
-            eps=0,
-            right_closed=True,
-        )
+) -> np.ndarray:
+    """Load and preprocess a specific audio clip segment based on config."""
+    try:
+        wav = (
+            audio.load_clip(clip, audio_dir=audio_dir)
+            .sel(channel=0)
+            .astype(dtype)
+        )
+    except LibsndfileError as e:
+        raise FileNotFoundError(
+            f"Could not load the recording at path: {clip.recording.path}. "
+            f"Error: {e}"
+        ) from e
+
+    if not config or not config.enabled or samplerate is None:
+        return wav.data.astype(dtype)
+
+    sr = int(1 / wav.time.attrs["step"])
+    return resample_audio(
+        wav.data,
+        sr=sr,
+        samplerate=samplerate,
+        method=config.method,
+    )


 def resample_audio(
-    wav: xr.DataArray,
+    wav: np.ndarray,
+    sr: int,
     samplerate: int = TARGET_SAMPLERATE_HZ,
     method: str = "poly",
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Resample an audio waveform DataArray to a target sample rate.
-
-    Updates the 'time' coordinate axis according to the new sample rate and
-    number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
-    or Fourier method (`scipy.signal.resample`) based on the `method`.
-
-    Parameters
-    ----------
-    wav : xr.DataArray
-        Input audio waveform with 'time' dimension and coordinates.
-    samplerate : int, default=TARGET_SAMPLERATE_HZ
-        Target sample rate in Hz.
-    method : str, default="poly"
-        Resampling algorithm: "poly" or "fourier".
-    dtype : DTypeLike, default=np.float32
-        Target data type for the resampled array.
-
-    Returns
-    -------
-    xr.DataArray
-        Resampled waveform with updated time coordinates. Returns the input
-        unmodified (but dtype cast) if the sample rate is already correct or
-        if the input array is empty.
-
-    Raises
-    ------
-    ValueError
-        If `wav` lacks a 'time' dimension, the original sample rate cannot
-        be determined, `samplerate` is non-positive, or `method` is invalid.
-    """
-    if "time" not in wav.dims:
-        raise ValueError("Audio must have a time dimension")
-
-    time_axis: int = wav.get_axis_num("time")  # type: ignore
-    step = arrays.get_dim_step(wav, dim="time")
-    original_samplerate = int(1 / step)
-
-    if original_samplerate == samplerate:
-        return wav.astype(dtype).assign_attrs(original_samplerate=samplerate)
+) -> np.ndarray:
+    """Resample an audio waveform DataArray to a target sample rate."""
+    if sr == samplerate:
+        return wav

     if method == "poly":
-        resampled = resample_audio_poly(
+        return resample_audio_poly(
             wav,
-            sr_orig=original_samplerate,
+            sr_orig=sr,
             sr_new=samplerate,
-            axis=time_axis,
         )
     elif method == "fourier":
-        resampled = resample_audio_fourier(
+        return resample_audio_fourier(
             wav,
-            sr_orig=original_samplerate,
+            sr_orig=sr,
             sr_new=samplerate,
-            axis=time_axis,
         )
     else:
         raise NotImplementedError(
             f"Resampling method '{method}' not implemented"
         )
-
-    start, stop = arrays.get_dim_range(wav, dim="time")
-    times = np.linspace(
-        start,
-        stop + step,
-        len(resampled),
-        endpoint=False,
-        dtype=dtype,
-    )
-
-    return xr.DataArray(
-        data=resampled.astype(dtype),
-        dims=wav.dims,
-        coords={
-            **wav.coords,
-            "time": arrays.create_time_dim_from_array(
-                times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs={
-            **wav.attrs,
-            "samplerate": samplerate,
-            "original_samplerate": original_samplerate,
-        },
-    )


 def resample_audio_poly(
-    array: xr.DataArray,
+    array: np.ndarray,
     sr_orig: int,
     sr_new: int,
     axis: int = -1,
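Note: in the new `load_clip_audio`, the source sample rate is recovered from
the xarray time coordinate rather than tracked in config; its `step`
attribute is the sampling period in seconds, so the inverse gives the rate
in Hz. Illustrative arithmetic:

    # Recovering the rate from the sampling period.
    step = 3.90625e-06           # seconds per sample, i.e. 1 / 256_000
    sr = int(round(1 / step))    # 256_000; the diff uses int(1 / step) directly

When resampling is disabled, or no target rate is given, the raw data array
is returned with only the dtype cast applied.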
@@ -605,7 +250,7 @@ def resample_audio_poly(
     """
     gcd = np.gcd(sr_orig, sr_new)
     return resample_poly(
-        array.values,
+        array,
         sr_new // gcd,
         sr_orig // gcd,
         axis=axis,
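Note: `scipy.signal.resample_poly` takes an integer up/down factor pair, so
the two rates are first reduced by their greatest common divisor. Worked
example (the rates are illustrative):

    import numpy as np
    from scipy.signal import resample_poly

    sr_orig, sr_new = 384_000, 256_000
    gcd = np.gcd(sr_orig, sr_new)               # 128_000
    up, down = sr_new // gcd, sr_orig // gcd    # 2, 3
    x = np.random.randn(sr_orig)                # 1 s of audio
    y = resample_poly(x, up, down)              # len(y) == 256_000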
@@ -613,7 +258,7 @@ def resample_audio_poly(


 def resample_audio_fourier(
-    array: xr.DataArray,
+    array: np.ndarray,
     sr_orig: int,
     sr_new: int,
     axis: int = -1,
@@ -649,66 +294,89 @@ def resample_audio_fourier(
     )


-def convert_to_xr(
-    wav: np.ndarray,
-    samplerate: int,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Convert a NumPy array to an xarray DataArray with time coordinates.
-
-    Parameters
-    ----------
-    wav : np.ndarray
-        The input waveform array. Expected to be 1D or 2D (with the first
-        axis as the channel dimension).
-    samplerate : int
-        The sample rate in Hz.
-    dtype : DTypeLike, default=np.float32
-        Target data type for the xarray DataArray.
-
-    Returns
-    -------
-    xr.DataArray
-        The waveform as an xarray DataArray with time coordinates.
-
-    Raises
-    ------
-    ValueError
-        If the input array is not 1D or 2D, or if the sample rate is
-        non-positive. If the input array is empty.
-    """
-    if wav.ndim == 2:
-        wav = wav[0, :]
-
-    if wav.ndim != 1:
-        raise ValueError(
-            "Audio must be 1D array or 2D channel where the first "
-            "axis is the channel dimension"
-        )
-
-    if wav.size == 0:
-        raise ValueError("Audio array is empty")
-
-    if samplerate <= 0:
-        raise ValueError("Sample rate must be positive")
-
-    times = np.linspace(
-        0,
-        wav.shape[0] / samplerate,
-        wav.shape[0],
-        endpoint=False,
-        dtype=dtype,
-    )
-
-    return xr.DataArray(
-        data=wav.astype(dtype),
-        dims=["time"],
-        coords={
-            "time": arrays.create_time_dim_from_array(
-                times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs={"samplerate": samplerate},
-    )
+class CenterAudioConfig(BaseConfig):
+    name: Literal["center_audio"] = "center_audio"
+
+
+class ScaleAudioConfig(BaseConfig):
+    name: Literal["scale_audio"] = "scale_audio"
+
+
+class FixDurationConfig(BaseConfig):
+    name: Literal["fix_duration"] = "fix_duration"
+    duration: float = 0.5
+
+
+class FixDuration(torch.nn.Module):
+    def __init__(self, samplerate: int, duration: float):
+        super().__init__()
+        self.samplerate = samplerate
+        self.duration = duration
+        self.length = int(samplerate * duration)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        length = wav.shape[-1]
+
+        if length == self.length:
+            return wav
+
+        if length > self.length:
+            return wav[: self.length]
+
+        return torch.nn.functional.pad(wav, (0, self.length - length))
+
+
+AudioTransform = Annotated[
+    Union[
+        FixDurationConfig,
+        ScaleAudioConfig,
+        CenterAudioConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+class AudioConfig(BaseConfig):
+    """Configuration for loading and initial audio preprocessing."""
+
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    transforms: List[AudioTransform] = Field(default_factory=list)
+
+
+def build_audio_loader(
+    config: Optional[AudioConfig] = None,
+) -> AudioLoader:
+    """Factory function to create an AudioLoader based on configuration."""
+    config = config or AudioConfig()
+    return SoundEventAudioLoader(
+        samplerate=config.samplerate,
+        config=config.resample,
+    )
+
+
+def build_audio_transform_step(
+    config: AudioTransform,
+    samplerate: int,
+) -> torch.nn.Module:
+    if config.name == "fix_duration":
+        return FixDuration(samplerate=samplerate, duration=config.duration)
+
+    if config.name == "scale_audio":
+        return PeakNormalize()
+
+    if config.name == "center_audio":
+        return CenterTensor()
+
+    raise NotImplementedError(
+        f"Audio preprocessing step {config.name} not implemented"
+    )
+
+
+def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
+    return torch.nn.Sequential(
+        *[
+            build_audio_transform_step(step, samplerate=config.samplerate)
+            for step in config.transforms
+        ]
+    )
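Note: `AudioTransform` is a pydantic discriminated union keyed on the `name`
field, so a config file can list waveform steps by name and
`build_audio_pipeline` maps each entry onto a torch module. A sketch of the
resulting pipeline (step order and values are illustrative):

    import torch

    config = AudioConfig(
        samplerate=256_000,
        transforms=[
            CenterAudioConfig(),               # subtract the mean (DC removal)
            ScaleAudioConfig(),                # peak normalization
            FixDurationConfig(duration=0.5),   # crop or zero-pad to 0.5 s
        ],
    )
    pipeline = build_audio_pipeline(config)    # nn.Sequential of three modules
    out = pipeline(torch.randn(300_000))       # out.shape == (128_000,)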
src/batdetect2/preprocess/common.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import torch
+
+__all__ = [
+    "CenterTensor",
+    "PeakNormalize",
+]
+
+
+class CenterTensor(torch.nn.Module):
+    def forward(self, wav: torch.Tensor):
+        return wav - wav.mean()
+
+
+class PeakNormalize(torch.nn.Module):
+    def forward(self, wav: torch.Tensor):
+        # Peak normalization divides by the largest absolute amplitude.
+        max_value = wav.abs().max()
+
+        denominator = torch.where(
+            max_value == 0,
+            torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
+            max_value,
+        )
+
+        return wav / denominator
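Note: `CenterTensor` removes the DC offset and `PeakNormalize` rescales so
the largest absolute sample becomes 1.0, with the `torch.where` guard
keeping all-zero (silent) input unchanged instead of producing NaNs.
Expected behavior (values are illustrative):

    import torch

    wav = torch.tensor([0.5, -1.0, 2.0])
    CenterTensor()(wav)              # result has zero mean
    PeakNormalize()(wav)             # tensor([0.25, -0.50, 1.00])
    PeakNormalize()(torch.zeros(4))  # stays all zeros, no division by zero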
@@ -1,48 +1,22 @@
-"""Computes spectrograms from audio waveforms with configurable parameters.
+"""Computes spectrograms from audio waveforms with configurable parameters."""

-This module provides the functionality to convert preprocessed audio waveforms
-(typically output from the `batdetect2.preprocessing.audio` module) into
-spectrogram representations suitable for input into deep learning models like
-BatDetect2.
-
-It offers a configurable pipeline including:
-1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
-2. Frequency axis cropping to a relevant range.
-3. Per-Channel Energy Normalization (PCEN) (optional).
-4. Amplitude scaling/representation (dB, power, or linear amplitude).
-5. Simple spectral mean subtraction denoising (optional).
-6. Resizing to target dimensions (optional).
-7. Final peak normalization (optional).
-
-Configuration is managed via the `SpectrogramConfig` class, allowing for
-reproducible spectrogram generation consistent between training and inference.
-The core computation is performed by `compute_spectrogram`.
-"""
-
-from typing import Callable, Literal, Optional, Union
+from typing import Annotated, Callable, List, Literal, Optional, Union

 import numpy as np
-import xarray as xr
-from numpy.typing import DTypeLike
+import torch
+import torchaudio
 from pydantic import Field
-from scipy import signal
-from soundevent import arrays, audio
-from soundevent.arrays import operations as ops

 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.audio import convert_to_xr
+from batdetect2.preprocess.common import PeakNormalize
 from batdetect2.typing.preprocess import SpectrogramBuilder

 __all__ = [
     "STFTConfig",
     "FrequencyConfig",
-    "SpecSizeConfig",
     "PcenConfig",
     "SpectrogramConfig",
-    "ConfigurableSpectrogramBuilder",
     "build_spectrogram_builder",
-    "compute_spectrogram",
-    "get_spectrogram_resolution",
     "MIN_FREQ",
     "MAX_FREQ",
 ]
@@ -79,6 +53,47 @@ class STFTConfig(BaseConfig):
     window_fn: str = "hann"


+def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
+    if name == "hann":
+        return torch.hann_window
+
+    if name == "hamming":
+        return torch.hamming_window
+
+    if name == "kaiser":
+        return torch.kaiser_window
+
+    if name == "blackman":
+        return torch.blackman_window
+
+    if name == "bartlett":
+        return torch.bartlett_window
+
+    raise NotImplementedError(
+        f"Spectrogram window function {name} not implemented"
+    )
+
+
+def _spec_params_from_config(samplerate: int, conf: STFTConfig):
+    n_fft = int(samplerate * conf.window_duration)
+    hop_length = int(n_fft * (1 - conf.window_overlap))
+    return n_fft, hop_length
+
+
+def build_spectrogram_builder(
+    samplerate: int,
+    conf: STFTConfig,
+) -> SpectrogramBuilder:
+    n_fft, hop_length = _spec_params_from_config(samplerate, conf)
+    return torchaudio.transforms.Spectrogram(
+        n_fft=n_fft,
+        hop_length=hop_length,
+        window_fn=get_spectrogram_window(conf.window_fn),
+        center=False,
+        power=1,
+    )
+
+
 class FrequencyConfig(BaseConfig):
     """Configuration for frequency axis parameters.

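Note: the STFT size is now derived from the sample rate and the window
settings on `STFTConfig`. Worked arithmetic, assuming window_duration =
0.002 s and window_overlap = 0.75 (these values are assumptions; they are
not shown in this hunk):

    samplerate = 256_000
    n_fft = int(samplerate * 0.002)        # 512 samples per window
    hop_length = int(n_fft * (1 - 0.75))   # 128 samples between frames
    # With power=1 and center=False, torchaudio returns magnitude frames of
    # height n_fft // 2 + 1 == 257, cropped to [min_freq, max_freq] later.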
@ -96,644 +111,282 @@ class FrequencyConfig(BaseConfig):
|
|||||||
min_freq: int = Field(default=10_000, ge=0)
|
min_freq: int = Field(default=10_000, ge=0)
|
||||||
|
|
||||||
|
|
||||||
class SpecSizeConfig(BaseConfig):
|
def _frequency_to_index(
|
||||||
"""Configuration for the final size and shape of the spectrogram.
|
freq: float,
|
||||||
|
samplerate: int,
|
||||||
|
n_fft: int,
|
||||||
|
) -> Optional[int]:
|
||||||
|
alpha = freq * 2 / samplerate
|
||||||
|
height = np.floor(n_fft / 2) + 1
|
||||||
|
index = int(np.floor(alpha * height))
|
||||||
|
|
||||||
Attributes
|
if index <= 0:
|
||||||
----------
|
return None
|
||||||
height : int, default=128
|
|
||||||
Target height of the spectrogram in pixels (frequency bins). The
|
|
||||||
frequency axis will be resized (e.g., via interpolation) to match this
|
|
||||||
height after frequency cropping and amplitude scaling. Must be > 0.
|
|
||||||
resize_factor : float, optional
|
|
||||||
Factor by which to resize the spectrogram along the time axis *after*
|
|
||||||
STFT calculation. A value of 0.5 halves the number of time bins,
|
|
||||||
2.0 doubles it. If None (default), no resizing along the time axis is
|
|
||||||
performed relative to the STFT output width. Must be > 0 if provided.
|
|
||||||
"""
|
|
||||||
|
|
||||||
height: int = 128
|
if index >= height:
|
||||||
resize_factor: Optional[float] = 0.5
|
return None
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyClip(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
low_index: Optional[int] = None,
|
||||||
|
high_index: Optional[int] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.low_index = low_index
|
||||||
|
self.high_index = high_index
|
||||||
|
|
||||||
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
return spec[self.low_index : self.high_index]
|
||||||
|
|
||||||
|
|
||||||
class PcenConfig(BaseConfig):
|
class PcenConfig(BaseConfig):
|
||||||
"""Configuration for Per-Channel Energy Normalization (PCEN).
|
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
||||||
|
|
||||||
PCEN is an adaptive gain control method that can help emphasize transients
|
name: Literal["pcen"] = "pcen"
|
||||||
and suppress stationary noise. Applied after STFT and frequency cropping,
|
time_constant: float = 0.4
|
||||||
but before final amplitude scaling (dB, power, amplitude).
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
time_constant : float, default=0.4
|
|
||||||
Time constant (in seconds) for the PCEN smoothing filter. Controls
|
|
||||||
how quickly the normalization adapts to energy changes.
|
|
||||||
gain : float, default=0.98
|
|
||||||
Gain factor (alpha). Controls the adaptive gain component.
|
|
||||||
bias : float, default=2.0
|
|
||||||
Bias factor (delta). Added before the exponentiation.
|
|
||||||
power : float, default=0.5
|
|
||||||
Exponent (r). Controls the compression characteristic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
time_constant: float = 0.01
|
|
||||||
gain: float = 0.98
|
gain: float = 0.98
|
||||||
bias: float = 2
|
bias: float = 2
|
||||||
power: float = 0.5
|
power: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
-class SpectrogramConfig(BaseConfig):
-    """Unified configuration for spectrogram generation pipeline.
-
-    Aggregates settings for all steps involved in converting a preprocessed
-    audio waveform into a final spectrogram representation suitable for model
-    input.
-
-    Attributes
-    ----------
-    stft : STFTConfig
-        Configuration for the initial Short-Time Fourier Transform.
-        Defaults to standard settings via `STFTConfig`.
-    frequencies : FrequencyConfig
-        Configuration for cropping the frequency range after STFT.
-        Defaults to standard settings via `FrequencyConfig`.
-    pcen : PcenConfig, optional
-        Configuration for applying Per-Channel Energy Normalization (PCEN). If
-        provided, PCEN is applied after frequency cropping. If None or omitted
-        (default), PCEN is skipped.
-    scale : Literal["dB", "amplitude", "power"], default="amplitude"
-        Determines the final amplitude representation *after* optional PCEN.
-        - "amplitude": Use linear magnitude values (output of STFT or PCEN).
-        - "power": Use power values (magnitude squared).
-        - "dB": Use logarithmic (decibel-like) scaling applied to the magnitude
-          (or PCEN output if enabled). Calculated as `log1p(C * S)`.
-    size : SpecSizeConfig, optional
-        Configuration for resizing the spectrogram dimensions
-        (frequency height, optional time width factor). Applied after PCEN and
-        scaling. If None (default), no resizing is performed.
-    spectral_mean_substraction : bool, default=True
-        If True (default), applies simple spectral mean subtraction denoising
-        *after* PCEN and amplitude scaling, but *before* resizing.
-    peak_normalize : bool, default=False
-        If True, applies a final peak normalization to the spectrogram *after*
-        all other steps (including resizing), scaling the overall maximum value
-        to 1.0. If False (default), this final normalization is skipped.
-    """
-
-    stft: STFTConfig = Field(default_factory=STFTConfig)
-    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
-    scale: Literal["dB", "amplitude", "power"] = "amplitude"
-    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
-    spectral_mean_substraction: bool = True
-    peak_normalize: bool = False
-
-
-class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
-    """Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`.
-
-    This class computes spectrograms according to the parameters specified in a
-    `SpectrogramConfig` object provided during initialization. It handles both
-    numpy array and xarray DataArray inputs for the waveform.
-    """
-
-    def __init__(
-        self,
-        config: SpectrogramConfig,
-        dtype: DTypeLike = np.float32,  # type: ignore
-    ) -> None:
-        """Initialize the ConfigurableSpectrogramBuilder.
-
-        Parameters
-        ----------
-        config : SpectrogramConfig
-            The configuration object specifying all spectrogram parameters.
-        dtype : DTypeLike, default=np.float32
-            The target NumPy data type for the computed spectrogram array.
-        """
-        self.config = config
-        self.dtype = dtype
-
-    def __call__(
-        self,
-        wav: Union[np.ndarray, xr.DataArray],
-        samplerate: Optional[int] = None,
-    ) -> xr.DataArray:
-        """Generate a spectrogram from an audio waveform using the config.
-
-        Implements the `SpectrogramBuilder` protocol. If the input `wav` is
-        a numpy array, `samplerate` must be provided; the array will be
-        converted to an xarray DataArray internally. If `wav` is already an
-        xarray DataArray with time coordinates, `samplerate` is ignored.
-        Delegates the main computation to `compute_spectrogram`.
-
-        Parameters
-        ----------
-        wav : Union[np.ndarray, xr.DataArray]
-            The input audio waveform.
-        samplerate : int, optional
-            The sample rate in Hz (required only if `wav` is np.ndarray).
-
-        Returns
-        -------
-        xr.DataArray
-            The computed spectrogram.
-
-        Raises
-        ------
-        ValueError
-            If `wav` is np.ndarray and `samplerate` is None.
-        """
-        if isinstance(wav, np.ndarray):
-            if samplerate is None:
-                raise ValueError(
-                    "Samplerate must be provided when passing a numpy array."
-                )
-            wav = convert_to_xr(
-                wav,
-                samplerate=samplerate,
-                dtype=self.dtype,
-            )
-
-        return compute_spectrogram(
-            wav,
-            config=self.config,
-            dtype=self.dtype,
-        )
-
-
-def build_spectrogram_builder(
-    config: SpectrogramConfig,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> SpectrogramBuilder:
-    """Factory function to create a SpectrogramBuilder based on configuration.
-
-    Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized
-    with the provided `SpectrogramConfig`.
-
-    Parameters
-    ----------
-    config : SpectrogramConfig
-        The configuration object specifying spectrogram parameters.
-    dtype : DTypeLike, default=np.float32
-        The target NumPy data type for the computed spectrogram array.
-
-    Returns
-    -------
-    SpectrogramBuilder
-        An instance of `ConfigurableSpectrogramBuilder` ready to compute
-        spectrograms according to the configuration.
-    """
-    return ConfigurableSpectrogramBuilder(config=config, dtype=dtype)
-
-
-def compute_spectrogram(
-    wav: xr.DataArray,
-    config: Optional[SpectrogramConfig] = None,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Compute a spectrogram from a waveform using specified configurations.
-
-    Applies a sequence of operations based on the `config`:
-    1. Compute STFT magnitude (`stft`).
-    2. Crop frequency axis (`crop_spectrogram_frequencies`).
-    3. Apply PCEN if configured (`apply_pcen`).
-    4. Apply final amplitude scaling (dB, power, amplitude)
-       (`scale_spectrogram`).
-    5. Apply spectral mean subtraction denoising if enabled.
-    6. Resize dimensions if specified (`resize_spectrogram`).
-    7. Apply final peak normalization if enabled.
-
-    Parameters
-    ----------
-    wav : xr.DataArray
-        Input audio waveform with a 'time' dimension and coordinates from
-        which the sample rate can be inferred.
-    config : SpectrogramConfig, optional
-        Configuration object specifying spectrogram parameters. If None,
-        default settings from `SpectrogramConfig` are used.
-    dtype : DTypeLike, default=np.float32
-        Target NumPy data type for the final spectrogram array.
-
-    Returns
-    -------
-    xr.DataArray
-        The computed and processed spectrogram with 'time' and 'frequency'
-        coordinates.
-
-    Raises
-    ------
-    ValueError
-        If `wav` lacks necessary 'time' coordinates or dimensions.
-    """
-    config = config or SpectrogramConfig()
-
-    with xr.set_options(keep_attrs=True):
-        spec = stft(
-            wav,
-            window_duration=config.stft.window_duration,
-            window_overlap=config.stft.window_overlap,
-            window_fn=config.stft.window_fn,
-        )
-
-        spec = crop_spectrogram_frequencies(
-            spec,
-            min_freq=config.frequencies.min_freq,
-            max_freq=config.frequencies.max_freq,
-        )
-
-        if config.pcen:
-            spec = apply_pcen(
-                spec,
-                time_constant=config.pcen.time_constant,
-                gain=config.pcen.gain,
-                power=config.pcen.power,
-                bias=config.pcen.bias,
-            )
-
-        spec = scale_spectrogram(spec, scale=config.scale)
-
-        if config.spectral_mean_substraction:
-            spec = remove_spectral_mean(spec)
-
-        if config.size:
-            spec = resize_spectrogram(
-                spec,
-                height=config.size.height,
-                resize_factor=config.size.resize_factor,
-            )
-
-        if config.peak_normalize:
-            spec = ops.normalize(spec)
-
-    return spec.astype(dtype)
+class PCEN(torch.nn.Module):
+    def __init__(
+        self,
+        smoothing_constant: float,
+        gain: float = 0.98,
+        bias: float = 2.0,
+        power: float = 0.5,
+        eps: float = 1e-6,
+        dtype=torch.float64,
+    ):
+        super().__init__()
+        self.smoothing_constant = smoothing_constant
+        self.gain = torch.tensor(gain, dtype=dtype)
+        self.bias = torch.tensor(bias, dtype=dtype)
+        self.power = torch.tensor(power, dtype=dtype)
+        self.eps = torch.tensor(eps, dtype=dtype)
+        self.dtype = dtype
+
+        self._b = torch.tensor([self.smoothing_constant, 0.0], dtype=dtype)
+        self._a = torch.tensor(
+            [1.0, self.smoothing_constant - 1.0], dtype=dtype
+        )
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        S = spec.to(self.dtype) * 2**31
+
+        M = (
+            torchaudio.functional.lfilter(
+                S,
+                self._a,
+                self._b,
+                clamp=False,
+            )
+        ).clamp(min=0)
+
+        smooth = torch.exp(
+            -self.gain * (torch.log(self.eps) + torch.log1p(M / self.eps))
+        )
+
+        return (
+            (self.bias**self.power)
+            * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
+        ).to(spec.dtype)
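Reading the new forward pass, the exp/log1p formulation appears to be a numerically safe rewrite of the standard PCEN transfer function (gain $\alpha$, bias $\delta$, power $r$, eps $\varepsilon$):

$$\operatorname{PCEN}(S) \;=\; \Bigl(\delta + \frac{S}{(\varepsilon + M)^{\alpha}}\Bigr)^{r} \;-\; \delta^{r}, \qquad M[t] \;=\; (1 - s)\,M[t-1] \;+\; s\,S[t],$$

where $s$ is the smoothing constant; `lfilter` with $b = [s, 0]$ and $a = [1, s - 1]$ implements the IIR smoother $M$. The `2**31` scaling mirrors the removed `apply_pcen`, which emulated PCEN on 32-bit integer-scaled input.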
-def crop_spectrogram_frequencies(
-    spec: xr.DataArray,
-    min_freq: int = 10_000,
-    max_freq: int = 120_000,
-) -> xr.DataArray:
-    """Crop the frequency axis of a spectrogram to a specified range.
-
-    Uses `soundevent.arrays.crop_dim` to select the frequency bins
-    corresponding to the range [`min_freq`, `max_freq`].
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input spectrogram with 'frequency' dimension and coordinates.
-    min_freq : int, default=MIN_FREQ
-        Minimum frequency (Hz) to keep.
-    max_freq : int, default=MAX_FREQ
-        Maximum frequency (Hz) to keep.
-
-    Returns
-    -------
-    xr.DataArray
-        Spectrogram cropped along the frequency axis. Preserves dtype.
-    """
-    start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
-
-    return arrays.crop_dim(
-        spec,
-        dim="frequency",
-        start=min_freq if start_freq < min_freq else None,
-        stop=max_freq if end_freq > max_freq else None,
-    ).astype(spec.dtype)
+def _compute_smoothing_constant(
+    samplerate: int,
+    time_constant: float,
+) -> float:
+    # NOTE: These were taken to match the original implementation
+    hop_length = 512
+    sr = samplerate / 10
+    time_constant = time_constant
+    t_frames = time_constant * sr / float(hop_length)
+    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
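The hard-coded `hop_length = 512` and `sr = samplerate / 10` reproduce the coefficient the removed `apply_pcen` derived from spectrogram attributes (the tenth of the sample rate presumably matches the original implementation's time-expanded processing, per the NOTE in the code). An illustrative stand-alone check of the closed form, which is the same one librosa's `pcen` uses for its default IIR coefficient:

    import numpy as np

    def smoothing(samplerate: int, time_constant: float = 0.4) -> float:
        t_frames = time_constant * (samplerate / 10) / 512.0
        return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)

    print(smoothing(256_000))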
-def stft(
-    wave: xr.DataArray,
-    window_duration: float,
-    window_overlap: float,
-    window_fn: str = "hann",
-) -> xr.DataArray:
-    """Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
-
-    Calculates STFT parameters (N-FFT, hop length) based on the window
-    duration, overlap, and waveform sample rate. Returns an xarray DataArray
-    with correctly calculated 'time' and 'frequency' coordinates.
-
-    Parameters
-    ----------
-    wave : xr.DataArray
-        Input audio waveform with 'time' coordinates.
-    window_duration : float
-        Duration of the STFT window in seconds.
-    window_overlap : float
-        Fractional overlap between consecutive windows.
-    window_fn : str, default="hann"
-        Name of the window function (e.g., "hann", "hamming").
-
-    Returns
-    -------
-    xr.DataArray
-        Magnitude spectrogram with 'time' and 'frequency' dimensions and
-        coordinates. STFT parameters are stored in the `attrs`.
-
-    Raises
-    ------
-    ValueError
-        If sample rate cannot be determined from `wave` coordinates.
-    """
-    if "channel" not in wave.coords:
-        wave = wave.assign_coords(channel=0)
-
-    return audio.compute_spectrogram(
-        wave,
-        window_size=window_duration,
-        hop_size=(1 - window_overlap) * window_duration,
-        window_type=window_fn,
-        scale="amplitude",
-        sort_dims=False,
-    )
-
-
-def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
-    """Apply simple spectral mean subtraction for denoising.
-
-    Subtracts the mean value of each frequency bin (calculated across time)
-    from that bin, then clips negative values to zero.
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input spectrogram with 'time' and 'frequency' dimensions.
-
-    Returns
-    -------
-    xr.DataArray
-        Denoised spectrogram with the same dimensions, coordinates, and dtype.
-    """
-    return xr.DataArray(
-        data=(spec - spec.mean("time")).clip(0),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
-
-
-def scale_spectrogram(
-    spec: xr.DataArray,
-    scale: Literal["dB", "power", "amplitude"],
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Apply final amplitude scaling/representation to the spectrogram.
-
-    Converts the input magnitude spectrogram based on the `scale` type:
-    - "dB": Applies logarithmic scaling `log10(S)`.
-    - "power": Squares the magnitude values `S^2`.
-    - "amplitude": Returns the input magnitude values `S` unchanged.
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input magnitude spectrogram (potentially after PCEN).
-    scale : Literal["dB", "power", "amplitude"]
-        The target amplitude representation.
-    dtype : DTypeLike, default=np.float32
-        Target data type for the output scaled spectrogram.
-
-    Returns
-    -------
-    xr.DataArray
-        Spectrogram with the specified amplitude scaling applied.
-    """
-    if scale == "dB":
-        return arrays.to_db(spec).astype(dtype)
-
-    if scale == "power":
-        return spec**2
-
-    return spec
+class ScaleAmplitudeConfig(BaseConfig):
+    name: Literal["scale_amplitude"] = "scale_amplitude"
+    scale: Literal["power", "db"] = "db"
+
+
+class ToPower(torch.nn.Module):
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        return spec**2
-def apply_pcen(
-    spec: xr.DataArray,
-    time_constant: float = 0.4,
-    gain: float = 0.98,
-    bias: float = 2,
-    eps: float = 1e-6,
-    power: float = 0.5,
-) -> xr.DataArray:
-    """Apply Per-Channel Energy Normalization (PCEN) to a spectrogram.
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input magnitude spectrogram with required attributes like
-        'processing_original_samplerate'.
-    time_constant : float, default=0.4
-        PCEN time constant in seconds.
-    gain : float, default=0.98
-        Gain factor (alpha).
-    bias : float, default=2.0
-        Bias factor (delta).
-    power : float, default=0.5
-        Exponent (r).
-    dtype : DTypeLike, default=np.float32
-        Target data type for the output spectrogram.
-
-    Returns
-    -------
-    xr.DataArray
-        PCEN-scaled spectrogram.
-    """
-    samplerate = 1 / spec.time.attrs["step"]
-    hop_size = spec.attrs["hop_size"]
-
-    hop_length = int(hop_size * samplerate)
-
-    t_frames = time_constant * samplerate / hop_length
-
-    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
-
-    axis = spec.get_axis_num("time")
-
-    shape = tuple([1] * spec.ndim)
-    zi = np.empty(shape)
-    zi[:] = signal.lfilter_zi(
-        [smoothing_constant],
-        [1, smoothing_constant - 1],
-    )[:]
-
-    spec_data = spec.data * (2**31)
-
-    # Smooth the input array along the given axis
-    smoothed, _ = signal.lfilter(
-        [smoothing_constant],
-        [1, smoothing_constant - 1],
-        spec_data,
-        zi=zi,
-        axis=axis,  # type: ignore
-    )
-
-    smooth = np.exp(-gain * (np.log(eps) + np.log1p(smoothed / eps)))
-    data = (bias**power) * np.expm1(
-        power * np.log1p(spec_data * smooth / bias)
-    )
-
-    return xr.DataArray(
-        data.astype(spec.dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
+def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
+    if conf.scale == "db":
+        return torchaudio.transforms.AmplitudeToDB()
+
+    if conf.scale == "power":
+        return ToPower()
+
+    raise NotImplementedError(
+        f"Amplitude scaling {conf.scale} not implemented"
+    )
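One hedged observation on the "db" branch: `torchaudio.transforms.AmplitudeToDB` defaults to `stype="power"` (10 · log10). If the spectrogram entering this step is a magnitude STFT, the magnitude variant may be the intended setting — worth double-checking against the old dB path. A sketch of the explicit form (the `stype` and `top_db` values here are assumptions, not what the diff configures):

    import torchaudio

    to_db = torchaudio.transforms.AmplitudeToDB(stype="magnitude", top_db=80.0)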
-def scale_log(
-    spec: xr.DataArray,
-    dtype: DTypeLike = np.float32,  # type: ignore
-    ref: Union[float, Callable] = np.max,
-    amin: float = 1e-10,
-    top_db: Optional[float] = 80.0,
-) -> xr.DataArray:
-    """Apply logarithmic scaling to a magnitude spectrogram.
-
-    Calculates `log10(S)`, where S is the input magnitude spectrogram.
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input magnitude spectrogram with required attributes like
-        'processing_original_samplerate', 'processing_nfft'.
-    dtype : DTypeLike, default=np.float32
-        Target data type for the output spectrogram.
-
-    Returns
-    -------
-    xr.DataArray
-        Log-scaled spectrogram.
-
-    Raises
-    ------
-    KeyError
-        If required attributes are missing from `spec.attrs`.
-    ValueError
-        If attributes are non-numeric or window function is invalid.
-
-    Notes
-    -----
-    Implementation mainly taken from librosa `power_to_db` function
-    """
-    if callable(ref):
-        ref_value = ref(spec)
-    else:
-        ref_value = np.abs(ref)
-
-    log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10(
-        np.maximum(amin, ref_value)
-    )
-
-    if top_db is not None:
-        if top_db < 0:
-            raise ValueError("top_db must be non-negative")
-        log_spec = np.maximum(log_spec, log_spec.max() - top_db)
-
-    return xr.DataArray(
-        data=log_spec.astype(dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
+class SpectralMeanSubstractionConfig(BaseConfig):
+    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
+
+
+class SpectralMeanSubstraction(torch.nn.Module):
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        mean = spec.mean(-1, keepdim=True)
+        return (spec - mean).clamp(min=0)
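The torch module reproduces the removed `remove_spectral_mean`: a per-frequency mean taken over time, subtracted and clamped at zero. It assumes time is the last tensor dimension. An equivalent one-liner for reference (shapes illustrative):

    import torch

    spec = torch.rand(128, 512)  # [frequency, time]
    denoised = (spec - spec.mean(-1, keepdim=True)).clamp(min=0)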
-def resize_spectrogram(
-    spec: xr.DataArray,
-    height: int = 128,
-    resize_factor: Optional[float] = 0.5,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> xr.DataArray:
-    """Resize a spectrogram to target dimensions using interpolation.
-
-    Resizes the frequency axis to `height` bins and optionally resizes the
-    time axis by `resize_factor`.
-
-    Parameters
-    ----------
-    spec : xr.DataArray
-        Input spectrogram with 'time' and 'frequency' dimensions.
-    height : int, default=128
-        Target number of frequency bins (vertical dimension).
-    resize_factor : float, optional
-        Factor to resize the time dimension. If 1.0 or None, time dimension
-        is unchanged. If 0.5, time dimension is halved, etc.
-
-    Returns
-    -------
-    xr.DataArray
-        Resized spectrogram. Coordinates are typically adjusted by the
-        underlying resize operation if implemented in `ops.resize`.
-        The dtype is currently hardcoded to float32 by ops.resize call.
-    """
-    resize_factor = resize_factor or 1
-    current_width = spec.sizes["time"]
-
-    target_sizes = {
-        "time": int(current_width * resize_factor),
-        "frequency": height,
-    }
-
-    new_coords = {}
-    for dim in ["time", "frequency"]:
-        step = arrays.get_dim_step(spec, dim)
-        start, stop = arrays.get_dim_range(spec, dim)
-        new_coords[dim] = arrays.create_range_dim(
-            name=dim,
-            start=start,
-            stop=stop + step,
-            size=target_sizes[dim],
-            dtype=dtype,
-        )
-
-    return spec.interp(
-        coords=new_coords, method="linear", kwargs=dict(fill_value=0)
-    )
+class ResizeConfig(BaseConfig):
+    name: Literal["resize_spec"] = "resize_spec"
+    height: int = 128
+    resize_factor: float = 0.5
+
+
+class ResizeSpec(torch.nn.Module):
+    def __init__(self, height: int, time_factor: float):
+        super().__init__()
+        self.height = height
+        self.time_factor = time_factor
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        current_length = spec.shape[-1]
+        target_length = int(self.time_factor * current_length)
+        return torch.nn.functional.interpolate(
+            spec.unsqueeze(0).unsqueeze(0),
+            size=(self.height, target_length),
+            mode="bilinear",
+        )
+
+
+class PeakNormalizeConfig(BaseConfig):
+    name: Literal["peak_normalize"] = "peak_normalize"
+
+
+SpectrogramTransform = Annotated[
+    Union[
+        PcenConfig,
+        ScaleAmplitudeConfig,
+        SpectralMeanSubstractionConfig,
+        PeakNormalizeConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+class SpectrogramConfig(BaseConfig):
+    stft: STFTConfig = Field(default_factory=STFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    size: ResizeConfig = Field(default_factory=ResizeConfig)
+    transforms: List[SpectrogramTransform] = Field(
+        default_factory=lambda: [
+            PcenConfig(),
+            SpectralMeanSubstractionConfig(),
+        ]
+    )
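With the `name` discriminator, the transforms list can be parsed straight from plain dicts (e.g. a YAML config). A round-trip sketch, assuming `BaseConfig` follows pydantic v2 semantics (the `model_validate` call is that assumption):

    config = SpectrogramConfig.model_validate(
        {
            "transforms": [
                {"name": "pcen", "time_constant": 0.4},
                {"name": "spectral_mean_substraction"},
                {"name": "peak_normalize"},
            ]
        }
    )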
-def get_spectrogram_resolution(
-    config: SpectrogramConfig,
-) -> tuple[float, float]:
-    """Calculate the approximate resolution of the final spectrogram.
-
-    Computes the width of each frequency bin (Hz/bin) and the duration
-    of each time bin (seconds/bin) based on the configuration parameters.
-
-    Parameters
-    ----------
-    config : SpectrogramConfig
-        The spectrogram configuration object.
-    samplerate : int, optional
-        The sample rate of the audio *before* STFT. Required if needed to
-        calculate hop duration accurately from STFT config, but the current
-        implementation calculates hop_duration directly from STFT config times.
-
-    Returns
-    -------
-    Tuple[float, float]
-        A tuple containing:
-        - frequency_resolution (float): Approximate Hz per frequency bin.
-        - time_resolution (float): Approximate seconds per time bin.
-
-    Raises
-    ------
-    ValueError
-        If required configuration fields (like `config.size`) are missing
-        or invalid.
-    """
-    max_freq = config.frequencies.max_freq
-    min_freq = config.frequencies.min_freq
-
-    if config.size is None:
-        raise ValueError("Spectrogram size configuration is required.")
-
-    spec_height = config.size.height
-    resize_factor = config.size.resize_factor or 1
-    freq_bin_width = (max_freq - min_freq) / spec_height
-    hop_duration = config.stft.window_duration * (
-        1 - config.stft.window_overlap
-    )
-    return freq_bin_width, hop_duration / resize_factor
+def _build_spectrogram_transform_step(
+    step: SpectrogramTransform,
+    samplerate: int,
+) -> torch.nn.Module:
+    if step.name == "pcen":
+        return PCEN(
+            smoothing_constant=_compute_smoothing_constant(
+                samplerate=samplerate,
+                time_constant=step.time_constant,
+            ),
+            gain=step.gain,
+            bias=step.bias,
+            power=step.power,
+        )
+
+    if step.name == "scale_amplitude":
+        return _build_amplitude_scaler(step)
+
+    if step.name == "spectral_mean_substraction":
+        return SpectralMeanSubstraction()
+
+    if step.name == "peak_normalize":
+        return PeakNormalize()
+
+    raise NotImplementedError(
+        f"Spectrogram preprocessing step {step.name} not implemented"
+    )
+
+
+def build_spectrogram_transform(
+    samplerate: int,
+    conf: SpectrogramConfig,
+) -> torch.nn.Module:
+    return torch.nn.Sequential(
+        *[
+            _build_spectrogram_transform_step(step, samplerate=samplerate)
+            for step in conf.transforms
+        ]
+    )
+
+
+class SpectrogramPipeline(torch.nn.Module):
+    def __init__(
+        self,
+        spec_builder: SpectrogramBuilder,
+        freq_cutter: torch.nn.Module,
+        transforms: torch.nn.Module,
+        resizer: torch.nn.Module,
+    ):
+        super().__init__()
+        self.spec_builder = spec_builder
+        self.freq_cutter = freq_cutter
+        self.transforms = transforms
+        self.resizer = resizer
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        spec = self.spec_builder(wav)
+        spec = self.freq_cutter(spec)
+        spec = self.transforms(spec)
+        return self.resizer(spec)
+
+    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
+        return self.spec_builder(wav)
+
+    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
+        return self.freq_cutter(spec)
+
+    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
+        return self.transforms(spec)
+
+    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
+        return self.resizer(spec)
+
+
+def build_spectrogram_pipeline(
+    samplerate: int,
+    conf: SpectrogramConfig,
+) -> SpectrogramPipeline:
+    spec_builder = build_spectrogram_builder(samplerate, conf.stft)
+    n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
+    cutter = FrequencyClip(
+        low_index=_frequency_to_index(
+            conf.frequencies.min_freq, samplerate, n_fft
+        ),
+        high_index=_frequency_to_index(
+            conf.frequencies.max_freq, samplerate, n_fft
+        ),
+    )
+    transforms = build_spectrogram_transform(samplerate, conf)
+    resizer = ResizeSpec(
+        height=conf.size.height,
+        time_factor=conf.size.resize_factor,
+    )
+    return SpectrogramPipeline(
+        spec_builder=spec_builder,
+        freq_cutter=cutter,
+        transforms=transforms,
+        resizer=resizer,
+    )
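End-to-end, the pipeline is a single `torch.nn.Module`. A minimal usage sketch (the sample rate is illustrative, and since `ResizeSpec` unsqueezes twice, the output should come back as a [1, 1, height, time] tensor):

    import torch

    conf = SpectrogramConfig()
    pipeline = build_spectrogram_pipeline(samplerate=256_000, conf=conf)

    wav = torch.randn(256_000)   # one second of audio
    spec = pipeline(wav)         # STFT -> frequency clip -> transforms -> resize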
@@ -28,8 +28,10 @@ from soundevent import data

 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.preprocess.audio import build_audio_loader
+from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
 from batdetect2.typing.targets import Position, Size
+from batdetect2.utils.arrays import spec_to_xarray

 __all__ = [
     "Anchor",
@@ -365,6 +367,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
     def __init__(
         self,
         preprocessor: PreprocessorProtocol,
+        audio_loader: AudioLoader,
         time_scale: float = DEFAULT_TIME_SCALE,
         frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
         loading_buffer: float = 0.01,
@@ -383,6 +386,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
             Buffer in seconds to add when loading audio clips.
         """
         self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
         self.time_scale = time_scale
         self.frequency_scale = frequency_scale
         self.loading_buffer = loading_buffer
@@ -422,6 +426,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):

         time, freq = get_peak_energy_coordinates(
             recording=sound_event.recording,
+            audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             start_time=start_time,
             end_time=end_time,
@@ -511,8 +516,10 @@ def build_roi_mapper(

     if config.name == "peak_energy_bbox":
         preprocessor = build_preprocessor(config.preprocessing)
+        audio_loader = build_audio_loader(config.preprocessing.audio)
         return PeakEnergyBBoxMapper(
             preprocessor=preprocessor,
+            audio_loader=audio_loader,
             time_scale=config.time_scale,
             frequency_scale=config.frequency_scale,
             loading_buffer=config.loading_buffer,
@@ -617,6 +624,7 @@ def _build_bounding_box(

 def get_peak_energy_coordinates(
     recording: data.Recording,
+    audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     start_time: float = 0,
     end_time: Optional[float] = None,
@@ -669,7 +677,15 @@ def get_peak_energy_coordinates(
         end_time=clip_end,
     )

-    spec = preprocessor.preprocess_clip(clip)
+    wav = audio_loader.load_clip(clip)
+    spec = preprocessor.process_numpy(wav)
+    spec = spec_to_xarray(
+        spec,
+        clip.start_time,
+        clip.end_time,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+    )
     low_freq = max(low_freq, preprocessor.min_freq)
     high_freq = min(high_freq, preprocessor.max_freq)
     selection = spec.sel(
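The xarray round-trip exists so the label-based `.sel()` below keeps working on the torch pipeline's plain numpy output. An illustrative stand-alone sketch (array shape and slice bounds are made up):

    import numpy as np

    spec = spec_to_xarray(np.random.rand(128, 512), 0.0, 0.5, 10_000, 120_000)
    window = spec.sel(time=slice(0.1, 0.2), frequency=slice(20_000, 80_000))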
@@ -129,9 +129,7 @@ def mix_examples(
     with xr.set_options(keep_attrs=True):
         combined = weight * audio1 + (1 - weight) * audio2

-    spectrogram = preprocessor.compute_spectrogram(
-        combined.rename({"audio_time": "time"})
-    ).data
+    spectrogram = preprocessor.process_numpy(combined.data)

     # NOTE: The subclip's spectrogram might be slightly longer than the
     # spectrogram computed from the subclip's audio. This is due to a
@@ -241,9 +239,7 @@ def add_echo(
     with xr.set_options(keep_attrs=True):
         audio = audio + weight * audio_delay

-    spectrogram = preprocessor.compute_spectrogram(
-        audio.rename({"audio_time": "time"}),
-    ).data
+    spectrogram = preprocessor.process_numpy(audio.data)

     # NOTE: The subclip's spectrogram might be slightly longer than the
     # spectrogram computed from the subclip's audio. This is due to a
@@ -21,10 +21,12 @@ class ClipingConfig(BaseConfig):
 class Clipper(ClipperProtocol):
     def __init__(
         self,
+        samplerate: int,
         duration: float = 0.5,
         max_empty: float = 0.2,
         random: bool = True,
     ):
+        self.samplerate = samplerate
         self.duration = duration
         self.random = random
         self.max_empty = max_empty
@@ -25,6 +25,8 @@ from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence

+import numpy as np
+import torch
 import xarray as xr
 from loguru import logger
 from pydantic import Field
@@ -34,9 +36,12 @@ from tqdm.auto import tqdm
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.datasets import Dataset
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
+from batdetect2.preprocess.audio import build_audio_loader
 from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.typing import ClipLabeller, PreprocessorProtocol
+from batdetect2.typing.preprocess import AudioLoader
+from batdetect2.utils.arrays import audio_to_xarray

 __all__ = [
     "preprocess_annotations",
@@ -76,6 +81,7 @@ def preprocess_dataset(
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
     labeller = build_clip_labeler(targets, config=config.labels)
+    audio_loader = build_audio_loader(config=config.preprocess.audio)

     if not output.exists():
         logger.debug("Creating directory {directory}", directory=output)
@@ -84,6 +90,7 @@ def preprocess_dataset(
     preprocess_annotations(
         dataset,
         output_dir=output,
+        audio_loader=audio_loader,
         preprocessor=preprocessor,
         labeller=labeller,
         replace=force,
@@ -93,6 +100,7 @@ def preprocess_dataset(

 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
+    audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
 ) -> xr.Dataset:
@@ -140,9 +148,15 @@ def generate_train_example(
    - The original `ClipAnnotation` metadata is stored as a JSON string in the
      Dataset's attributes for provenance.
    """
-    wave = preprocessor.load_clip_audio(clip_annotation.clip)
+    wave = audio_loader.load_clip(clip_annotation.clip)

-    spectrogram = preprocessor.compute_spectrogram(wave)
+    spectrogram = _spec_to_xr(
+        preprocessor(torch.tensor(wave)),
+        start_time=clip_annotation.clip.start_time,
+        end_time=clip_annotation.clip.end_time,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+    )

    heatmaps = labeller(clip_annotation, spectrogram)

@@ -152,7 +166,12 @@ def generate_train_example(
             # the spectrogram time dimension, otherwise xarray will interpolate
             # the spectrogram and the heatmaps to the same temporal resolution
             # as the waveform.
-            "audio": wave.rename({"time": "audio_time"}),
+            "audio": audio_to_xarray(
+                wave,
+                start_time=clip_annotation.clip.start_time,
+                end_time=clip_annotation.clip.end_time,
+                time_axis="audio_time",
+            ),
             "spectrogram": spectrogram,
             "detection": heatmaps.detection,
             "class": heatmaps.classes,
@@ -170,6 +189,32 @@ def generate_train_example(
     )


+def _spec_to_xr(
+    spec: torch.Tensor,
+    start_time: float,
+    end_time: float,
+    min_freq: float,
+    max_freq: float,
+) -> xr.DataArray:
+    data = spec.numpy()[0, 0]
+
+    height, width = data.shape
+
+    return xr.DataArray(
+        data=data,
+        dims=[
+            "frequency",
+            "time",
+        ],
+        coords={
+            "frequency": np.linspace(
+                min_freq, max_freq, height, endpoint=False
+            ),
+            "time": np.linspace(start_time, end_time, width, endpoint=False),
+        },
+    )
+
+
 def _save_xr_dataset_to_file(
     dataset: xr.Dataset,
     path: data.PathLike,
@@ -206,6 +251,7 @@ def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
     output_dir: data.PathLike,
     preprocessor: PreprocessorProtocol,
+    audio_loader: AudioLoader,
     labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
@@ -275,6 +321,7 @@ def preprocess_annotations(
             output_dir=output_dir,
             filename_fn=filename_fn,
             replace=replace,
+            audio_loader=audio_loader,
             preprocessor=preprocessor,
             labeller=labeller,
         ),
@@ -290,6 +337,7 @@ def preprocess_annotations(
 def preprocess_single_annotation(
     clip_annotation: data.ClipAnnotation,
     output_dir: data.PathLike,
+    audio_loader: AudioLoader,
     preprocessor: PreprocessorProtocol,
     labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
@@ -335,6 +383,7 @@ def preprocess_single_annotation(
     try:
         sample = generate_train_example(
             clip_annotation,
+            audio_loader=audio_loader,
             preprocessor=preprocessor,
             labeller=labeller,
         )
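`_spec_to_xr` strips two leading singleton axes via `[0, 0]`, so it assumes the preprocessor emits a [batch, channel, frequency, time] tensor. A defensive sketch of that contract, reusing the names from the hunk above:

    out = preprocessor(torch.tensor(wave))
    assert out.ndim == 4 and out.shape[:2] == (1, 1)  # [1, 1, height, width]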
@@ -10,10 +10,10 @@ pipeline can interact consistently, regardless of the specific underlying
 implementation (e.g., different libraries or custom configurations).
 """

-from typing import Optional, Protocol, Union
+from typing import Optional, Protocol

 import numpy as np
-import xarray as xr
+import torch
 from soundevent import data

 __all__ = [
@@ -36,7 +36,7 @@ class AudioLoader(Protocol):
         self,
         path: data.PathLike,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
+    ) -> np.ndarray:
         """Load and preprocess audio directly from a file path.

         Parameters
@@ -46,12 +46,6 @@ class AudioLoader(Protocol):
         audio_dir : PathLike, optional
             A directory prefix to prepend to the path if `path` is relative.

-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform as an xarray DataArray
-            with time coordinates. Typically loads only the first channel.
-
         Raises
         ------
         FileNotFoundError
@@ -65,7 +59,7 @@ class AudioLoader(Protocol):
         self,
         recording: data.Recording,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
+    ) -> np.ndarray:
         """Load and preprocess the entire audio for a Recording object.

         Parameters
@@ -95,7 +89,7 @@ class AudioLoader(Protocol):
         self,
         clip: data.Clip,
         audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
+    ) -> np.ndarray:
         """Load and preprocess the audio segment defined by a Clip object.

         Parameters
@@ -124,264 +118,41 @@ class AudioLoader(Protocol):


 class SpectrogramBuilder(Protocol):
-    """Defines the interface for a spectrogram generation component.
-
-    A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
-    and produces a spectrogram (as an xarray DataArray) based on its internal
-    configuration or implementation.
-    """
-
-    def __call__(
-        self,
-        wav: Union[np.ndarray, xr.DataArray],
-        samplerate: Optional[int] = None,
-    ) -> xr.DataArray:
-        """Generate a spectrogram from an audio waveform.
-
-        Parameters
-        ----------
-        wav : Union[np.ndarray, xr.DataArray]
-            The input audio waveform. If a numpy array, `samplerate` must
-            also be provided. If an xarray DataArray, it must have a 'time'
-            coordinate from which the sample rate can be inferred.
-        samplerate : int, optional
-            The sample rate of the audio in Hz. Required if `wav` is a
-            numpy array. If `wav` is an xarray DataArray, this parameter is
-            ignored as the sample rate is derived from the coordinates.
-
-        Returns
-        -------
-        xr.DataArray
-            The computed spectrogram as an xarray DataArray with 'time' and
-            'frequency' coordinates.
-
-        Raises
-        ------
-        ValueError
-            If `wav` is a numpy array and `samplerate` is not provided, or
-            if `wav` is an xarray DataArray without a valid 'time' coordinate.
-        """
+    """Defines the interface for a spectrogram generation component."""
+
+    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
+        """Generate a spectrogram from an audio waveform."""
         ...


-class PreprocessorProtocol(Protocol):
-    """Defines a high-level interface for the complete preprocessing pipeline.
-
-    A Preprocessor combines audio loading and spectrogram generation steps.
-    It provides methods to go directly from source descriptions (file paths,
-    Recording objects, Clip objects) to the final spectrogram representation
-    needed by the model. It may also expose intermediate steps like audio
-    loading or spectrogram computation from a waveform.
-    """
+class AudioPipeline(Protocol):
+    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
+
+
+class SpectrogramPipeline(Protocol):
+    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
+
+    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...
+
+    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
+
+    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
+
+    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
+
+
+class PreprocessorProtocol(Protocol):
+    """Defines a high-level interface for the complete preprocessing pipeline."""

     max_freq: float

     min_freq: float

-    def preprocess_file(
-        self,
-        path: data.PathLike,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio from a file and compute the final processed spectrogram.
-
-        Performs the full pipeline:
-
-        Load -> Preprocess Audio -> Compute Spectrogram.
-
-        Parameters
-        ----------
-        path : PathLike
-            Path to the audio file.
-        audio_dir : PathLike, optional
-            A directory prefix if `path` is relative.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the audio file cannot be found.
-        Exception
-            If any step in the loading or preprocessing fails.
-        """
-        ...
-
-    def preprocess_recording(
-        self,
-        recording: data.Recording,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio for a Recording and compute the processed spectrogram.
-
-        Performs the full pipeline for the entire duration of the recording.
-
-        Parameters
-        ----------
-        recording : data.Recording
-            The Recording object.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the audio file cannot be found.
-        Exception
-            If any step in the loading or preprocessing fails.
-        """
-        ...
-
-    def preprocess_clip(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load audio for a Clip and compute the final processed spectrogram.
-
-        Performs the full pipeline for the specified clip segment.
-
-        Parameters
-        ----------
-        clip : data.Clip
-            The Clip object defining the audio segment.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The final processed spectrogram.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the audio file cannot be found.
-        Exception
-            If any step in the loading or preprocessing fails.
-        """
-        ...
-
-    def load_file_audio(
-        self,
-        path: data.PathLike,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform from a file path.
-
-        Performs the initial audio loading and waveform processing steps
-        (like resampling, scaling), but stops *before* spectrogram generation.
-
-        Parameters
-        ----------
-        path : PathLike
-            Path to the audio file.
-        audio_dir : PathLike, optional
-            A directory prefix if `path` is relative.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform.
-
-        Raises
-        ------
-        FileNotFoundError, Exception
-            If audio loading/preprocessing fails.
-        """
-        ...
-
-    def load_recording_audio(
-        self,
-        recording: data.Recording,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform for a Recording.
-
-        Performs the initial audio loading and waveform processing steps
-        for the entire recording duration.
-
-        Parameters
-        ----------
-        recording : data.Recording
-            The Recording object.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform.
-
-        Raises
-        ------
-        FileNotFoundError, Exception
-            If audio loading/preprocessing fails.
-        """
-        ...
-
-    def load_clip_audio(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> xr.DataArray:
-        """Load and preprocess *only* the audio waveform for a Clip.
-
-        Performs the initial audio loading and waveform processing steps
-        for the specified clip segment.
-
-        Parameters
-        ----------
-        clip : data.Clip
-            The Clip object defining the segment.
-        audio_dir : PathLike, optional
-            Directory containing the audio file.
-
-        Returns
-        -------
-        xr.DataArray
-            The loaded and preprocessed audio waveform segment.
-
-        Raises
-        ------
-        FileNotFoundError, Exception
-            If audio loading/preprocessing fails.
-        """
-        ...
-
-    def compute_spectrogram(
-        self,
-        wav: Union[xr.DataArray, np.ndarray],
-    ) -> xr.DataArray:
-        """Compute the spectrogram from a pre-loaded audio waveform.
-
-        Applies the spectrogram generation steps (STFT, scaling, etc.) defined
-        by the `SpectrogramBuilder` component of the preprocessor to an
-        already loaded (and potentially preprocessed) waveform.
-
-        Parameters
-        ----------
-        wav : Union[xr.DataArray, np.ndarray]
-            The input audio waveform. If numpy array, `samplerate` is required.
-        samplerate : int, optional
-            Sample rate in Hz (required if `wav` is np.ndarray).
-
-        Returns
-        -------
-        xr.DataArray
-            The computed spectrogram.
-
-        Raises
-        ------
-        ValueError, Exception
-            If waveform input is invalid or spectrogram computation fails.
-        """
-        ...
+    audio_pipeline: AudioPipeline
+
+    spectrogram_pipeline: SpectrogramPipeline
+
+    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
+
+    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
+        return self(torch.tensor(wav)).numpy()[0, 0]
@@ -2,6 +2,62 @@ import numpy as np
 import xarray as xr


+def spec_to_xarray(
+    spec: np.ndarray,
+    start_time: float,
+    end_time: float,
+    min_freq: float,
+    max_freq: float,
+) -> xr.DataArray:
+    if spec.ndim != 2:
+        raise ValueError(
+            "Input numpy spectrogram array should be 2-dimensional"
+        )
+
+    height, width = spec.shape
+    return xr.DataArray(
+        data=spec,
+        dims=["frequency", "time"],
+        coords={
+            "frequency": np.linspace(
+                min_freq,
+                max_freq,
+                height,
+                endpoint=False,
+            ),
+            "time": np.linspace(
+                start_time,
+                end_time,
+                width,
+                endpoint=False,
+            ),
+        },
+    )
+
+
+def audio_to_xarray(
+    wav: np.ndarray,
+    start_time: float,
+    end_time: float,
+    time_axis: str = "time",
+) -> xr.DataArray:
+    if wav.ndim != 1:
+        raise ValueError("Input numpy audio array should be 1-dimensional")
+
+    return xr.DataArray(
+        data=wav,
+        dims=[time_axis],
+        coords={
+            time_axis: np.linspace(
+                start_time,
+                end_time,
+                len(wav),
+                endpoint=False,
+            ),
+        },
+    )
+
+
 def extend_width(
     array: np.ndarray,
     extra: int,
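Usage sketch for the new helpers: wrap raw arrays coming out of the torch pipeline so downstream xarray code keeps its labelled axes (all values illustrative):

    import numpy as np

    wav = np.random.rand(24_000).astype(np.float32)
    audio = audio_to_xarray(
        wav, start_time=0.0, end_time=0.5, time_axis="audio_time"
    )
    print(audio.coords["audio_time"].values[:3])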
|||||||
@ -12,6 +12,7 @@ from soundevent import data, terms
|
|||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.preprocess.audio import build_audio_loader
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
@ -27,6 +28,7 @@ from batdetect2.typing import (
|
|||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -368,6 +370,11 @@ def sample_preprocessor() -> PreprocessorProtocol:
|
|||||||
return build_preprocessor()
|
return build_preprocessor()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_audio_loader() -> AudioLoader:
|
||||||
|
return build_audio_loader()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def bat_tag() -> TagInfo:
|
def bat_tag() -> TagInfo:
|
||||||
return TagInfo(key="class", value="bat")
|
return TagInfo(key="class", value="bat")
|
||||||
|
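With the new fixture in place, tests can request a ready-made loader directly. A minimal sketch of a consuming test (hypothetical test name; the `dummy_wav_path` fixture appears further down in this diff, and the assertion mirrors the behaviour exercised by the pre-existing audio tests):

    def test_loader_reads_waveform(sample_audio_loader, dummy_wav_path):
        # Load a waveform through the shared fixture instead of calling
        # audio.load_file_audio with an explicit config.
        wav = sample_audio_loader.load_file(dummy_wav_path)
        assert wav.sizes["time"] > 0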
@@ -1,13 +1,10 @@
 import pathlib
 import uuid
-from pathlib import Path
 
 import numpy as np
 import pytest
 import soundfile as sf
-import xarray as xr
 from soundevent import data
-from soundevent.arrays import Dimensions, create_time_dim_from_array
 
 from batdetect2.preprocess import audio
 
@@ -30,44 +27,6 @@ def create_dummy_wave(
     return wave.astype(dtype)
 
 
-def create_xr_wave(
-    samplerate: int,
-    duration: float,
-    num_channels: int = 1,
-    freq: float = 440.0,
-    amplitude: float = 0.5,
-    start_time: float = 0.0,
-) -> xr.DataArray:
-    """Generates a simple xarray waveform."""
-    num_samples = int(samplerate * duration)
-    times = np.linspace(
-        start_time,
-        start_time + duration,
-        num_samples,
-        endpoint=False,
-    )
-    coords = {
-        Dimensions.time.value: create_time_dim_from_array(
-            times, samplerate=samplerate, start_time=start_time
-        )
-    }
-    dims = [Dimensions.time.value]
-
-    wave_data = amplitude * np.sin(2 * np.pi * freq * times)
-
-    if num_channels > 1:
-        coords[Dimensions.channel.value] = np.arange(num_channels)
-        dims = [Dimensions.channel.value] + dims
-        wave_data = np.stack([wave_data] * num_channels, axis=0)
-
-    return xr.DataArray(
-        wave_data.astype(np.float32),
-        coords=coords,
-        dims=dims,
-        attrs={"samplerate": samplerate},
-    )
-
-
 @pytest.fixture
 def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
     """Creates a dummy WAV file and returns its path."""
@@ -99,408 +58,3 @@ def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
 @pytest.fixture
 def default_audio_config() -> audio.AudioConfig:
     return audio.AudioConfig()
-
-
-@pytest.fixture
-def no_resample_config() -> audio.AudioConfig:
-    return audio.AudioConfig(resample=None)
-
-
-@pytest.fixture
-def fixed_duration_config() -> audio.AudioConfig:
-    return audio.AudioConfig(duration=0.5)
-
-
-@pytest.fixture
-def scale_config() -> audio.AudioConfig:
-    return audio.AudioConfig(scale=True, center=False)
-
-
-@pytest.fixture
-def no_center_config() -> audio.AudioConfig:
-    return audio.AudioConfig(center=False)
-
-
-@pytest.fixture
-def resample_fourier_config() -> audio.AudioConfig:
-    return audio.AudioConfig(
-        resample=audio.ResampleConfig(
-            samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier"
-        )
-    )
-
-
-def test_resample_config_defaults():
-    config = audio.ResampleConfig()
-    assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
-    assert config.method == "poly"
-
-
-def test_audio_config_defaults():
-    config = audio.AudioConfig()
-    assert config.resample is not None
-    assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
-    assert config.resample.method == "poly"
-    assert config.scale == audio.SCALE_RAW_AUDIO
-    assert config.center is False
-    assert config.duration == audio.DEFAULT_DURATION
-
-
-def test_audio_config_override():
-    resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier")
-    config = audio.AudioConfig(
-        resample=resample_cfg,
-        scale=True,
-        center=False,
-        duration=1.0,
-    )
-    assert config.resample == resample_cfg
-    assert config.scale is True
-    assert config.center is False
-    assert config.duration == 1.0
-
-
-def test_audio_config_no_resample():
-    config = audio.AudioConfig(resample=None)
-    assert config.resample is None
-
-
-@pytest.mark.parametrize(
-    "orig_sr, orig_dur, target_dur",
-    [
-        (256_000, 1.0, 0.5),
-        (256_000, 0.5, 1.0),
-        (256_000, 1.0, 1.0),
-    ],
-)
-def test_adjust_audio_duration(orig_sr, orig_dur, target_dur):
-    wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur)
-    adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur)
-    expected_samples = int(target_dur * orig_sr)
-    assert adjusted_wave.sizes["time"] == expected_samples
-    assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr
-    assert adjusted_wave.dtype == wave.dtype
-    if orig_dur > 0 and target_dur > orig_dur:
-        padding_start_index = int(orig_dur * orig_sr) + 1
-        assert np.all(adjusted_wave.values[padding_start_index:] == 0)
-
-
-def test_adjust_audio_duration_negative_target_raises():
-    wave = create_xr_wave(1000, 1.0)
-    with pytest.raises(ValueError):
-        audio.adjust_audio_duration(wave, duration=-0.5)
-
-
-@pytest.mark.parametrize(
-    "orig_sr, target_sr, mode",
-    [
-        (48000, 96000, "poly"),
-        (96000, 48000, "poly"),
-        (48000, 96000, "fourier"),
-        (96000, 48000, "fourier"),
-        (48000, 44100, "poly"),
-        (48000, 44100, "fourier"),
-    ],
-)
-def test_resample_audio(orig_sr, target_sr, mode):
-    duration = 0.1
-    wave = create_xr_wave(orig_sr, duration)
-    resampled_wave = audio.resample_audio(
-        wave, samplerate=target_sr, method=mode, dtype=np.float32
-    )
-    expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
-    assert resampled_wave.sizes["time"] == expected_samples
-    assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr
-    assert np.isclose(
-        resampled_wave.coords["time"].values[-1]
-        - resampled_wave.coords["time"].values[0],
-        duration,
-        atol=2 / target_sr,
-    )
-    assert resampled_wave.dtype == np.float32
-
-
-def test_resample_audio_same_samplerate():
-    sr = 48000
-    duration = 0.1
-    wave = create_xr_wave(sr, duration)
-    resampled_wave = audio.resample_audio(
-        wave, samplerate=sr, dtype=np.float64
-    )
-    xr.testing.assert_equal(wave.astype(np.float64), resampled_wave)
-
-
-def test_resample_audio_invalid_mode_raises():
-    wave = create_xr_wave(48000, 0.1)
-    with pytest.raises(NotImplementedError):
-        audio.resample_audio(wave, samplerate=96000, method="invalid_mode")
-
-
-def test_resample_audio_no_time_dim_raises():
-    wave = xr.DataArray(np.random.rand(100), dims=["samples"])
-    with pytest.raises(ValueError, match="Audio must have a time dimension"):
-        audio.resample_audio(wave, samplerate=96000)
-
-
-def test_load_clip_audio_default_config(
-    dummy_clip: data.Clip,
-    default_audio_config: audio.AudioConfig,
-    tmp_path: Path,
-):
-    assert default_audio_config.resample is not None
-    target_sr = default_audio_config.resample.samplerate
-    orig_duration = dummy_clip.duration
-    expected_samples = int(orig_duration * target_sr)
-
-    wav = audio.load_clip_audio(
-        dummy_clip, config=default_audio_config, audio_dir=tmp_path
-    )
-
-    assert isinstance(wav, xr.DataArray)
-    assert wav.dims == ("time",)
-    assert wav.sizes["time"] == expected_samples
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
-    assert wav.dtype == np.float32
-
-
-def test_load_clip_audio_no_resample(
-    dummy_clip: data.Clip,
-    no_resample_config: audio.AudioConfig,
-    tmp_path: Path,
-):
-    orig_sr = dummy_clip.recording.samplerate
-    orig_duration = dummy_clip.duration
-    expected_samples = int(orig_duration * orig_sr)
-
-    wav = audio.load_clip_audio(
-        dummy_clip, config=no_resample_config, audio_dir=tmp_path
-    )
-
-    assert wav.coords["time"].attrs["step"] == 1 / orig_sr
-    assert wav.sizes["time"] == expected_samples
-    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
-
-
-def test_load_clip_audio_fixed_duration_crop(
-    dummy_clip: data.Clip,
-    fixed_duration_config: audio.AudioConfig,
-    tmp_path: Path,
-):
-    target_sr = audio.TARGET_SAMPLERATE_HZ
-    target_duration = fixed_duration_config.duration
-    assert target_duration is not None
-    expected_samples = int(target_duration * target_sr)
-
-    assert dummy_clip.duration > target_duration
-
-    wav = audio.load_clip_audio(
-        dummy_clip, config=fixed_duration_config, audio_dir=tmp_path
-    )
-
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert wav.sizes["time"] == expected_samples
-
-
-def test_load_clip_audio_fixed_duration_pad(
-    dummy_clip: data.Clip,
-    tmp_path: Path,
-):
-    target_duration = dummy_clip.duration * 2
-    config = audio.AudioConfig(duration=target_duration)
-
-    assert config.resample is not None
-    target_sr = config.resample.samplerate
-    expected_samples = int(target_duration * target_sr)
-
-    wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path)
-
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert wav.sizes["time"] == expected_samples
-
-    original_samples_after_resample = int(dummy_clip.duration * target_sr)
-    assert np.allclose(
-        wav.values[original_samples_after_resample:], 0.0, atol=1e-6
-    )
-
-
-def test_load_clip_audio_scale(
-    dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path
-):
-    wav = audio.load_clip_audio(
-        dummy_clip,
-        config=scale_config,
-        audio_dir=tmp_path,
-    )
-
-    assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5)
-
-
-def test_load_clip_audio_no_center(
-    dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path
-):
-    wav = audio.load_clip_audio(
-        dummy_clip, config=no_center_config, audio_dir=tmp_path
-    )
-
-    raw_wav, _ = sf.read(
-        dummy_clip.recording.path,
-        start=int(dummy_clip.start_time * dummy_clip.recording.samplerate),
-        stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate),
-        dtype=np.float32,  # type: ignore
-    )
-    raw_wav_mono = raw_wav[:, 0]
-
-    if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7):
-        assert not np.isclose(wav.mean(), 0.0, atol=1e-6)
-
-
-def test_load_clip_audio_resample_fourier(
-    dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path
-):
-    assert resample_fourier_config.resample is not None
-    target_sr = resample_fourier_config.resample.samplerate
-    orig_duration = dummy_clip.duration
-    expected_samples = int(orig_duration * target_sr)
-
-    wav = audio.load_clip_audio(
-        dummy_clip, config=resample_fourier_config, audio_dir=tmp_path
-    )
-
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert wav.sizes["time"] == expected_samples
-
-
-def test_load_clip_audio_dtype(
-    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
-):
-    wav = audio.load_clip_audio(
-        dummy_clip,
-        config=default_audio_config,
-        audio_dir=tmp_path,
-        dtype=np.float64,
-    )
-    assert wav.dtype == np.float64
-
-
-def test_load_clip_audio_file_not_found(
-    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
-):
-    non_existent_path = tmp_path / "not_a_real_file.wav"
-    dummy_clip.recording = data.Recording(
-        path=non_existent_path,
-        duration=1,
-        channels=1,
-        samplerate=256000,
-    )
-    with pytest.raises(FileNotFoundError):
-        audio.load_clip_audio(
-            dummy_clip, config=default_audio_config, audio_dir=tmp_path
-        )
-
-
-def test_load_recording_audio(
-    dummy_recording: data.Recording,
-    default_audio_config: audio.AudioConfig,
-    tmp_path,
-):
-    assert default_audio_config.resample is not None
-    target_sr = default_audio_config.resample.samplerate
-    orig_duration = dummy_recording.duration
-    expected_samples = int(orig_duration * target_sr)
-
-    wav = audio.load_recording_audio(
-        dummy_recording, config=default_audio_config, audio_dir=tmp_path
-    )
-
-    assert isinstance(wav, xr.DataArray)
-    assert wav.dims == ("time",)
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert wav.sizes["time"] == expected_samples
-    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
-    assert wav.dtype == np.float32
-
-
-def test_load_recording_audio_file_not_found(
-    dummy_recording: data.Recording,
-    default_audio_config: audio.AudioConfig,
-    tmp_path,
-):
-    non_existent_path = tmp_path / "not_a_real_file.wav"
-    dummy_recording = data.Recording(
-        path=non_existent_path,
-        duration=1,
-        channels=1,
-        samplerate=256000,
-    )
-    with pytest.raises(FileNotFoundError):
-        audio.load_recording_audio(
-            dummy_recording, config=default_audio_config, audio_dir=tmp_path
-        )
-
-
-def test_load_file_audio(
-    dummy_wav_path: pathlib.Path,
-    default_audio_config: audio.AudioConfig,
-    tmp_path,
-):
-    info = sf.info(dummy_wav_path)
-    orig_duration = info.duration
-    assert default_audio_config.resample is not None
-    target_sr = default_audio_config.resample.samplerate
-    expected_samples = int(orig_duration * target_sr)
-
-    wav = audio.load_file_audio(
-        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
-    )
-
-    assert isinstance(wav, xr.DataArray)
-    assert wav.dims == ("time",)
-    assert wav.coords["time"].attrs["step"] == 1 / target_sr
-    assert wav.sizes["time"] == expected_samples
-    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
-    assert wav.dtype == np.float32
-
-
-def test_load_file_audio_file_not_found(
-    default_audio_config: audio.AudioConfig, tmp_path
-):
-    non_existent_path = tmp_path / "not_a_real_file.wav"
-    with pytest.raises(FileNotFoundError):
-        audio.load_file_audio(
-            non_existent_path, config=default_audio_config, audio_dir=tmp_path
-        )
-
-
-def test_build_audio_loader(default_audio_config: audio.AudioConfig):
-    loader = audio.build_audio_loader(config=default_audio_config)
-    assert isinstance(loader, audio.ConfigurableAudioLoader)
-    assert loader.config == default_audio_config
-
-
-def test_configurable_audio_loader_methods(
-    default_audio_config: audio.AudioConfig,
-    dummy_wav_path: pathlib.Path,
-    dummy_recording: data.Recording,
-    dummy_clip: data.Clip,
-    tmp_path,
-):
-    loader = audio.build_audio_loader(config=default_audio_config)
-
-    expected_wav_file = audio.load_file_audio(
-        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
-    )
-    loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path)
-    xr.testing.assert_equal(expected_wav_file, loaded_wav_file)
-
-    expected_wav_rec = audio.load_recording_audio(
-        dummy_recording, config=default_audio_config, audio_dir=tmp_path
-    )
-    loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path)
-    xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec)
-
-    expected_wav_clip = audio.load_clip_audio(
-        dummy_clip, config=default_audio_config, audio_dir=tmp_path
-    )
-    loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path)
-    xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip)
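For reference, the defaults that the deleted tests pinned down translate into an explicit configuration like the following (a sketch assembled purely from the assertions above, spelling out values that `audio.AudioConfig()` already used implicitly before this commit):

    from batdetect2.preprocess import audio

    config = audio.AudioConfig(
        resample=audio.ResampleConfig(
            samplerate=audio.TARGET_SAMPLERATE_HZ,  # default target rate
            method="poly",  # polyphase resampling was the default
        ),
        scale=audio.SCALE_RAW_AUDIO,
        center=False,
        duration=audio.DEFAULT_DURATION,
    )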
@@ -1,32 +1,7 @@
-import math
-from pathlib import Path
-from typing import Callable, Union
-
 import numpy as np
 import pytest
 import xarray as xr
-from soundevent import arrays
 
-from batdetect2.preprocess.audio import AudioConfig, load_file_audio
-from batdetect2.preprocess.spectrogram import (
-    MAX_FREQ,
-    MIN_FREQ,
-    ConfigurableSpectrogramBuilder,
-    FrequencyConfig,
-    PcenConfig,
-    SpecSizeConfig,
-    SpectrogramConfig,
-    STFTConfig,
-    apply_pcen,
-    build_spectrogram_builder,
-    compute_spectrogram,
-    crop_spectrogram_frequencies,
-    get_spectrogram_resolution,
-    remove_spectral_mean,
-    resize_spectrogram,
-    scale_spectrogram,
-    stft,
-)
 
 SAMPLERATE = 250_000
 DURATION = 0.1
@@ -61,389 +36,3 @@ def constant_wave_xr() -> xr.DataArray:
         dims=["time"],
         attrs={"samplerate": SAMPLERATE},
     )
-
-
-@pytest.fixture
-def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
-    """Generate a basic spectrogram for testing downstream functions."""
-    config = SpectrogramConfig(
-        stft=STFTConfig(window_duration=0.002, window_overlap=0.5),
-        frequencies=FrequencyConfig(
-            min_freq=0,
-            max_freq=int(SAMPLERATE / 2),
-        ),
-        size=None,
-        pcen=None,
-        spectral_mean_substraction=False,
-        peak_normalize=False,
-        scale="amplitude",
-    )
-    spec = stft(
-        sine_wave_xr,
-        window_duration=config.stft.window_duration,
-        window_overlap=config.stft.window_overlap,
-        window_fn=config.stft.window_fn,
-    )
-    return spec
-
-
-def test_stft_config_defaults():
-    config = STFTConfig()
-    assert config.window_duration == 0.002
-    assert config.window_overlap == 0.75
-    assert config.window_fn == "hann"
-
-
-def test_frequency_config_defaults():
-    config = FrequencyConfig()
-    assert config.min_freq == MIN_FREQ
-    assert config.max_freq == MAX_FREQ
-
-
-def test_spec_size_config_defaults():
-    config = SpecSizeConfig()
-    assert config.height == 128
-    assert config.resize_factor == 0.5
-
-
-def test_pcen_config_defaults():
-    config = PcenConfig()
-    assert config.time_constant == 0.01
-    assert config.gain == 0.98
-    assert config.bias == 2
-    assert config.power == 0.5
-
-
-def test_spectrogram_config_defaults():
-    config = SpectrogramConfig()
-    assert isinstance(config.stft, STFTConfig)
-    assert isinstance(config.frequencies, FrequencyConfig)
-    assert isinstance(config.pcen, PcenConfig)
-    assert config.scale == "amplitude"
-    assert isinstance(config.size, SpecSizeConfig)
-    assert config.spectral_mean_substraction is True
-    assert config.peak_normalize is False
-
-
-def test_stft_output_properties(sine_wave_xr: xr.DataArray):
-    window_duration = 0.002
-    window_overlap = 0.5
-    samplerate = sine_wave_xr.attrs["samplerate"]
-    nfft = int(window_duration * samplerate)
-    hop_len = nfft - int(window_overlap * nfft)
-
-    spec = stft(
-        sine_wave_xr,
-        window_duration=window_duration,
-        window_overlap=window_overlap,
-        window_fn="hann",
-    )
-
-    assert isinstance(spec, xr.DataArray)
-    assert spec.dims == ("frequency", "time")
-    assert spec.dtype == np.float32
-    assert "frequency" in spec.coords
-    assert "time" in spec.coords
-
-    time_step = arrays.get_dim_step(spec, "time")
-    freq_step = arrays.get_dim_step(spec, "frequency")
-    freq_start, freq_end = arrays.get_dim_range(spec, "frequency")
-    assert np.isclose(freq_step, samplerate / nfft)
-    assert np.isclose(time_step, hop_len / samplerate)
-    assert spec.frequency.min() >= 0
-    assert freq_start == 0
-    assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2)
-    assert np.isclose(spec.time.min(), 0)
-    assert spec.time.max() < DURATION
-
-    assert spec.attrs["samplerate"] == samplerate
-    assert spec.attrs["window_size"] == window_duration
-    assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap)
-
-    assert np.all(spec.data >= 0)
-
-
-@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
-def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
-    spec = stft(
-        sine_wave_xr,
-        window_duration=0.002,
-        window_overlap=0.5,
-        window_fn=window_fn,
-    )
-    assert isinstance(spec, xr.DataArray)
-    assert np.all(spec.data >= 0)
-
-
-def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
-    min_f, max_f = 20_000, 80_000
-    cropped_spec = crop_spectrogram_frequencies(
-        sample_spec, min_freq=min_f, max_freq=max_f
-    )
-
-    assert cropped_spec.dims == sample_spec.dims
-    assert cropped_spec.dtype == sample_spec.dtype
-    assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
-    assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
-    assert cropped_spec.coords["frequency"].min() >= min_f
-
-    assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)
-
-
-def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
-    samplerate = sample_spec.attrs["samplerate"]
-    min_f, max_f = 0, samplerate / 2
-    cropped_spec = crop_spectrogram_frequencies(
-        sample_spec, min_freq=min_f, max_freq=max_f
-    )
-
-    assert cropped_spec.sizes == sample_spec.sizes
-    assert np.allclose(cropped_spec.data, sample_spec.data)
-
-
-def test_apply_pcen(sample_spec: xr.DataArray):
-    pcen_config = PcenConfig()
-    pcen_spec = apply_pcen(
-        sample_spec,
-        time_constant=pcen_config.time_constant,
-        gain=pcen_config.gain,
-        bias=pcen_config.bias,
-        power=pcen_config.power,
-    )
-
-    assert pcen_spec.dims == sample_spec.dims
-    assert pcen_spec.sizes == sample_spec.sizes
-    assert pcen_spec.dtype == sample_spec.dtype
-    assert np.all(pcen_spec.data >= 0)
-
-    assert not np.allclose(pcen_spec.data, sample_spec.data)
-
-
-def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
-    scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
-    assert np.allclose(scaled_spec.data, sample_spec.data)
-    assert scaled_spec.dtype == sample_spec.dtype
-
-
-def test_scale_spectrogram_power(sample_spec: xr.DataArray):
-    scaled_spec = scale_spectrogram(sample_spec, scale="power")
-    assert np.allclose(scaled_spec.data, sample_spec.data**2)
-    assert scaled_spec.dtype == sample_spec.dtype
-
-
-def test_scale_spectrogram_db(sample_spec: xr.DataArray):
-    scaled_spec = scale_spectrogram(sample_spec, scale="dB")
-    log_spec_expected = arrays.to_db(sample_spec)
-    xr.testing.assert_allclose(scaled_spec, log_spec_expected)
-
-
-def test_remove_spectral_mean(sample_spec: xr.DataArray):
-    spec_noisy = sample_spec.copy() + 0.1
-    denoised_spec = remove_spectral_mean(spec_noisy)
-
-    assert denoised_spec.dims == spec_noisy.dims
-    assert denoised_spec.sizes == spec_noisy.sizes
-    assert denoised_spec.dtype == spec_noisy.dtype
-    assert np.all(denoised_spec.data >= 0)
-
-
-def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
-    const_spec = stft(constant_wave_xr, 0.002, 0.5)
-    denoised_spec = remove_spectral_mean(const_spec)
-    assert np.all(denoised_spec.data >= 0)
-
-
-@pytest.mark.parametrize(
-    "height, resize_factor, expected_freq_size, expected_time_factor",
-    [
-        (128, 1.0, 128, 1.0),
-        (64, 0.5, 64, 0.5),
-        (256, None, 256, 1.0),
-        (100, 2.0, 100, 2.0),
-    ],
-)
-def test_resize_spectrogram(
-    sample_spec: xr.DataArray,
-    height: int,
-    resize_factor: Union[float, None],
-    expected_freq_size: int,
-    expected_time_factor: float,
-):
-    original_time_size = sample_spec.sizes["time"]
-    resized_spec = resize_spectrogram(
-        sample_spec,
-        height=height,
-        resize_factor=resize_factor,
-    )
-
-    assert resized_spec.dims == ("frequency", "time")
-    assert resized_spec.sizes["frequency"] == expected_freq_size
-    expected_time_size = int(original_time_size * expected_time_factor)
-
-    assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
-
-
-def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
-    config = SpectrogramConfig()
-    spec = compute_spectrogram(sine_wave_xr, config=config)
-
-    assert isinstance(spec, xr.DataArray)
-    assert spec.dims == ("frequency", "time")
-    assert spec.dtype == np.float32
-    assert config.size is not None
-    assert spec.sizes["frequency"] == config.size.height
-
-    temp_stft = stft(
-        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
-    )
-    assert config.size.resize_factor is not None
-    expected_time_size = int(
-        temp_stft.sizes["time"] * config.size.resize_factor
-    )
-    assert abs(spec.sizes["time"] - expected_time_size) <= 1
-
-    assert spec.coords["frequency"].min() >= config.frequencies.min_freq
-    assert np.isclose(
-        spec.coords["frequency"].max(),
-        config.frequencies.max_freq,
-        rtol=0.1,
-    )
-
-
-def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
-    sine_wave_xr: xr.DataArray,
-):
-    config = SpectrogramConfig(
-        pcen=None,
-        spectral_mean_substraction=False,
-        size=None,
-        scale="power",
-        frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
-    )
-    spec = compute_spectrogram(sine_wave_xr, config=config)
-
-    stft_direct = stft(
-        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
-    )
-    expected_spec = scale_spectrogram(stft_direct, scale="power")
-
-    assert spec.sizes == expected_spec.sizes
-    assert np.allclose(spec.data, expected_spec.data)
-    assert spec.dtype == expected_spec.dtype
-
-
-def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
-    config = SpectrogramConfig(peak_normalize=True, pcen=None)
-    spec = compute_spectrogram(sine_wave_xr, config=config)
-    assert np.isclose(spec.data.max(), 1.0, atol=1e-6)
-
-    config = SpectrogramConfig(peak_normalize=False)
-    spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
-    assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)
-
-
-def test_get_spectrogram_resolution_calculation():
-    config = SpectrogramConfig(
-        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
-        size=SpecSizeConfig(height=100, resize_factor=0.5),
-        frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
-    )
-
-    freq_res, time_res = get_spectrogram_resolution(config)
-
-    expected_freq_res = (110_000 - 10_000) / 100
-    expected_hop_duration = 0.002 * (1 - 0.75)
-    expected_time_res = expected_hop_duration / 0.5
-
-    assert np.isclose(freq_res, expected_freq_res)
-    assert np.isclose(time_res, expected_time_res)
-
-
-def test_get_spectrogram_resolution_no_resize_factor():
-    config = SpectrogramConfig(
-        stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
-        size=SpecSizeConfig(height=200, resize_factor=None),
-        frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
-    )
-    freq_res, time_res = get_spectrogram_resolution(config)
-    expected_freq_res = (120_000 - 20_000) / 200
-    expected_hop_duration = 0.004 * (1 - 0.5)
-    expected_time_res = expected_hop_duration / 1.0
-
-    assert np.isclose(freq_res, expected_freq_res)
-    assert np.isclose(time_res, expected_time_res)
-
-
-def test_get_spectrogram_resolution_no_size_config():
-    config = SpectrogramConfig(size=None)
-    with pytest.raises(
-        ValueError, match="Spectrogram size configuration is required"
-    ):
-        get_spectrogram_resolution(config)
-
-
-def test_configurable_spectrogram_builder_init():
-    config = SpectrogramConfig()
-    builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
-    assert builder.config is config
-    assert builder.dtype == np.float16
-
-
-def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
-    config = SpectrogramConfig()
-    builder = ConfigurableSpectrogramBuilder(config=config)
-    spec_builder = builder(sine_wave_xr)
-    spec_direct = compute_spectrogram(sine_wave_xr, config=config)
-    assert isinstance(spec_builder, xr.DataArray)
-    assert np.allclose(spec_builder.data, spec_direct.data)
-    assert spec_builder.dtype == spec_direct.dtype
-
-
-def test_configurable_spectrogram_builder_call_np_no_samplerate(
-    sine_wave_xr: xr.DataArray,
-):
-    config = SpectrogramConfig()
-    builder = ConfigurableSpectrogramBuilder(config=config)
-    wav_np = sine_wave_xr.data
-    with pytest.raises(ValueError, match="Samplerate must be provided"):
-        builder(wav_np, samplerate=None)
-
-
-def test_build_spectrogram_builder():
-    config = SpectrogramConfig(peak_normalize=True)
-    builder = build_spectrogram_builder(config=config, dtype=np.float64)
-    assert isinstance(builder, ConfigurableSpectrogramBuilder)
-    assert builder.config is config
-    assert builder.dtype == np.float64
-
-
-def test_can_estimate_spectrogram_resolution(
-    wav_factory: Callable[..., Path],
-):
-    path = wav_factory(duration=0.2, samplerate=256_000)
-
-    audio_data = load_file_audio(
-        path,
-        config=AudioConfig(resample=None, duration=None),
-    )
-
-    config = SpectrogramConfig(
-        stft=STFTConfig(),
-        size=SpecSizeConfig(height=256, resize_factor=0.5),
-        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
-    )
-
-    spec = compute_spectrogram(audio_data, config=config)
-
-    freq_res, time_res = get_spectrogram_resolution(config)
-
-    assert math.isclose(
-        arrays.get_dim_step(spec, dim="frequency"),
-        freq_res,
-        rel_tol=0.1,
-    )
-    assert math.isclose(
-        arrays.get_dim_step(spec, dim="time"),
-        time_res,
-        rel_tol=0.1,
-    )
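The deleted resolution tests encode a simple relationship: frequency resolution is the covered band divided by the output height, and time resolution is the STFT hop duration divided by the resize factor. Worked through with the first test's numbers:

    freq_res = (110_000 - 10_000) / 100  # 1000.0 Hz per frequency bin
    hop_duration = 0.002 * (1 - 0.75)    # 0.0005 s between STFT frames
    time_res = hop_duration / 0.5        # 0.001 s per output time step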
@@ -3,7 +3,11 @@ import pytest
 import soundfile as sf
 from soundevent import data
 
-from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
+from batdetect2.preprocess import (
+    PreprocessingConfig,
+    build_preprocessor,
+)
+from batdetect2.preprocess.audio import build_audio_loader
 from batdetect2.targets.rois import (
     DEFAULT_ANCHOR,
     DEFAULT_FREQUENCY_SCALE,
@@ -275,6 +279,8 @@ def test_get_peak_energy_coordinates(generate_whistle):
     # Build a preprocessor (default config should be fine for this test)
     preprocessor = build_preprocessor()
 
+    audio_loader = build_audio_loader()
+
     # Define a region of interest that contains the whistle
     start_time = 0.2
     end_time = 0.7
@@ -285,6 +291,7 @@ def test_get_peak_energy_coordinates(generate_whistle):
     peak_time, peak_freq = get_peak_energy_coordinates(
         recording=recording,
         preprocessor=preprocessor,
+        audio_loader=audio_loader,
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
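These hunks establish the pattern used throughout the updated tests: waveform loading and spectrogram preprocessing are now two separate collaborators that get wired up independently. A schematic of the intended usage, restricted to names that appear in this diff (the step that consumes the loaded waveform is left abstract, since its exact method is not shown here):

    from batdetect2.preprocess import build_preprocessor
    from batdetect2.preprocess.audio import build_audio_loader

    audio_loader = build_audio_loader()  # file I/O, resampling, duration handling
    preprocessor = build_preprocessor()  # waveform -> spectrogram stage

    wav = audio_loader.load_clip(clip)  # `clip` is a soundevent `data.Clip`
    # ...hand both collaborators to helpers such as
    # get_peak_energy_coordinates, which take them as keyword arguments.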
@@ -356,6 +363,7 @@ def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
     peak_time, peak_freq = get_peak_energy_coordinates(
         recording=recording,
         preprocessor=preprocessor,
+        audio_loader=build_audio_loader(),
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
@@ -389,6 +397,7 @@ def test_get_peak_energy_coordinates_silent_region(create_recording):
     peak_time, peak_freq = get_peak_energy_coordinates(
         recording=recording,
         preprocessor=preprocessor,
+        audio_loader=build_audio_loader(),
        start_time=start_time,
        end_time=end_time,
        low_freq=low_freq,
@@ -443,17 +452,11 @@ def test_peak_energy_bbox_mapper_encode(generate_whistle):
 
     # Instantiate the mapper with a preprocessor
     preprocessor = build_preprocessor(
-        PreprocessingConfig.model_validate(
-            {
-                "spectrogram": {
-                    "pcen": None,
-                    "spectral_mean_substraction": False,
-                }
-            }
-        )
+        PreprocessingConfig.model_validate({"spectrogram": {"transforms": []}})
     )
     mapper = PeakEnergyBBoxMapper(
         preprocessor=preprocessor,
+        audio_loader=build_audio_loader(),
         time_scale=time_scale,
         frequency_scale=freq_scale,
     )
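`PeakEnergyBBoxMapper` now receives the loader explicitly instead of reaching through the preprocessor for audio. The new construction, sketched with placeholder scale values (the project defaults are not shown in this hunk):

    mapper = PeakEnergyBBoxMapper(
        preprocessor=build_preprocessor(),
        audio_loader=build_audio_loader(),
        time_scale=1000.0,  # placeholder, not the project default
        frequency_scale=0.001,  # placeholder, not the project default
    )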
@@ -493,6 +496,7 @@ def test_peak_energy_bbox_mapper_decode():
 
     mapper = PeakEnergyBBoxMapper(
         preprocessor=build_preprocessor(),
+        audio_loader=build_audio_loader(),
         time_scale=time_scale,
         frequency_scale=freq_scale,
     )
@@ -553,7 +557,11 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
             }
         )
     )
-    mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor)
+    audio_loader = build_audio_loader()
+    mapper = PeakEnergyBBoxMapper(
+        preprocessor=preprocessor,
+        audio_loader=audio_loader,
+    )
 
     # When
     # Encode the sound event, then immediately decode the result.
@@ -11,11 +11,12 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import select_subclip
 from batdetect2.train.preprocess import generate_train_example
-from batdetect2.typing import ClipLabeller, PreprocessorProtocol
+from batdetect2.typing import AudioLoader, ClipLabeller, PreprocessorProtocol
 
 
 def test_mix_examples(
     sample_preprocessor: PreprocessorProtocol,
+    sample_audio_loader: AudioLoader,
     sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
@@ -30,11 +31,13 @@ def test_mix_examples(
 
     example1 = generate_train_example(
         clip_annotation_1,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
     example2 = generate_train_example(
         clip_annotation_2,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
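Every call to `generate_train_example` in this file gains the same keyword, so the effective call shape exercised by the suite becomes (sketched with the fixture names from this diff):

    example = generate_train_example(
        clip_annotation,                    # a soundevent data.ClipAnnotation
        audio_loader=sample_audio_loader,   # new, threaded in explicitly
        preprocessor=sample_preprocessor,
        labeller=sample_labeller,
    )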
@@ -51,6 +54,7 @@ def test_mix_examples(
 @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
 def test_mix_examples_of_different_durations(
     sample_preprocessor: PreprocessorProtocol,
+    sample_audio_loader: AudioLoader,
     sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
     duration1: float,
@@ -67,11 +71,13 @@ def test_mix_examples_of_different_durations(
 
     example1 = generate_train_example(
         clip_annotation_1,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
     example2 = generate_train_example(
         clip_annotation_2,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
@@ -87,6 +93,7 @@ def test_mix_examples_of_different_durations(
 
 def test_add_echo(
     sample_preprocessor: PreprocessorProtocol,
+    sample_audio_loader: AudioLoader,
     sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
@@ -96,6 +103,7 @@ def test_add_echo(
 
     original = generate_train_example(
         clip_annotation_1,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
@@ -109,6 +117,7 @@ def test_add_echo(
 
 def test_selected_random_subclip_has_the_correct_width(
     sample_preprocessor: PreprocessorProtocol,
+    sample_audio_loader: AudioLoader,
     sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
@@ -118,6 +127,7 @@ def test_selected_random_subclip_has_the_correct_width(
 
     original = generate_train_example(
         clip_annotation_1,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )
@@ -128,6 +138,7 @@ def test_selected_random_subclip_has_the_correct_width(
 
 def test_add_echo_after_subclip(
     sample_preprocessor: PreprocessorProtocol,
+    sample_audio_loader: AudioLoader,
     sample_labeller: ClipLabeller,
     create_recording: Callable[..., data.Recording],
 ):
@@ -136,6 +147,7 @@ def test_add_echo_after_subclip(
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     original = generate_train_example(
         clip_annotation_1,
+        audio_loader=sample_audio_loader,
         preprocessor=sample_preprocessor,
         labeller=sample_labeller,
     )