mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add tests for peak energy function
This commit is contained in:
parent
3103630c26
commit
ad0f0bcb24
@ -5,6 +5,7 @@ from typing import Callable, List, Optional
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
from scipy import signal
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
@ -127,6 +128,43 @@ def create_recording(wav_factory: Callable[..., Path]):
|
||||
return factory
|
||||
|
||||
|
||||
@pytest.fixture
def generate_whistle(tmp_path: Path):
    """
    Pytest fixture returning a factory that writes single-whistle WAV files.

    Each call to the factory synthesises a Gaussian-modulated sinusoid
    (``scipy.signal.gausspulse``), shifts it so that it is centred at the
    requested time, and stores it as a 16-bit PCM WAV file.
    """

    def factory(
        time: float,
        frequency: int,
        path: Optional[Path] = None,
        duration: float = 0.3,
        # NOTE(review): 441_000 Hz is kept exactly as found; confirm it is
        # intentional and not a typo for 44_100.
        samplerate: int = 441_000,
        whistle_duration: float = 0.1,
    ) -> Path:
        """
        Write a WAV file containing one whistle and return its path.

        Parameters
        ----------
        time
            Position (seconds) at which the pulse centre is placed.
        frequency
            Centre frequency (Hz) handed to ``gausspulse`` as ``fc``.
        path
            Destination file. When falsy, a uniquely named ``.wav`` file
            inside ``tmp_path`` is used.
        duration
            Total length of the generated signal in seconds.
        samplerate
            Sampling rate in Hz used both for synthesis and for writing.
        whistle_duration
            Used to derive the fractional bandwidth given to ``gausspulse``.
        """
        if not path:
            path = tmp_path / f"{uuid.uuid4()}.wav"

        half_span = duration / 2
        n_samples = int(samplerate * duration)

        # The pulse is synthesised centred on t=0, i.e. the middle of the
        # array; this many samples moves that centre to `time`.
        shift = int((time - half_span) * samplerate)

        timeline = np.linspace(
            -half_span,
            half_span,
            n_samples,
            endpoint=False,
        )
        fractional_bw = 2 / (frequency * whistle_duration)
        pulse = signal.gausspulse(timeline, fc=frequency, bw=fractional_bw)

        # np.roll is circular: samples pushed past one end reappear at the
        # other end of the array.
        shifted = np.roll(pulse, shift)
        full_scale = np.iinfo(np.int16).max
        pcm = (shifted * full_scale).astype(np.int16)

        sf.write(str(path), pcm, samplerate, subtype="PCM_16")
        return path

    return factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recording(
|
||||
create_recording: Callable[..., data.Recording],
|
||||
|
@ -1,7 +1,9 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets.rois import (
|
||||
DEFAULT_ANCHOR,
|
||||
DEFAULT_FREQUENCY_SCALE,
|
||||
@ -12,6 +14,7 @@ from batdetect2.targets.rois import (
|
||||
BBoxAnchorMapperConfig,
|
||||
_build_bounding_box,
|
||||
build_roi_mapper,
|
||||
get_peak_energy_coordinates,
|
||||
)
|
||||
|
||||
|
||||
@ -247,3 +250,154 @@ def test_build_roi_mapper():
|
||||
assert mapper.anchor == config.anchor
|
||||
assert mapper.time_scale == config.time_scale
|
||||
assert mapper.frequency_scale == config.frequency_scale
|
||||
|
||||
|
||||
def test_get_peak_energy_coordinates(generate_whistle):
    """A single whistle inside the ROI is reported as the energy peak.

    A 1 s recording sampled at 256 kHz is generated with one whistle at
    0.5 s / 40 kHz. The search region (0.2-0.7 s, 20-60 kHz) contains it,
    so the returned coordinates must match the whistle to within 0.01 s
    and 1000 Hz.
    """
    expected_time = 0.5
    expected_freq = 40_000

    # Synthesise the audio file holding the whistle.
    wav_path = generate_whistle(
        time=expected_time,
        frequency=expected_freq,
        duration=1.0,
        samplerate=256_000,
        whistle_duration=0.01,
    )

    # Region of interest that contains the whistle.
    roi = {
        "start_time": 0.2,
        "end_time": 0.7,
        "low_freq": 20_000,
        "high_freq": 60_000,
    }

    # The recording is loaded from the generated file and the preprocessor
    # is built with its default configuration.
    found_time, found_freq = get_peak_energy_coordinates(
        recording=data.Recording.from_file(path=wav_path),
        preprocessor=build_preprocessor(),
        loading_buffer=0.05,
        **roi,
    )

    assert found_time == pytest.approx(expected_time, abs=0.01)
    assert found_freq == pytest.approx(expected_freq, abs=1000)
|
||||
|
||||
|
||||
def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
    """The peak search is restricted to the requested region.

    Two whistles are mixed into one file: a full-amplitude one at
    0.2 s / 30 kHz and one scaled to 0.1 at 0.8 s / 50 kHz. The search
    region (0.7-0.9 s, 45-55 kHz) only covers the weaker whistle, so the
    weaker whistle's coordinates are the expected result.
    """
    samplerate = 256_000

    # (time in s, frequency in Hz, mixing gain) for each whistle.
    loud_time, loud_freq, loud_gain = 0.2, 30_000, 1.0
    quiet_time, quiet_freq, quiet_gain = 0.8, 50_000, 0.1

    # Settings common to both generated files.
    shared = {
        "duration": 1.0,
        "samplerate": samplerate,
        "whistle_duration": 0.01,
    }
    loud_path = generate_whistle(
        time=loud_time,
        frequency=loud_freq,
        **shared,
    )
    quiet_path = generate_whistle(
        time=quiet_time,
        frequency=quiet_freq,
        **shared,
    )

    # Read both files back and blend them with their respective gains.
    loud_signal, _ = sf.read(loud_path)
    quiet_signal, _ = sf.read(quiet_path)
    blend = loud_signal * loud_gain + quiet_signal * quiet_gain

    blend_path = loud_path.parent / "mixed_whistles.wav"
    sf.write(str(blend_path), blend, samplerate)

    # Region of interest that contains only the weaker whistle.
    roi = {
        "start_time": 0.7,
        "end_time": 0.9,
        "low_freq": 45_000,
        "high_freq": 55_000,
    }

    found_time, found_freq = get_peak_energy_coordinates(
        recording=data.Recording.from_file(path=blend_path),
        preprocessor=build_preprocessor(),
        loading_buffer=0.05,
        **roi,
    )

    assert found_time == pytest.approx(quiet_time, abs=0.01)
    assert found_freq == pytest.approx(quiet_freq, abs=1000)
|
||||
|
||||
|
||||
def test_get_peak_energy_coordinates_silent_region(create_recording):
    """With no real peak, the result must still lie inside the ROI.

    The recording (2 s at 44.1 kHz) is treated as silent, so there is no
    specific peak to expect. Exact values may vary with argmax behaviour
    on all-zero or very low, uniform energy; the test therefore only
    checks that the returned coordinates fall within the search bounds.
    """
    quiet_recording = create_recording(duration=2.0, samplerate=44_100)

    # Search bounds inside the silent recording.
    t_min, t_max = 0.5, 1.5
    f_min, f_max = 10_000, 20_000

    found_time, found_freq = get_peak_energy_coordinates(
        recording=quiet_recording,
        preprocessor=build_preprocessor(),
        start_time=t_min,
        end_time=t_max,
        low_freq=f_min,
        high_freq=f_max,
        loading_buffer=0.05,
    )

    assert t_min <= found_time <= t_max
    assert f_min <= found_freq <= f_max
|
||||
|
Loading…
Reference in New Issue
Block a user