"""Tests for the spectrogram preprocessing helpers of ``batdetect2``.

The tests run on short synthetic waveforms (a pure sine tone and a constant
signal) and check the configuration defaults, the individual processing
steps, the composed ``compute_spectrogram`` pipeline and the builder wrapper.
"""

import math
from pathlib import Path
from typing import Callable, Union

import numpy as np
import pytest
import xarray as xr
from soundevent import arrays

from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
    MAX_FREQ,
    MIN_FREQ,
    ConfigurableSpectrogramBuilder,
    FrequencyConfig,
    PcenConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
    apply_pcen,
    build_spectrogram_builder,
    compute_spectrogram,
    crop_spectrogram_frequencies,
    get_spectrogram_resolution,
    remove_spectral_mean,
    resize_spectrogram,
    scale_log,
    scale_spectrogram,
    stft,
)

# Parameters of the synthetic test signals.
SAMPLERATE = 250_000
DURATION = 0.1
TEST_FREQ = 30_000
N_SAMPLES = int(SAMPLERATE * DURATION)
TIME_COORD = np.linspace(
    0, DURATION, N_SAMPLES, endpoint=False, dtype=np.float32
)


def _ensure_stft_attrs(spec: xr.DataArray) -> None:
    """Fill in the ``original_samplerate`` and ``nfft`` attrs when absent.

    Existing values are left untouched; missing ones are set in place to the
    values matching the 2 ms window used by the ``sample_spec`` fixture.
    """
    spec.attrs.setdefault("original_samplerate", SAMPLERATE)
    spec.attrs.setdefault("nfft", int(0.002 * SAMPLERATE))


@pytest.fixture
def sine_wave_xr() -> xr.DataArray:
    """Generate a single sine wave as an xr.DataArray."""
    t = TIME_COORD
    wav_data = np.sin(2 * np.pi * TEST_FREQ * t, dtype=np.float32)
    return xr.DataArray(
        wav_data,
        coords={"time": t},
        dims=["time"],
        attrs={"samplerate": SAMPLERATE},
    )


@pytest.fixture
def constant_wave_xr() -> xr.DataArray:
    """Generate a constant signal as an xr.DataArray."""
    t = TIME_COORD
    wav_data = np.ones(N_SAMPLES, dtype=np.float32) * 0.5
    return xr.DataArray(
        wav_data,
        coords={"time": t},
        dims=["time"],
        attrs={"samplerate": SAMPLERATE},
    )


@pytest.fixture
def sample_spec(sine_wave_xr: xr.DataArray) -> xr.DataArray:
    """Generate a basic spectrogram for testing downstream functions."""
    # Only the STFT settings influence the result, so only they are built.
    stft_config = STFTConfig(window_duration=0.002, window_overlap=0.5)
    return stft(
        sine_wave_xr,
        window_duration=stft_config.window_duration,
        window_overlap=stft_config.window_overlap,
        window_fn=stft_config.window_fn,
    )


def test_stft_config_defaults():
    """STFTConfig exposes the expected default window settings."""
    config = STFTConfig()
    assert config.window_duration == 0.002
    assert config.window_overlap == 0.75
    assert config.window_fn == "hann"


def test_frequency_config_defaults():
    """FrequencyConfig defaults to the module-level MIN_FREQ / MAX_FREQ."""
    config = FrequencyConfig()
    assert config.min_freq == MIN_FREQ
    assert config.max_freq == MAX_FREQ


def test_spec_size_config_defaults():
    """SpecSizeConfig exposes the expected default height and factor."""
    config = SpecSizeConfig()
    assert config.height == 128
    assert config.resize_factor == 0.5


def test_pcen_config_defaults():
    """PcenConfig exposes the expected default PCEN parameters."""
    config = PcenConfig()
    assert config.time_constant == 0.4
    assert config.gain == 0.98
    assert config.bias == 2
    assert config.power == 0.5


def test_spectrogram_config_defaults():
    """SpectrogramConfig builds default sub-configs and flags."""
    config = SpectrogramConfig()
    assert isinstance(config.stft, STFTConfig)
    assert isinstance(config.frequencies, FrequencyConfig)
    assert isinstance(config.pcen, PcenConfig)
    assert config.scale == "amplitude"
    assert isinstance(config.size, SpecSizeConfig)
    assert config.spectral_mean_substraction is True
    assert config.peak_normalize is False


def test_stft_output_properties(sine_wave_xr: xr.DataArray):
    """The STFT output has the expected dims, dtype, coords and attrs."""
    window_duration = 0.002
    window_overlap = 0.5
    samplerate = sine_wave_xr.attrs["samplerate"]
    nfft = int(window_duration * samplerate)
    hop_len = nfft - int(window_overlap * nfft)

    spec = stft(
        sine_wave_xr,
        window_duration=window_duration,
        window_overlap=window_overlap,
        window_fn="hann",
    )

    assert isinstance(spec, xr.DataArray)
    assert spec.dims == ("frequency", "time")
    assert spec.dtype == np.float32
    assert "frequency" in spec.coords
    assert "time" in spec.coords

    time_step = arrays.get_dim_step(spec, "time")
    freq_step = arrays.get_dim_step(spec, "frequency")
    freq_start, freq_end = arrays.get_dim_range(spec, "frequency")

    assert np.isclose(freq_step, samplerate / nfft)
    assert np.isclose(time_step, hop_len / samplerate)
    assert spec.frequency.min() >= 0
    assert freq_start == 0
    assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5)
    assert spec.time.min() >= 0
    assert spec.time.max() < DURATION

    assert spec.attrs["original_samplerate"] == samplerate
    assert spec.attrs["nfft"] == nfft
    assert spec.attrs["noverlap"] == int(window_overlap * nfft)

    assert np.all(spec.data >= 0)


@pytest.mark.parametrize("window_fn", ["hann", "hamming"])
def test_stft_window_fn(sine_wave_xr: xr.DataArray, window_fn: str):
    """The STFT accepts each window name and yields non-negative values."""
    spec = stft(
        sine_wave_xr,
        window_duration=0.002,
        window_overlap=0.5,
        window_fn=window_fn,
    )
    assert isinstance(spec, xr.DataArray)
    assert np.all(spec.data >= 0)


def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
    """Cropping keeps the time axis and narrows the frequency axis."""
    min_f, max_f = 20_000, 80_000

    cropped_spec = crop_spectrogram_frequencies(
        sample_spec, min_freq=min_f, max_freq=max_f
    )

    assert cropped_spec.dims == sample_spec.dims
    assert cropped_spec.dtype == sample_spec.dtype
    assert cropped_spec.sizes["time"] == sample_spec.sizes["time"]
    assert cropped_spec.sizes["frequency"] < sample_spec.sizes["frequency"]
    assert cropped_spec.coords["frequency"].min() >= min_f
    assert np.isclose(cropped_spec.coords["frequency"].max(), max_f, rtol=0.1)


def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
    """Cropping to 0..samplerate/2 leaves the spectrogram unchanged."""
    samplerate = sample_spec.attrs["original_samplerate"]
    min_f, max_f = 0, samplerate / 2

    cropped_spec = crop_spectrogram_frequencies(
        sample_spec, min_freq=min_f, max_freq=max_f
    )

    assert cropped_spec.sizes == sample_spec.sizes
    assert np.allclose(cropped_spec.data, sample_spec.data)


def test_apply_pcen(sample_spec: xr.DataArray):
    """PCEN keeps shape and dtype, stays non-negative and alters values."""
    _ensure_stft_attrs(sample_spec)
    # Evaluated after ``nfft`` is guaranteed to be present.
    sample_spec.attrs.setdefault(
        "noverlap", int(0.5 * sample_spec.attrs["nfft"])
    )

    pcen_config = PcenConfig()
    pcen_spec = apply_pcen(
        sample_spec,
        time_constant=pcen_config.time_constant,
        gain=pcen_config.gain,
        bias=pcen_config.bias,
        power=pcen_config.power,
    )

    assert pcen_spec.dims == sample_spec.dims
    assert pcen_spec.sizes == sample_spec.sizes
    assert pcen_spec.dtype == sample_spec.dtype
    assert np.all(pcen_spec.data >= 0)
    assert not np.allclose(pcen_spec.data, sample_spec.data)


def test_scale_log(sample_spec: xr.DataArray):
    """Log scaling keeps the shape, honours dtype and alters values."""
    _ensure_stft_attrs(sample_spec)

    log_spec = scale_log(sample_spec, dtype=np.float32)

    assert log_spec.dims == sample_spec.dims
    assert log_spec.sizes == sample_spec.sizes
    assert log_spec.dtype == np.float32
    assert np.all(log_spec.data >= 0)
    assert not np.allclose(log_spec.data, sample_spec.data)


@pytest.mark.parametrize("missing_attr", ["original_samplerate", "nfft"])
def test_scale_log_missing_attrs(sample_spec: xr.DataArray, missing_attr: str):
    """scale_log raises KeyError when a required attribute is absent.

    Each attribute is exercised on its own freshly built fixture so one case
    cannot mask or influence the other.
    """
    spec_copy = sample_spec.copy()
    del spec_copy.attrs[missing_attr]
    with pytest.raises(KeyError):
        scale_log(spec_copy)


def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
    """The "amplitude" scale returns the input values unchanged."""
    scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
    assert np.allclose(scaled_spec.data, sample_spec.data)
    assert scaled_spec.dtype == sample_spec.dtype


def test_scale_spectrogram_power(sample_spec: xr.DataArray):
    """The "power" scale squares the input values."""
    scaled_spec = scale_spectrogram(sample_spec, scale="power")
    assert np.allclose(scaled_spec.data, sample_spec.data**2)
    assert scaled_spec.dtype == sample_spec.dtype


def test_scale_spectrogram_db(sample_spec: xr.DataArray):
    """The "dB" scale matches a direct call to scale_log."""
    _ensure_stft_attrs(sample_spec)

    scaled_spec = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64)
    log_spec_expected = scale_log(sample_spec, dtype=np.float64)

    assert scaled_spec.dtype == np.float64
    assert np.allclose(scaled_spec.data, log_spec_expected.data)


def test_remove_spectral_mean(sample_spec: xr.DataArray):
    """Spectral mean removal keeps shape and dtype and stays non-negative."""
    spec_noisy = sample_spec.copy() + 0.1

    denoised_spec = remove_spectral_mean(spec_noisy)

    assert denoised_spec.dims == spec_noisy.dims
    assert denoised_spec.sizes == spec_noisy.sizes
    assert denoised_spec.dtype == spec_noisy.dtype
    assert np.all(denoised_spec.data >= 0)


def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
    """A constant signal is reduced to (almost) zero everywhere."""
    const_spec = stft(constant_wave_xr, 0.002, 0.5)
    denoised_spec = remove_spectral_mean(const_spec)
    assert np.allclose(denoised_spec.data, 0, atol=1e-6)


@pytest.mark.parametrize(
    "height, resize_factor, expected_freq_size, expected_time_factor",
    [
        (128, 1.0, 128, 1.0),
        (64, 0.5, 64, 0.5),
        (256, None, 256, 1.0),
        (100, 2.0, 100, 2.0),
    ],
)
def test_resize_spectrogram(
    sample_spec: xr.DataArray,
    height: int,
    resize_factor: Union[float, None],
    expected_freq_size: int,
    expected_time_factor: float,
):
    """Resizing yields the requested height and a scaled time axis."""
    original_time_size = sample_spec.sizes["time"]

    resized_spec = resize_spectrogram(
        sample_spec,
        height=height,
        resize_factor=resize_factor,
    )

    assert resized_spec.dims == ("frequency", "time")
    assert resized_spec.sizes["frequency"] == expected_freq_size
    expected_time_size = int(original_time_size * expected_time_factor)
    # Allow one frame of slack for rounding in the resize.
    assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
    assert resized_spec.dtype == np.float32


def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
    """The default pipeline output matches the default config's geometry."""
    config = SpectrogramConfig()

    spec = compute_spectrogram(sine_wave_xr, config=config)

    assert isinstance(spec, xr.DataArray)
    assert spec.dims == ("frequency", "time")
    assert spec.dtype == np.float32

    assert config.size is not None
    assert spec.sizes["frequency"] == config.size.height

    temp_stft = stft(
        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
    )
    assert config.size.resize_factor is not None
    expected_time_size = int(
        temp_stft.sizes["time"] * config.size.resize_factor
    )
    assert abs(spec.sizes["time"] - expected_time_size) <= 1

    assert spec.coords["frequency"].min() >= config.frequencies.min_freq
    assert np.isclose(
        spec.coords["frequency"].max(),
        config.frequencies.max_freq,
        rtol=0.1,
    )


def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
    sine_wave_xr: xr.DataArray,
):
    """With optional steps disabled the result equals the scaled STFT."""
    config = SpectrogramConfig(
        pcen=None,
        spectral_mean_substraction=False,
        size=None,
        scale="power",
        frequencies=FrequencyConfig(min_freq=0, max_freq=int(SAMPLERATE / 2)),
    )

    spec = compute_spectrogram(sine_wave_xr, config=config)

    stft_direct = stft(
        sine_wave_xr, config.stft.window_duration, config.stft.window_overlap
    )
    expected_spec = scale_spectrogram(stft_direct, scale="power")

    assert spec.sizes == expected_spec.sizes
    assert np.allclose(spec.data, expected_spec.data)
    assert spec.dtype == expected_spec.dtype


def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
    """Peak normalisation sets the maximum to 1; without it the max differs."""
    config = SpectrogramConfig(peak_normalize=True)
    spec = compute_spectrogram(sine_wave_xr, config=config)
    assert np.isclose(spec.data.max(), 1.0, atol=1e-6)

    config = SpectrogramConfig(peak_normalize=False)
    spec_no_norm = compute_spectrogram(sine_wave_xr, config=config)
    assert not np.isclose(spec_no_norm.data.max(), 1.0, atol=1e-6)


def test_get_spectrogram_resolution_calculation():
    """Resolution follows from frequency range, height, hop and factor."""
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        size=SpecSizeConfig(height=100, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=110_000),
    )

    freq_res, time_res = get_spectrogram_resolution(config)

    expected_freq_res = (110_000 - 10_000) / 100
    expected_hop_duration = 0.002 * (1 - 0.75)
    expected_time_res = expected_hop_duration / 0.5

    assert np.isclose(freq_res, expected_freq_res)
    assert np.isclose(time_res, expected_time_res)


def test_get_spectrogram_resolution_no_resize_factor():
    """A ``None`` resize factor yields the plain hop duration."""
    config = SpectrogramConfig(
        stft=STFTConfig(window_duration=0.004, window_overlap=0.5),
        size=SpecSizeConfig(height=200, resize_factor=None),
        frequencies=FrequencyConfig(min_freq=20_000, max_freq=120_000),
    )

    freq_res, time_res = get_spectrogram_resolution(config)

    expected_freq_res = (120_000 - 20_000) / 200
    expected_hop_duration = 0.004 * (1 - 0.5)
    expected_time_res = expected_hop_duration / 1.0

    assert np.isclose(freq_res, expected_freq_res)
    assert np.isclose(time_res, expected_time_res)


def test_get_spectrogram_resolution_no_size_config():
    """Without a size config the resolution lookup raises ValueError."""
    config = SpectrogramConfig(size=None)
    with pytest.raises(
        ValueError, match="Spectrogram size configuration is required"
    ):
        get_spectrogram_resolution(config)


def test_configurable_spectrogram_builder_init():
    """The builder stores the config object and dtype it was given."""
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config, dtype=np.float16)
    assert builder.config is config
    assert builder.dtype == np.float16


def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
    """Calling the builder on a DataArray matches compute_spectrogram."""
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config)

    spec_builder = builder(sine_wave_xr)
    spec_direct = compute_spectrogram(sine_wave_xr, config=config)

    assert isinstance(spec_builder, xr.DataArray)
    assert np.allclose(spec_builder.data, spec_direct.data)
    assert spec_builder.dtype == spec_direct.dtype


def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray):
    """Calling the builder on a raw array matches compute_spectrogram."""
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config)
    wav_np = sine_wave_xr.data
    samplerate = sine_wave_xr.attrs["samplerate"]

    spec_builder = builder(wav_np.astype(np.float32), samplerate=samplerate)
    spec_direct = compute_spectrogram(sine_wave_xr, config=config)

    assert isinstance(spec_builder, xr.DataArray)
    assert np.allclose(spec_builder.data, spec_direct.data, atol=1e-4)
    assert spec_builder.dtype == spec_direct.dtype


def test_configurable_spectrogram_builder_call_np_no_samplerate(
    sine_wave_xr: xr.DataArray,
):
    """A raw array without a samplerate is rejected with ValueError."""
    config = SpectrogramConfig()
    builder = ConfigurableSpectrogramBuilder(config=config)
    wav_np = sine_wave_xr.data

    with pytest.raises(ValueError, match="Samplerate must be provided"):
        builder(wav_np, samplerate=None)


def test_build_spectrogram_builder():
    """The factory returns a builder carrying the given config and dtype."""
    config = SpectrogramConfig(peak_normalize=True)
    builder = build_spectrogram_builder(config=config, dtype=np.float64)
    assert isinstance(builder, ConfigurableSpectrogramBuilder)
    assert builder.config is config
    assert builder.dtype == np.float64


def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
    """The estimated resolution is within 10% of the computed axis steps."""
    path = wav_factory(duration=0.2, samplerate=256_000)

    audio_data = load_file_audio(
        path,
        config=AudioConfig(resample=None, duration=None),
    )

    config = SpectrogramConfig(
        stft=STFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )

    spec = compute_spectrogram(audio_data, config=config)

    freq_res, time_res = get_spectrogram_resolution(config)

    assert math.isclose(
        arrays.get_dim_step(spec, dim="frequency"),
        freq_res,
        rel_tol=0.1,
    )
    assert math.isclose(
        arrays.get_dim_step(spec, dim="time"),
        time_res,
        rel_tol=0.1,
    )