diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py
index 89b85ac..5d4fe17 100644
--- a/batdetect2/preprocess/spectrogram.py
+++ b/batdetect2/preprocess/spectrogram.py
@@ -19,7 +19,7 @@ reproducible spectrogram generation consistent between training and inference.
 The core computation is performed by `compute_spectrogram`.
 """
 
-from typing import Literal, Optional, Union
+from typing import Callable, Literal, Optional, Union
 
 import numpy as np
 import xarray as xr
@@ -327,44 +327,45 @@ def compute_spectrogram(
     """
     config = config or SpectrogramConfig()
 
-    spec = stft(
-        wav,
-        window_duration=config.stft.window_duration,
-        window_overlap=config.stft.window_overlap,
-        window_fn=config.stft.window_fn,
-    )
-
-    spec = crop_spectrogram_frequencies(
-        spec,
-        min_freq=config.frequencies.min_freq,
-        max_freq=config.frequencies.max_freq,
-    )
-
-    if config.pcen:
-        spec = apply_pcen(
-            spec,
-            time_constant=config.pcen.time_constant,
-            gain=config.pcen.gain,
-            power=config.pcen.power,
-            bias=config.pcen.bias,
+    with xr.set_options(keep_attrs=True):
+        spec = stft(
+            wav,
+            window_duration=config.stft.window_duration,
+            window_overlap=config.stft.window_overlap,
+            window_fn=config.stft.window_fn,
         )
 
-    spec = scale_spectrogram(spec, scale=config.scale)
-
-    if config.spectral_mean_substraction:
-        spec = remove_spectral_mean(spec)
-
-    if config.size:
-        spec = resize_spectrogram(
+        spec = crop_spectrogram_frequencies(
             spec,
-            height=config.size.height,
-            resize_factor=config.size.resize_factor,
+            min_freq=config.frequencies.min_freq,
+            max_freq=config.frequencies.max_freq,
         )
 
-    if config.peak_normalize:
-        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+        if config.pcen:
+            spec = apply_pcen(
+                spec,
+                time_constant=config.pcen.time_constant,
+                gain=config.pcen.gain,
+                power=config.pcen.power,
+                bias=config.pcen.bias,
+            )
 
-    return spec.astype(dtype)
+        spec = scale_spectrogram(spec, scale=config.scale)
+
+        if config.spectral_mean_substraction:
+            spec = remove_spectral_mean(spec)
+
+        if config.size:
+            spec = resize_spectrogram(
+                spec,
+                height=config.size.height,
+                resize_factor=config.size.resize_factor,
+            )
+
+        if config.peak_normalize:
+            spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+        return spec.astype(dtype)
 
 
 def crop_spectrogram_frequencies(
@@ -477,7 +478,7 @@ def scale_spectrogram(
     """Apply final amplitude scaling/representation to the spectrogram.
 
     Converts the input magnitude spectrogram based on the `scale` type:
-    - "dB": Applies logarithmic scaling `log1p(C * S)`.
+    - "dB": Applies logarithmic scaling `log10(S)`.
     - "power": Squares the magnitude values `S^2`.
     - "amplitude": Returns the input magnitude values `S` unchanged.
 
@@ -496,7 +497,7 @@ def scale_spectrogram(
         Spectrogram with the specified amplitude scaling applied.
     """
     if scale == "dB":
-        return scale_log(spec, dtype=dtype)
+        return arrays.to_db(spec).astype(dtype)
 
     if scale == "power":
         return spec**2
@@ -561,12 +562,13 @@ def apply_pcen(
 def scale_log(
     spec: xr.DataArray,
     dtype: DTypeLike = np.float32,  # type: ignore
+    ref: Union[float, Callable] = np.max,
+    amin: float = 1e-10,
+    top_db: Optional[float] = 80.0,
 ) -> xr.DataArray:
     """Apply logarithmic scaling to a magnitude spectrogram.
 
-    Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
-    and C is a scaling factor derived from the original STFT parameters
-    (sample rate, N-FFT, window function) stored in `spec.attrs`.
+    Calculates `10 * log10(S / ref)`, where S is the input spectrogram.
 
     Parameters
     ----------
@@ -587,12 +589,29 @@ def scale_log(
         If required attributes are missing from `spec.attrs`.
     ValueError
         If attributes are non-numeric or window function is invalid.
+
+
+    Notes
+    -----
+    Implementation mainly taken from librosa `power_to_db` function
     """
-    samplerate = spec.attrs["original_samplerate"]
-    nfft = spec.attrs["nfft"]
-    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
+
+    if callable(ref):
+        ref_value = ref(spec)
+    else:
+        ref_value = np.abs(ref)
+
+    log_spec = 10.0 * np.log10(np.maximum(amin, spec)) - 10.0 * np.log10(
+        np.maximum(amin, ref_value)
+    )
+
+    if top_db is not None:
+        if top_db < 0:
+            raise ValueError("top_db must be non-negative")
+        log_spec = np.maximum(log_spec, log_spec.max() - top_db)
+
     return xr.DataArray(
-        data=np.log1p(log_scaling * spec).astype(dtype),
+        data=log_spec.astype(dtype),
         dims=spec.dims,
         coords=spec.coords,
         attrs=spec.attrs,