diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py index ff1e751..b80e903 100644 --- a/batdetect2/preprocess/audio.py +++ b/batdetect2/preprocess/audio.py @@ -29,6 +29,7 @@ from pydantic import Field from scipy.signal import resample, resample_poly from soundevent import arrays, audio, data from soundevent.arrays import operations as ops +from soundfile import LibsndfileError from batdetect2.configs import BaseConfig @@ -360,7 +361,13 @@ def load_file_audio( xr.DataArray Loaded and preprocessed waveform (first channel only). """ - recording = data.Recording.from_file(path) + try: + recording = data.Recording.from_file(path) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {path}. Error: {e}" + ) from e + return load_recording_audio( recording, config=config, @@ -421,10 +428,10 @@ def load_clip_audio( This is the core function performing the configured processing pipeline: 1. Loads the specified clip segment using `soundevent.audio.load_clip`. 2. Selects the first audio channel. - 3. Adjusts duration (crop/pad) if `config.duration` is set. - 4. Resamples if `config.resample` is configured. - 5. Centers (DC offset removal) if `config.center` is True. - 6. Scales (peak normalization) if `config.scale` is True. + 3. Resamples if `config.resample` is configured. + 4. Centers (DC offset removal) if `config.center` is True. + 5. Scales (peak normalization) if `config.scale` is True. + 6. Adjusts duration (crop/pad) if `config.duration` is set. Parameters ---------- @@ -461,12 +468,17 @@ def load_clip_audio( """ config = config or AudioConfig() - wav = ( - audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype) - ) - - if config.duration is not None: - wav = adjust_audio_duration(wav, duration=config.duration) + try: + wav = ( + audio.load_clip(clip, audio_dir=audio_dir) + .sel(channel=0) + .astype(dtype) + ) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {clip.recording.path}. " + f"Error: {e}" + ) from e if config.resample: wav = resample_audio( @@ -479,11 +491,35 @@ def load_clip_audio( wav = ops.center(wav) if config.scale: - wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav)))) + wav = scale_audio(wav) + + if config.duration is not None: + wav = adjust_audio_duration(wav, duration=config.duration) return wav.astype(dtype) +def scale_audio( + wave: xr.DataArray, +) -> xr.DataArray: + """ + Scale the audio waveform to have a maximum absolute value of 1.0. + + This function normalizes the waveform by dividing it by its maximum + absolute value. If the maximum value is zero, the waveform is returned + unchanged. Also known as peak normalization, this process ensures that the + waveform's amplitude is within a standard range, which can be useful for + audio processing and analysis. + + """ + max_val = np.max(np.abs(wave)) + + if max_val == 0: + return wave + + return ops.scale(wave, 1 / max_val) + + def adjust_audio_duration( wave: xr.DataArray, duration: float, @@ -513,26 +549,31 @@ def adjust_audio_duration( If `duration` is negative. """ start_time, end_time = arrays.get_dim_range(wave, dim="time") - current_duration = end_time - start_time + step = arrays.get_dim_step(wave, dim="time") + current_duration = end_time - start_time + step if current_duration == duration: return wave - if current_duration > duration: - return arrays.crop_dim( + with xr.set_options(keep_attrs=True): + if current_duration > duration: + return arrays.crop_dim( + wave, + dim="time", + start=start_time, + stop=start_time + duration - step / 2, + right_closed=True, + ) + + return arrays.extend_dim( wave, dim="time", start=start_time, - stop=start_time + duration, + stop=start_time + duration - step / 2, + eps=0, + right_closed=True, ) - return arrays.extend_dim( - wave, - dim="time", - start=start_time, - stop=start_time + duration, - ) - def resample_audio( wav: xr.DataArray, @@ -616,7 +657,7 @@ def resample_audio( samplerate=samplerate, ), }, - attrs=wav.attrs, + attrs={**wav.attrs, "samplerate": samplerate}, ) diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py index e69de29..4a96c32 100644 --- a/tests/test_preprocessing/test_audio.py +++ b/tests/test_preprocessing/test_audio.py @@ -0,0 +1,506 @@ +import pathlib +import uuid +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf +import xarray as xr +from soundevent import data +from soundevent.arrays import Dimensions, create_time_dim_from_array + +from batdetect2.preprocess import audio + + +def create_dummy_wave( + samplerate: int, + duration: float, + num_channels: int = 1, + freq: float = 440.0, + amplitude: float = 0.5, + dtype: np.dtype = np.float32, +) -> np.ndarray: + """Generates a simple numpy waveform.""" + t = np.linspace( + 0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype + ) + wave = amplitude * np.sin(2 * np.pi * freq * t) + if num_channels > 1: + wave = np.stack([wave] * num_channels, axis=0) + return wave.astype(dtype) + + +def create_xr_wave( + samplerate: int, + duration: float, + num_channels: int = 1, + freq: float = 440.0, + amplitude: float = 0.5, + start_time: float = 0.0, +) -> xr.DataArray: + """Generates a simple xarray waveform.""" + num_samples = int(samplerate * duration) + times = np.linspace( + start_time, + start_time + duration, + num_samples, + endpoint=False, + ) + coords = { + Dimensions.time.value: create_time_dim_from_array( + times, samplerate=samplerate, start_time=start_time + ) + } + dims = [Dimensions.time.value] + + wave_data = amplitude * np.sin(2 * np.pi * freq * times) + + if num_channels > 1: + coords[Dimensions.channel.value] = np.arange(num_channels) + dims = [Dimensions.channel.value] + dims + wave_data = np.stack([wave_data] * num_channels, axis=0) + + return xr.DataArray( + wave_data.astype(np.float32), + coords=coords, + dims=dims, + attrs={"samplerate": samplerate}, + ) + + +@pytest.fixture +def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path: + """Creates a dummy WAV file and returns its path.""" + samplerate = 48000 + duration = 2.0 + num_channels = 2 + wave_data = create_dummy_wave(samplerate, duration, num_channels) + file_path = tmp_path / f"{uuid.uuid4()}.wav" + sf.write(file_path, wave_data.T, samplerate, format="WAV", subtype="FLOAT") + return file_path + + +@pytest.fixture +def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording: + """Creates a Recording object pointing to the dummy WAV file.""" + return data.Recording.from_file(dummy_wav_path) + + +@pytest.fixture +def dummy_clip(dummy_recording: data.Recording) -> data.Clip: + """Creates a Clip object from the dummy recording.""" + return data.Clip( + recording=dummy_recording, + start_time=0.5, + end_time=1.5, + ) + + +@pytest.fixture +def default_audio_config() -> audio.AudioConfig: + return audio.AudioConfig() + + +@pytest.fixture +def no_resample_config() -> audio.AudioConfig: + return audio.AudioConfig(resample=None) + + +@pytest.fixture +def fixed_duration_config() -> audio.AudioConfig: + return audio.AudioConfig(duration=0.5) + + +@pytest.fixture +def scale_config() -> audio.AudioConfig: + return audio.AudioConfig(scale=True, center=False) + + +@pytest.fixture +def no_center_config() -> audio.AudioConfig: + return audio.AudioConfig(center=False) + + +@pytest.fixture +def resample_fourier_config() -> audio.AudioConfig: + return audio.AudioConfig( + resample=audio.ResampleConfig( + samplerate=audio.TARGET_SAMPLERATE_HZ // 2, mode="fourier" + ) + ) + + +def test_resample_config_defaults(): + config = audio.ResampleConfig() + assert config.samplerate == audio.TARGET_SAMPLERATE_HZ + assert config.mode == "poly" + + +def test_audio_config_defaults(): + config = audio.AudioConfig() + assert config.resample is not None + assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ + assert config.resample.mode == "poly" + assert config.scale == audio.SCALE_RAW_AUDIO + assert config.center is True + assert config.duration == audio.DEFAULT_DURATION + + +def test_audio_config_override(): + resample_cfg = audio.ResampleConfig(samplerate=44100, mode="fourier") + config = audio.AudioConfig( + resample=resample_cfg, + scale=True, + center=False, + duration=1.0, + ) + assert config.resample == resample_cfg + assert config.scale is True + assert config.center is False + assert config.duration == 1.0 + + +def test_audio_config_no_resample(): + config = audio.AudioConfig(resample=None) + assert config.resample is None + + +@pytest.mark.parametrize( + "orig_sr, orig_dur, target_dur", + [ + (256_000, 1.0, 0.5), + (256_000, 0.5, 1.0), + (256_000, 1.0, 1.0), + ], +) +def test_adjust_audio_duration(orig_sr, orig_dur, target_dur): + wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur) + adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur) + expected_samples = int(target_dur * orig_sr) + assert adjusted_wave.sizes["time"] == expected_samples + assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr + assert adjusted_wave.dtype == wave.dtype + if orig_dur > 0 and target_dur > orig_dur: + padding_start_index = int(orig_dur * orig_sr) + 1 + assert np.all(adjusted_wave.values[padding_start_index:] == 0) + + +def test_adjust_audio_duration_negative_target_raises(): + wave = create_xr_wave(1000, 1.0) + with pytest.raises(ValueError): + audio.adjust_audio_duration(wave, duration=-0.5) + + +@pytest.mark.parametrize( + "orig_sr, target_sr, mode", + [ + (48000, 96000, "poly"), + (96000, 48000, "poly"), + (48000, 96000, "fourier"), + (96000, 48000, "fourier"), + (48000, 44100, "poly"), + (48000, 44100, "fourier"), + ], +) +def test_resample_audio(orig_sr, target_sr, mode): + duration = 0.1 + wave = create_xr_wave(orig_sr, duration) + resampled_wave = audio.resample_audio( + wave, samplerate=target_sr, mode=mode, dtype=np.float32 + ) + expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr)) + assert resampled_wave.sizes["time"] == expected_samples + assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr + assert np.isclose( + resampled_wave.coords["time"].values[-1] + - resampled_wave.coords["time"].values[0], + duration, + atol=2 / target_sr, + ) + assert resampled_wave.dtype == np.float32 + + +def test_resample_audio_same_samplerate(): + sr = 48000 + duration = 0.1 + wave = create_xr_wave(sr, duration) + resampled_wave = audio.resample_audio( + wave, samplerate=sr, dtype=np.float64 + ) + xr.testing.assert_equal(wave.astype(np.float64), resampled_wave) + + +def test_resample_audio_invalid_mode_raises(): + wave = create_xr_wave(48000, 0.1) + with pytest.raises(NotImplementedError): + audio.resample_audio(wave, samplerate=96000, mode="invalid_mode") + + +def test_resample_audio_no_time_dim_raises(): + wave = xr.DataArray(np.random.rand(100), dims=["samples"]) + with pytest.raises(ValueError, match="Audio must have a time dimension"): + audio.resample_audio(wave, samplerate=96000) + + +def test_load_clip_audio_default_config( + dummy_clip: data.Clip, + default_audio_config: audio.AudioConfig, + tmp_path: Path, +): + assert default_audio_config.resample is not None + target_sr = default_audio_config.resample.samplerate + orig_duration = dummy_clip.duration + expected_samples = int(orig_duration * target_sr) + + wav = audio.load_clip_audio( + dummy_clip, config=default_audio_config, audio_dir=tmp_path + ) + + assert isinstance(wav, xr.DataArray) + assert wav.dims == ("time",) + assert wav.sizes["time"] == expected_samples + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert np.isclose(wav.mean(), 0.0, atol=1e-6) + assert wav.dtype == np.float32 + + +def test_load_clip_audio_no_resample( + dummy_clip: data.Clip, + no_resample_config: audio.AudioConfig, + tmp_path: Path, +): + orig_sr = dummy_clip.recording.samplerate + orig_duration = dummy_clip.duration + expected_samples = int(orig_duration * orig_sr) + + wav = audio.load_clip_audio( + dummy_clip, config=no_resample_config, audio_dir=tmp_path + ) + + assert wav.coords["time"].attrs["step"] == 1 / orig_sr + assert wav.sizes["time"] == expected_samples + assert np.isclose(wav.mean(), 0.0, atol=1e-6) + + +def test_load_clip_audio_fixed_duration_crop( + dummy_clip: data.Clip, + fixed_duration_config: audio.AudioConfig, + tmp_path: Path, +): + target_sr = audio.TARGET_SAMPLERATE_HZ + target_duration = fixed_duration_config.duration + assert target_duration is not None + expected_samples = int(target_duration * target_sr) + + assert dummy_clip.duration > target_duration + + wav = audio.load_clip_audio( + dummy_clip, config=fixed_duration_config, audio_dir=tmp_path + ) + + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert wav.sizes["time"] == expected_samples + + +def test_load_clip_audio_fixed_duration_pad( + dummy_clip: data.Clip, + tmp_path: Path, +): + target_duration = dummy_clip.duration * 2 + config = audio.AudioConfig(duration=target_duration) + + assert config.resample is not None + target_sr = config.resample.samplerate + expected_samples = int(target_duration * target_sr) + + wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path) + + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert wav.sizes["time"] == expected_samples + + original_samples_after_resample = int(dummy_clip.duration * target_sr) + assert np.allclose( + wav.values[original_samples_after_resample:], 0.0, atol=1e-6 + ) + + +def test_load_clip_audio_scale( + dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path +): + wav = audio.load_clip_audio( + dummy_clip, + config=scale_config, + audio_dir=tmp_path, + ) + + assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5) + + +def test_load_clip_audio_no_center( + dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path +): + wav = audio.load_clip_audio( + dummy_clip, config=no_center_config, audio_dir=tmp_path + ) + + raw_wav, _ = sf.read( + dummy_clip.recording.path, + start=int(dummy_clip.start_time * dummy_clip.recording.samplerate), + stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate), + dtype=np.float32, # type: ignore + ) + raw_wav_mono = raw_wav[:, 0] + + if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7): + assert not np.isclose(wav.mean(), 0.0, atol=1e-6) + + +def test_load_clip_audio_resample_fourier( + dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path +): + assert resample_fourier_config.resample is not None + target_sr = resample_fourier_config.resample.samplerate + orig_duration = dummy_clip.duration + expected_samples = int(orig_duration * target_sr) + + wav = audio.load_clip_audio( + dummy_clip, config=resample_fourier_config, audio_dir=tmp_path + ) + + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert wav.sizes["time"] == expected_samples + + +def test_load_clip_audio_dtype( + dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path +): + wav = audio.load_clip_audio( + dummy_clip, + config=default_audio_config, + audio_dir=tmp_path, + dtype=np.float64, + ) + assert wav.dtype == np.float64 + + +def test_load_clip_audio_file_not_found( + dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path +): + non_existent_path = tmp_path / "not_a_real_file.wav" + dummy_clip.recording = data.Recording( + path=non_existent_path, + duration=1, + channels=1, + samplerate=256000, + ) + with pytest.raises(FileNotFoundError): + audio.load_clip_audio( + dummy_clip, config=default_audio_config, audio_dir=tmp_path + ) + + +def test_load_recording_audio( + dummy_recording: data.Recording, + default_audio_config: audio.AudioConfig, + tmp_path, +): + assert default_audio_config.resample is not None + target_sr = default_audio_config.resample.samplerate + orig_duration = dummy_recording.duration + expected_samples = int(orig_duration * target_sr) + + wav = audio.load_recording_audio( + dummy_recording, config=default_audio_config, audio_dir=tmp_path + ) + + assert isinstance(wav, xr.DataArray) + assert wav.dims == ("time",) + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert wav.sizes["time"] == expected_samples + assert np.isclose(wav.mean(), 0.0, atol=1e-6) + assert wav.dtype == np.float32 + + +def test_load_recording_audio_file_not_found( + dummy_recording: data.Recording, + default_audio_config: audio.AudioConfig, + tmp_path, +): + non_existent_path = tmp_path / "not_a_real_file.wav" + dummy_recording = data.Recording( + path=non_existent_path, + duration=1, + channels=1, + samplerate=256000, + ) + with pytest.raises(FileNotFoundError): + audio.load_recording_audio( + dummy_recording, config=default_audio_config, audio_dir=tmp_path + ) + + +def test_load_file_audio( + dummy_wav_path: pathlib.Path, + default_audio_config: audio.AudioConfig, + tmp_path, +): + info = sf.info(dummy_wav_path) + orig_duration = info.duration + assert default_audio_config.resample is not None + target_sr = default_audio_config.resample.samplerate + expected_samples = int(orig_duration * target_sr) + + wav = audio.load_file_audio( + dummy_wav_path, config=default_audio_config, audio_dir=tmp_path + ) + + assert isinstance(wav, xr.DataArray) + assert wav.dims == ("time",) + assert wav.coords["time"].attrs["step"] == 1 / target_sr + assert wav.sizes["time"] == expected_samples + assert np.isclose(wav.mean(), 0.0, atol=1e-6) + assert wav.dtype == np.float32 + + +def test_load_file_audio_file_not_found( + default_audio_config: audio.AudioConfig, tmp_path +): + non_existent_path = tmp_path / "not_a_real_file.wav" + with pytest.raises(FileNotFoundError): + audio.load_file_audio( + non_existent_path, config=default_audio_config, audio_dir=tmp_path + ) + + +def test_build_audio_loader(default_audio_config: audio.AudioConfig): + loader = audio.build_audio_loader(config=default_audio_config) + assert isinstance(loader, audio.ConfigurableAudioLoader) + assert loader.config == default_audio_config + + +def test_configurable_audio_loader_methods( + default_audio_config: audio.AudioConfig, + dummy_wav_path: pathlib.Path, + dummy_recording: data.Recording, + dummy_clip: data.Clip, + tmp_path, +): + loader = audio.build_audio_loader(config=default_audio_config) + + expected_wav_file = audio.load_file_audio( + dummy_wav_path, config=default_audio_config, audio_dir=tmp_path + ) + loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path) + xr.testing.assert_equal(expected_wav_file, loaded_wav_file) + + expected_wav_rec = audio.load_recording_audio( + dummy_recording, config=default_audio_config, audio_dir=tmp_path + ) + loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path) + xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec) + + expected_wav_clip = audio.load_clip_audio( + dummy_clip, config=default_audio_config, audio_dir=tmp_path + ) + loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path) + xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip)