mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add other roi tests
This commit is contained in:
parent
0a0d6f7162
commit
3407e1b5f0
@ -324,7 +324,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
|
||||
Scaling factor applied to the frequency dimensions.
|
||||
"""
|
||||
|
||||
name: Literal["peak_energy_bbox"]
|
||||
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
@ -515,7 +515,7 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"No ROI mapper of name {config.name} is implemented"
|
||||
f"No ROI mapper of name '{config.name}' is implemented"
|
||||
)
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ import pytest
|
||||
import soundfile as sf
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.targets.rois import (
|
||||
DEFAULT_ANCHOR,
|
||||
DEFAULT_FREQUENCY_SCALE,
|
||||
@ -12,6 +12,8 @@ from batdetect2.targets.rois import (
|
||||
SIZE_WIDTH,
|
||||
AnchorBBoxMapper,
|
||||
BBoxAnchorMapperConfig,
|
||||
PeakEnergyBBoxMapper,
|
||||
PeakEnergyBBoxMapperConfig,
|
||||
_build_bounding_box,
|
||||
build_roi_mapper,
|
||||
get_peak_energy_coordinates,
|
||||
@ -401,3 +403,231 @@ def test_get_peak_energy_coordinates_silent_region(create_recording):
|
||||
# Since there's no actual peak, the exact values might vary depending on
|
||||
# argmax behavior with all-zero or very low, uniform energy. We just need
|
||||
# to ensure they are within the search bounds.
|
||||
|
||||
|
||||
def test_peak_energy_bbox_mapper_encode(generate_whistle):
|
||||
"""
|
||||
Tests the 'happy path' for PeakEnergyBBoxMapper.encode.
|
||||
|
||||
It verifies that the method correctly identifies a known peak within a
|
||||
bounding box and calculates the four scaled distances to the box edges.
|
||||
"""
|
||||
# 1. SETUP
|
||||
samplerate = 256_000
|
||||
time_scale = 100.0
|
||||
freq_scale = 0.1
|
||||
|
||||
bbox_start_time, bbox_low_freq = 1.0, 10000
|
||||
bbox_end_time, bbox_high_freq = 2.0, 30000
|
||||
bbox = data.BoundingBox(
|
||||
coordinates=[
|
||||
bbox_start_time,
|
||||
bbox_low_freq,
|
||||
bbox_end_time,
|
||||
bbox_high_freq,
|
||||
]
|
||||
)
|
||||
|
||||
# Define the known location of the peak energy inside the bbox
|
||||
peak_time, peak_freq = 1.6, 25000
|
||||
|
||||
# Create a recording with a whistle at the defined peak location
|
||||
recording_path = generate_whistle(
|
||||
time=peak_time,
|
||||
frequency=peak_freq,
|
||||
duration=3.0,
|
||||
samplerate=samplerate,
|
||||
)
|
||||
recording = data.Recording.from_file(path=recording_path)
|
||||
sound_event = data.SoundEvent(geometry=bbox, recording=recording)
|
||||
|
||||
# Instantiate the mapper with a preprocessor
|
||||
preprocessor = build_preprocessor(
|
||||
PreprocessingConfig.model_validate(
|
||||
{
|
||||
"spectrogram": {
|
||||
"pcen": None,
|
||||
"spectral_mean_substraction": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
mapper = PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
time_scale=time_scale,
|
||||
frequency_scale=freq_scale,
|
||||
)
|
||||
|
||||
# Encode the sound event to get the position and size
|
||||
actual_pos, actual_size = mapper.encode(sound_event)
|
||||
|
||||
# Then
|
||||
assert actual_pos[0] == pytest.approx(peak_time, abs=0.01)
|
||||
assert actual_pos[1] == pytest.approx(peak_freq, abs=1000)
|
||||
|
||||
# Assert that the calculated scaled distances are correct
|
||||
identified_time, identified_freq = actual_pos
|
||||
expected_left = (identified_time - bbox_start_time) * time_scale
|
||||
expected_bottom = (identified_freq - bbox_low_freq) * freq_scale
|
||||
expected_right = (bbox_end_time - identified_time) * time_scale
|
||||
expected_top = (bbox_high_freq - identified_freq) * freq_scale
|
||||
expected_size = np.array(
|
||||
[expected_left, expected_bottom, expected_right, expected_top]
|
||||
)
|
||||
|
||||
assert actual_size.shape == (4,)
|
||||
np.testing.assert_allclose(actual_size, expected_size, rtol=1e-5)
|
||||
|
||||
|
||||
def test_peak_energy_bbox_mapper_decode():
|
||||
"""
|
||||
Tests that PeakEnergyBBoxMapper.decode correctly reconstructs a BoundingBox.
|
||||
"""
|
||||
# Given
|
||||
time_scale = 100.0
|
||||
freq_scale = 0.1
|
||||
|
||||
# Define a known peak position and scaled distances.
|
||||
peak_position = (1.5, 15000)
|
||||
scaled_size = np.array([50.0, 500.0, 50.0, 500.0])
|
||||
|
||||
mapper = PeakEnergyBBoxMapper(
|
||||
preprocessor=build_preprocessor(),
|
||||
time_scale=time_scale,
|
||||
frequency_scale=freq_scale,
|
||||
)
|
||||
|
||||
# When
|
||||
reconstructed_bbox = mapper.decode(peak_position, scaled_size)
|
||||
|
||||
# Then
|
||||
# Calculate the expected coordinates based on the decode logic.
|
||||
expected_start_time = peak_position[0] - scaled_size[0] / time_scale
|
||||
expected_low_freq = peak_position[1] - scaled_size[1] / freq_scale
|
||||
expected_end_time = peak_position[0] + scaled_size[2] / time_scale
|
||||
expected_high_freq = peak_position[1] + scaled_size[3] / freq_scale
|
||||
|
||||
expected_coordinates = [
|
||||
expected_start_time,
|
||||
expected_low_freq,
|
||||
expected_end_time,
|
||||
expected_high_freq,
|
||||
]
|
||||
|
||||
assert isinstance(reconstructed_bbox, data.BoundingBox)
|
||||
np.testing.assert_allclose(
|
||||
reconstructed_bbox.coordinates, expected_coordinates
|
||||
)
|
||||
|
||||
|
||||
def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
||||
"""
|
||||
Tests that encoding and then decoding a SoundEvent with the
|
||||
PeakEnergyBBoxMapper results in the original BoundingBox.
|
||||
"""
|
||||
# Given
|
||||
samplerate = 256_000
|
||||
|
||||
# Define the original geometry and the peak location within it.
|
||||
original_bbox = data.BoundingBox(coordinates=[1.0, 10000, 2.0, 30000])
|
||||
peak_time, peak_freq = 1.6, 25000
|
||||
|
||||
# Create the recording and sound event.
|
||||
recording_path = generate_whistle(
|
||||
time=peak_time,
|
||||
frequency=peak_freq,
|
||||
duration=3.0,
|
||||
samplerate=samplerate,
|
||||
)
|
||||
recording = data.Recording.from_file(path=recording_path)
|
||||
sound_event = data.SoundEvent(geometry=original_bbox, recording=recording)
|
||||
|
||||
# Instantiate the mapper.
|
||||
preprocessor = build_preprocessor(
|
||||
PreprocessingConfig.model_validate(
|
||||
{
|
||||
"spectrogram": {
|
||||
"pcen": None,
|
||||
"spectral_mean_substraction": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor)
|
||||
|
||||
# When
|
||||
# Encode the sound event, then immediately decode the result.
|
||||
position, size = mapper.encode(sound_event)
|
||||
reconstructed_bbox = mapper.decode(position, size)
|
||||
|
||||
# Then
|
||||
# Verify the reconstructed bounding box is identical to the original.
|
||||
np.testing.assert_allclose(
|
||||
reconstructed_bbox.coordinates,
|
||||
original_bbox.coordinates,
|
||||
rtol=1e-5,
|
||||
)
|
||||
|
||||
|
||||
def test_build_roi_mapper_for_anchor_bbox():
|
||||
# Given
|
||||
config = BBoxAnchorMapperConfig(
|
||||
anchor="center",
|
||||
time_scale=123.0,
|
||||
frequency_scale=456.0,
|
||||
)
|
||||
|
||||
# When
|
||||
mapper = build_roi_mapper(config)
|
||||
|
||||
# Then
|
||||
assert isinstance(mapper, AnchorBBoxMapper)
|
||||
assert mapper.anchor == "center"
|
||||
assert mapper.time_scale == 123.0
|
||||
assert mapper.frequency_scale == 456.0
|
||||
|
||||
|
||||
def test_build_roi_mapper_for_peak_energy_bbox():
|
||||
# Given
|
||||
preproc_config = PreprocessingConfig.model_validate(
|
||||
{
|
||||
"spectrogram": {
|
||||
"pcen": None,
|
||||
"spectral_mean_substraction": True,
|
||||
"scale": "dB",
|
||||
}
|
||||
}
|
||||
)
|
||||
config = PeakEnergyBBoxMapperConfig(
|
||||
loading_buffer=0.99,
|
||||
time_scale=789.0,
|
||||
frequency_scale=123.0,
|
||||
preprocessing=preproc_config,
|
||||
)
|
||||
|
||||
# When
|
||||
mapper = build_roi_mapper(config)
|
||||
|
||||
# Then
|
||||
assert isinstance(mapper, PeakEnergyBBoxMapper)
|
||||
assert mapper.loading_buffer == 0.99
|
||||
assert mapper.time_scale == 789.0
|
||||
assert mapper.frequency_scale == 123.0
|
||||
|
||||
|
||||
def test_build_roi_mapper_raises_error_for_unknown_name():
|
||||
"""
|
||||
Tests that the factory raises a NotImplementedError when given a
|
||||
config with an unrecognized mapper name.
|
||||
"""
|
||||
|
||||
# Given
|
||||
class DummyConfig:
|
||||
name = "non_existent_mapper"
|
||||
|
||||
# Then
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
build_roi_mapper(DummyConfig()) # type: ignore
|
||||
|
||||
# Check that the error message is informative.
|
||||
assert "No ROI mapper of name 'non_existent_mapper'" in str(excinfo.value)
|
||||
|
Loading…
Reference in New Issue
Block a user