Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)
Add audio test suite

commit aca0b58443 (parent f5071d00a1)
@@ -29,6 +29,7 @@ from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
+from soundfile import LibsndfileError

from batdetect2.configs import BaseConfig
@@ -360,7 +361,13 @@ def load_file_audio(
    xr.DataArray
        Loaded and preprocessed waveform (first channel only).
    """
+    try:
        recording = data.Recording.from_file(path)
+    except LibsndfileError as e:
+        raise FileNotFoundError(
+            f"Could not load the recording at path: {path}. Error: {e}"
+        ) from e

    return load_recording_audio(
        recording,
        config=config,
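With this change, a missing or unreadable file surfaces as a FileNotFoundError instead of a raw LibsndfileError, which is the behaviour exercised by test_load_file_audio_file_not_found further down. A minimal usage sketch, not part of this commit, with an illustrative path:

from batdetect2.preprocess import audio

try:
    wav = audio.load_file_audio("missing_file.wav", config=audio.AudioConfig())
except FileNotFoundError as err:
    # The original LibsndfileError is kept as the exception's __cause__ (raise ... from e).
    print(f"Audio file could not be loaded: {err}")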
@@ -421,10 +428,10 @@ def load_clip_audio(
    This is the core function performing the configured processing pipeline:
    1. Loads the specified clip segment using `soundevent.audio.load_clip`.
    2. Selects the first audio channel.
-    3. Adjusts duration (crop/pad) if `config.duration` is set.
-    4. Resamples if `config.resample` is configured.
-    5. Centers (DC offset removal) if `config.center` is True.
-    6. Scales (peak normalization) if `config.scale` is True.
+    3. Resamples if `config.resample` is configured.
+    4. Centers (DC offset removal) if `config.center` is True.
+    5. Scales (peak normalization) if `config.scale` is True.
+    6. Adjusts duration (crop/pad) if `config.duration` is set.

    Parameters
    ----------
@@ -461,12 +468,17 @@ def load_clip_audio(
    """
    config = config or AudioConfig()

+    try:
        wav = (
-            audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
+            audio.load_clip(clip, audio_dir=audio_dir)
+            .sel(channel=0)
+            .astype(dtype)
        )

-    if config.duration is not None:
-        wav = adjust_audio_duration(wav, duration=config.duration)
+    except LibsndfileError as e:
+        raise FileNotFoundError(
+            f"Could not load the recording at path: {clip.recording.path}. "
+            f"Error: {e}"
+        ) from e

    if config.resample:
        wav = resample_audio(
@@ -479,11 +491,35 @@ def load_clip_audio(
        wav = ops.center(wav)

    if config.scale:
-        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
+        wav = scale_audio(wav)

+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)

    return wav.astype(dtype)
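For orientation, a usage sketch of the reordered pipeline. It is not part of this commit and is built only from names that appear in this diff and in the tests below; the WAV path is illustrative:

import numpy as np
from soundevent import data

from batdetect2.preprocess import audio

# Resample to 96 kHz, remove the DC offset, peak-normalize, then crop/pad to 0.5 s.
# Duration adjustment now runs last, so the returned array covers exactly 0.5 s at the
# target samplerate.
config = audio.AudioConfig(
    resample=audio.ResampleConfig(samplerate=96_000, mode="poly"),
    center=True,
    scale=True,
    duration=0.5,
)

recording = data.Recording.from_file("example.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=1.0)
wav = audio.load_clip_audio(clip, config=config, dtype=np.float32)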
def scale_audio(
    wave: xr.DataArray,
) -> xr.DataArray:
    """
    Scale the audio waveform to have a maximum absolute value of 1.0.

    This function normalizes the waveform by dividing it by its maximum
    absolute value. If the maximum value is zero, the waveform is returned
    unchanged. Also known as peak normalization, this process ensures that the
    waveform's amplitude is within a standard range, which can be useful for
    audio processing and analysis.

    """
    max_val = np.max(np.abs(wave))

    if max_val == 0:
        return wave

    return ops.scale(wave, 1 / max_val)
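A small numeric illustration of the peak normalization described above (not from the repository; plain NumPy standing in for the xarray-based helper):

import numpy as np

wave = np.array([0.1, -0.5, 0.25], dtype=np.float32)
max_val = np.max(np.abs(wave))           # 0.5
normalized = wave / max_val              # scaled by 1 / max_val, peaks at |1.0|
assert np.isclose(np.max(np.abs(normalized)), 1.0)

silence = np.zeros(4, dtype=np.float32)
assert np.max(np.abs(silence)) == 0      # an all-zero wave would be returned unchanged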
def adjust_audio_duration(
    wave: xr.DataArray,
    duration: float,
@@ -513,24 +549,29 @@ def adjust_audio_duration(
        If `duration` is negative.
    """
    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    current_duration = end_time - start_time
+    step = arrays.get_dim_step(wave, dim="time")
+    current_duration = end_time - start_time + step

    if current_duration == duration:
        return wave

    with xr.set_options(keep_attrs=True):
        if current_duration > duration:
            return arrays.crop_dim(
                wave,
                dim="time",
                start=start_time,
-                stop=start_time + duration,
+                stop=start_time + duration - step / 2,
                right_closed=True,
            )

        return arrays.extend_dim(
            wave,
            dim="time",
            start=start_time,
-            stop=start_time + duration,
+            stop=start_time + duration - step / 2,
            eps=0,
            right_closed=True,
        )
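The two changes above follow from simple sample-count arithmetic: for a signal sampled every `step` seconds, n samples span times start, start + step, ..., start + (n - 1) * step, so end_time - start_time is (n - 1) * step and the covered duration is that value plus one step. Likewise, when cropping to a target duration the last wanted sample sits at start + duration - step, so a stop of start + duration - step / 2 (used as an inclusive boundary via right_closed=True) keeps exactly the intended samples while tolerating floating-point jitter in the time coordinate. A small numeric sketch, not from the repository, with an illustrative 250 kHz rate:

import numpy as np

samplerate = 250_000
step = 1 / samplerate                            # 4 microseconds between samples
times = np.arange(5) * step                      # 5 samples at 0, 4, 8, 12, 16 us

current_duration = times[-1] - times[0] + step
assert np.isclose(current_duration, 5 * step)    # 20 us, not 16 us

duration = 3 * step                              # crop to the first 3 samples
stop = times[0] + duration - step / 2
assert len(times[times <= stop]) == 3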
@@ -616,7 +657,7 @@ def resample_audio(
                samplerate=samplerate,
            ),
        },
-        attrs=wav.attrs,
+        attrs={**wav.attrs, "samplerate": samplerate},
    )
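This keeps the "samplerate" attribute consistent with the resampled time axis instead of carrying over the pre-resampling value. The dict-merge pattern in isolation, as a generic xarray sketch (not repository code; the "units" attribute is invented for illustration):

import numpy as np
import xarray as xr

wav = xr.DataArray(
    np.zeros(4),
    dims=["time"],
    attrs={"samplerate": 48_000, "units": "V"},
)

samplerate = 96_000
resampled = xr.DataArray(
    np.zeros(8),
    dims=["time"],
    attrs={**wav.attrs, "samplerate": samplerate},  # copy attrs, override samplerate
)

assert resampled.attrs["samplerate"] == 96_000
assert resampled.attrs["units"] == "V"              # other attributes are preserved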
@@ -0,0 +1,506 @@
import pathlib
import uuid
from pathlib import Path

import numpy as np
import pytest
import soundfile as sf
import xarray as xr
from soundevent import data
from soundevent.arrays import Dimensions, create_time_dim_from_array

from batdetect2.preprocess import audio


def create_dummy_wave(
    samplerate: int,
    duration: float,
    num_channels: int = 1,
    freq: float = 440.0,
    amplitude: float = 0.5,
    dtype: np.dtype = np.float32,
) -> np.ndarray:
    """Generates a simple numpy waveform."""
    t = np.linspace(
        0.0, duration, int(samplerate * duration), endpoint=False, dtype=dtype
    )
    wave = amplitude * np.sin(2 * np.pi * freq * t)
    if num_channels > 1:
        wave = np.stack([wave] * num_channels, axis=0)
    return wave.astype(dtype)


def create_xr_wave(
    samplerate: int,
    duration: float,
    num_channels: int = 1,
    freq: float = 440.0,
    amplitude: float = 0.5,
    start_time: float = 0.0,
) -> xr.DataArray:
    """Generates a simple xarray waveform."""
    num_samples = int(samplerate * duration)
    times = np.linspace(
        start_time,
        start_time + duration,
        num_samples,
        endpoint=False,
    )
    coords = {
        Dimensions.time.value: create_time_dim_from_array(
            times, samplerate=samplerate, start_time=start_time
        )
    }
    dims = [Dimensions.time.value]

    wave_data = amplitude * np.sin(2 * np.pi * freq * times)

    if num_channels > 1:
        coords[Dimensions.channel.value] = np.arange(num_channels)
        dims = [Dimensions.channel.value] + dims
        wave_data = np.stack([wave_data] * num_channels, axis=0)

    return xr.DataArray(
        wave_data.astype(np.float32),
        coords=coords,
        dims=dims,
        attrs={"samplerate": samplerate},
    )


@pytest.fixture
def dummy_wav_path(tmp_path: pathlib.Path) -> pathlib.Path:
    """Creates a dummy WAV file and returns its path."""
    samplerate = 48000
    duration = 2.0
    num_channels = 2
    wave_data = create_dummy_wave(samplerate, duration, num_channels)
    file_path = tmp_path / f"{uuid.uuid4()}.wav"
    sf.write(file_path, wave_data.T, samplerate, format="WAV", subtype="FLOAT")
    return file_path


@pytest.fixture
def dummy_recording(dummy_wav_path: pathlib.Path) -> data.Recording:
    """Creates a Recording object pointing to the dummy WAV file."""
    return data.Recording.from_file(dummy_wav_path)


@pytest.fixture
def dummy_clip(dummy_recording: data.Recording) -> data.Clip:
    """Creates a Clip object from the dummy recording."""
    return data.Clip(
        recording=dummy_recording,
        start_time=0.5,
        end_time=1.5,
    )


@pytest.fixture
def default_audio_config() -> audio.AudioConfig:
    return audio.AudioConfig()


@pytest.fixture
def no_resample_config() -> audio.AudioConfig:
    return audio.AudioConfig(resample=None)


@pytest.fixture
def fixed_duration_config() -> audio.AudioConfig:
    return audio.AudioConfig(duration=0.5)


@pytest.fixture
def scale_config() -> audio.AudioConfig:
    return audio.AudioConfig(scale=True, center=False)


@pytest.fixture
def no_center_config() -> audio.AudioConfig:
    return audio.AudioConfig(center=False)


@pytest.fixture
def resample_fourier_config() -> audio.AudioConfig:
    return audio.AudioConfig(
        resample=audio.ResampleConfig(
            samplerate=audio.TARGET_SAMPLERATE_HZ // 2, mode="fourier"
        )
    )


def test_resample_config_defaults():
    config = audio.ResampleConfig()
    assert config.samplerate == audio.TARGET_SAMPLERATE_HZ
    assert config.mode == "poly"


def test_audio_config_defaults():
    config = audio.AudioConfig()
    assert config.resample is not None
    assert config.resample.samplerate == audio.TARGET_SAMPLERATE_HZ
    assert config.resample.mode == "poly"
    assert config.scale == audio.SCALE_RAW_AUDIO
    assert config.center is True
    assert config.duration == audio.DEFAULT_DURATION


def test_audio_config_override():
    resample_cfg = audio.ResampleConfig(samplerate=44100, mode="fourier")
    config = audio.AudioConfig(
        resample=resample_cfg,
        scale=True,
        center=False,
        duration=1.0,
    )
    assert config.resample == resample_cfg
    assert config.scale is True
    assert config.center is False
    assert config.duration == 1.0


def test_audio_config_no_resample():
    config = audio.AudioConfig(resample=None)
    assert config.resample is None


@pytest.mark.parametrize(
    "orig_sr, orig_dur, target_dur",
    [
        (256_000, 1.0, 0.5),
        (256_000, 0.5, 1.0),
        (256_000, 1.0, 1.0),
    ],
)
def test_adjust_audio_duration(orig_sr, orig_dur, target_dur):
    wave = create_xr_wave(samplerate=orig_sr, duration=orig_dur)
    adjusted_wave = audio.adjust_audio_duration(wave, duration=target_dur)
    expected_samples = int(target_dur * orig_sr)
    assert adjusted_wave.sizes["time"] == expected_samples
    assert adjusted_wave.coords["time"].attrs["step"] == 1 / orig_sr
    assert adjusted_wave.dtype == wave.dtype
    if orig_dur > 0 and target_dur > orig_dur:
        padding_start_index = int(orig_dur * orig_sr) + 1
        assert np.all(adjusted_wave.values[padding_start_index:] == 0)


def test_adjust_audio_duration_negative_target_raises():
    wave = create_xr_wave(1000, 1.0)
    with pytest.raises(ValueError):
        audio.adjust_audio_duration(wave, duration=-0.5)


@pytest.mark.parametrize(
    "orig_sr, target_sr, mode",
    [
        (48000, 96000, "poly"),
        (96000, 48000, "poly"),
        (48000, 96000, "fourier"),
        (96000, 48000, "fourier"),
        (48000, 44100, "poly"),
        (48000, 44100, "fourier"),
    ],
)
def test_resample_audio(orig_sr, target_sr, mode):
    duration = 0.1
    wave = create_xr_wave(orig_sr, duration)
    resampled_wave = audio.resample_audio(
        wave, samplerate=target_sr, mode=mode, dtype=np.float32
    )
    expected_samples = int(wave.sizes["time"] * (target_sr / orig_sr))
    assert resampled_wave.sizes["time"] == expected_samples
    assert resampled_wave.coords["time"].attrs["step"] == 1 / target_sr
    assert np.isclose(
        resampled_wave.coords["time"].values[-1]
        - resampled_wave.coords["time"].values[0],
        duration,
        atol=2 / target_sr,
    )
    assert resampled_wave.dtype == np.float32


def test_resample_audio_same_samplerate():
    sr = 48000
    duration = 0.1
    wave = create_xr_wave(sr, duration)
    resampled_wave = audio.resample_audio(
        wave, samplerate=sr, dtype=np.float64
    )
    xr.testing.assert_equal(wave.astype(np.float64), resampled_wave)


def test_resample_audio_invalid_mode_raises():
    wave = create_xr_wave(48000, 0.1)
    with pytest.raises(NotImplementedError):
        audio.resample_audio(wave, samplerate=96000, mode="invalid_mode")


def test_resample_audio_no_time_dim_raises():
    wave = xr.DataArray(np.random.rand(100), dims=["samples"])
    with pytest.raises(ValueError, match="Audio must have a time dimension"):
        audio.resample_audio(wave, samplerate=96000)


def test_load_clip_audio_default_config(
    dummy_clip: data.Clip,
    default_audio_config: audio.AudioConfig,
    tmp_path: Path,
):
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.sizes["time"] == expected_samples
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_clip_audio_no_resample(
    dummy_clip: data.Clip,
    no_resample_config: audio.AudioConfig,
    tmp_path: Path,
):
    orig_sr = dummy_clip.recording.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * orig_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=no_resample_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / orig_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)


def test_load_clip_audio_fixed_duration_crop(
    dummy_clip: data.Clip,
    fixed_duration_config: audio.AudioConfig,
    tmp_path: Path,
):
    target_sr = audio.TARGET_SAMPLERATE_HZ
    target_duration = fixed_duration_config.duration
    assert target_duration is not None
    expected_samples = int(target_duration * target_sr)

    assert dummy_clip.duration > target_duration

    wav = audio.load_clip_audio(
        dummy_clip, config=fixed_duration_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples


def test_load_clip_audio_fixed_duration_pad(
    dummy_clip: data.Clip,
    tmp_path: Path,
):
    target_duration = dummy_clip.duration * 2
    config = audio.AudioConfig(duration=target_duration)

    assert config.resample is not None
    target_sr = config.resample.samplerate
    expected_samples = int(target_duration * target_sr)

    wav = audio.load_clip_audio(dummy_clip, config=config, audio_dir=tmp_path)

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples

    original_samples_after_resample = int(dummy_clip.duration * target_sr)
    assert np.allclose(
        wav.values[original_samples_after_resample:], 0.0, atol=1e-6
    )


def test_load_clip_audio_scale(
    dummy_clip: data.Clip, scale_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip,
        config=scale_config,
        audio_dir=tmp_path,
    )

    assert np.isclose(np.max(np.abs(wav.values)), 1.0, atol=1e-5)


def test_load_clip_audio_no_center(
    dummy_clip: data.Clip, no_center_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip, config=no_center_config, audio_dir=tmp_path
    )

    raw_wav, _ = sf.read(
        dummy_clip.recording.path,
        start=int(dummy_clip.start_time * dummy_clip.recording.samplerate),
        stop=int(dummy_clip.end_time * dummy_clip.recording.samplerate),
        dtype=np.float32,  # type: ignore
    )
    raw_wav_mono = raw_wav[:, 0]

    if not np.isclose(raw_wav_mono.mean(), 0.0, atol=1e-7):
        assert not np.isclose(wav.mean(), 0.0, atol=1e-6)


def test_load_clip_audio_resample_fourier(
    dummy_clip: data.Clip, resample_fourier_config: audio.AudioConfig, tmp_path
):
    assert resample_fourier_config.resample is not None
    target_sr = resample_fourier_config.resample.samplerate
    orig_duration = dummy_clip.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_clip_audio(
        dummy_clip, config=resample_fourier_config, audio_dir=tmp_path
    )

    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples


def test_load_clip_audio_dtype(
    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
    wav = audio.load_clip_audio(
        dummy_clip,
        config=default_audio_config,
        audio_dir=tmp_path,
        dtype=np.float64,
    )
    assert wav.dtype == np.float64


def test_load_clip_audio_file_not_found(
    dummy_clip: data.Clip, default_audio_config: audio.AudioConfig, tmp_path
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    dummy_clip.recording = data.Recording(
        path=non_existent_path,
        duration=1,
        channels=1,
        samplerate=256000,
    )
    with pytest.raises(FileNotFoundError):
        audio.load_clip_audio(
            dummy_clip, config=default_audio_config, audio_dir=tmp_path
        )


def test_load_recording_audio(
    dummy_recording: data.Recording,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    orig_duration = dummy_recording.duration
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_recording_audio(
        dummy_recording, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_recording_audio_file_not_found(
    dummy_recording: data.Recording,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    dummy_recording = data.Recording(
        path=non_existent_path,
        duration=1,
        channels=1,
        samplerate=256000,
    )
    with pytest.raises(FileNotFoundError):
        audio.load_recording_audio(
            dummy_recording, config=default_audio_config, audio_dir=tmp_path
        )


def test_load_file_audio(
    dummy_wav_path: pathlib.Path,
    default_audio_config: audio.AudioConfig,
    tmp_path,
):
    info = sf.info(dummy_wav_path)
    orig_duration = info.duration
    assert default_audio_config.resample is not None
    target_sr = default_audio_config.resample.samplerate
    expected_samples = int(orig_duration * target_sr)

    wav = audio.load_file_audio(
        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
    )

    assert isinstance(wav, xr.DataArray)
    assert wav.dims == ("time",)
    assert wav.coords["time"].attrs["step"] == 1 / target_sr
    assert wav.sizes["time"] == expected_samples
    assert np.isclose(wav.mean(), 0.0, atol=1e-6)
    assert wav.dtype == np.float32


def test_load_file_audio_file_not_found(
    default_audio_config: audio.AudioConfig, tmp_path
):
    non_existent_path = tmp_path / "not_a_real_file.wav"
    with pytest.raises(FileNotFoundError):
        audio.load_file_audio(
            non_existent_path, config=default_audio_config, audio_dir=tmp_path
        )


def test_build_audio_loader(default_audio_config: audio.AudioConfig):
    loader = audio.build_audio_loader(config=default_audio_config)
    assert isinstance(loader, audio.ConfigurableAudioLoader)
    assert loader.config == default_audio_config


def test_configurable_audio_loader_methods(
    default_audio_config: audio.AudioConfig,
    dummy_wav_path: pathlib.Path,
    dummy_recording: data.Recording,
    dummy_clip: data.Clip,
    tmp_path,
):
    loader = audio.build_audio_loader(config=default_audio_config)

    expected_wav_file = audio.load_file_audio(
        dummy_wav_path, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_file = loader.load_file(dummy_wav_path, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_file, loaded_wav_file)

    expected_wav_rec = audio.load_recording_audio(
        dummy_recording, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_rec = loader.load_recording(dummy_recording, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_rec, loaded_wav_rec)

    expected_wav_clip = audio.load_clip_audio(
        dummy_clip, config=default_audio_config, audio_dir=tmp_path
    )
    loaded_wav_clip = loader.load_clip(dummy_clip, audio_dir=tmp_path)
    xr.testing.assert_equal(expected_wav_clip, loaded_wav_clip)
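To run just this suite locally, a pytest keyword filter is usually enough; the test-file path is not shown in this view, so the selector below is an assumption and may also pick up other tests whose names contain "audio":

pytest -k audio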