Add tests for peak energy function

This commit is contained in:
mbsantiago 2025-06-21 23:01:08 +01:00
parent 3103630c26
commit ad0f0bcb24
2 changed files with 192 additions and 0 deletions

View File

@ -5,6 +5,7 @@ from typing import Callable, List, Optional
import numpy as np import numpy as np
import pytest import pytest
import soundfile as sf import soundfile as sf
from scipy import signal
from soundevent import data, terms from soundevent import data, terms
from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data import DatasetConfig, load_dataset
@ -127,6 +128,43 @@ def create_recording(wav_factory: Callable[..., Path]):
return factory return factory
@pytest.fixture
def generate_whistle(tmp_path: Path):
    """Provide a factory that writes a WAV file containing one whistle.

    The whistle is a Gaussian-modulated sinusoid produced with
    ``scipy.signal.gausspulse``. It is generated centred on the middle of
    the clip and then circularly shifted with ``np.roll`` so that its
    centre lands at the requested time.

    The returned factory accepts:

    - ``time``: position of the whistle centre, in seconds from the start.
    - ``frequency``: centre frequency passed to ``gausspulse`` as ``fc``.
    - ``path``: output location; when omitted a random ``<uuid4>.wav``
      file inside ``tmp_path`` is used.
    - ``duration``: total clip length in seconds.
    - ``samplerate``: sample rate used for synthesis and for the file.
    - ``whistle_duration``: controls the pulse bandwidth through
      ``bw = 2 / (frequency * whistle_duration)``.

    It returns the path of the 16-bit PCM file that was written.
    """

    def factory(
        time: float,
        frequency: int,
        path: Optional[Path] = None,
        duration: float = 0.3,
        samplerate: int = 441_000,
        whistle_duration: float = 0.1,
    ) -> Path:
        destination = path if path else tmp_path / f"{uuid.uuid4()}.wav"

        half_duration = duration / 2
        num_frames = int(samplerate * duration)

        # The pulse is synthesised around t=0 (the clip midpoint), so the
        # shift is measured relative to the midpoint, not the clip start.
        shift = int((time - half_duration) * samplerate)

        timeline = np.linspace(
            -half_duration,
            half_duration,
            num_frames,
            endpoint=False,
        )
        fractional_bandwidth = 2 / (frequency * whistle_duration)
        pulse = signal.gausspulse(
            timeline,
            fc=frequency,
            bw=fractional_bandwidth,
        )

        # Scale the pulse to the int16 range before the integer cast.
        int16_peak = np.iinfo(np.int16).max
        shifted = np.roll(pulse, shift)
        samples = (shifted * int16_peak).astype(np.int16)

        sf.write(str(destination), samples, samplerate, subtype="PCM_16")
        return destination

    return factory
@pytest.fixture @pytest.fixture
def recording( def recording(
create_recording: Callable[..., data.Recording], create_recording: Callable[..., data.Recording],

View File

@ -1,7 +1,9 @@
import numpy as np import numpy as np
import pytest import pytest
import soundfile as sf
from soundevent import data from soundevent import data
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
DEFAULT_ANCHOR, DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE, DEFAULT_FREQUENCY_SCALE,
@ -12,6 +14,7 @@ from batdetect2.targets.rois import (
BBoxAnchorMapperConfig, BBoxAnchorMapperConfig,
_build_bounding_box, _build_bounding_box,
build_roi_mapper, build_roi_mapper,
get_peak_energy_coordinates,
) )
@ -247,3 +250,154 @@ def test_build_roi_mapper():
assert mapper.anchor == config.anchor assert mapper.anchor == config.anchor
assert mapper.time_scale == config.time_scale assert mapper.time_scale == config.time_scale
assert mapper.frequency_scale == config.frequency_scale assert mapper.frequency_scale == config.frequency_scale
def test_get_peak_energy_coordinates(generate_whistle):
    """A single whistle inside the ROI is located at its time and frequency."""
    expected_time = 0.5
    expected_frequency = 40_000

    # One-second clip holding a single short whistle.
    wav_path = generate_whistle(
        time=expected_time,
        frequency=expected_frequency,
        duration=1.0,
        samplerate=256_000,
        whistle_duration=0.01,
    )
    recording = data.Recording.from_file(path=wav_path)

    # Search window that comfortably contains the whistle on both axes.
    roi = {
        "start_time": 0.2,
        "end_time": 0.7,
        "low_freq": 20_000,
        "high_freq": 60_000,
    }

    # The preprocessor is built with its default configuration.
    found_time, found_frequency = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=build_preprocessor(),
        loading_buffer=0.05,
        **roi,
    )

    assert found_time == pytest.approx(expected_time, abs=0.01)
    assert found_frequency == pytest.approx(expected_frequency, abs=1000)
def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
    """The peak search stays inside the ROI even when a louder whistle exists.

    Two whistles are mixed into one recording: a full-amplitude one at
    0.2 s / 30 kHz and a ten-times weaker one at 0.8 s / 50 kHz. The ROI
    only covers the weaker whistle, so that is the one expected back.
    """
    loud_time, loud_frequency, loud_gain = 0.2, 30_000, 1.0
    quiet_time, quiet_frequency, quiet_gain = 0.8, 50_000, 0.1

    clip_duration = 1.0
    samplerate = 256_000

    # Each whistle is rendered to its own file first.
    loud_path = generate_whistle(
        time=loud_time,
        frequency=loud_frequency,
        duration=clip_duration,
        samplerate=samplerate,
        whistle_duration=0.01,
    )
    quiet_path = generate_whistle(
        time=quiet_time,
        frequency=quiet_frequency,
        duration=clip_duration,
        samplerate=samplerate,
        whistle_duration=0.01,
    )

    loud_samples, _ = sf.read(loud_path)
    quiet_samples, _ = sf.read(quiet_path)

    # Weighted sum of the two clips, stored next to the source files.
    mixture = loud_samples * loud_gain + quiet_samples * quiet_gain
    mixture_path = loud_path.parent / "mixed_whistles.wav"
    sf.write(str(mixture_path), mixture, samplerate)

    recording = data.Recording.from_file(path=mixture_path)

    # Window around the quiet whistle only; the loud one lies outside it
    # in both time and frequency.
    roi = {
        "start_time": 0.7,
        "end_time": 0.9,
        "low_freq": 45_000,
        "high_freq": 55_000,
    }

    found_time, found_frequency = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=build_preprocessor(),
        loading_buffer=0.05,
        **roi,
    )

    assert found_time == pytest.approx(quiet_time, abs=0.01)
    assert found_frequency == pytest.approx(quiet_frequency, abs=1000)
def test_get_peak_energy_coordinates_silent_region(create_recording):
    """Without a real peak, the result must still fall inside the ROI."""
    # NOTE(review): the recording is presumably silent; that depends on the
    # `create_recording` fixture, which is defined elsewhere.
    recording = create_recording(duration=2.0, samplerate=44_100)

    roi_start, roi_end = 0.5, 1.5
    roi_low, roi_high = 10_000, 20_000

    found_time, found_frequency = get_peak_energy_coordinates(
        recording=recording,
        preprocessor=build_preprocessor(),
        start_time=roi_start,
        end_time=roi_end,
        low_freq=roi_low,
        high_freq=roi_high,
        loading_buffer=0.05,
    )

    # With no distinct maximum the exact coordinates are not meaningful and
    # may vary with how ties are resolved, so only the bounds are checked.
    assert roi_start <= found_time <= roi_end
    assert roi_low <= found_frequency <= roi_high