Use soundevent.to_db

This commit is contained in:
mbsantiago 2025-06-19 00:28:01 +01:00
parent 835fe1ccdf
commit ebc89af4c6

View File

@ -19,7 +19,7 @@ reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""
from typing import Literal, Optional, Union
from typing import Callable, Literal, Optional, Union
import numpy as np
import xarray as xr
@ -327,44 +327,45 @@ def compute_spectrogram(
"""
config = config or SpectrogramConfig()
spec = stft(
wav,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
)
spec = crop_spectrogram_frequencies(
spec,
min_freq=config.frequencies.min_freq,
max_freq=config.frequencies.max_freq,
)
if config.pcen:
spec = apply_pcen(
spec,
time_constant=config.pcen.time_constant,
gain=config.pcen.gain,
power=config.pcen.power,
bias=config.pcen.bias,
with xr.set_options(keep_attrs=True):
spec = stft(
wav,
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
)
spec = scale_spectrogram(spec, scale=config.scale)
if config.spectral_mean_substraction:
spec = remove_spectral_mean(spec)
if config.size:
spec = resize_spectrogram(
spec = crop_spectrogram_frequencies(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
min_freq=config.frequencies.min_freq,
max_freq=config.frequencies.max_freq,
)
if config.peak_normalize:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
if config.pcen:
spec = apply_pcen(
spec,
time_constant=config.pcen.time_constant,
gain=config.pcen.gain,
power=config.pcen.power,
bias=config.pcen.bias,
)
return spec.astype(dtype)
spec = scale_spectrogram(spec, scale=config.scale)
if config.spectral_mean_substraction:
spec = remove_spectral_mean(spec)
if config.size:
spec = resize_spectrogram(
spec,
height=config.size.height,
resize_factor=config.size.resize_factor,
)
if config.peak_normalize:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec.astype(dtype)
def crop_spectrogram_frequencies(
@ -477,7 +478,7 @@ def scale_spectrogram(
"""Apply final amplitude scaling/representation to the spectrogram.
Converts the input magnitude spectrogram based on the `scale` type:
- "dB": Applies logarithmic scaling `log1p(C * S)`.
- "dB": Applies logarithmic scaling `log10(S)`.
- "power": Squares the magnitude values `S^2`.
- "amplitude": Returns the input magnitude values `S` unchanged.
@ -496,7 +497,7 @@ def scale_spectrogram(
Spectrogram with the specified amplitude scaling applied.
"""
if scale == "dB":
return scale_log(spec, dtype=dtype)
return arrays.to_db(spec).astype(dtype)
if scale == "power":
return spec**2
@ -561,12 +562,13 @@ def apply_pcen(
def scale_log(
spec: xr.DataArray,
dtype: DTypeLike = np.float32, # type: ignore
ref: Union[float, Callable] = np.max,
amin: float = 1e-10,
top_db: Optional[float] = 80.0,
) -> xr.DataArray:
"""Apply logarithmic scaling to a magnitude spectrogram.
Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
and C is a scaling factor derived from the original STFT parameters
(sample rate, N-FFT, window function) stored in `spec.attrs`.
Calculates `log10(S)`, where S is the input magnitude spectrogram.
Parameters
----------
@ -587,12 +589,29 @@ def scale_log(
If required attributes are missing from `spec.attrs`.
ValueError
If attributes are non-numeric or window function is invalid.
Notes
-----
Implementation mainly taken from librosa `power_to_db` function
"""
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
if callable(ref):
ref_value = ref(spec)
else:
ref_value = np.abs(ref)
log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10(
np.maximum(amin, ref_value)
)
if top_db is not None:
if top_db < 0:
raise ValueError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return xr.DataArray(
data=np.log1p(log_scaling * spec).astype(dtype),
data=log_spec.astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs=spec.attrs,