Added min and max freq attributes to preprocessor protocol

This commit is contained in:
mbsantiago 2025-04-23 23:14:31 +01:00
parent 6498b6ca37
commit ac4bb8f023
4 changed files with 43 additions and 68 deletions

View File

@ -144,12 +144,16 @@ class StandardPreprocessor(PreprocessorProtocol):
audio_loader: AudioLoader audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder spectrogram_builder: SpectrogramBuilder
default_samplerate: int default_samplerate: int
max_freq: float
min_freq: float
def __init__( def __init__(
self, self,
audio_loader: AudioLoader, audio_loader: AudioLoader,
spectrogram_builder: SpectrogramBuilder, spectrogram_builder: SpectrogramBuilder,
default_samplerate: int, default_samplerate: int,
max_freq: float,
min_freq: float,
) -> None: ) -> None:
"""Initialize the StandardPreprocessor. """Initialize the StandardPreprocessor.
@ -167,6 +171,8 @@ class StandardPreprocessor(PreprocessorProtocol):
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate self.default_samplerate = default_samplerate
self.max_freq = max_freq
self.min_freq = min_freq
def load_file_audio( def load_file_audio(
self, self,
@ -429,8 +435,14 @@ def build_preprocessor(
if config.audio.resample if config.audio.resample
else TARGET_SAMPLERATE_HZ else TARGET_SAMPLERATE_HZ
) )
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
return StandardPreprocessor( return StandardPreprocessor(
audio_loader=build_audio_loader(config.audio), audio_loader=build_audio_loader(config.audio),
spectrogram_builder=build_spectrogram_builder(config.spectrogram), spectrogram_builder=build_spectrogram_builder(config.spectrogram),
default_samplerate=default_samplerate, default_samplerate=default_samplerate,
min_freq=min_freq,
max_freq=max_freq,
) )

View File

@ -286,7 +286,8 @@ def load_recording_audio(
"""Load and preprocess the entire audio content of a recording using config. """Load and preprocess the entire audio content of a recording using config.
Creates a `soundevent.data.Clip` spanning the full duration of the Creates a `soundevent.data.Clip` spanning the full duration of the
recording and then delegates the loading and processing to `load_clip_audio`. recording and then delegates the loading and processing to
`load_clip_audio`.
Parameters Parameters
---------- ----------
@ -636,7 +637,11 @@ def resample_audio_fourier(
If `num` is negative. If `num` is negative.
""" """
ratio = sr_new / sr_orig ratio = sr_new / sr_orig
return resample(array, int(array.shape[axis] * ratio), axis=axis) # type: ignore return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
def convert_to_xr( def convert_to_xr(
@ -649,8 +654,8 @@ def convert_to_xr(
Parameters Parameters
---------- ----------
wav : np.ndarray wav : np.ndarray
The input waveform array. Expected to be 1D or 2D (with the first axis as The input waveform array. Expected to be 1D or 2D (with the first
the channel dimension). axis as the channel dimension).
samplerate : int samplerate : int
The sample rate in Hz. The sample rate in Hz.
dtype : DTypeLike, default=np.float32 dtype : DTypeLike, default=np.float32
@ -673,7 +678,8 @@ def convert_to_xr(
if wav.ndim != 1: if wav.ndim != 1:
raise ValueError( raise ValueError(
"Audio must be 1D array or 2D channel where the first axis is the channel dimension" "Audio must be 1D array or 2D channel where the first "
"axis is the channel dimension"
) )
if wav.size == 0: if wav.size == 0:

View File

@ -21,8 +21,6 @@ The core computation is performed by `compute_spectrogram`.
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import librosa
import librosa.core.spectrum
import numpy as np import numpy as np
import xarray as xr import xarray as xr
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
@ -147,7 +145,8 @@ class SpectrogramConfig(BaseConfig):
"""Unified configuration for spectrogram generation pipeline. """Unified configuration for spectrogram generation pipeline.
Aggregates settings for all steps involved in converting a preprocessed Aggregates settings for all steps involved in converting a preprocessed
audio waveform into a final spectrogram representation suitable for model input. audio waveform into a final spectrogram representation suitable for model
input.
Attributes Attributes
---------- ----------
@ -298,7 +297,8 @@ def compute_spectrogram(
1. Compute STFT magnitude (`stft`). 1. Compute STFT magnitude (`stft`).
2. Crop frequency axis (`crop_spectrogram_frequencies`). 2. Crop frequency axis (`crop_spectrogram_frequencies`).
3. Apply PCEN if configured (`apply_pcen`). 3. Apply PCEN if configured (`apply_pcen`).
4. Apply final amplitude scaling (dB, power, amplitude) (`scale_spectrogram`). 4. Apply final amplitude scaling (dB, power, amplitude)
(`scale_spectrogram`).
5. Apply spectral mean subtraction denoising if enabled. 5. Apply spectral mean subtraction denoising if enabled.
6. Resize dimensions if specified (`resize_spectrogram`). 6. Resize dimensions if specified (`resize_spectrogram`).
7. Apply final peak normalization if enabled. 7. Apply final peak normalization if enabled.
@ -324,9 +324,6 @@ def compute_spectrogram(
------ ------
ValueError ValueError
If `wav` lacks necessary 'time' coordinates or dimensions. If `wav` lacks necessary 'time' coordinates or dimensions.
Exception
Can re-raise exceptions from underlying libraries (e.g., librosa, numpy)
if invalid parameters or data are encountered.
""" """
config = config or SpectrogramConfig() config = config or SpectrogramConfig()
@ -335,7 +332,6 @@ def compute_spectrogram(
window_duration=config.stft.window_duration, window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap, window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn, window_fn=config.stft.window_fn,
dtype=dtype,
) )
spec = crop_spectrogram_frequencies( spec = crop_spectrogram_frequencies(
@ -410,7 +406,6 @@ def stft(
window_duration: float, window_duration: float,
window_overlap: float, window_overlap: float,
window_fn: str = "hann", window_fn: str = "hann",
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray: ) -> xr.DataArray:
"""Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram. """Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
@ -425,11 +420,9 @@ def stft(
window_duration : float window_duration : float
Duration of the STFT window in seconds. Duration of the STFT window in seconds.
window_overlap : float window_overlap : float
Fractional overlap between consecutive windows [0, 1). Fractional overlap between consecutive windows.
window_fn : str, default="hann" window_fn : str, default="hann"
Name of the window function (e.g., "hann", "hamming"). Name of the window function (e.g., "hann", "hamming").
dtype : DTypeLike, default=np.float32
Target data type for the spectrogram array.
Returns Returns
------- -------
@ -442,55 +435,13 @@ def stft(
ValueError ValueError
If sample rate cannot be determined from `wave` coordinates. If sample rate cannot be determined from `wave` coordinates.
""" """
start_time, end_time = arrays.get_dim_range(wave, dim="time") return audio.compute_spectrogram(
step = arrays.get_dim_step(wave, dim="time") wave,
sampling_rate = 1 / step window_size=window_duration,
hop_size=(1 - window_overlap) * window_duration,
nfft = int(window_duration * sampling_rate) window_type=window_fn,
noverlap = int(window_overlap * nfft) scale="amplitude",
hop_len = nfft - noverlap sort_dims=False,
hop_duration = hop_len / sampling_rate
spec, _ = librosa.core.spectrum._spectrogram(
y=wave.data.astype(dtype),
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
window=window_fn,
)
return xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": arrays.create_frequency_dim_from_array(
np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
step=sampling_rate / nfft,
),
"time": arrays.create_time_dim_from_array(
np.linspace(
start_time,
end_time - (window_duration - hop_duration),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
step=hop_duration,
),
},
attrs={
**wave.attrs,
"original_samplerate": sampling_rate,
"nfft": nfft,
"noverlap": noverlap,
},
) )
@ -592,8 +543,10 @@ def apply_pcen(
verified against the specific `soundevent.audio.pcen` implementation verified against the specific `soundevent.audio.pcen` implementation
details. details.
""" """
samplerate = spec.attrs["original_samplerate"] samplerate = spec.attrs["samplerate"]
hop_length = spec.attrs["nfft"] - spec.attrs["noverlap"] hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
t_frames = time_constant * samplerate / (float(hop_length) * 10) t_frames = time_constant * samplerate / (float(hop_length) * 10)
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
return audio.pcen( return audio.pcen(

View File

@ -168,6 +168,10 @@ class PreprocessorProtocol(Protocol):
loading or spectrogram computation from a waveform. loading or spectrogram computation from a waveform.
""" """
max_freq: float
min_freq: float
def preprocess_file( def preprocess_file(
self, self,
path: data.PathLike, path: data.PathLike,