diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 8ecf886..c1787fc 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -324,7 +324,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig): Scaling factor applied to the frequency dimensions. """ - name: Literal["peak_energy_bbox"] + name: Literal["peak_energy_bbox"] = "peak_energy_bbox" preprocessing: PreprocessingConfig = Field( default_factory=PreprocessingConfig ) @@ -515,7 +515,7 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: ) raise NotImplementedError( - f"No ROI mapper of name {config.name} is implemented" + f"No ROI mapper of name '{config.name}' is implemented" ) diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 474c15a..c6afc24 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -3,7 +3,7 @@ import pytest import soundfile as sf from soundevent import data -from batdetect2.preprocess import build_preprocessor +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.targets.rois import ( DEFAULT_ANCHOR, DEFAULT_FREQUENCY_SCALE, @@ -12,6 +12,8 @@ from batdetect2.targets.rois import ( SIZE_WIDTH, AnchorBBoxMapper, BBoxAnchorMapperConfig, + PeakEnergyBBoxMapper, + PeakEnergyBBoxMapperConfig, _build_bounding_box, build_roi_mapper, get_peak_energy_coordinates, @@ -401,3 +403,231 @@ def test_get_peak_energy_coordinates_silent_region(create_recording): # Since there's no actual peak, the exact values might vary depending on # argmax behavior with all-zero or very low, uniform energy. We just need # to ensure they are within the search bounds. + + +def test_peak_energy_bbox_mapper_encode(generate_whistle): + """ + Tests the 'happy path' for PeakEnergyBBoxMapper.encode. + + It verifies that the method correctly identifies a known peak within a + bounding box and calculates the four scaled distances to the box edges. + """ + # 1. SETUP + samplerate = 256_000 + time_scale = 100.0 + freq_scale = 0.1 + + bbox_start_time, bbox_low_freq = 1.0, 10000 + bbox_end_time, bbox_high_freq = 2.0, 30000 + bbox = data.BoundingBox( + coordinates=[ + bbox_start_time, + bbox_low_freq, + bbox_end_time, + bbox_high_freq, + ] + ) + + # Define the known location of the peak energy inside the bbox + peak_time, peak_freq = 1.6, 25000 + + # Create a recording with a whistle at the defined peak location + recording_path = generate_whistle( + time=peak_time, + frequency=peak_freq, + duration=3.0, + samplerate=samplerate, + ) + recording = data.Recording.from_file(path=recording_path) + sound_event = data.SoundEvent(geometry=bbox, recording=recording) + + # Instantiate the mapper with a preprocessor + preprocessor = build_preprocessor( + PreprocessingConfig.model_validate( + { + "spectrogram": { + "pcen": None, + "spectral_mean_substraction": False, + } + } + ) + ) + mapper = PeakEnergyBBoxMapper( + preprocessor=preprocessor, + time_scale=time_scale, + frequency_scale=freq_scale, + ) + + # Encode the sound event to get the position and size + actual_pos, actual_size = mapper.encode(sound_event) + + # Then + assert actual_pos[0] == pytest.approx(peak_time, abs=0.01) + assert actual_pos[1] == pytest.approx(peak_freq, abs=1000) + + # Assert that the calculated scaled distances are correct + identified_time, identified_freq = actual_pos + expected_left = (identified_time - bbox_start_time) * time_scale + expected_bottom = (identified_freq - bbox_low_freq) * freq_scale + expected_right = (bbox_end_time - identified_time) * time_scale + expected_top = (bbox_high_freq - identified_freq) * freq_scale + expected_size = np.array( + [expected_left, expected_bottom, expected_right, expected_top] + ) + + assert actual_size.shape == (4,) + np.testing.assert_allclose(actual_size, expected_size, rtol=1e-5) + + +def test_peak_energy_bbox_mapper_decode(): + """ + Tests that PeakEnergyBBoxMapper.decode correctly reconstructs a BoundingBox. + """ + # Given + time_scale = 100.0 + freq_scale = 0.1 + + # Define a known peak position and scaled distances. + peak_position = (1.5, 15000) + scaled_size = np.array([50.0, 500.0, 50.0, 500.0]) + + mapper = PeakEnergyBBoxMapper( + preprocessor=build_preprocessor(), + time_scale=time_scale, + frequency_scale=freq_scale, + ) + + # When + reconstructed_bbox = mapper.decode(peak_position, scaled_size) + + # Then + # Calculate the expected coordinates based on the decode logic. + expected_start_time = peak_position[0] - scaled_size[0] / time_scale + expected_low_freq = peak_position[1] - scaled_size[1] / freq_scale + expected_end_time = peak_position[0] + scaled_size[2] / time_scale + expected_high_freq = peak_position[1] + scaled_size[3] / freq_scale + + expected_coordinates = [ + expected_start_time, + expected_low_freq, + expected_end_time, + expected_high_freq, + ] + + assert isinstance(reconstructed_bbox, data.BoundingBox) + np.testing.assert_allclose( + reconstructed_bbox.coordinates, expected_coordinates + ) + + +def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle): + """ + Tests that encoding and then decoding a SoundEvent with the + PeakEnergyBBoxMapper results in the original BoundingBox. + """ + # Given + samplerate = 256_000 + + # Define the original geometry and the peak location within it. + original_bbox = data.BoundingBox(coordinates=[1.0, 10000, 2.0, 30000]) + peak_time, peak_freq = 1.6, 25000 + + # Create the recording and sound event. + recording_path = generate_whistle( + time=peak_time, + frequency=peak_freq, + duration=3.0, + samplerate=samplerate, + ) + recording = data.Recording.from_file(path=recording_path) + sound_event = data.SoundEvent(geometry=original_bbox, recording=recording) + + # Instantiate the mapper. + preprocessor = build_preprocessor( + PreprocessingConfig.model_validate( + { + "spectrogram": { + "pcen": None, + "spectral_mean_substraction": False, + } + } + ) + ) + mapper = PeakEnergyBBoxMapper(preprocessor=preprocessor) + + # When + # Encode the sound event, then immediately decode the result. + position, size = mapper.encode(sound_event) + reconstructed_bbox = mapper.decode(position, size) + + # Then + # Verify the reconstructed bounding box is identical to the original. + np.testing.assert_allclose( + reconstructed_bbox.coordinates, + original_bbox.coordinates, + rtol=1e-5, + ) + + +def test_build_roi_mapper_for_anchor_bbox(): + # Given + config = BBoxAnchorMapperConfig( + anchor="center", + time_scale=123.0, + frequency_scale=456.0, + ) + + # When + mapper = build_roi_mapper(config) + + # Then + assert isinstance(mapper, AnchorBBoxMapper) + assert mapper.anchor == "center" + assert mapper.time_scale == 123.0 + assert mapper.frequency_scale == 456.0 + + +def test_build_roi_mapper_for_peak_energy_bbox(): + # Given + preproc_config = PreprocessingConfig.model_validate( + { + "spectrogram": { + "pcen": None, + "spectral_mean_substraction": True, + "scale": "dB", + } + } + ) + config = PeakEnergyBBoxMapperConfig( + loading_buffer=0.99, + time_scale=789.0, + frequency_scale=123.0, + preprocessing=preproc_config, + ) + + # When + mapper = build_roi_mapper(config) + + # Then + assert isinstance(mapper, PeakEnergyBBoxMapper) + assert mapper.loading_buffer == 0.99 + assert mapper.time_scale == 789.0 + assert mapper.frequency_scale == 123.0 + + +def test_build_roi_mapper_raises_error_for_unknown_name(): + """ + Tests that the factory raises a NotImplementedError when given a + config with an unrecognized mapper name. + """ + + # Given + class DummyConfig: + name = "non_existent_mapper" + + # Then + with pytest.raises(NotImplementedError) as excinfo: + build_roi_mapper(DummyConfig()) # type: ignore + + # Check that the error message is informative. + assert "No ROI mapper of name 'non_existent_mapper'" in str(excinfo.value)