mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added min and max freq attributes to preprocessor protocol
This commit is contained in:
parent
6498b6ca37
commit
ac4bb8f023
@ -144,12 +144,16 @@ class StandardPreprocessor(PreprocessorProtocol):
|
||||
audio_loader: AudioLoader
|
||||
spectrogram_builder: SpectrogramBuilder
|
||||
default_samplerate: int
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_loader: AudioLoader,
|
||||
spectrogram_builder: SpectrogramBuilder,
|
||||
default_samplerate: int,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
) -> None:
|
||||
"""Initialize the StandardPreprocessor.
|
||||
|
||||
@ -167,6 +171,8 @@ class StandardPreprocessor(PreprocessorProtocol):
|
||||
self.audio_loader = audio_loader
|
||||
self.spectrogram_builder = spectrogram_builder
|
||||
self.default_samplerate = default_samplerate
|
||||
self.max_freq = max_freq
|
||||
self.min_freq = min_freq
|
||||
|
||||
def load_file_audio(
|
||||
self,
|
||||
@ -429,8 +435,14 @@ def build_preprocessor(
|
||||
if config.audio.resample
|
||||
else TARGET_SAMPLERATE_HZ
|
||||
)
|
||||
|
||||
min_freq = config.spectrogram.frequencies.min_freq
|
||||
max_freq = config.spectrogram.frequencies.max_freq
|
||||
|
||||
return StandardPreprocessor(
|
||||
audio_loader=build_audio_loader(config.audio),
|
||||
spectrogram_builder=build_spectrogram_builder(config.spectrogram),
|
||||
default_samplerate=default_samplerate,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
@ -286,7 +286,8 @@ def load_recording_audio(
|
||||
"""Load and preprocess the entire audio content of a recording using config.
|
||||
|
||||
Creates a `soundevent.data.Clip` spanning the full duration of the
|
||||
recording and then delegates the loading and processing to `load_clip_audio`.
|
||||
recording and then delegates the loading and processing to
|
||||
`load_clip_audio`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -636,7 +637,11 @@ def resample_audio_fourier(
|
||||
If `num` is negative.
|
||||
"""
|
||||
ratio = sr_new / sr_orig
|
||||
return resample(array, int(array.shape[axis] * ratio), axis=axis) # type: ignore
|
||||
return resample( # type: ignore
|
||||
array,
|
||||
int(array.shape[axis] * ratio),
|
||||
axis=axis,
|
||||
)
|
||||
|
||||
|
||||
def convert_to_xr(
|
||||
@ -649,8 +654,8 @@ def convert_to_xr(
|
||||
Parameters
|
||||
----------
|
||||
wav : np.ndarray
|
||||
The input waveform array. Expected to be 1D or 2D (with the first axis as
|
||||
the channel dimension).
|
||||
The input waveform array. Expected to be 1D or 2D (with the first
|
||||
axis as the channel dimension).
|
||||
samplerate : int
|
||||
The sample rate in Hz.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
@ -673,7 +678,8 @@ def convert_to_xr(
|
||||
|
||||
if wav.ndim != 1:
|
||||
raise ValueError(
|
||||
"Audio must be 1D array or 2D channel where the first axis is the channel dimension"
|
||||
"Audio must be 1D array or 2D channel where the first "
|
||||
"axis is the channel dimension"
|
||||
)
|
||||
|
||||
if wav.size == 0:
|
||||
|
@ -21,8 +21,6 @@ The core computation is performed by `compute_spectrogram`.
|
||||
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from numpy.typing import DTypeLike
|
||||
@ -147,7 +145,8 @@ class SpectrogramConfig(BaseConfig):
|
||||
"""Unified configuration for spectrogram generation pipeline.
|
||||
|
||||
Aggregates settings for all steps involved in converting a preprocessed
|
||||
audio waveform into a final spectrogram representation suitable for model input.
|
||||
audio waveform into a final spectrogram representation suitable for model
|
||||
input.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@ -298,7 +297,8 @@ def compute_spectrogram(
|
||||
1. Compute STFT magnitude (`stft`).
|
||||
2. Crop frequency axis (`crop_spectrogram_frequencies`).
|
||||
3. Apply PCEN if configured (`apply_pcen`).
|
||||
4. Apply final amplitude scaling (dB, power, amplitude) (`scale_spectrogram`).
|
||||
4. Apply final amplitude scaling (dB, power, amplitude)
|
||||
(`scale_spectrogram`).
|
||||
5. Apply spectral mean subtraction denoising if enabled.
|
||||
6. Resize dimensions if specified (`resize_spectrogram`).
|
||||
7. Apply final peak normalization if enabled.
|
||||
@ -324,9 +324,6 @@ def compute_spectrogram(
|
||||
------
|
||||
ValueError
|
||||
If `wav` lacks necessary 'time' coordinates or dimensions.
|
||||
Exception
|
||||
Can re-raise exceptions from underlying libraries (e.g., librosa, numpy)
|
||||
if invalid parameters or data are encountered.
|
||||
"""
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
@ -335,7 +332,6 @@ def compute_spectrogram(
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
spec = crop_spectrogram_frequencies(
|
||||
@ -410,7 +406,6 @@ def stft(
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
window_fn: str = "hann",
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
|
||||
|
||||
@ -425,11 +420,9 @@ def stft(
|
||||
window_duration : float
|
||||
Duration of the STFT window in seconds.
|
||||
window_overlap : float
|
||||
Fractional overlap between consecutive windows [0, 1).
|
||||
Fractional overlap between consecutive windows.
|
||||
window_fn : str, default="hann"
|
||||
Name of the window function (e.g., "hann", "hamming").
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target data type for the spectrogram array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -442,55 +435,13 @@ def stft(
|
||||
ValueError
|
||||
If sample rate cannot be determined from `wave` coordinates.
|
||||
"""
|
||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||
step = arrays.get_dim_step(wave, dim="time")
|
||||
sampling_rate = 1 / step
|
||||
|
||||
nfft = int(window_duration * sampling_rate)
|
||||
noverlap = int(window_overlap * nfft)
|
||||
hop_len = nfft - noverlap
|
||||
hop_duration = hop_len / sampling_rate
|
||||
|
||||
spec, _ = librosa.core.spectrum._spectrogram(
|
||||
y=wave.data.astype(dtype),
|
||||
power=1,
|
||||
n_fft=nfft,
|
||||
hop_length=nfft - noverlap,
|
||||
center=False,
|
||||
window=window_fn,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data=spec.astype(dtype),
|
||||
dims=["frequency", "time"],
|
||||
coords={
|
||||
"frequency": arrays.create_frequency_dim_from_array(
|
||||
np.linspace(
|
||||
0,
|
||||
sampling_rate / 2,
|
||||
spec.shape[0],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=sampling_rate / nfft,
|
||||
),
|
||||
"time": arrays.create_time_dim_from_array(
|
||||
np.linspace(
|
||||
start_time,
|
||||
end_time - (window_duration - hop_duration),
|
||||
spec.shape[1],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=hop_duration,
|
||||
),
|
||||
},
|
||||
attrs={
|
||||
**wave.attrs,
|
||||
"original_samplerate": sampling_rate,
|
||||
"nfft": nfft,
|
||||
"noverlap": noverlap,
|
||||
},
|
||||
return audio.compute_spectrogram(
|
||||
wave,
|
||||
window_size=window_duration,
|
||||
hop_size=(1 - window_overlap) * window_duration,
|
||||
window_type=window_fn,
|
||||
scale="amplitude",
|
||||
sort_dims=False,
|
||||
)
|
||||
|
||||
|
||||
@ -592,8 +543,10 @@ def apply_pcen(
|
||||
verified against the specific `soundevent.audio.pcen` implementation
|
||||
details.
|
||||
"""
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
hop_length = spec.attrs["nfft"] - spec.attrs["noverlap"]
|
||||
samplerate = spec.attrs["samplerate"]
|
||||
hop_size = spec.attrs["hop_size"]
|
||||
|
||||
hop_length = int(hop_size * samplerate)
|
||||
t_frames = time_constant * samplerate / (float(hop_length) * 10)
|
||||
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||
return audio.pcen(
|
||||
|
@ -168,6 +168,10 @@ class PreprocessorProtocol(Protocol):
|
||||
loading or spectrogram computation from a waveform.
|
||||
"""
|
||||
|
||||
max_freq: float
|
||||
|
||||
min_freq: float
|
||||
|
||||
def preprocess_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
|
Loading…
Reference in New Issue
Block a user