mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added unit tests for spectrogram preprocessing
This commit is contained in:
parent
f314942628
commit
fd7f2b0081
@ -90,11 +90,11 @@ class FrequencyConfig(BaseConfig):
|
|||||||
Frequencies above this value will be cropped. Must be > 0.
|
Frequencies above this value will be cropped. Must be > 0.
|
||||||
min_freq : int, default=10000
|
min_freq : int, default=10000
|
||||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||||
Frequencies below this value will be cropped. Must be > 0.
|
Frequencies below this value will be cropped. Must be >= 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_freq: int = Field(default=120_000, gt=0)
|
max_freq: int = Field(default=120_000, ge=0)
|
||||||
min_freq: int = Field(default=10_000, gt=0)
|
min_freq: int = Field(default=10_000, ge=0)
|
||||||
|
|
||||||
|
|
||||||
class SpecSizeConfig(BaseConfig):
|
class SpecSizeConfig(BaseConfig):
|
||||||
@ -395,11 +395,13 @@ def crop_spectrogram_frequencies(
|
|||||||
xr.DataArray
|
xr.DataArray
|
||||||
Spectrogram cropped along the frequency axis. Preserves dtype.
|
Spectrogram cropped along the frequency axis. Preserves dtype.
|
||||||
"""
|
"""
|
||||||
|
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
|
||||||
|
|
||||||
return arrays.crop_dim(
|
return arrays.crop_dim(
|
||||||
spec,
|
spec,
|
||||||
dim="frequency",
|
dim="frequency",
|
||||||
start=min_freq,
|
start=min_freq if start_freq < min_freq else None,
|
||||||
stop=max_freq,
|
stop=max_freq if end_freq > max_freq else None,
|
||||||
).astype(spec.dtype)
|
).astype(spec.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,69 +2,477 @@ import math
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from hypothesis import HealthCheck, given, settings
|
import numpy as np
|
||||||
from hypothesis import strategies as st
|
import pytest
|
||||||
|
import xarray as xr
|
||||||
from soundevent import arrays
|
from soundevent import arrays
|
||||||
|
|
||||||
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
|
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.spectrogram import (
|
||||||
STFTConfig,
|
MAX_FREQ,
|
||||||
|
MIN_FREQ,
|
||||||
|
ConfigurableSpectrogramBuilder,
|
||||||
FrequencyConfig,
|
FrequencyConfig,
|
||||||
|
PcenConfig,
|
||||||
SpecSizeConfig,
|
SpecSizeConfig,
|
||||||
SpectrogramConfig,
|
SpectrogramConfig,
|
||||||
|
STFTConfig,
|
||||||
|
apply_pcen,
|
||||||
|
build_spectrogram_builder,
|
||||||
compute_spectrogram,
|
compute_spectrogram,
|
||||||
duration_to_spec_width,
|
crop_spectrogram_frequencies,
|
||||||
get_spectrogram_resolution,
|
get_spectrogram_resolution,
|
||||||
spec_width_to_samples,
|
remove_spectral_mean,
|
||||||
|
resize_spectrogram,
|
||||||
|
scale_log,
|
||||||
|
scale_spectrogram,
|
||||||
stft,
|
stft,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
SAMPLERATE = 250_000
|
||||||
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
|
DURATION = 0.1
|
||||||
@given(
|
TEST_FREQ = 30_000
|
||||||
duration=st.floats(min_value=0.1, max_value=1.0),
|
N_SAMPLES = int(SAMPLERATE * DURATION)
|
||||||
window_duration=st.floats(min_value=0.001, max_value=0.01),
|
TIME_COORD = np.linspace(
|
||||||
window_overlap=st.floats(min_value=0.2, max_value=0.9),
|
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
|
||||||
samplerate=st.integers(min_value=256_000, max_value=512_000),
|
|
||||||
)
|
)
|
||||||
def test_can_estimate_correctly_spectrogram_width_from_duration(
|
|
||||||
duration: float,
|
|
||||||
window_duration: float,
|
|
||||||
window_overlap: float,
|
|
||||||
samplerate: int,
|
|
||||||
wav_factory: Callable[..., Path],
|
|
||||||
):
|
|
||||||
path = wav_factory(duration=duration, samplerate=samplerate)
|
|
||||||
audio = load_file_audio(
|
|
||||||
path,
|
|
||||||
# NOTE: Dont resample nor adjust duration to test if the width
|
|
||||||
# estimation works on all scenarios
|
|
||||||
config=AudioConfig(resample=None, duration=None),
|
|
||||||
)
|
|
||||||
spectrogram = stft(audio, window_duration, window_overlap)
|
|
||||||
|
|
||||||
spec_width = duration_to_spec_width(
|
|
||||||
duration,
|
@pytest.fixture
|
||||||
samplerate=samplerate,
|
def sine_wave_xr() -> xr.DataArray:
|
||||||
|
"""Generate a single sine wave as an xr.DataArray."""
|
||||||
|
t = TIME_COORD
|
||||||
|
wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
|
||||||
|
return xr.DataArray(
|
||||||
|
wav_data,
|
||||||
|
coords={"time": t},
|
||||||
|
dims=["time"],
|
||||||
|
attrs={"samplerate": SAMPLERATE},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def constant_wave_xr() -> xr.DataArray:
|
||||||
|
"""Generate a constant signal as an xr.DataArray."""
|
||||||
|
t = TIME_COORD
|
||||||
|
wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5
|
||||||
|
return xr.DataArray(
|
||||||
|
wav_data,
|
||||||
|
coords={"time": t},
|
||||||
|
dims=["time"],
|
||||||
|
attrs={"samplerate": SAMPLERATE},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
|
||||||
|
"""Generate a basic spectrogram for testing downstream functions."""
|
||||||
|
config = SpectrogramConfig(
|
||||||
|
stft=STFTConfig(window_duration=0.002, window_overlap=0.5),
|
||||||
|
frequencies=FrequencyConfig(
|
||||||
|
min_freq=0,
|
||||||
|
max_freq=int(SAMPLERATE / 2),
|
||||||
|
),
|
||||||
|
size=None,
|
||||||
|
pcen=None,
|
||||||
|
spectral_mean_substraction=False,
|
||||||
|
peak_normalize=False,
|
||||||
|
scale="amplitude",
|
||||||
|
)
|
||||||
|
spec = stft(
|
||||||
|
sine_wave_xr,
|
||||||
|
window_duration=config.stft.window_duration,
|
||||||
|
window_overlap=config.stft.window_overlap,
|
||||||
|
window_fn=config.stft.window_fn,
|
||||||
|
)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
def test_stft_config_defaults():
|
||||||
|
config = STFTConfig()
|
||||||
|
assert config.window_duration == 0.002
|
||||||
|
assert config.window_overlap == 0.75
|
||||||
|
assert config.window_fn == "hann"
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_config_defaults():
|
||||||
|
config = FrequencyConfig()
|
||||||
|
assert config.min_freq == MIN_FREQ
|
||||||
|
assert config.max_freq == MAX_FREQ
|
||||||
|
|
||||||
|
|
||||||
|
def test_spec_size_config_defaults():
|
||||||
|
config = SpecSizeConfig()
|
||||||
|
assert config.height == 128
|
||||||
|
assert config.resize_factor == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_pcen_config_defaults():
|
||||||
|
config = PcenConfig()
|
||||||
|
assert config.time_constant == 0.4
|
||||||
|
assert config.gain == 0.98
|
||||||
|
assert config.bias == 2
|
||||||
|
assert config.power == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_spectrogram_config_defaults():
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
assert isinstance(config.stft, STFTConfig)
|
||||||
|
assert isinstance(config.frequencies, FrequencyConfig)
|
||||||
|
assert isinstance(config.pcen, PcenConfig)
|
||||||
|
assert config.scale == "amplitude"
|
||||||
|
assert isinstance(config.size, SpecSizeConfig)
|
||||||
|
assert config.spectral_mean_substraction is True
|
||||||
|
assert config.peak_normalize is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_stft_output_properties(sine_wave_xr: xr.DataArray):
|
||||||
|
window_duration = 0.002
|
||||||
|
window_overlap = 0.5
|
||||||
|
samplerate = sine_wave_xr.attrs["samplerate"]
|
||||||
|
nfft = int(window_duration * samplerate)
|
||||||
|
hop_len = nfft - int(window_overlap * nfft)
|
||||||
|
|
||||||
|
spec = stft(
|
||||||
|
sine_wave_xr,
|
||||||
window_duration=window_duration,
|
window_duration=window_duration,
|
||||||
window_overlap=window_overlap,
|
window_overlap=window_overlap,
|
||||||
)
|
window_fn="hann",
|
||||||
assert spectrogram.sizes["time"] == spec_width
|
|
||||||
|
|
||||||
rebuilt_duration = (
|
|
||||||
spec_width_to_samples(
|
|
||||||
spec_width,
|
|
||||||
samplerate=samplerate,
|
|
||||||
window_duration=window_duration,
|
|
||||||
window_overlap=window_overlap,
|
|
||||||
)
|
|
||||||
/ samplerate
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert isinstance(spec, xr.DataArray)
|
||||||
abs(duration - rebuilt_duration)
|
assert spec.dims == ("frequency", "time")
|
||||||
< (1 - window_overlap) * window_duration
|
assert spec.dtype == np.float32
|
||||||
|
assert "frequency" in spec.coords
|
||||||
|
assert "time" in spec.coords
|
||||||
|
|
||||||
|
time_step = arrays.get_dim_step(spec, "time")
|
||||||
|
freq_step = arrays.get_dim_step(spec, "frequency")
|
||||||
|
freq_start, freq_end = arrays.get_dim_range(spec, "frequency")
|
||||||
|
assert np.isclose(freq_step, samplerate / nfft)
|
||||||
|
assert np.isclose(time_step, hop_len / samplerate)
|
||||||
|
assert spec.frequency.min() >= 0
|
||||||
|
assert freq_start == 0
|
||||||
|
assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5)
|
||||||
|
assert spec.time.min() >= 0
|
||||||
|
assert spec.time.max() < DURATION
|
||||||
|
|
||||||
|
assert spec.attrs["original_samplerate"] == samplerate
|
||||||
|
assert spec.attrs["nfft"] == nfft
|
||||||
|
assert spec.attrs["noverlap"] == int(window_overlap * nfft)
|
||||||
|
|
||||||
|
assert np.all(spec.data >= 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
|
||||||
|
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
|
||||||
|
spec = stft(
|
||||||
|
sine_wave_xr,
|
||||||
|
window_duration=0.002,
|
||||||
|
window_overlap=0.5,
|
||||||
|
window_fn=window_fn,
|
||||||
)
|
)
|
||||||
|
assert isinstance(spec, xr.DataArray)
|
||||||
|
assert np.all(spec.data >= 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
|
||||||
|
min_f, max_f = 20_000, 80_000
|
||||||
|
cropped_spec = crop_spectrogram_frequencies(
|
||||||
|
sample_spec, min_freq=min_f, max_freq=max_f
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cropped_spec.dims == sample_spec.dims
|
||||||
|
assert cropped_spec.dtype == sample_spec.dtype
|
||||||
|
assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
|
||||||
|
assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
|
||||||
|
assert cropped_spec.coords["frequency"].min() >= min_f
|
||||||
|
|
||||||
|
assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
|
||||||
|
samplerate = sample_spec.attrs["original_samplerate"]
|
||||||
|
min_f, max_f = 0, samplerate / 2
|
||||||
|
cropped_spec = crop_spectrogram_frequencies(
|
||||||
|
sample_spec, min_freq=min_f, max_freq=max_f
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cropped_spec.sizes == sample_spec.sizes
|
||||||
|
assert np.allclose(cropped_spec.data, sample_spec.data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_pcen(sample_spec: xr.DataArray):
|
||||||
|
if "original_samplerate" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||||
|
if "nfft" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||||
|
if "noverlap" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
|
||||||
|
|
||||||
|
pcen_config = PcenConfig()
|
||||||
|
pcen_spec = apply_pcen(
|
||||||
|
sample_spec,
|
||||||
|
time_constant=pcen_config.time_constant,
|
||||||
|
gain=pcen_config.gain,
|
||||||
|
bias=pcen_config.bias,
|
||||||
|
power=pcen_config.power,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert pcen_spec.dims == sample_spec.dims
|
||||||
|
assert pcen_spec.sizes == sample_spec.sizes
|
||||||
|
assert pcen_spec.dtype == sample_spec.dtype
|
||||||
|
assert np.all(pcen_spec.data >= 0)
|
||||||
|
|
||||||
|
assert not np.allclose(pcen_spec.data, sample_spec.data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_log(sample_spec: xr.DataArray):
|
||||||
|
if "original_samplerate" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||||
|
if "nfft" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||||
|
|
||||||
|
log_spec = scale_log(sample_spec, dtype=np.float32)
|
||||||
|
|
||||||
|
assert log_spec.dims == sample_spec.dims
|
||||||
|
assert log_spec.sizes == sample_spec.sizes
|
||||||
|
assert log_spec.dtype == np.float32
|
||||||
|
assert np.all(log_spec.data >= 0)
|
||||||
|
assert not np.allclose(log_spec.data, sample_spec.data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_log_missing_attrs(sample_spec: xr.DataArray):
|
||||||
|
spec_copy = sample_spec.copy()
|
||||||
|
del spec_copy.attrs["original_samplerate"]
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
scale_log(spec_copy)
|
||||||
|
|
||||||
|
spec_copy = sample_spec.copy()
|
||||||
|
del spec_copy.attrs["nfft"]
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
scale_log(spec_copy)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
|
||||||
|
scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
|
||||||
|
assert np.allclose(scaled_spec.data, sample_spec.data)
|
||||||
|
assert scaled_spec.dtype == sample_spec.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_spectrogram_power(sample_spec: xr.DataArray):
|
||||||
|
scaled_spec = scale_spectrogram(sample_spec, scale="power")
|
||||||
|
assert np.allclose(scaled_spec.data, sample_spec.data**2)
|
||||||
|
assert scaled_spec.dtype == sample_spec.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_scale_spectrogram_db(sample_spec: xr.DataArray):
|
||||||
|
if "original_samplerate" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||||
|
if "nfft" not in sample_spec.attrs:
|
||||||
|
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||||
|
|
||||||
|
scaled_spec = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64)
|
||||||
|
log_spec_expected = scale_log(sample_spec, dtype=np.float64)
|
||||||
|
assert scaled_spec.dtype == np.float64
|
||||||
|
assert np.allclose(scaled_spec.data, log_spec_expected.data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_spectral_mean(sample_spec: xr.DataArray):
|
||||||
|
spec_noisy = sample_spec.copy() + 0.1
|
||||||
|
denoised_spec = remove_spectral_mean(spec_noisy)
|
||||||
|
|
||||||
|
assert denoised_spec.dims == spec_noisy.dims
|
||||||
|
assert denoised_spec.sizes == spec_noisy.sizes
|
||||||
|
assert denoised_spec.dtype == spec_noisy.dtype
|
||||||
|
assert np.all(denoised_spec.data >= 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
|
||||||
|
const_spec = stft(constant_wave_xr, 0.002, 0.5)
|
||||||
|
denoised_spec = remove_spectral_mean(const_spec)
|
||||||
|
|
||||||
|
assert np.allclose(denoised_spec.data, 0, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"height, resize_factor, expected_freq_size, expected_time_factor",
|
||||||
|
[
|
||||||
|
(128, 1.0, 128, 1.0),
|
||||||
|
(64, 0.5, 64, 0.5),
|
||||||
|
(256, None, 256, 1.0),
|
||||||
|
(100, 2.0, 100, 2.0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_resize_spectrogram(
|
||||||
|
sample_spec: xr.DataArray,
|
||||||
|
height: int,
|
||||||
|
resize_factor: float | None,
|
||||||
|
expected_freq_size: int,
|
||||||
|
expected_time_factor: float,
|
||||||
|
):
|
||||||
|
original_time_size = sample_spec.sizes["time"]
|
||||||
|
resized_spec = resize_spectrogram(
|
||||||
|
sample_spec,
|
||||||
|
height=height,
|
||||||
|
resize_factor=resize_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resized_spec.dims == ("frequency", "time")
|
||||||
|
assert resized_spec.sizes["frequency"] == expected_freq_size
|
||||||
|
expected_time_size = int(original_time_size * expected_time_factor)
|
||||||
|
|
||||||
|
assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
|
||||||
|
|
||||||
|
assert resized_spec.dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
|
||||||
|
assert isinstance(spec, xr.DataArray)
|
||||||
|
assert spec.dims == ("frequency", "time")
|
||||||
|
assert spec.dtype == np.float32
|
||||||
|
assert config.size is not None
|
||||||
|
assert spec.sizes["frequency"] == config.size.height
|
||||||
|
|
||||||
|
temp_stft = stft(
|
||||||
|
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
|
||||||
|
)
|
||||||
|
assert config.size.resize_factor is not None
|
||||||
|
expected_time_size = int(
|
||||||
|
temp_stft.sizes["time"] * config.size.resize_factor
|
||||||
|
)
|
||||||
|
assert abs(spec.sizes["time"] - expected_time_size) <= 1
|
||||||
|
|
||||||
|
assert spec.coords["frequency"].min() >= config.frequencies.min_freq
|
||||||
|
assert np.isclose(
|
||||||
|
spec.coords["frequency"].max(),
|
||||||
|
config.frequencies.max_freq,
|
||||||
|
rtol=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
|
||||||
|
sine_wave_xr: xr.DataArray,
|
||||||
|
):
|
||||||
|
config = SpectrogramConfig(
|
||||||
|
pcen=None,
|
||||||
|
spectral_mean_substraction=False,
|
||||||
|
size=None,
|
||||||
|
scale="power",
|
||||||
|
frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
|
||||||
|
)
|
||||||
|
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
|
||||||
|
stft_direct = stft(
|
||||||
|
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
|
||||||
|
)
|
||||||
|
expected_spec = scale_spectrogram(stft_direct, scale="power")
|
||||||
|
|
||||||
|
assert spec.sizes == expected_spec.sizes
|
||||||
|
assert np.allclose(spec.data, expected_spec.data)
|
||||||
|
assert spec.dtype == expected_spec.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
|
||||||
|
config = SpectrogramConfig(peak_normalize=True)
|
||||||
|
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
assert np.isclose(spec.data.max(), 1.0, atol=1e-6)
|
||||||
|
|
||||||
|
config = SpectrogramConfig(peak_normalize=False)
|
||||||
|
spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_spectrogram_resolution_calculation():
|
||||||
|
config = SpectrogramConfig(
|
||||||
|
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||||
|
size=SpecSizeConfig(height=100, resize_factor=0.5),
|
||||||
|
frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
|
||||||
|
)
|
||||||
|
|
||||||
|
freq_res, time_res = get_spectrogram_resolution(config)
|
||||||
|
|
||||||
|
expected_freq_res = (110_000 - 10_000) / 100
|
||||||
|
expected_hop_duration = 0.002 * (1 - 0.75)
|
||||||
|
expected_time_res = expected_hop_duration / 0.5
|
||||||
|
|
||||||
|
assert np.isclose(freq_res, expected_freq_res)
|
||||||
|
assert np.isclose(time_res, expected_time_res)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_spectrogram_resolution_no_resize_factor():
|
||||||
|
config = SpectrogramConfig(
|
||||||
|
stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
|
||||||
|
size=SpecSizeConfig(height=200, resize_factor=None),
|
||||||
|
frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
|
||||||
|
)
|
||||||
|
freq_res, time_res = get_spectrogram_resolution(config)
|
||||||
|
expected_freq_res = (120_000 - 20_000) / 200
|
||||||
|
expected_hop_duration = 0.004 * (1 - 0.5)
|
||||||
|
expected_time_res = expected_hop_duration / 1.0
|
||||||
|
|
||||||
|
assert np.isclose(freq_res, expected_freq_res)
|
||||||
|
assert np.isclose(time_res, expected_time_res)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_spectrogram_resolution_no_size_config():
|
||||||
|
config = SpectrogramConfig(size=None)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Spectrogram size configuration is required"
|
||||||
|
):
|
||||||
|
get_spectrogram_resolution(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_spectrogram_builder_init():
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
|
||||||
|
assert builder.config is config
|
||||||
|
assert builder.dtype == np.float16
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||||
|
spec_builder = builder(sine_wave_xr)
|
||||||
|
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
assert isinstance(spec_builder, xr.DataArray)
|
||||||
|
assert np.allclose(spec_builder.data, spec_direct.data)
|
||||||
|
assert spec_builder.dtype == spec_direct.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray):
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||||
|
wav_np = sine_wave_xr.data
|
||||||
|
samplerate = sine_wave_xr.attrs["samplerate"]
|
||||||
|
|
||||||
|
spec_builder = builder(wav_np.astype(np.float32), samplerate=samplerate)
|
||||||
|
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
|
||||||
|
|
||||||
|
assert isinstance(spec_builder, xr.DataArray)
|
||||||
|
assert np.allclose(spec_builder.data, spec_direct.data, atol=1e-4)
|
||||||
|
assert spec_builder.dtype == spec_direct.dtype
|
||||||
|
|
||||||
|
|
||||||
|
def test_configurable_spectrogram_builder_call_np_no_samplerate(
|
||||||
|
sine_wave_xr: xr.DataArray,
|
||||||
|
):
|
||||||
|
config = SpectrogramConfig()
|
||||||
|
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||||
|
wav_np = sine_wave_xr.data
|
||||||
|
with pytest.raises(ValueError, match="Samplerate must be provided"):
|
||||||
|
builder(wav_np, samplerate=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_spectrogram_builder():
|
||||||
|
config = SpectrogramConfig(peak_normalize=True)
|
||||||
|
builder = build_spectrogram_builder(config=config, dtype=np.float64)
|
||||||
|
assert isinstance(builder, ConfigurableSpectrogramBuilder)
|
||||||
|
assert builder.config is config
|
||||||
|
assert builder.dtype == np.float64
|
||||||
|
|
||||||
|
|
||||||
def test_can_estimate_spectrogram_resolution(
|
def test_can_estimate_spectrogram_resolution(
|
||||||
@ -72,10 +480,8 @@ def test_can_estimate_spectrogram_resolution(
|
|||||||
):
|
):
|
||||||
path = wav_factory(duration=0.2, samplerate=256_000)
|
path = wav_factory(duration=0.2, samplerate=256_000)
|
||||||
|
|
||||||
audio = load_file_audio(
|
audio_data = load_file_audio(
|
||||||
path,
|
path,
|
||||||
# NOTE: Dont resample nor adjust duration to test if the width
|
|
||||||
# estimation works on all scenarios
|
|
||||||
config=AudioConfig(resample=None, duration=None),
|
config=AudioConfig(resample=None, duration=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -85,7 +491,7 @@ def test_can_estimate_spectrogram_resolution(
|
|||||||
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||||
)
|
)
|
||||||
|
|
||||||
spec = compute_spectrogram(audio, config=config)
|
spec = compute_spectrogram(audio_data, config=config)
|
||||||
|
|
||||||
freq_res, time_res = get_spectrogram_resolution(config)
|
freq_res, time_res = get_spectrogram_resolution(config)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user