mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Use soundevent.to_db
This commit is contained in:
parent
835fe1ccdf
commit
ebc89af4c6
@ -19,7 +19,7 @@ reproducible spectrogram generation consistent between training and inference.
|
||||
The core computation is performed by `compute_spectrogram`.
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Callable, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
@ -327,6 +327,7 @@ def compute_spectrogram(
|
||||
"""
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
spec = stft(
|
||||
wav,
|
||||
window_duration=config.stft.window_duration,
|
||||
@ -477,7 +478,7 @@ def scale_spectrogram(
|
||||
"""Apply final amplitude scaling/representation to the spectrogram.
|
||||
|
||||
Converts the input magnitude spectrogram based on the `scale` type:
|
||||
- "dB": Applies logarithmic scaling `log1p(C * S)`.
|
||||
- "dB": Applies logarithmic scaling `log10(S)`.
|
||||
- "power": Squares the magnitude values `S^2`.
|
||||
- "amplitude": Returns the input magnitude values `S` unchanged.
|
||||
|
||||
@ -496,7 +497,7 @@ def scale_spectrogram(
|
||||
Spectrogram with the specified amplitude scaling applied.
|
||||
"""
|
||||
if scale == "dB":
|
||||
return scale_log(spec, dtype=dtype)
|
||||
return arrays.to_db(spec).astype(dtype)
|
||||
|
||||
if scale == "power":
|
||||
return spec**2
|
||||
@ -561,12 +562,13 @@ def apply_pcen(
|
||||
def scale_log(
|
||||
spec: xr.DataArray,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
ref: Union[float, Callable] = np.max,
|
||||
amin: float = 1e-10,
|
||||
top_db: Optional[float] = 80.0,
|
||||
) -> xr.DataArray:
|
||||
"""Apply logarithmic scaling to a magnitude spectrogram.
|
||||
|
||||
Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
|
||||
and C is a scaling factor derived from the original STFT parameters
|
||||
(sample rate, N-FFT, window function) stored in `spec.attrs`.
|
||||
Calculates `log10(S)`, where S is the input magnitude spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -587,12 +589,29 @@ def scale_log(
|
||||
If required attributes are missing from `spec.attrs`.
|
||||
ValueError
|
||||
If attributes are non-numeric or window function is invalid.
|
||||
|
||||
|
||||
Notes
|
||||
-----
|
||||
Implementation mainly taken from librosa `power_to_db` function
|
||||
"""
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
nfft = spec.attrs["nfft"]
|
||||
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
|
||||
|
||||
if callable(ref):
|
||||
ref_value = ref(spec)
|
||||
else:
|
||||
ref_value = np.abs(ref)
|
||||
|
||||
log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10(
|
||||
np.maximum(amin, ref_value)
|
||||
)
|
||||
|
||||
if top_db is not None:
|
||||
if top_db < 0:
|
||||
raise ValueError("top_db must be non-negative")
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
|
||||
|
||||
return xr.DataArray(
|
||||
data=np.log1p(log_scaling * spec).astype(dtype),
|
||||
data=log_spec.astype(dtype),
|
||||
dims=spec.dims,
|
||||
coords=spec.coords,
|
||||
attrs=spec.attrs,
|
||||
|
Loading…
Reference in New Issue
Block a user