mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Documented the preprocessing module
This commit is contained in:
parent
19febf2216
commit
638f93fe92
@ -35,7 +35,7 @@ def get_preprocessing_config(params: dict) -> PreprocessingConfig:
|
|||||||
audio=AudioConfig(
|
audio=AudioConfig(
|
||||||
resample=ResampleConfig(
|
resample=ResampleConfig(
|
||||||
samplerate=params["target_samp_rate"],
|
samplerate=params["target_samp_rate"],
|
||||||
mode="poly",
|
method="poly",
|
||||||
),
|
),
|
||||||
scale=params["scale_raw_audio"],
|
scale=params["scale_raw_audio"],
|
||||||
center=params["scale_raw_audio"],
|
center=params["scale_raw_audio"],
|
||||||
|
@ -1,4 +1,32 @@
|
|||||||
"""Module containing functions for preprocessing audio clips."""
|
"""Main entry point for the BatDetect2 Preprocessing subsystem.
|
||||||
|
|
||||||
|
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
|
||||||
|
for converting raw audio input (from files or data objects) into processed
|
||||||
|
spectrograms suitable for input to BatDetect2 models. This ensures consistent
|
||||||
|
data handling between model training and inference.
|
||||||
|
|
||||||
|
The preprocessing pipeline consists of two main stages, configured via nested
|
||||||
|
data structures:
|
||||||
|
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
|
||||||
|
processing like resampling, duration adjustment, centering, and scaling.
|
||||||
|
Configured via `AudioConfig`.
|
||||||
|
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
|
||||||
|
the processed waveform using STFT, followed by frequency cropping, optional
|
||||||
|
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
|
||||||
|
resizing, and optional peak normalization. Configured via
|
||||||
|
`SpectrogramConfig`.
|
||||||
|
|
||||||
|
This module provides the primary interface:
|
||||||
|
|
||||||
|
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
|
||||||
|
and `SpectrogramConfig`.
|
||||||
|
- `load_preprocessing_config`: Function to load the unified configuration.
|
||||||
|
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
|
||||||
|
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
|
||||||
|
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
|
||||||
|
instance from a `PreprocessingConfig`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@ -14,13 +42,7 @@ from batdetect2.preprocess.audio import (
|
|||||||
TARGET_SAMPLERATE_HZ,
|
TARGET_SAMPLERATE_HZ,
|
||||||
AudioConfig,
|
AudioConfig,
|
||||||
ResampleConfig,
|
ResampleConfig,
|
||||||
adjust_audio_duration,
|
|
||||||
build_audio_loader,
|
build_audio_loader,
|
||||||
convert_to_xr,
|
|
||||||
load_clip_audio,
|
|
||||||
load_file_audio,
|
|
||||||
load_recording_audio,
|
|
||||||
resample_audio,
|
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.spectrogram import (
|
||||||
MAX_FREQ,
|
MAX_FREQ,
|
||||||
@ -32,7 +54,6 @@ from batdetect2.preprocess.spectrogram import (
|
|||||||
SpectrogramConfig,
|
SpectrogramConfig,
|
||||||
STFTConfig,
|
STFTConfig,
|
||||||
build_spectrogram_builder,
|
build_spectrogram_builder,
|
||||||
compute_spectrogram,
|
|
||||||
get_spectrogram_resolution,
|
get_spectrogram_resolution,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.types import (
|
from batdetect2.preprocess.types import (
|
||||||
@ -47,44 +68,79 @@ __all__ = [
|
|||||||
"ConfigurableSpectrogramBuilder",
|
"ConfigurableSpectrogramBuilder",
|
||||||
"DEFAULT_DURATION",
|
"DEFAULT_DURATION",
|
||||||
"FrequencyConfig",
|
"FrequencyConfig",
|
||||||
"FrequencyConfig",
|
|
||||||
"MAX_FREQ",
|
"MAX_FREQ",
|
||||||
"MIN_FREQ",
|
"MIN_FREQ",
|
||||||
"PcenConfig",
|
"PcenConfig",
|
||||||
"PcenConfig",
|
|
||||||
"PreprocessingConfig",
|
"PreprocessingConfig",
|
||||||
"ResampleConfig",
|
"ResampleConfig",
|
||||||
"SCALE_RAW_AUDIO",
|
"SCALE_RAW_AUDIO",
|
||||||
"STFTConfig",
|
"STFTConfig",
|
||||||
"STFTConfig",
|
|
||||||
"SpecSizeConfig",
|
|
||||||
"SpecSizeConfig",
|
"SpecSizeConfig",
|
||||||
"SpectrogramBuilder",
|
"SpectrogramBuilder",
|
||||||
"SpectrogramConfig",
|
"SpectrogramConfig",
|
||||||
"SpectrogramConfig",
|
"StandardPreprocessor",
|
||||||
"TARGET_SAMPLERATE_HZ",
|
"TARGET_SAMPLERATE_HZ",
|
||||||
"adjust_audio_duration",
|
|
||||||
"build_audio_loader",
|
"build_audio_loader",
|
||||||
|
"build_preprocessor",
|
||||||
"build_spectrogram_builder",
|
"build_spectrogram_builder",
|
||||||
"compute_spectrogram",
|
|
||||||
"convert_to_xr",
|
|
||||||
"get_spectrogram_resolution",
|
"get_spectrogram_resolution",
|
||||||
"load_clip_audio",
|
|
||||||
"load_file_audio",
|
|
||||||
"load_preprocessing_config",
|
"load_preprocessing_config",
|
||||||
"load_recording_audio",
|
|
||||||
"resample_audio",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseConfig):
|
class PreprocessingConfig(BaseConfig):
|
||||||
"""Configuration for preprocessing data."""
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
|
Aggregates the configuration for both the initial audio processing stage
|
||||||
|
and the subsequent spectrogram generation stage.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
audio : AudioConfig
|
||||||
|
Configuration settings for the audio loading and initial waveform
|
||||||
|
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||||
|
Defaults to default `AudioConfig` settings if omitted.
|
||||||
|
spectrogram : SpectrogramConfig
|
||||||
|
Configuration settings for the spectrogram generation process
|
||||||
|
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||||
|
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||||
|
"""
|
||||||
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||||
|
|
||||||
|
|
||||||
class StandardPreprocessor(Preprocessor):
|
class StandardPreprocessor(Preprocessor):
|
||||||
|
"""Standard implementation of the `Preprocessor` protocol.
|
||||||
|
|
||||||
|
Orchestrates the audio loading and spectrogram generation pipeline using
|
||||||
|
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
|
||||||
|
configured according to a `PreprocessingConfig`.
|
||||||
|
|
||||||
|
This class is typically instantiated using the `build_preprocessor`
|
||||||
|
factory function.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
audio_loader : AudioLoader
|
||||||
|
The configured audio loader instance used for waveform loading and
|
||||||
|
initial processing.
|
||||||
|
spectrogram_builder : SpectrogramBuilder
|
||||||
|
The configured spectrogram builder instance used for generating
|
||||||
|
spectrograms from waveforms.
|
||||||
|
default_samplerate : int
|
||||||
|
The sample rate (in Hz) assumed for input waveforms when they are
|
||||||
|
provided as raw NumPy arrays without coordinate information (e.g.,
|
||||||
|
when calling `compute_spectrogram` directly with `np.ndarray`).
|
||||||
|
This value is derived from the `AudioConfig` (target resample rate
|
||||||
|
or default if resampling is off) and also serves as documentation
|
||||||
|
for the pipeline's intended operating sample rate. Note that when
|
||||||
|
processing `xr.DataArray` inputs that have coordinate information
|
||||||
|
(the standard internal workflow), the sample rate embedded in the
|
||||||
|
coordinates takes precedence over this default value during
|
||||||
|
spectrogram calculation.
|
||||||
|
"""
|
||||||
|
|
||||||
audio_loader: AudioLoader
|
audio_loader: AudioLoader
|
||||||
spectrogram_builder: SpectrogramBuilder
|
spectrogram_builder: SpectrogramBuilder
|
||||||
default_samplerate: int
|
default_samplerate: int
|
||||||
@ -95,6 +151,19 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
spectrogram_builder: SpectrogramBuilder,
|
spectrogram_builder: SpectrogramBuilder,
|
||||||
default_samplerate: int,
|
default_samplerate: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the StandardPreprocessor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_loader : AudioLoader
|
||||||
|
An initialized audio loader conforming to the AudioLoader protocol.
|
||||||
|
spectrogram_builder : SpectrogramBuilder
|
||||||
|
An initialized spectrogram builder conforming to the
|
||||||
|
SpectrogramBuilder protocol.
|
||||||
|
default_samplerate : int
|
||||||
|
The sample rate to assume for NumPy array inputs and potentially
|
||||||
|
reflecting the target rate of the audio config.
|
||||||
|
"""
|
||||||
self.audio_loader = audio_loader
|
self.audio_loader = audio_loader
|
||||||
self.spectrogram_builder = spectrogram_builder
|
self.spectrogram_builder = spectrogram_builder
|
||||||
self.default_samplerate = default_samplerate
|
self.default_samplerate = default_samplerate
|
||||||
@ -104,6 +173,23 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load and preprocess *only* the audio waveform from a file path.
|
||||||
|
|
||||||
|
Delegates to the internal `audio_loader`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the audio file.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
A directory prefix if `path` is relative.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The loaded and preprocessed audio waveform (typically first
|
||||||
|
channel).
|
||||||
|
"""
|
||||||
return self.audio_loader.load_file(
|
return self.audio_loader.load_file(
|
||||||
path,
|
path,
|
||||||
audio_dir=audio_dir,
|
audio_dir=audio_dir,
|
||||||
@ -114,6 +200,23 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load and preprocess *only* the audio waveform for a Recording.
|
||||||
|
|
||||||
|
Delegates to the internal `audio_loader`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
recording : data.Recording
|
||||||
|
The Recording object.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
Directory containing the audio file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The loaded and preprocessed audio waveform (typically first
|
||||||
|
channel).
|
||||||
|
"""
|
||||||
return self.audio_loader.load_recording(
|
return self.audio_loader.load_recording(
|
||||||
recording,
|
recording,
|
||||||
audio_dir=audio_dir,
|
audio_dir=audio_dir,
|
||||||
@ -124,6 +227,23 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load and preprocess *only* the audio waveform for a Clip.
|
||||||
|
|
||||||
|
Delegates to the internal `audio_loader`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
clip : data.Clip
|
||||||
|
The Clip object defining the segment.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
Directory containing the audio file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The loaded and preprocessed audio waveform segment (typically first
|
||||||
|
channel).
|
||||||
|
"""
|
||||||
return self.audio_loader.load_clip(
|
return self.audio_loader.load_clip(
|
||||||
clip,
|
clip,
|
||||||
audio_dir=audio_dir,
|
audio_dir=audio_dir,
|
||||||
@ -134,6 +254,24 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load audio from a file and compute the final processed spectrogram.
|
||||||
|
|
||||||
|
Performs the full pipeline:
|
||||||
|
|
||||||
|
Load -> Preprocess Audio -> Compute Spectrogram.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the audio file.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
A directory prefix if `path` is relative.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The final processed spectrogram.
|
||||||
|
"""
|
||||||
wav = self.load_file_audio(path, audio_dir=audio_dir)
|
wav = self.load_file_audio(path, audio_dir=audio_dir)
|
||||||
return self.spectrogram_builder(
|
return self.spectrogram_builder(
|
||||||
wav,
|
wav,
|
||||||
@ -145,6 +283,22 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load audio for a Recording and compute the processed spectrogram.
|
||||||
|
|
||||||
|
Performs the full pipeline for the entire duration of the recording.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
recording : data.Recording
|
||||||
|
The Recording object.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
Directory containing the audio file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The final processed spectrogram.
|
||||||
|
"""
|
||||||
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
|
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
|
||||||
return self.spectrogram_builder(
|
return self.spectrogram_builder(
|
||||||
wav,
|
wav,
|
||||||
@ -156,6 +310,22 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Load audio for a Clip and compute the final processed spectrogram.
|
||||||
|
|
||||||
|
Performs the full pipeline for the specified clip segment.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
clip : data.Clip
|
||||||
|
The Clip object defining the audio segment.
|
||||||
|
audio_dir : PathLike, optional
|
||||||
|
Directory containing the audio file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The final processed spectrogram.
|
||||||
|
"""
|
||||||
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
|
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
|
||||||
return self.spectrogram_builder(
|
return self.spectrogram_builder(
|
||||||
wav,
|
wav,
|
||||||
@ -165,6 +335,27 @@ class StandardPreprocessor(Preprocessor):
|
|||||||
def compute_spectrogram(
|
def compute_spectrogram(
|
||||||
self, wav: Union[xr.DataArray, np.ndarray]
|
self, wav: Union[xr.DataArray, np.ndarray]
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
|
"""Compute the spectrogram from a pre-loaded audio waveform.
|
||||||
|
|
||||||
|
Applies the configured spectrogram generation steps
|
||||||
|
(STFT, scaling, etc.) using the internal `spectrogram_builder`.
|
||||||
|
|
||||||
|
If `wav` is a NumPy array, the `default_samplerate` stored in this
|
||||||
|
preprocessor instance will be used. If `wav` is an xarray DataArray
|
||||||
|
with time coordinates, the sample rate derived from those coordinates
|
||||||
|
will take precedence over `default_samplerate`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : Union[xr.DataArray, np.ndarray]
|
||||||
|
The input audio waveform. If numpy array, `default_samplerate`
|
||||||
|
stored in this object will be assumed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
The computed spectrogram.
|
||||||
|
"""
|
||||||
return self.spectrogram_builder(
|
return self.spectrogram_builder(
|
||||||
wav,
|
wav,
|
||||||
samplerate=self.default_samplerate,
|
samplerate=self.default_samplerate,
|
||||||
@ -175,12 +366,64 @@ def load_preprocessing_config(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
) -> PreprocessingConfig:
|
) -> PreprocessingConfig:
|
||||||
|
"""Load the unified preprocessing configuration from a file.
|
||||||
|
|
||||||
|
Reads a configuration file (YAML) and validates it against the
|
||||||
|
`PreprocessingConfig` schema, potentially extracting data from a nested
|
||||||
|
field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
preprocessing configuration (e.g., "train.preprocessing"). If None, the
|
||||||
|
entire file content is validated as the PreprocessingConfig.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PreprocessingConfig
|
||||||
|
Loaded and validated preprocessing configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
yaml.YAMLError
|
||||||
|
If the file content is not valid YAML.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the loaded config data does not conform to PreprocessingConfig.
|
||||||
|
KeyError, TypeError
|
||||||
|
If `field` specifies an invalid path.
|
||||||
|
"""
|
||||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
def build_preprocessor_from_config(
|
def build_preprocessor(
|
||||||
config: PreprocessingConfig,
|
config: Optional[PreprocessingConfig] = None,
|
||||||
) -> Preprocessor:
|
) -> Preprocessor:
|
||||||
|
"""Factory function to build the standard preprocessor from configuration.
|
||||||
|
|
||||||
|
Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
|
||||||
|
based on the provided `PreprocessingConfig` (or defaults if config is None),
|
||||||
|
determines the effective default sample rate, and initializes the
|
||||||
|
`StandardPreprocessor`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : PreprocessingConfig, optional
|
||||||
|
The unified preprocessing configuration object. If None, default
|
||||||
|
configurations for audio and spectrogram processing will be used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Preprocessor
|
||||||
|
An initialized `StandardPreprocessor` instance ready to process audio
|
||||||
|
according to the configuration.
|
||||||
|
"""
|
||||||
|
config = config or PreprocessingConfig()
|
||||||
|
|
||||||
default_samplerate = (
|
default_samplerate = (
|
||||||
config.audio.resample.samplerate
|
config.audio.resample.samplerate
|
||||||
if config.audio.resample
|
if config.audio.resample
|
||||||
|
@ -125,7 +125,7 @@ def no_center_config() -> audio.AudioConfig:
|
|||||||
def resample_fourier_config() -> audio.AudioConfig:
|
def resample_fourier_config() -> audio.AudioConfig:
|
||||||
return audio.AudioConfig(
|
return audio.AudioConfig(
|
||||||
resample=audio.ResampleConfig(
|
resample=audio.ResampleConfig(
|
||||||
samplerate=audio.TARGET_SAMPLERATE_HZ // 2, mode="fourier"
|
samplerate=audio.TARGET_SAMPLERATE_HZ // 2, method="fourier"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -133,21 +133,21 @@ def resample_fourier_config() -> audio.AudioConfig:
|
|||||||
def test_resample_config_defaults():
|
def test_resample_config_defaults():
|
||||||
config = audio.ResampleConfig()
|
config = audio.ResampleConfig()
|
||||||
assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
|
assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
|
||||||
assert config.mode == "poly"
|
assert config.method == "poly"
|
||||||
|
|
||||||
|
|
||||||
def test_audio_config_defaults():
|
def test_audio_config_defaults():
|
||||||
config = audio.AudioConfig()
|
config = audio.AudioConfig()
|
||||||
assert config.resample is not None
|
assert config.resample is not None
|
||||||
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
|
assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
|
||||||
assert config.resample.mode == "poly"
|
assert config.resample.method == "poly"
|
||||||
assert config.scale == audio.SCALE_RAW_AUDIO
|
assert config.scale == audio.SCALE_RAW_AUDIO
|
||||||
assert config.center is True
|
assert config.center is True
|
||||||
assert config.duration == audio.DEFAULT_DURATION
|
assert config.duration == audio.DEFAULT_DURATION
|
||||||
|
|
||||||
|
|
||||||
def test_audio_config_override():
|
def test_audio_config_override():
|
||||||
resample_cfg = audio.ResampleConfig(samplerate=44100, mode="fourier")
|
resample_cfg = audio.ResampleConfig(samplerate=44100, method="fourier")
|
||||||
config = audio.AudioConfig(
|
config = audio.AudioConfig(
|
||||||
resample=resample_cfg,
|
resample=resample_cfg,
|
||||||
scale=True,
|
scale=True,
|
||||||
@ -206,7 +206,7 @@ def test_resample_audio(orig_sr, target_sr, mode):
|
|||||||
duration = 0.1
|
duration = 0.1
|
||||||
wave = create_xr_wave(orig_sr, duration)
|
wave = create_xr_wave(orig_sr, duration)
|
||||||
resampled_wave = audio.resample_audio(
|
resampled_wave = audio.resample_audio(
|
||||||
wave, samplerate=target_sr, mode=mode, dtype=np.float32
|
wave, samplerate=target_sr, method=mode, dtype=np.float32
|
||||||
)
|
)
|
||||||
expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
|
expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
|
||||||
assert resampled_wave.sizes["time"] == expected_samples
|
assert resampled_wave.sizes["time"] == expected_samples
|
||||||
@ -233,7 +233,7 @@ def test_resample_audio_same_samplerate():
|
|||||||
def test_resample_audio_invalid_mode_raises():
|
def test_resample_audio_invalid_mode_raises():
|
||||||
wave = create_xr_wave(48000, 0.1)
|
wave = create_xr_wave(48000, 0.1)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
audio.resample_audio(wave, samplerate=96000, mode="invalid_mode")
|
audio.resample_audio(wave, samplerate=96000, method="invalid_mode")
|
||||||
|
|
||||||
|
|
||||||
def test_resample_audio_no_time_dim_raises():
|
def test_resample_audio_no_time_dim_raises():
|
||||||
|
Loading…
Reference in New Issue
Block a user