@pytest.fixture
def generate_whistle(tmp_path: Path):
    """Provide a factory that writes synthetic "whistle" WAV files.

    Each call to the factory writes a 16-bit PCM recording holding a single
    Gaussian-modulated sinusoid (``scipy.signal.gausspulse``) centred on the
    requested time. The known time and frequency of the pulse make the file a
    convenient ground truth for testing audio analysis functions.
    """

    def factory(
        time: float,
        frequency: int,
        path: Optional[Path] = None,
        duration: float = 0.3,
        # NOTE(review): 441_000 Hz is an unusual sample rate and may be a
        # typo for 44_100 -- confirm. Left unchanged so callers relying on
        # the default keep their behaviour.
        samplerate: int = 441_000,
        whistle_duration: float = 0.1,
    ) -> Path:
        """Write a WAV file containing one whistle and return its path.

        Parameters
        ----------
        time
            Position of the centre of the whistle, in seconds from the start
            of the recording.
        frequency
            Centre frequency of the whistle, in Hz.
        path
            Destination file. When omitted, a uniquely named ``.wav`` file is
            created inside pytest's ``tmp_path``.
        duration
            Total length of the recording, in seconds.
        samplerate
            Sample rate of the recording, in Hz.
        whistle_duration
            Controls the bandwidth of the pulse (see the ``bw`` argument
            below); shorter values give a wider band.

        Returns
        -------
        Path
            Location of the written WAV file.
        """
        path = path or tmp_path / f"{uuid.uuid4()}.wav"
        frames = int(samplerate * duration)

        # The pulse is synthesised centred on t = 0, which is the middle of
        # the recording, and then shifted so that its centre lands on `time`.
        offset = int((time - duration / 2) * samplerate)
        t = np.linspace(-duration / 2, duration / 2, frames, endpoint=False)

        # Named `pulse` rather than `data` so the local does not shadow the
        # `soundevent.data` module imported at the top of this file.
        pulse = signal.gausspulse(
            t,
            fc=frequency,
            # Fractional bandwidth: 2 / whistle_duration Hz relative to the
            # centre frequency.
            bw=2 / (frequency * whistle_duration),
        )

        # np.roll is a circular shift, so a `time` outside the recording
        # wraps around instead of raising. Scale to the int16 range before
        # writing as PCM_16.
        wave = (np.roll(pulse, offset) * np.iinfo(np.int16).max).astype(
            np.int16
        )
        sf.write(str(path), wave, samplerate, subtype="PCM_16")
        return path

    return factory
def test_get_peak_energy_coordinates(generate_whistle):
    """A lone whistle is located at its own time and frequency."""
    expected_time = 0.5
    expected_frequency = 40_000

    # Synthesise a one-second recording holding a single short whistle and
    # wrap the resulting WAV file in a Recording object.
    recording = data.Recording.from_file(
        path=generate_whistle(
            time=expected_time,
            frequency=expected_frequency,
            duration=1.0,
            samplerate=256_000,
            whistle_duration=0.01,
        )
    )

    # Search a region of interest that encloses the whistle, using a
    # preprocessor built without any explicit configuration.
    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=build_preprocessor(),
        start_time=0.2,
        end_time=0.7,
        low_freq=20_000,
        high_freq=60_000,
        loading_buffer=0.05,
    )

    # The reported peak must match the whistle within tolerance.
    assert peak_time == pytest.approx(expected_time, abs=0.01)
    assert peak_freq == pytest.approx(expected_frequency, abs=1000)


def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
    """The search is confined to the ROI even if a louder whistle exists."""
    # Loud whistle, placed outside the region of interest.
    loud_time, loud_frequency, loud_gain = 0.2, 30_000, 1.0

    # Quiet whistle, the only one inside the region of interest.
    quiet_time, quiet_frequency, quiet_gain = 0.8, 50_000, 0.1

    # Settings shared by both generated recordings.
    duration = 1.0
    samplerate = 256_000

    # Write each whistle to its own WAV file.
    loud_path = generate_whistle(
        time=loud_time,
        frequency=loud_frequency,
        duration=duration,
        samplerate=samplerate,
        whistle_duration=0.01,
    )
    quiet_path = generate_whistle(
        time=quiet_time,
        frequency=quiet_frequency,
        duration=duration,
        samplerate=samplerate,
        whistle_duration=0.01,
    )

    # Read both files back and sum them with their respective gains.
    loud_audio, _ = sf.read(loud_path)
    quiet_audio, _ = sf.read(quiet_path)
    mixture = loud_audio * loud_gain + quiet_audio * quiet_gain

    # Store the mixture next to the generated files and load it as a
    # Recording object.
    mixture_path = loud_path.parent / "mixed_whistles.wav"
    sf.write(str(mixture_path), mixture, samplerate)
    recording = data.Recording.from_file(path=mixture_path)

    # Restrict the search to a box that contains only the quiet whistle.
    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=build_preprocessor(),
        start_time=0.7,
        end_time=0.9,
        low_freq=45_000,
        high_freq=55_000,
        loading_buffer=0.05,
    )

    # The peak must be the quiet whistle, not the louder one outside the box.
    assert peak_time == pytest.approx(quiet_time, abs=0.01)
    assert peak_freq == pytest.approx(quiet_frequency, abs=1000)


def test_get_peak_energy_coordinates_silent_region(create_recording):
    """On a silent recording the reported peak stays inside the ROI."""
    # Two seconds of silence sampled at 44.1 kHz.
    recording = create_recording(duration=2.0, samplerate=44_100)
    preprocessor = build_preprocessor()

    # Region of interest lying entirely within the silent recording.
    roi_start, roi_end = 0.5, 1.5
    roi_low, roi_high = 10_000, 20_000

    peak_time, peak_freq = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=preprocessor,
        start_time=roi_start,
        end_time=roi_end,
        low_freq=roi_low,
        high_freq=roi_high,
        loading_buffer=0.05,
    )

    # With no real peak, the exact coordinates depend on how argmax resolves
    # all-zero or very low, uniform energy, so only containment within the
    # search bounds is checked.
    assert roi_start <= peak_time <= roi_end
    assert roi_low <= peak_freq <= roi_high