Add other roi tests

This commit is contained in:
mbsantiago 2025-06-21 23:51:07 +01:00
parent 0a0d6f7162
commit 3407e1b5f0
2 changed files with 233 additions and 3 deletions

View File

@ -324,7 +324,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
Scaling factor applied to the frequency dimensions. Scaling factor applied to the frequency dimensions.
""" """
name: Literal["peak_energy_bbox"] name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
preprocessing: PreprocessingConfig = Field( preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
) )
@ -515,7 +515,7 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
) )
raise NotImplementedError( raise NotImplementedError(
f"No ROI mapper of name {config.name} is implemented" f"No ROI mapper of name '{config.name}' is implemented"
) )

View File

@ -3,7 +3,7 @@ import pytest
import soundfile as sf import soundfile as sf
from soundevent import data from soundevent import data
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
DEFAULT_ANCHOR, DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE, DEFAULT_FREQUENCY_SCALE,
@ -12,6 +12,8 @@ from batdetect2.targets.rois import (
SIZE_WIDTH, SIZE_WIDTH,
AnchorBBoxMapper, AnchorBBoxMapper,
BBoxAnchorMapperConfig, BBoxAnchorMapperConfig,
PeakEnergyBBoxMapper,
PeakEnergyBBoxMapperConfig,
_build_bounding_box, _build_bounding_box,
build_roi_mapper, build_roi_mapper,
get_peak_energy_coordinates, get_peak_energy_coordinates,
@ -401,3 +403,231 @@ def test_get_peak_energy_coordinates_silent_region(create_recording):
# Since there's no actual peak, the exact values might vary depending on # Since there's no actual peak, the exact values might vary depending on
# argmax behavior with all-zero or very low, uniform energy. We just need # argmax behavior with all-zero or very low, uniform energy. We just need
# to ensure they are within the search bounds. # to ensure they are within the search bounds.
def test_peak_energy_bbox_mapper_encode(generate_whistle):
    """Exercise the happy path of ``PeakEnergyBBoxMapper.encode``.

    A whistle is generated at a known time/frequency inside a bounding box.
    The encoded position must land close to that point, and the encoded
    size must hold the four scaled distances (left, bottom, right, top)
    measured from the identified position to the box edges.
    """
    # Given: scaling factors and a box spanning times 1.0-2.0 and
    # frequencies 10000-30000.
    time_scale, freq_scale = 100.0, 0.1
    box_start, box_low = 1.0, 10000
    box_end, box_high = 2.0, 30000
    geometry = data.BoundingBox(
        coordinates=[box_start, box_low, box_end, box_high]
    )

    # The whistle marks where the energy peak is expected inside the box.
    whistle_time, whistle_freq = 1.6, 25000
    wav_path = generate_whistle(
        time=whistle_time,
        frequency=whistle_freq,
        duration=3.0,
        samplerate=256_000,
    )
    event = data.SoundEvent(
        geometry=geometry,
        recording=data.Recording.from_file(path=wav_path),
    )

    # Preprocessor built with ``pcen`` set to None and spectral mean
    # subtraction set to False.
    spectrogram_options = {
        "pcen": None,
        "spectral_mean_substraction": False,
    }
    mapper = PeakEnergyBBoxMapper(
        preprocessor=build_preprocessor(
            PreprocessingConfig.model_validate(
                {"spectrogram": spectrogram_options}
            )
        ),
        time_scale=time_scale,
        frequency_scale=freq_scale,
    )

    # When
    position, size = mapper.encode(event)

    # Then: the reported position is near the whistle, within tolerance.
    assert position[0] == pytest.approx(whistle_time, abs=0.01)
    assert position[1] == pytest.approx(whistle_freq, abs=1000)

    # The size is checked against the position the mapper actually
    # identified, not the nominal whistle location.
    found_time, found_freq = position
    expected_size = np.array(
        [
            (found_time - box_start) * time_scale,
            (found_freq - box_low) * freq_scale,
            (box_end - found_time) * time_scale,
            (box_high - found_freq) * freq_scale,
        ]
    )
    assert size.shape == (4,)
    np.testing.assert_allclose(size, expected_size, rtol=1e-5)
def test_peak_energy_bbox_mapper_decode():
    """Check that ``PeakEnergyBBoxMapper.decode`` rebuilds a ``BoundingBox``.

    The expected box is obtained by dividing each scaled distance by its
    scaling factor and offsetting the peak position by the result.
    """
    # Given: a known peak position plus scaled distances to the box edges.
    time_scale, freq_scale = 100.0, 0.1
    position = (1.5, 15000)
    size = np.array([50.0, 500.0, 50.0, 500.0])
    mapper = PeakEnergyBBoxMapper(
        preprocessor=build_preprocessor(),
        time_scale=time_scale,
        frequency_scale=freq_scale,
    )

    # When
    decoded = mapper.decode(position, size)

    # Then: undo the scaling by hand to get the reference coordinates.
    peak_time, peak_freq = position
    left, bottom, right, top = size
    expected_coordinates = [
        peak_time - left / time_scale,
        peak_freq - bottom / freq_scale,
        peak_time + right / time_scale,
        peak_freq + top / freq_scale,
    ]
    assert isinstance(decoded, data.BoundingBox)
    np.testing.assert_allclose(decoded.coordinates, expected_coordinates)
def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
    """Check that ``decode(*encode(event))`` recovers the original box.

    A sound event is encoded with a default-scaled ``PeakEnergyBBoxMapper``
    and the result is decoded straight away; the recovered coordinates
    must match the original ones within a relative tolerance of 1e-5.
    """
    # Given: the source geometry and a whistle placed inside it.
    source_box = data.BoundingBox(coordinates=[1.0, 10000, 2.0, 30000])
    whistle_time, whistle_freq = 1.6, 25000
    wav_path = generate_whistle(
        time=whistle_time,
        frequency=whistle_freq,
        duration=3.0,
        samplerate=256_000,
    )
    event = data.SoundEvent(
        geometry=source_box,
        recording=data.Recording.from_file(path=wav_path),
    )

    # Preprocessor built with ``pcen`` set to None and spectral mean
    # subtraction set to False.
    spectrogram_options = {
        "pcen": None,
        "spectral_mean_substraction": False,
    }
    mapper = PeakEnergyBBoxMapper(
        preprocessor=build_preprocessor(
            PreprocessingConfig.model_validate(
                {"spectrogram": spectrogram_options}
            )
        )
    )

    # When: encode, then decode the encoded pair without modification.
    encoded_position, encoded_size = mapper.encode(event)
    recovered_box = mapper.decode(encoded_position, encoded_size)

    # Then
    np.testing.assert_allclose(
        recovered_box.coordinates,
        source_box.coordinates,
        rtol=1e-5,
    )
def test_build_roi_mapper_for_anchor_bbox():
    """Check that an anchor config yields a matching ``AnchorBBoxMapper``.

    The built mapper must expose the anchor and both scaling factors that
    were given in the ``BBoxAnchorMapperConfig``.
    """
    # Given: values chosen to be easy to recognise on the built mapper.
    expected_anchor = "center"
    expected_time_scale = 123.0
    expected_frequency_scale = 456.0

    # When
    built = build_roi_mapper(
        BBoxAnchorMapperConfig(
            anchor=expected_anchor,
            time_scale=expected_time_scale,
            frequency_scale=expected_frequency_scale,
        )
    )

    # Then
    assert isinstance(built, AnchorBBoxMapper)
    assert built.anchor == expected_anchor
    assert built.time_scale == expected_time_scale
    assert built.frequency_scale == expected_frequency_scale
def test_build_roi_mapper_for_peak_energy_bbox():
    """Check that a peak-energy config yields a ``PeakEnergyBBoxMapper``.

    The built mapper must expose the loading buffer and both scaling
    factors that were given in the ``PeakEnergyBBoxMapperConfig``.
    """
    # Given: values chosen to be easy to recognise on the built mapper.
    expected_buffer = 0.99
    expected_time_scale = 789.0
    expected_frequency_scale = 123.0

    spectrogram_options = {
        "pcen": None,
        "spectral_mean_substraction": True,
        "scale": "dB",
    }
    preprocessing = PreprocessingConfig.model_validate(
        {"spectrogram": spectrogram_options}
    )
    mapper_config = PeakEnergyBBoxMapperConfig(
        loading_buffer=expected_buffer,
        time_scale=expected_time_scale,
        frequency_scale=expected_frequency_scale,
        preprocessing=preprocessing,
    )

    # When
    built = build_roi_mapper(mapper_config)

    # Then
    assert isinstance(built, PeakEnergyBBoxMapper)
    assert built.loading_buffer == expected_buffer
    assert built.time_scale == expected_time_scale
    assert built.frequency_scale == expected_frequency_scale
def test_build_roi_mapper_raises_error_for_unknown_name():
    """Check that an unrecognised mapper name is rejected.

    ``build_roi_mapper`` must raise ``NotImplementedError`` and the error
    message must quote the offending name.
    """

    # Given: a stand-in object whose only attribute is an unknown name.
    class UnknownMapperConfig:
        name = "non_existent_mapper"

    # Then: the error type and the message are verified together; the
    # pattern contains no regex metacharacters, so it is matched literally.
    with pytest.raises(
        NotImplementedError,
        match="No ROI mapper of name 'non_existent_mapper'",
    ):
        build_roi_mapper(UnknownMapperConfig())  # type: ignore