mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Add tests for peak energy function
This commit is contained in:
parent
3103630c26
commit
ad0f0bcb24
@ -5,6 +5,7 @@ from typing import Callable, List, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
from scipy import signal
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
@ -127,6 +128,43 @@ def create_recording(wav_factory: Callable[..., Path]):
|
|||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def generate_whistle(tmp_path: Path):
|
||||||
|
"""
|
||||||
|
Pytest fixture that provides a factory for generating WAV audio files.
|
||||||
|
|
||||||
|
The factory creates a recording containing a "whistle" (a short,
|
||||||
|
frequency-specific pulse) positioned at a precise time, suitable for
|
||||||
|
testing audio analysis functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def factory(
|
||||||
|
time: float,
|
||||||
|
frequency: int,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
duration: float = 0.3,
|
||||||
|
samplerate: int = 441_000,
|
||||||
|
whistle_duration: float = 0.1,
|
||||||
|
) -> Path:
|
||||||
|
path = path or tmp_path / f"{uuid.uuid4()}.wav"
|
||||||
|
frames = int(samplerate * duration)
|
||||||
|
|
||||||
|
offset = int((time - duration / 2) * samplerate)
|
||||||
|
t = np.linspace(-duration / 2, duration / 2, frames, endpoint=False)
|
||||||
|
data = signal.gausspulse(
|
||||||
|
t,
|
||||||
|
fc=frequency,
|
||||||
|
bw=2 / (frequency * whistle_duration),
|
||||||
|
)
|
||||||
|
wave = (np.roll(data, offset) * np.iinfo(np.int16).max).astype(
|
||||||
|
np.int16
|
||||||
|
)
|
||||||
|
sf.write(str(path), wave, samplerate, subtype="PCM_16")
|
||||||
|
return path
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def recording(
|
def recording(
|
||||||
create_recording: Callable[..., data.Recording],
|
create_recording: Callable[..., data.Recording],
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import soundfile as sf
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
DEFAULT_ANCHOR,
|
DEFAULT_ANCHOR,
|
||||||
DEFAULT_FREQUENCY_SCALE,
|
DEFAULT_FREQUENCY_SCALE,
|
||||||
@ -12,6 +14,7 @@ from batdetect2.targets.rois import (
|
|||||||
BBoxAnchorMapperConfig,
|
BBoxAnchorMapperConfig,
|
||||||
_build_bounding_box,
|
_build_bounding_box,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
|
get_peak_energy_coordinates,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -247,3 +250,154 @@ def test_build_roi_mapper():
|
|||||||
assert mapper.anchor == config.anchor
|
assert mapper.anchor == config.anchor
|
||||||
assert mapper.time_scale == config.time_scale
|
assert mapper.time_scale == config.time_scale
|
||||||
assert mapper.frequency_scale == config.frequency_scale
|
assert mapper.frequency_scale == config.frequency_scale
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_peak_energy_coordinates(generate_whistle):
|
||||||
|
whistle_time = 0.5
|
||||||
|
whistle_frequency = 40_000
|
||||||
|
duration = 1.0
|
||||||
|
samplerate = 256_000
|
||||||
|
|
||||||
|
# Generate a WAV file with a whistle
|
||||||
|
whistle_path = generate_whistle(
|
||||||
|
time=whistle_time,
|
||||||
|
frequency=whistle_frequency,
|
||||||
|
duration=duration,
|
||||||
|
samplerate=samplerate,
|
||||||
|
whistle_duration=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a recording object from the generated WAV
|
||||||
|
recording = data.Recording.from_file(path=whistle_path)
|
||||||
|
|
||||||
|
# Build a preprocessor (default config should be fine for this test)
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
# Define a region of interest that contains the whistle
|
||||||
|
start_time = 0.2
|
||||||
|
end_time = 0.7
|
||||||
|
low_freq = 20_000
|
||||||
|
high_freq = 60_000
|
||||||
|
|
||||||
|
# Get the peak energy coordinates
|
||||||
|
peak_time, peak_freq = get_peak_energy_coordinates(
|
||||||
|
recording=recording,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
loading_buffer=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that the peak coordinates are close to the expected values
|
||||||
|
assert peak_time == pytest.approx(whistle_time, abs=0.01)
|
||||||
|
assert peak_freq == pytest.approx(whistle_frequency, abs=1000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_peak_energy_coordinates_with_two_whistles(generate_whistle):
|
||||||
|
# Parameters for the first (stronger) whistle
|
||||||
|
strong_whistle_time = 0.2
|
||||||
|
strong_whistle_frequency = 30_000
|
||||||
|
strong_whistle_amplitude = 1.0 # Full amplitude
|
||||||
|
|
||||||
|
# Parameters for the second (weaker) whistle
|
||||||
|
weak_whistle_time = 0.8
|
||||||
|
weak_whistle_frequency = 50_000
|
||||||
|
weak_whistle_amplitude = 0.1 # Weaker amplitude
|
||||||
|
|
||||||
|
# Recording parameters
|
||||||
|
duration = 1.0
|
||||||
|
samplerate = 256_000
|
||||||
|
|
||||||
|
# Generate WAV files for each whistle
|
||||||
|
strong_whistle_path = generate_whistle(
|
||||||
|
time=strong_whistle_time,
|
||||||
|
frequency=strong_whistle_frequency,
|
||||||
|
duration=duration,
|
||||||
|
samplerate=samplerate,
|
||||||
|
whistle_duration=0.01,
|
||||||
|
)
|
||||||
|
weak_whistle_path = generate_whistle(
|
||||||
|
time=weak_whistle_time,
|
||||||
|
frequency=weak_whistle_frequency,
|
||||||
|
duration=duration,
|
||||||
|
samplerate=samplerate,
|
||||||
|
whistle_duration=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load audio data
|
||||||
|
strong_audio, _ = sf.read(strong_whistle_path)
|
||||||
|
weak_audio, _ = sf.read(weak_whistle_path)
|
||||||
|
|
||||||
|
# Mix the audio files
|
||||||
|
mixed_audio = (
|
||||||
|
strong_audio * strong_whistle_amplitude
|
||||||
|
+ weak_audio * weak_whistle_amplitude
|
||||||
|
)
|
||||||
|
mixed_audio_path = strong_whistle_path.parent / "mixed_whistles.wav"
|
||||||
|
sf.write(str(mixed_audio_path), mixed_audio, samplerate)
|
||||||
|
|
||||||
|
# Create a recording object from the mixed WAV
|
||||||
|
recording = data.Recording.from_file(path=mixed_audio_path)
|
||||||
|
|
||||||
|
# Build a preprocessor
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
# Define a region of interest that contains only the weaker whistle
|
||||||
|
start_time = 0.7
|
||||||
|
end_time = 0.9
|
||||||
|
low_freq = 45_000
|
||||||
|
high_freq = 55_000
|
||||||
|
|
||||||
|
# Get the peak energy coordinates within the bounding box
|
||||||
|
peak_time, peak_freq = get_peak_energy_coordinates(
|
||||||
|
recording=recording,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
loading_buffer=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that the peak coordinates are close to the weaker whistle's values
|
||||||
|
assert peak_time == pytest.approx(weak_whistle_time, abs=0.01)
|
||||||
|
assert peak_freq == pytest.approx(weak_whistle_frequency, abs=1000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_peak_energy_coordinates_silent_region(create_recording):
|
||||||
|
# Parameters for a silent recording
|
||||||
|
duration = 2.0 # seconds
|
||||||
|
samplerate = 44_100 # Hz
|
||||||
|
|
||||||
|
# Create a silent recording
|
||||||
|
recording = create_recording(duration=duration, samplerate=samplerate)
|
||||||
|
|
||||||
|
# Build a preprocessor
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
# Define a region of interest within the silent recording
|
||||||
|
start_time = 0.5
|
||||||
|
end_time = 1.5
|
||||||
|
low_freq = 10_000
|
||||||
|
high_freq = 20_000
|
||||||
|
|
||||||
|
# Get the peak energy coordinates from the silent region
|
||||||
|
peak_time, peak_freq = get_peak_energy_coordinates(
|
||||||
|
recording=recording,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
loading_buffer=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that the peak coordinates are within the defined ROI bounds
|
||||||
|
assert start_time <= peak_time <= end_time
|
||||||
|
assert low_freq <= peak_freq <= high_freq
|
||||||
|
|
||||||
|
# Since there's no actual peak, the exact values might vary depending on
|
||||||
|
# argmax behavior with all-zero or very low, uniform energy. We just need
|
||||||
|
# to ensure they are within the search bounds.
|
||||||
|
Loading…
Reference in New Issue
Block a user