mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Use soundevent.to_db
This commit is contained in:
parent
835fe1ccdf
commit
ebc89af4c6
@ -19,7 +19,7 @@ reproducible spectrogram generation consistent between training and inference.
|
|||||||
The core computation is performed by `compute_spectrogram`.
|
The core computation is performed by `compute_spectrogram`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
@ -327,44 +327,45 @@ def compute_spectrogram(
|
|||||||
"""
|
"""
|
||||||
config = config or SpectrogramConfig()
|
config = config or SpectrogramConfig()
|
||||||
|
|
||||||
spec = stft(
|
with xr.set_options(keep_attrs=True):
|
||||||
wav,
|
spec = stft(
|
||||||
window_duration=config.stft.window_duration,
|
wav,
|
||||||
window_overlap=config.stft.window_overlap,
|
window_duration=config.stft.window_duration,
|
||||||
window_fn=config.stft.window_fn,
|
window_overlap=config.stft.window_overlap,
|
||||||
)
|
window_fn=config.stft.window_fn,
|
||||||
|
|
||||||
spec = crop_spectrogram_frequencies(
|
|
||||||
spec,
|
|
||||||
min_freq=config.frequencies.min_freq,
|
|
||||||
max_freq=config.frequencies.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.pcen:
|
|
||||||
spec = apply_pcen(
|
|
||||||
spec,
|
|
||||||
time_constant=config.pcen.time_constant,
|
|
||||||
gain=config.pcen.gain,
|
|
||||||
power=config.pcen.power,
|
|
||||||
bias=config.pcen.bias,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
spec = scale_spectrogram(spec, scale=config.scale)
|
spec = crop_spectrogram_frequencies(
|
||||||
|
|
||||||
if config.spectral_mean_substraction:
|
|
||||||
spec = remove_spectral_mean(spec)
|
|
||||||
|
|
||||||
if config.size:
|
|
||||||
spec = resize_spectrogram(
|
|
||||||
spec,
|
spec,
|
||||||
height=config.size.height,
|
min_freq=config.frequencies.min_freq,
|
||||||
resize_factor=config.size.resize_factor,
|
max_freq=config.frequencies.max_freq,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.peak_normalize:
|
if config.pcen:
|
||||||
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
spec = apply_pcen(
|
||||||
|
spec,
|
||||||
|
time_constant=config.pcen.time_constant,
|
||||||
|
gain=config.pcen.gain,
|
||||||
|
power=config.pcen.power,
|
||||||
|
bias=config.pcen.bias,
|
||||||
|
)
|
||||||
|
|
||||||
return spec.astype(dtype)
|
spec = scale_spectrogram(spec, scale=config.scale)
|
||||||
|
|
||||||
|
if config.spectral_mean_substraction:
|
||||||
|
spec = remove_spectral_mean(spec)
|
||||||
|
|
||||||
|
if config.size:
|
||||||
|
spec = resize_spectrogram(
|
||||||
|
spec,
|
||||||
|
height=config.size.height,
|
||||||
|
resize_factor=config.size.resize_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.peak_normalize:
|
||||||
|
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
||||||
|
|
||||||
|
return spec.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
def crop_spectrogram_frequencies(
|
def crop_spectrogram_frequencies(
|
||||||
@ -477,7 +478,7 @@ def scale_spectrogram(
|
|||||||
"""Apply final amplitude scaling/representation to the spectrogram.
|
"""Apply final amplitude scaling/representation to the spectrogram.
|
||||||
|
|
||||||
Converts the input magnitude spectrogram based on the `scale` type:
|
Converts the input magnitude spectrogram based on the `scale` type:
|
||||||
- "dB": Applies logarithmic scaling `log1p(C * S)`.
|
- "dB": Applies logarithmic scaling `log10(S)`.
|
||||||
- "power": Squares the magnitude values `S^2`.
|
- "power": Squares the magnitude values `S^2`.
|
||||||
- "amplitude": Returns the input magnitude values `S` unchanged.
|
- "amplitude": Returns the input magnitude values `S` unchanged.
|
||||||
|
|
||||||
@ -496,7 +497,7 @@ def scale_spectrogram(
|
|||||||
Spectrogram with the specified amplitude scaling applied.
|
Spectrogram with the specified amplitude scaling applied.
|
||||||
"""
|
"""
|
||||||
if scale == "dB":
|
if scale == "dB":
|
||||||
return scale_log(spec, dtype=dtype)
|
return arrays.to_db(spec).astype(dtype)
|
||||||
|
|
||||||
if scale == "power":
|
if scale == "power":
|
||||||
return spec**2
|
return spec**2
|
||||||
@ -561,12 +562,13 @@ def apply_pcen(
|
|||||||
def scale_log(
|
def scale_log(
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
|
ref: Union[float, Callable] = np.max,
|
||||||
|
amin: float = 1e-10,
|
||||||
|
top_db: Optional[float] = 80.0,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
"""Apply logarithmic scaling to a magnitude spectrogram.
|
"""Apply logarithmic scaling to a magnitude spectrogram.
|
||||||
|
|
||||||
Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
|
Calculates `log10(S)`, where S is the input magnitude spectrogram.
|
||||||
and C is a scaling factor derived from the original STFT parameters
|
|
||||||
(sample rate, N-FFT, window function) stored in `spec.attrs`.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -587,12 +589,29 @@ def scale_log(
|
|||||||
If required attributes are missing from `spec.attrs`.
|
If required attributes are missing from `spec.attrs`.
|
||||||
ValueError
|
ValueError
|
||||||
If attributes are non-numeric or window function is invalid.
|
If attributes are non-numeric or window function is invalid.
|
||||||
|
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
Implementation mainly taken from librosa `power_to_db` function
|
||||||
"""
|
"""
|
||||||
samplerate = spec.attrs["original_samplerate"]
|
|
||||||
nfft = spec.attrs["nfft"]
|
if callable(ref):
|
||||||
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
|
ref_value = ref(spec)
|
||||||
|
else:
|
||||||
|
ref_value = np.abs(ref)
|
||||||
|
|
||||||
|
log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - np.log10(
|
||||||
|
np.maximum(amin, ref_value)
|
||||||
|
)
|
||||||
|
|
||||||
|
if top_db is not None:
|
||||||
|
if top_db < 0:
|
||||||
|
raise ValueError("top_db must be non-negative")
|
||||||
|
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
|
||||||
|
|
||||||
return xr.DataArray(
|
return xr.DataArray(
|
||||||
data=np.log1p(log_scaling * spec).astype(dtype),
|
data=log_spec.astype(dtype),
|
||||||
dims=spec.dims,
|
dims=spec.dims,
|
||||||
coords=spec.coords,
|
coords=spec.coords,
|
||||||
attrs=spec.attrs,
|
attrs=spec.attrs,
|
||||||
|
Loading…
Reference in New Issue
Block a user