Added unit tests for spectrogram preprocessing

This commit is contained in:
mbsantiago 2025-04-17 18:31:24 +01:00
parent f314942628
commit fd7f2b0081
2 changed files with 461 additions and 53 deletions

View File

@ -90,11 +90,11 @@ class FrequencyConfig(BaseConfig):
Frequencies above this value will be cropped. Must be > 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be > 0.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, gt=0)
min_freq: int = Field(default=10_000, gt=0)
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
class SpecSizeConfig(BaseConfig):
@ -395,11 +395,13 @@ def crop_spectrogram_frequencies(
xr.DataArray
Spectrogram cropped along the frequency axis. Preserves dtype.
"""
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
return arrays.crop_dim(
spec,
dim="frequency",
start=min_freq,
stop=max_freq,
start=min_freq if start_freq < min_freq else None,
stop=max_freq if end_freq > max_freq else None,
).astype(spec.dtype)

View File

@ -2,69 +2,477 @@ import math
from pathlib import Path
from typing import Callable
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
import numpy as np
import pytest
import xarray as xr
from soundevent import arrays
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
STFTConfig,
MAX_FREQ,
MIN_FREQ,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
apply_pcen,
build_spectrogram_builder,
compute_spectrogram,
duration_to_spec_width,
crop_spectrogram_frequencies,
get_spectrogram_resolution,
spec_width_to_samples,
remove_spectral_mean,
resize_spectrogram,
scale_log,
scale_spectrogram,
stft,
)
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
duration=st.floats(min_value=0.1, max_value=1.0),
window_duration=st.floats(min_value=0.001, max_value=0.01),
window_overlap=st.floats(min_value=0.2, max_value=0.9),
samplerate=st.integers(min_value=256_000, max_value=512_000),
# Shared parameters for the synthetic waveforms used by the fixtures below.
# SAMPLERATE is stored in the waveform attrs under "samplerate" and
# SAMPLERATE / 2 is used as the upper frequency bound in several configs.
SAMPLERATE = 250_000
# Length of the synthetic waveforms (presumably seconds, given that
# N_SAMPLES = SAMPLERATE * DURATION).
DURATION = 0.1
# Frequency of the sine wave produced by the sine-wave fixture.
TEST_FREQ = 30_000
# Number of samples in each synthetic waveform.
N_SAMPLES = int(SAMPLERATE * DURATION)
# Time coordinate shared by the waveform fixtures: N_SAMPLES evenly spaced
# float32 values starting at 0 and stopping short of DURATION.
TIME_COORD = np.linspace(
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
)
def test_can_estimate_correctly_spectrogram_width_from_duration(
duration: float,
window_duration: float,
window_overlap: float,
samplerate: int,
wav_factory: Callable[..., Path],
@pytest.fixture
def sine_wave_xr() -> xr.DataArray:
    """Return a float32 sine wave at TEST_FREQ sampled on TIME_COORD.

    The result is a 1-D ``xr.DataArray`` with a ``time`` dimension and a
    ``samplerate`` attribute set to SAMPLERATE.
    """
    t = TIME_COORD
    samples = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
    metadata = {"samplerate": SAMPLERATE}
    return xr.DataArray(
        samples,
        coords={"time": t},
        dims=["time"],
        attrs=metadata,
    )
@pytest.fixture
def constant_wave_xr() -> xr.DataArray:
    """Return a float32 signal that is 0.5 everywhere, sampled on TIME_COORD.

    The result is a 1-D ``xr.DataArray`` with a ``time`` dimension and a
    ``samplerate`` attribute set to SAMPLERATE.
    """
    flat_signal = np.full(N_SAMPLES, 0.5, dtype=np.float32)
    return xr.DataArray(
        flat_signal,
        coords={"time": TIME_COORD},
        dims=["time"],
        attrs={"samplerate": SAMPLERATE},
    )
@pytest.fixture
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
    """Return the STFT of the sine-wave fixture for downstream tests.

    A full ``SpectrogramConfig`` is built, but only its ``stft`` settings
    are passed on; no cropping, resizing, PCEN or scaling is applied here.
    """
    stft_settings = STFTConfig(window_duration=0.002, window_overlap=0.5)
    full_band = FrequencyConfig(
        min_freq=0,
        max_freq=int(SAMPLERATE / 2),
    )
    config = SpectrogramConfig(
        stft=stft_settings,
        frequencies=full_band,
        size=None,
        pcen=None,
        spectral_mean_substraction=False,
        peak_normalize=False,
        scale="amplitude",
    )

    # Only the STFT parameters of the config are consumed.
    return stft(
        sine_wave_xr,
        window_duration=config.stft.window_duration,
        window_overlap=config.stft.window_overlap,
        window_fn=config.stft.window_fn,
    )
def test_stft_config_defaults():
    """A default STFTConfig uses a 2 ms hann window with 75% overlap."""
    defaults = STFTConfig()

    assert defaults.window_duration == 0.002
    assert defaults.window_overlap == 0.75
    assert defaults.window_fn == "hann"
def test_frequency_config_defaults():
    """A default FrequencyConfig matches the MIN_FREQ / MAX_FREQ constants."""
    defaults = FrequencyConfig()

    assert defaults.min_freq == MIN_FREQ
    assert defaults.max_freq == MAX_FREQ
def test_spec_size_config_defaults():
    """A default SpecSizeConfig has height 128 and resize factor 0.5."""
    defaults = SpecSizeConfig()

    assert defaults.height == 128
    assert defaults.resize_factor == 0.5
def test_pcen_config_defaults():
    """A default PcenConfig carries the expected PCEN parameters."""
    defaults = PcenConfig()

    assert defaults.time_constant == 0.4
    assert defaults.gain == 0.98
    assert defaults.bias == 2
    assert defaults.power == 0.5
def test_spectrogram_config_defaults():
    """A default SpectrogramConfig populates every sub-config and flag."""
    defaults = SpectrogramConfig()

    # Nested configurations are instantiated by default.
    assert isinstance(defaults.stft, STFTConfig)
    assert isinstance(defaults.frequencies, FrequencyConfig)
    assert isinstance(defaults.pcen, PcenConfig)
    assert defaults.scale == "amplitude"
    assert isinstance(defaults.size, SpecSizeConfig)

    # Boolean processing switches.
    assert defaults.spectral_mean_substraction is True
    assert defaults.peak_normalize is False
def test_stft_output_properties(sine_wave_xr: xr.DataArray):
    """Check dims, dtype, coordinate spacing and attrs of the STFT output."""
    window_duration = 0.002
    window_overlap = 0.5
    samplerate = sine_wave_xr.attrs["samplerate"]

    # Window, overlap and hop sizes in samples, as the assertions below
    # expect them.
    nfft = int(window_duration * samplerate)
    noverlap = int(window_overlap * nfft)
    hop_len = nfft - noverlap

    result = stft(
        sine_wave_xr,
        window_duration=window_duration,
        window_overlap=window_overlap,
        window_fn="hann",
    )

    # Container type and layout.
    assert isinstance(result, xr.DataArray)
    assert result.dims == ("frequency", "time")
    assert result.dtype == np.float32
    assert "frequency" in result.coords
    assert "time" in result.coords

    # Coordinate resolution.
    time_step = arrays.get_dim_step(result, "time")
    freq_step = arrays.get_dim_step(result, "frequency")
    freq_start, freq_end = arrays.get_dim_range(result, "frequency")
    assert np.isclose(freq_step, samplerate / nfft)
    assert np.isclose(time_step, hop_len / samplerate)

    # Frequency axis starts at 0 and ends about one bin below samplerate / 2.
    assert result.frequency.min() >= 0
    assert freq_start == 0
    assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5)

    # Time axis stays within the clip.
    assert result.time.min() >= 0
    assert result.time.max() < DURATION

    # Metadata recorded on the spectrogram.
    assert result.attrs["original_samplerate"] == samplerate
    assert result.attrs["nfft"] == nfft
    assert result.attrs["noverlap"] == noverlap

    assert np.all(result.data >= 0)
@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
    """The STFT runs with each window function and yields non-negative values."""
    stft_kwargs = {
        "window_duration": 0.002,
        "window_overlap": 0.5,
        "window_fn": window_fn,
    }
    result = stft(sine_wave_xr, **stft_kwargs)

    assert isinstance(result, xr.DataArray)
    assert np.all(result.data >= 0)
def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
    """Cropping to 20-80 kHz shrinks only the frequency axis."""
    low, high = 20_000, 80_000

    cropped = crop_spectrogram_frequencies(
        sample_spec,
        min_freq=low,
        max_freq=high,
    )

    # Layout and dtype are preserved; the time axis is untouched.
    assert cropped.dims == sample_spec.dims
    assert cropped.dtype == sample_spec.dtype
    assert cropped.sizes["time"] == sample_spec.sizes["time"]

    # The frequency axis is reduced to (approximately) the requested band.
    assert cropped.sizes["frequency"] < sample_spec.sizes["frequency"]
    assert cropped.coords["frequency"].min() >= low
    assert np.isclose(cropped.coords["frequency"].max(), high, rtol=0.1)
def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
    """Cropping to 0..samplerate/2 leaves the spectrogram unchanged."""
    upper_bound = sample_spec.attrs["original_samplerate"] / 2

    cropped = crop_spectrogram_frequencies(
        sample_spec,
        min_freq=0,
        max_freq=upper_bound,
    )

    assert cropped.sizes == sample_spec.sizes
    assert np.allclose(cropped.data, sample_spec.data)
def test_apply_pcen(sample_spec: xr.DataArray):
    """PCEN keeps shape and dtype, stays non-negative and changes the values."""
    # Fill in any STFT metadata that is missing from the fixture.
    attrs = sample_spec.attrs
    attrs.setdefault("original_samplerate", SAMPLERATE)
    attrs.setdefault("nfft", int(0.002 * SAMPLERATE))
    attrs.setdefault("noverlap", int(0.5 * attrs["nfft"]))

    params = PcenConfig()
    result = apply_pcen(
        sample_spec,
        time_constant=params.time_constant,
        gain=params.gain,
        bias=params.bias,
        power=params.power,
    )

    assert result.dims == sample_spec.dims
    assert result.sizes == sample_spec.sizes
    assert result.dtype == sample_spec.dtype
    assert np.all(result.data >= 0)
    assert not np.allclose(result.data, sample_spec.data)
def test_scale_log(sample_spec: xr.DataArray):
    """Log scaling keeps the shape, honours dtype and changes the values."""
    # Fill in any STFT metadata that is missing from the fixture.
    attrs = sample_spec.attrs
    attrs.setdefault("original_samplerate", SAMPLERATE)
    attrs.setdefault("nfft", int(0.002 * SAMPLERATE))

    result = scale_log(sample_spec, dtype=np.float32)

    assert result.dims == sample_spec.dims
    assert result.sizes == sample_spec.sizes
    assert result.dtype == np.float32
    assert np.all(result.data >= 0)
    assert not np.allclose(result.data, sample_spec.data)
def test_scale_log_missing_attrs(sample_spec: xr.DataArray):
    """scale_log raises KeyError when a required attribute is absent."""
    for required_attr in ("original_samplerate", "nfft"):
        # Work on a fresh copy so only one attribute is missing at a time.
        stripped = sample_spec.copy()
        del stripped.attrs[required_attr]

        with pytest.raises(KeyError):
            scale_log(stripped)
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
    """Amplitude scaling returns the same values with the same dtype."""
    result = scale_spectrogram(sample_spec, scale="amplitude")

    assert np.allclose(result.data, sample_spec.data)
    assert result.dtype == sample_spec.dtype
def test_scale_spectrogram_power(sample_spec: xr.DataArray):
    """Power scaling squares the values and keeps the dtype."""
    result = scale_spectrogram(sample_spec, scale="power")
    expected = sample_spec.data**2

    assert np.allclose(result.data, expected)
    assert result.dtype == sample_spec.dtype
def test_scale_spectrogram_db(sample_spec: xr.DataArray):
    """The "dB" scale gives the same result as calling scale_log directly."""
    # Fill in any STFT metadata that is missing from the fixture.
    attrs = sample_spec.attrs
    attrs.setdefault("original_samplerate", SAMPLERATE)
    attrs.setdefault("nfft", int(0.002 * SAMPLERATE))

    result = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64)
    reference = scale_log(sample_spec, dtype=np.float64)

    assert result.dtype == np.float64
    assert np.allclose(result.data, reference.data)
def test_remove_spectral_mean(sample_spec: xr.DataArray):
    """Mean removal keeps shape and dtype and yields non-negative values."""
    # Add a constant offset to the spectrogram before denoising.
    noisy = sample_spec.copy() + 0.1

    result = remove_spectral_mean(noisy)

    assert result.dims == noisy.dims
    assert result.sizes == noisy.sizes
    assert result.dtype == noisy.dtype
    assert np.all(result.data >= 0)
def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
    """Mean removal on the spectrogram of a constant signal gives ~0."""
    flat_spec = stft(constant_wave_xr, 0.002, 0.5)

    result = remove_spectral_mean(flat_spec)

    assert np.allclose(result.data, 0, atol=1e-6)
@pytest.mark.parametrize(
    "height, resize_factor, expected_freq_size, expected_time_factor",
    [
        (128, 1.0, 128, 1.0),
        (64, 0.5, 64, 0.5),
        (256, None, 256, 1.0),
        (100, 2.0, 100, 2.0),
    ],
)
def test_resize_spectrogram(
    sample_spec: xr.DataArray,
    height: int,
    resize_factor: float | None,
    expected_freq_size: int,
    expected_time_factor: float,
):
    """Check the output size and dtype of ``resize_spectrogram``.

    Parameters
    ----------
    sample_spec : xr.DataArray
        Spectrogram fixture to be resized.
    height : int
        Requested number of frequency bins.
    resize_factor : float | None
        Factor passed to ``resize_spectrogram`` for the time axis. The
        ``None`` case is expected to leave the time axis unchanged
        (``expected_time_factor`` of 1.0).
    expected_freq_size : int
        Expected size of the ``frequency`` dimension after resizing.
    expected_time_factor : float
        Expected ratio between the resized and original time sizes.
    """
    original_time_size = sample_spec.sizes["time"]

    resized_spec = resize_spectrogram(
        sample_spec,
        height=height,
        resize_factor=resize_factor,
    )

    assert resized_spec.dims == ("frequency", "time")
    assert resized_spec.sizes["frequency"] == expected_freq_size

    # Allow one frame of slack for rounding of the resized time axis.
    expected_time_size = int(original_time_size * expected_time_factor)
    assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1

    assert resized_spec.dtype == np.float32
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
    """The default pipeline yields a resized, frequency-cropped float32 array."""
    config = SpectrogramConfig()

    result = compute_spectrogram(sine_wave_xr, config=config)

    assert isinstance(result, xr.DataArray)
    assert result.dims == ("frequency", "time")
    assert result.dtype == np.float32

    # The frequency axis has the configured height.
    assert config.size is not None
    assert result.sizes["frequency"] == config.size.height

    # The time axis is the raw STFT width scaled by the resize factor.
    raw_stft = stft(
        sine_wave_xr,
        config.stft.window_duration,
        config.stft.window_overlap,
    )
    assert config.size.resize_factor is not None
    expected_width = int(raw_stft.sizes["time"] * config.size.resize_factor)
    assert abs(result.sizes["time"] - expected_width) <= 1

    # The frequency coordinates fall within the configured band.
    freq_coord = result.coords["frequency"]
    assert freq_coord.min() >= config.frequencies.min_freq
    assert np.isclose(
        freq_coord.max(),
        config.frequencies.max_freq,
        rtol=0.1,
    )
def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
    sine_wave_xr: xr.DataArray,
):
    """With optional steps disabled the result equals a power-scaled STFT."""
    full_band = FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2))
    config = SpectrogramConfig(
        pcen=None,
        spectral_mean_substraction=False,
        size=None,
        scale="power",
        frequencies=full_band,
    )

    result = compute_spectrogram(sine_wave_xr, config=config)

    # Build the reference by hand: STFT followed by power scaling.
    raw_stft = stft(
        sine_wave_xr,
        config.stft.window_duration,
        config.stft.window_overlap,
    )
    reference = scale_spectrogram(raw_stft, scale="power")

    assert result.sizes == reference.sizes
    assert np.allclose(result.data, reference.data)
    assert result.dtype == reference.dtype
def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
    """The maximum is 1.0 only when peak normalisation is enabled."""
    normalized = compute_spectrogram(
        sine_wave_xr,
        config=SpectrogramConfig(peak_normalize=True),
    )
    assert np.isclose(normalized.data.max(), 1.0, atol=1e-6)

    unnormalized = compute_spectrogram(
        sine_wave_xr,
        config=SpectrogramConfig(peak_normalize=False),
    )
    assert not np.isclose(unnormalized.data.max(), 1.0, atol=1e-6)
def test_get_spectrogram_resolution_calculation():
    """Resolution follows from the band, height, hop duration and resize factor."""
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        size=SpecSizeConfig(height=100, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
    )

    freq_res, time_res = get_spectrogram_resolution(config)

    # Frequency: band width divided by the number of bins.
    assert np.isclose(freq_res, (110_000 - 10_000) / 100)

    # Time: hop duration divided by the resize factor.
    hop_duration = 0.002 * (1 - 0.75)
    assert np.isclose(time_res, hop_duration / 0.5)
def test_get_spectrogram_resolution_no_resize_factor():
    """Without a resize factor the time resolution equals the hop duration."""
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
        size=SpecSizeConfig(height=200, resize_factor=None),
        frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
    )

    freq_res, time_res = get_spectrogram_resolution(config)

    # Frequency: band width divided by the number of bins.
    assert np.isclose(freq_res, (120_000 - 20_000) / 200)

    # Time: hop duration with an implied factor of 1.0.
    hop_duration = 0.004 * (1 - 0.5)
    assert np.isclose(time_res, hop_duration / 1.0)
def test_get_spectrogram_resolution_no_size_config():
    """Resolution lookup raises ValueError when no size config is set."""
    sizeless_config = SpectrogramConfig(size=None)
    expected_message = "Spectrogram size configuration is required"

    with pytest.raises(ValueError, match=expected_message):
        get_spectrogram_resolution(sizeless_config)
def test_configurable_spectrogram_builder_init():
    """The builder keeps the config object and dtype it was given."""
    cfg = SpectrogramConfig()

    instance = ConfigurableSpectrogramBuilder(config=cfg, dtype=np.float16)

    assert instance.config is cfg
    assert instance.dtype == np.float16
def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
    """Calling the builder on a DataArray matches compute_spectrogram."""
    cfg = SpectrogramConfig()

    from_builder = ConfigurableSpectrogramBuilder(config=cfg)(sine_wave_xr)
    reference = compute_spectrogram(sine_wave_xr, config=cfg)

    assert isinstance(from_builder, xr.DataArray)
    assert np.allclose(from_builder.data, reference.data)
    assert from_builder.dtype == reference.dtype
def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray):
    """Calling the builder on a NumPy array plus samplerate matches the xarray path."""
    cfg = SpectrogramConfig()
    instance = ConfigurableSpectrogramBuilder(config=cfg)

    # Strip the xarray wrapper and pass the samplerate explicitly.
    raw_samples = sine_wave_xr.data
    rate = sine_wave_xr.attrs["samplerate"]
    from_builder = instance(raw_samples.astype(np.float32), samplerate=rate)

    reference = compute_spectrogram(sine_wave_xr, config=cfg)

    assert isinstance(from_builder, xr.DataArray)
    assert np.allclose(from_builder.data, reference.data, atol=1e-4)
    assert from_builder.dtype == reference.dtype
def test_configurable_spectrogram_builder_call_np_no_samplerate(
    sine_wave_xr: xr.DataArray,
):
    """The builder raises ValueError for a NumPy input without a samplerate."""
    instance = ConfigurableSpectrogramBuilder(config=SpectrogramConfig())
    raw_samples = sine_wave_xr.data

    with pytest.raises(ValueError, match="Samplerate must be provided"):
        instance(raw_samples, samplerate=None)
def test_build_spectrogram_builder():
    """The factory returns a ConfigurableSpectrogramBuilder with the given settings."""
    cfg = SpectrogramConfig(peak_normalize=True)

    built = build_spectrogram_builder(config=cfg, dtype=np.float64)

    assert isinstance(built, ConfigurableSpectrogramBuilder)
    assert built.config is cfg
    assert built.dtype == np.float64
def test_can_estimate_spectrogram_resolution(
@ -72,10 +480,8 @@ def test_can_estimate_spectrogram_resolution(
):
path = wav_factory(duration=0.2, samplerate=256_000)
audio = load_file_audio(
audio_data = load_file_audio(
path,
# NOTE: Dont resample nor adjust duration to test if the width
# estimation works on all scenarios
config=AudioConfig(resample=None, duration=None),
)
@ -85,7 +491,7 @@ def test_can_estimate_spectrogram_resolution(
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
)
spec = compute_spectrogram(audio, config=config)
spec = compute_spectrogram(audio_data, config=config)
freq_res, time_res = get_spectrogram_resolution(config)