mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added unit tests for spectrogram preprocessing
This commit is contained in:
parent
f314942628
commit
fd7f2b0081
@ -90,11 +90,11 @@ class FrequencyConfig(BaseConfig):
|
||||
Frequencies above this value will be cropped. Must be > 0.
|
||||
min_freq : int, default=10000
|
||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies below this value will be cropped. Must be > 0.
|
||||
Frequencies below this value will be cropped. Must be >= 0.
|
||||
"""
|
||||
|
||||
max_freq: int = Field(default=120_000, gt=0)
|
||||
min_freq: int = Field(default=10_000, gt=0)
|
||||
max_freq: int = Field(default=120_000, ge=0)
|
||||
min_freq: int = Field(default=10_000, ge=0)
|
||||
|
||||
|
||||
class SpecSizeConfig(BaseConfig):
|
||||
@ -395,11 +395,13 @@ def crop_spectrogram_frequencies(
|
||||
xr.DataArray
|
||||
Spectrogram cropped along the frequency axis. Preserves dtype.
|
||||
"""
|
||||
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
|
||||
|
||||
return arrays.crop_dim(
|
||||
spec,
|
||||
dim="frequency",
|
||||
start=min_freq,
|
||||
stop=max_freq,
|
||||
start=min_freq if start_freq < min_freq else None,
|
||||
stop=max_freq if end_freq > max_freq else None,
|
||||
).astype(spec.dtype)
|
||||
|
||||
|
||||
|
@ -2,69 +2,477 @@ import math
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from hypothesis import HealthCheck, given, settings
|
||||
from hypothesis import strategies as st
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent import arrays
|
||||
|
||||
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
STFTConfig,
|
||||
MAX_FREQ,
|
||||
MIN_FREQ,
|
||||
ConfigurableSpectrogramBuilder,
|
||||
FrequencyConfig,
|
||||
PcenConfig,
|
||||
SpecSizeConfig,
|
||||
SpectrogramConfig,
|
||||
STFTConfig,
|
||||
apply_pcen,
|
||||
build_spectrogram_builder,
|
||||
compute_spectrogram,
|
||||
duration_to_spec_width,
|
||||
crop_spectrogram_frequencies,
|
||||
get_spectrogram_resolution,
|
||||
spec_width_to_samples,
|
||||
remove_spectral_mean,
|
||||
resize_spectrogram,
|
||||
scale_log,
|
||||
scale_spectrogram,
|
||||
stft,
|
||||
)
|
||||
|
||||
|
||||
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||
@given(
|
||||
duration=st.floats(min_value=0.1, max_value=1.0),
|
||||
window_duration=st.floats(min_value=0.001, max_value=0.01),
|
||||
window_overlap=st.floats(min_value=0.2, max_value=0.9),
|
||||
samplerate=st.integers(min_value=256_000, max_value=512_000),
|
||||
SAMPLERATE = 250_000
|
||||
DURATION = 0.1
|
||||
TEST_FREQ = 30_000
|
||||
N_SAMPLES = int(SAMPLERATE * DURATION)
|
||||
TIME_COORD = np.linspace(
|
||||
0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
|
||||
)
|
||||
def test_can_estimate_correctly_spectrogram_width_from_duration(
|
||||
duration: float,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
samplerate: int,
|
||||
wav_factory: Callable[..., Path],
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sine_wave_xr() -> xr.DataArray:
|
||||
"""Generate a single sine wave as an xr.DataArray."""
|
||||
t = TIME_COORD
|
||||
wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
|
||||
return xr.DataArray(
|
||||
wav_data,
|
||||
coords={"time": t},
|
||||
dims=["time"],
|
||||
attrs={"samplerate": SAMPLERATE},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def constant_wave_xr() -> xr.DataArray:
|
||||
"""Generate a constant signal as an xr.DataArray."""
|
||||
t = TIME_COORD
|
||||
wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5
|
||||
return xr.DataArray(
|
||||
wav_data,
|
||||
coords={"time": t},
|
||||
dims=["time"],
|
||||
attrs={"samplerate": SAMPLERATE},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
|
||||
"""Generate a basic spectrogram for testing downstream functions."""
|
||||
config = SpectrogramConfig(
|
||||
stft=STFTConfig(window_duration=0.002, window_overlap=0.5),
|
||||
frequencies=FrequencyConfig(
|
||||
min_freq=0,
|
||||
max_freq=int(SAMPLERATE / 2),
|
||||
),
|
||||
size=None,
|
||||
pcen=None,
|
||||
spectral_mean_substraction=False,
|
||||
peak_normalize=False,
|
||||
scale="amplitude",
|
||||
)
|
||||
spec = stft(
|
||||
sine_wave_xr,
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def test_stft_config_defaults():
|
||||
config = STFTConfig()
|
||||
assert config.window_duration == 0.002
|
||||
assert config.window_overlap == 0.75
|
||||
assert config.window_fn == "hann"
|
||||
|
||||
|
||||
def test_frequency_config_defaults():
|
||||
config = FrequencyConfig()
|
||||
assert config.min_freq == MIN_FREQ
|
||||
assert config.max_freq == MAX_FREQ
|
||||
|
||||
|
||||
def test_spec_size_config_defaults():
|
||||
config = SpecSizeConfig()
|
||||
assert config.height == 128
|
||||
assert config.resize_factor == 0.5
|
||||
|
||||
|
||||
def test_pcen_config_defaults():
|
||||
config = PcenConfig()
|
||||
assert config.time_constant == 0.4
|
||||
assert config.gain == 0.98
|
||||
assert config.bias == 2
|
||||
assert config.power == 0.5
|
||||
|
||||
|
||||
def test_spectrogram_config_defaults():
|
||||
config = SpectrogramConfig()
|
||||
assert isinstance(config.stft, STFTConfig)
|
||||
assert isinstance(config.frequencies, FrequencyConfig)
|
||||
assert isinstance(config.pcen, PcenConfig)
|
||||
assert config.scale == "amplitude"
|
||||
assert isinstance(config.size, SpecSizeConfig)
|
||||
assert config.spectral_mean_substraction is True
|
||||
assert config.peak_normalize is False
|
||||
|
||||
|
||||
def test_stft_output_properties(sine_wave_xr: xr.DataArray):
|
||||
window_duration = 0.002
|
||||
window_overlap = 0.5
|
||||
samplerate = sine_wave_xr.attrs["samplerate"]
|
||||
nfft = int(window_duration * samplerate)
|
||||
hop_len = nfft - int(window_overlap * nfft)
|
||||
|
||||
spec = stft(
|
||||
sine_wave_xr,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
window_fn="hann",
|
||||
)
|
||||
|
||||
assert isinstance(spec, xr.DataArray)
|
||||
assert spec.dims == ("frequency", "time")
|
||||
assert spec.dtype == np.float32
|
||||
assert "frequency" in spec.coords
|
||||
assert "time" in spec.coords
|
||||
|
||||
time_step = arrays.get_dim_step(spec, "time")
|
||||
freq_step = arrays.get_dim_step(spec, "frequency")
|
||||
freq_start, freq_end = arrays.get_dim_range(spec, "frequency")
|
||||
assert np.isclose(freq_step, samplerate / nfft)
|
||||
assert np.isclose(time_step, hop_len / samplerate)
|
||||
assert spec.frequency.min() >= 0
|
||||
assert freq_start == 0
|
||||
assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5)
|
||||
assert spec.time.min() >= 0
|
||||
assert spec.time.max() < DURATION
|
||||
|
||||
assert spec.attrs["original_samplerate"] == samplerate
|
||||
assert spec.attrs["nfft"] == nfft
|
||||
assert spec.attrs["noverlap"] == int(window_overlap * nfft)
|
||||
|
||||
assert np.all(spec.data >= 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
|
||||
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
|
||||
spec = stft(
|
||||
sine_wave_xr,
|
||||
window_duration=0.002,
|
||||
window_overlap=0.5,
|
||||
window_fn=window_fn,
|
||||
)
|
||||
assert isinstance(spec, xr.DataArray)
|
||||
assert np.all(spec.data >= 0)
|
||||
|
||||
|
||||
def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
|
||||
min_f, max_f = 20_000, 80_000
|
||||
cropped_spec = crop_spectrogram_frequencies(
|
||||
sample_spec, min_freq=min_f, max_freq=max_f
|
||||
)
|
||||
|
||||
assert cropped_spec.dims == sample_spec.dims
|
||||
assert cropped_spec.dtype == sample_spec.dtype
|
||||
assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
|
||||
assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
|
||||
assert cropped_spec.coords["frequency"].min() >= min_f
|
||||
|
||||
assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)
|
||||
|
||||
|
||||
def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
|
||||
samplerate = sample_spec.attrs["original_samplerate"]
|
||||
min_f, max_f = 0, samplerate / 2
|
||||
cropped_spec = crop_spectrogram_frequencies(
|
||||
sample_spec, min_freq=min_f, max_freq=max_f
|
||||
)
|
||||
|
||||
assert cropped_spec.sizes == sample_spec.sizes
|
||||
assert np.allclose(cropped_spec.data, sample_spec.data)
|
||||
|
||||
|
||||
def test_apply_pcen(sample_spec: xr.DataArray):
|
||||
if "original_samplerate" not in sample_spec.attrs:
|
||||
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||
if "nfft" not in sample_spec.attrs:
|
||||
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||
if "noverlap" not in sample_spec.attrs:
|
||||
sample_spec.attrs["noverlap"] = int(0.5 * sample_spec.attrs["nfft"])
|
||||
|
||||
pcen_config = PcenConfig()
|
||||
pcen_spec = apply_pcen(
|
||||
sample_spec,
|
||||
time_constant=pcen_config.time_constant,
|
||||
gain=pcen_config.gain,
|
||||
bias=pcen_config.bias,
|
||||
power=pcen_config.power,
|
||||
)
|
||||
|
||||
assert pcen_spec.dims == sample_spec.dims
|
||||
assert pcen_spec.sizes == sample_spec.sizes
|
||||
assert pcen_spec.dtype == sample_spec.dtype
|
||||
assert np.all(pcen_spec.data >= 0)
|
||||
|
||||
assert not np.allclose(pcen_spec.data, sample_spec.data)
|
||||
|
||||
|
||||
def test_scale_log(sample_spec: xr.DataArray):
|
||||
if "original_samplerate" not in sample_spec.attrs:
|
||||
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||
if "nfft" not in sample_spec.attrs:
|
||||
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||
|
||||
log_spec = scale_log(sample_spec, dtype=np.float32)
|
||||
|
||||
assert log_spec.dims == sample_spec.dims
|
||||
assert log_spec.sizes == sample_spec.sizes
|
||||
assert log_spec.dtype == np.float32
|
||||
assert np.all(log_spec.data >= 0)
|
||||
assert not np.allclose(log_spec.data, sample_spec.data)
|
||||
|
||||
|
||||
def test_scale_log_missing_attrs(sample_spec: xr.DataArray):
|
||||
spec_copy = sample_spec.copy()
|
||||
del spec_copy.attrs["original_samplerate"]
|
||||
with pytest.raises(KeyError):
|
||||
scale_log(spec_copy)
|
||||
|
||||
spec_copy = sample_spec.copy()
|
||||
del spec_copy.attrs["nfft"]
|
||||
with pytest.raises(KeyError):
|
||||
scale_log(spec_copy)
|
||||
|
||||
|
||||
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
|
||||
scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
|
||||
assert np.allclose(scaled_spec.data, sample_spec.data)
|
||||
assert scaled_spec.dtype == sample_spec.dtype
|
||||
|
||||
|
||||
def test_scale_spectrogram_power(sample_spec: xr.DataArray):
|
||||
scaled_spec = scale_spectrogram(sample_spec, scale="power")
|
||||
assert np.allclose(scaled_spec.data, sample_spec.data**2)
|
||||
assert scaled_spec.dtype == sample_spec.dtype
|
||||
|
||||
|
||||
def test_scale_spectrogram_db(sample_spec: xr.DataArray):
|
||||
if "original_samplerate" not in sample_spec.attrs:
|
||||
sample_spec.attrs["original_samplerate"] = SAMPLERATE
|
||||
if "nfft" not in sample_spec.attrs:
|
||||
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
|
||||
|
||||
scaled_spec = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64)
|
||||
log_spec_expected = scale_log(sample_spec, dtype=np.float64)
|
||||
assert scaled_spec.dtype == np.float64
|
||||
assert np.allclose(scaled_spec.data, log_spec_expected.data)
|
||||
|
||||
|
||||
def test_remove_spectral_mean(sample_spec: xr.DataArray):
|
||||
spec_noisy = sample_spec.copy() + 0.1
|
||||
denoised_spec = remove_spectral_mean(spec_noisy)
|
||||
|
||||
assert denoised_spec.dims == spec_noisy.dims
|
||||
assert denoised_spec.sizes == spec_noisy.sizes
|
||||
assert denoised_spec.dtype == spec_noisy.dtype
|
||||
assert np.all(denoised_spec.data >= 0)
|
||||
|
||||
|
||||
def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
|
||||
const_spec = stft(constant_wave_xr, 0.002, 0.5)
|
||||
denoised_spec = remove_spectral_mean(const_spec)
|
||||
|
||||
assert np.allclose(denoised_spec.data, 0, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"height, resize_factor, expected_freq_size, expected_time_factor",
|
||||
[
|
||||
(128, 1.0, 128, 1.0),
|
||||
(64, 0.5, 64, 0.5),
|
||||
(256, None, 256, 1.0),
|
||||
(100, 2.0, 100, 2.0),
|
||||
],
|
||||
)
|
||||
def test_resize_spectrogram(
|
||||
sample_spec: xr.DataArray,
|
||||
height: int,
|
||||
resize_factor: float | None,
|
||||
expected_freq_size: int,
|
||||
expected_time_factor: float,
|
||||
):
|
||||
path = wav_factory(duration=duration, samplerate=samplerate)
|
||||
audio = load_file_audio(
|
||||
path,
|
||||
# NOTE: Dont resample nor adjust duration to test if the width
|
||||
# estimation works on all scenarios
|
||||
config=AudioConfig(resample=None, duration=None),
|
||||
)
|
||||
spectrogram = stft(audio, window_duration, window_overlap)
|
||||
|
||||
spec_width = duration_to_spec_width(
|
||||
duration,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
)
|
||||
assert spectrogram.sizes["time"] == spec_width
|
||||
|
||||
rebuilt_duration = (
|
||||
spec_width_to_samples(
|
||||
spec_width,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
)
|
||||
/ samplerate
|
||||
original_time_size = sample_spec.sizes["time"]
|
||||
resized_spec = resize_spectrogram(
|
||||
sample_spec,
|
||||
height=height,
|
||||
resize_factor=resize_factor,
|
||||
)
|
||||
|
||||
assert (
|
||||
abs(duration - rebuilt_duration)
|
||||
< (1 - window_overlap) * window_duration
|
||||
assert resized_spec.dims == ("frequency", "time")
|
||||
assert resized_spec.sizes["frequency"] == expected_freq_size
|
||||
expected_time_size = int(original_time_size * expected_time_factor)
|
||||
|
||||
assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
|
||||
|
||||
assert resized_spec.dtype == np.float32
|
||||
|
||||
|
||||
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
|
||||
config = SpectrogramConfig()
|
||||
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||
|
||||
assert isinstance(spec, xr.DataArray)
|
||||
assert spec.dims == ("frequency", "time")
|
||||
assert spec.dtype == np.float32
|
||||
assert config.size is not None
|
||||
assert spec.sizes["frequency"] == config.size.height
|
||||
|
||||
temp_stft = stft(
|
||||
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
|
||||
)
|
||||
assert config.size.resize_factor is not None
|
||||
expected_time_size = int(
|
||||
temp_stft.sizes["time"] * config.size.resize_factor
|
||||
)
|
||||
assert abs(spec.sizes["time"] - expected_time_size) <= 1
|
||||
|
||||
assert spec.coords["frequency"].min() >= config.frequencies.min_freq
|
||||
assert np.isclose(
|
||||
spec.coords["frequency"].max(),
|
||||
config.frequencies.max_freq,
|
||||
rtol=0.1,
|
||||
)
|
||||
|
||||
|
||||
def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
|
||||
sine_wave_xr: xr.DataArray,
|
||||
):
|
||||
config = SpectrogramConfig(
|
||||
pcen=None,
|
||||
spectral_mean_substraction=False,
|
||||
size=None,
|
||||
scale="power",
|
||||
frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
|
||||
)
|
||||
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||
|
||||
stft_direct = stft(
|
||||
sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
|
||||
)
|
||||
expected_spec = scale_spectrogram(stft_direct, scale="power")
|
||||
|
||||
assert spec.sizes == expected_spec.sizes
|
||||
assert np.allclose(spec.data, expected_spec.data)
|
||||
assert spec.dtype == expected_spec.dtype
|
||||
|
||||
|
||||
def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
|
||||
config = SpectrogramConfig(peak_normalize=True)
|
||||
spec = compute_spectrogram(sine_wave_xr, config=config)
|
||||
assert np.isclose(spec.data.max(), 1.0, atol=1e-6)
|
||||
|
||||
config = SpectrogramConfig(peak_normalize=False)
|
||||
spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
|
||||
assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)
|
||||
|
||||
|
||||
def test_get_spectrogram_resolution_calculation():
|
||||
config = SpectrogramConfig(
|
||||
stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
|
||||
size=SpecSizeConfig(height=100, resize_factor=0.5),
|
||||
frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
|
||||
)
|
||||
|
||||
freq_res, time_res = get_spectrogram_resolution(config)
|
||||
|
||||
expected_freq_res = (110_000 - 10_000) / 100
|
||||
expected_hop_duration = 0.002 * (1 - 0.75)
|
||||
expected_time_res = expected_hop_duration / 0.5
|
||||
|
||||
assert np.isclose(freq_res, expected_freq_res)
|
||||
assert np.isclose(time_res, expected_time_res)
|
||||
|
||||
|
||||
def test_get_spectrogram_resolution_no_resize_factor():
|
||||
config = SpectrogramConfig(
|
||||
stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
|
||||
size=SpecSizeConfig(height=200, resize_factor=None),
|
||||
frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
|
||||
)
|
||||
freq_res, time_res = get_spectrogram_resolution(config)
|
||||
expected_freq_res = (120_000 - 20_000) / 200
|
||||
expected_hop_duration = 0.004 * (1 - 0.5)
|
||||
expected_time_res = expected_hop_duration / 1.0
|
||||
|
||||
assert np.isclose(freq_res, expected_freq_res)
|
||||
assert np.isclose(time_res, expected_time_res)
|
||||
|
||||
|
||||
def test_get_spectrogram_resolution_no_size_config():
|
||||
config = SpectrogramConfig(size=None)
|
||||
with pytest.raises(
|
||||
ValueError, match="Spectrogram size configuration is required"
|
||||
):
|
||||
get_spectrogram_resolution(config)
|
||||
|
||||
|
||||
def test_configurable_spectrogram_builder_init():
|
||||
config = SpectrogramConfig()
|
||||
builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
|
||||
assert builder.config is config
|
||||
assert builder.dtype == np.float16
|
||||
|
||||
|
||||
def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
|
||||
config = SpectrogramConfig()
|
||||
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||
spec_builder = builder(sine_wave_xr)
|
||||
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
|
||||
assert isinstance(spec_builder, xr.DataArray)
|
||||
assert np.allclose(spec_builder.data, spec_direct.data)
|
||||
assert spec_builder.dtype == spec_direct.dtype
|
||||
|
||||
|
||||
def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray):
|
||||
config = SpectrogramConfig()
|
||||
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||
wav_np = sine_wave_xr.data
|
||||
samplerate = sine_wave_xr.attrs["samplerate"]
|
||||
|
||||
spec_builder = builder(wav_np.astype(np.float32), samplerate=samplerate)
|
||||
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
|
||||
|
||||
assert isinstance(spec_builder, xr.DataArray)
|
||||
assert np.allclose(spec_builder.data, spec_direct.data, atol=1e-4)
|
||||
assert spec_builder.dtype == spec_direct.dtype
|
||||
|
||||
|
||||
def test_configurable_spectrogram_builder_call_np_no_samplerate(
|
||||
sine_wave_xr: xr.DataArray,
|
||||
):
|
||||
config = SpectrogramConfig()
|
||||
builder = ConfigurableSpectrogramBuilder(config=config)
|
||||
wav_np = sine_wave_xr.data
|
||||
with pytest.raises(ValueError, match="Samplerate must be provided"):
|
||||
builder(wav_np, samplerate=None)
|
||||
|
||||
|
||||
def test_build_spectrogram_builder():
|
||||
config = SpectrogramConfig(peak_normalize=True)
|
||||
builder = build_spectrogram_builder(config=config, dtype=np.float64)
|
||||
assert isinstance(builder, ConfigurableSpectrogramBuilder)
|
||||
assert builder.config is config
|
||||
assert builder.dtype == np.float64
|
||||
|
||||
|
||||
def test_can_estimate_spectrogram_resolution(
|
||||
@ -72,10 +480,8 @@ def test_can_estimate_spectrogram_resolution(
|
||||
):
|
||||
path = wav_factory(duration=0.2, samplerate=256_000)
|
||||
|
||||
audio = load_file_audio(
|
||||
audio_data = load_file_audio(
|
||||
path,
|
||||
# NOTE: Dont resample nor adjust duration to test if the width
|
||||
# estimation works on all scenarios
|
||||
config=AudioConfig(resample=None, duration=None),
|
||||
)
|
||||
|
||||
@ -85,7 +491,7 @@ def test_can_estimate_spectrogram_resolution(
|
||||
frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
|
||||
)
|
||||
|
||||
spec = compute_spectrogram(audio, config=config)
|
||||
spec = compute_spectrogram(audio_data, config=config)
|
||||
|
||||
freq_res, time_res = get_spectrogram_resolution(config)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user