From c559bcc682102cfbc0498df7c487e75a4928d965 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 21 Jun 2025 11:44:15 +0100 Subject: [PATCH] Changed ROIMapper protocol to only have encoder/decoder methods --- batdetect2/targets/__init__.py | 32 +-- batdetect2/targets/classes.py | 23 +++ batdetect2/targets/rois.py | 345 ++++++++++++++++++++------------ tests/test_targets/test_rois.py | 58 +++--- tests/test_train/test_labels.py | 4 +- 5 files changed, 278 insertions(+), 184 deletions(-) diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index 9a5da1e..7d1ab3e 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -50,7 +50,7 @@ from batdetect2.targets.filtering import ( load_filter_from_config, ) from batdetect2.targets.rois import ( - ROIConfig, + BBoxAnchorMapperConfig, ROITargetMapper, build_roi_mapper, ) @@ -88,7 +88,7 @@ __all__ = [ "FilterConfig", "FilterRule", "MapValueRule", - "ROIConfig", + "BBoxAnchorMapperConfig", "ROITargetMapper", "ReplaceRule", "SoundEventDecoder", @@ -156,12 +156,12 @@ class TargetConfig(BaseConfig): omitted, default ROI mapping settings are used. """ - filtering: Optional[FilterConfig] = None - transforms: Optional[TransformConfig] = None + filtering: FilterConfig = Field(default_factory=FilterConfig) + transforms: TransformConfig = Field(default_factory=TransformConfig) classes: ClassesConfig = Field( default_factory=lambda: DEFAULT_CLASSES_CONFIG ) - roi: Optional[ROIConfig] = None + roi: Optional[BBoxAnchorMapperConfig] = None def load_target_config( @@ -374,14 +374,7 @@ class Targets(TargetProtocol): ValueError If the annotation lacks geometry. """ - geom = sound_event.sound_event.geometry - - if geom is None: - raise ValueError( - "Sound event has no geometry, cannot get its position." - ) - - return self._roi_mapper.get_roi_position(geom) + return self._roi_mapper.encode_position(sound_event.sound_event) def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: """Calculate the target size dimensions from the annotation's geometry. @@ -405,14 +398,7 @@ class Targets(TargetProtocol): ValueError If the annotation lacks geometry. """ - geom = sound_event.sound_event.geometry - - if geom is None: - raise ValueError( - "Sound event has no geometry, cannot get its size." - ) - - return self._roi_mapper.get_roi_size(geom) + return self._roi_mapper.encode_size(sound_event.sound_event) def recover_roi( self, @@ -438,7 +424,7 @@ class Targets(TargetProtocol): data.Geometry The reconstructed geometry (typically `BoundingBox`). """ - return self._roi_mapper.recover_roi(pos, dims) + return self._roi_mapper.decode(pos, dims) DEFAULT_CLASSES = [ @@ -606,7 +592,7 @@ def build_targets( if config.transforms else None ) - roi_mapper = build_roi_mapper(config.roi or ROIConfig()) + roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig()) class_names = get_class_names_from_config(config.classes) generic_class_tags = build_generic_class_tags( config.classes, diff --git a/batdetect2/targets/classes.py b/batdetect2/targets/classes.py index 54d0d81..0f223cb 100644 --- a/batdetect2/targets/classes.py +++ b/batdetect2/targets/classes.py @@ -27,8 +27,29 @@ __all__ = [ "build_generic_class_tags", "get_class_names_from_config", "DEFAULT_SPECIES_LIST", + "PositionMethod", + "CornerPosition", + "SizeMethod", + "BoundingBoxSize", ] +class PositionMethod(BaseConfig): + """Base class for defining how to select a position from a geometry.""" + method_type: str + +class CornerPosition(PositionMethod): + """Selects a position based on a corner or center of the bounding box.""" + method_type: Literal["corner"] = "corner" + corner: Literal["upper_left", "lower_left", "center"] = "lower_left" + +class SizeMethod(BaseConfig): + """Base class for defining how to select a size from a geometry.""" + method_type: str + +class BoundingBoxSize(SizeMethod): + """Uses the width and height of the bounding box as the size.""" + method_type: Literal["bounding_box"] = "bounding_box" + SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] """Type alias for a sound event class encoder function. @@ -113,6 +134,8 @@ class TargetClass(BaseConfig): tags: List[TagInfo] = Field(min_length=1) match_type: Literal["all", "any"] = Field(default="all") output_tags: Optional[List[TagInfo]] = None + position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left")) + size_method: SizeMethod = Field(default_factory=BoundingBoxSize) def _get_default_classes() -> List[TargetClass]: diff --git a/batdetect2/targets/rois.py b/batdetect2/targets/rois.py index bd05397..484506a 100644 --- a/batdetect2/targets/rois.py +++ b/batdetect2/targets/rois.py @@ -20,14 +20,31 @@ scaling factors) is managed by the `ROIConfig`. This module separates the handled in `batdetect2.targets.classes`. """ -from typing import List, Literal, Optional, Protocol, Tuple +from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union import numpy as np +from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor +from batdetect2.preprocess.types import PreprocessorProtocol -Positions = Literal[ +__all__ = [ + "ROITargetMapper", + "BBoxAnchorMapperConfig", + "AnchorBBoxMapper", + "build_roi_mapper", + "load_roi_mapper", + "DEFAULT_ANCHOR", + "SIZE_WIDTH", + "SIZE_HEIGHT", + "SIZE_ORDER", + "DEFAULT_TIME_SCALE", + "DEFAULT_FREQUENCY_SCALE", +] + +Anchor = Literal[ "bottom-left", "bottom-right", "top-left", @@ -41,20 +58,6 @@ Positions = Literal[ "point_on_surface", ] -__all__ = [ - "ROITargetMapper", - "ROIConfig", - "BBoxEncoder", - "build_roi_mapper", - "load_roi_mapper", - "DEFAULT_POSITION", - "SIZE_WIDTH", - "SIZE_HEIGHT", - "SIZE_ORDER", - "DEFAULT_TIME_SCALE", - "DEFAULT_FREQUENCY_SCALE", -] - SIZE_WIDTH = "width" """Standard name for the width/time dimension component ('width').""" @@ -71,10 +74,15 @@ DEFAULT_FREQUENCY_SCALE = 1 / 859.375 """Default scaling factor for frequency bandwidth.""" -DEFAULT_POSITION = "bottom-left" +DEFAULT_ANCHOR = "bottom-left" """Default reference position within the geometry ('bottom-left' corner).""" +Position = tuple[float, float] + +Size = np.ndarray + + class ROITargetMapper(Protocol): """Protocol defining the interface for ROI-to-target mapping. @@ -93,21 +101,13 @@ class ROITargetMapper(Protocol): dimension_names: List[str] - def get_roi_position( - self, - geom: data.Geometry, - position: Optional[Positions] = None, - ) -> tuple[float, float]: + def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]: """Extract the reference position from a geometry. Parameters ---------- geom : soundevent.data.Geometry The input geometry (e.g., BoundingBox, Polygon). - position : Positions, optional - Overrides the default `position` configured for the mapper. - If provided, this position will be used instead of the mapper's - internal default. Returns ------- @@ -124,36 +124,7 @@ class ROITargetMapper(Protocol): """ ... - def get_roi_size(self, geom: data.Geometry) -> np.ndarray: - """Calculate the scaled target dimensions from a geometry. - - Computes the relevant size measures. - - Parameters - ---------- - geom : soundevent.data.Geometry - The input geometry. - - Returns - ------- - np.ndarray - A NumPy array containing the scaled dimensions corresponding to - `dimension_names`. For bounding boxes, typically contains - `[scaled_width, scaled_height]`. - - Raises - ------ - TypeError, ValueError - If the size cannot be computed for the given geometry type. - """ - ... - - def recover_roi( - self, - pos: tuple[float, float], - dims: np.ndarray, - position: Optional[Positions] = None, - ) -> data.Geometry: + def decode(self, position: Position, size: Size) -> data.Geometry: """Recover an approximate ROI from a position and target dimensions. Performs the inverse mapping: takes a reference position and the @@ -161,15 +132,11 @@ class ROITargetMapper(Protocol): Parameters ---------- - pos : Tuple[float, float] + position : Tuple[float, float] The reference position (time, frequency). - dims : np.ndarray + size : np.ndarray NumPy array containing the dimensions, matching the order specified by `dimension_names`. - position : Positions, optional - Overrides the default `position` configured for the mapper. - If provided, this position will be used instead of the mapper's - internal default when reconstructing the roi geometry. Returns ------- @@ -185,7 +152,7 @@ class ROITargetMapper(Protocol): ... -class ROIConfig(BaseConfig): +class BBoxAnchorMapperConfig(BaseConfig): """Configuration for mapping Regions of Interest (ROIs). Defines parameters controlling how geometric ROIs are converted into @@ -193,10 +160,9 @@ class ROIConfig(BaseConfig): Attributes ---------- - position : Positions, default="bottom-left" + anchor : Anchor, default="bottom-left" Specifies the reference point within the geometry (e.g., bounding box) to use as the target location (e.g., "center", "bottom-left"). - See `soundevent.geometry.operations.Positions`. time_scale : float, default=1000.0 Scaling factor applied to the time duration (width) of the ROI when calculating the target size representation. Must match model @@ -207,12 +173,13 @@ class ROIConfig(BaseConfig): expectations. """ - position: Positions = DEFAULT_POSITION + name: Literal["anchor_bbox"] = "anchor_bbox" + anchor: Anchor = DEFAULT_ANCHOR time_scale: float = DEFAULT_TIME_SCALE frequency_scale: float = DEFAULT_FREQUENCY_SCALE -class BBoxEncoder(ROITargetMapper): +class AnchorBBoxMapper(ROITargetMapper): """Concrete implementation of `ROITargetMapper` focused on Bounding Boxes. This class implements the ROI mapping protocol primarily for @@ -224,7 +191,7 @@ class BBoxEncoder(ROITargetMapper): ---------- dimension_names : List[str] Specifies the output dimension names as ['width', 'height']. - position : Positions + anchor : Anchor The configured reference point type (e.g., "center", "bottom-left"). time_scale : float The configured scaling factor for the time dimension (width). @@ -236,7 +203,7 @@ class BBoxEncoder(ROITargetMapper): def __init__( self, - position: Positions = DEFAULT_POSITION, + anchor: Anchor = DEFAULT_ANCHOR, time_scale: float = DEFAULT_TIME_SCALE, frequency_scale: float = DEFAULT_FREQUENCY_SCALE, ): @@ -244,22 +211,18 @@ class BBoxEncoder(ROITargetMapper): Parameters ---------- - position : Positions, default="bottom-left" + anchor : Anchor, default="bottom-left" Reference point type within the bounding box. time_scale : float, default=1000.0 Scaling factor for time duration (width). frequency_scale : float, default=1/859.375 Scaling factor for frequency bandwidth (height). """ - self.position: Positions = position + self.anchor: Anchor = anchor self.time_scale = time_scale self.frequency_scale = frequency_scale - def get_roi_position( - self, - geom: data.Geometry, - position: Optional[Positions] = None, - ) -> Tuple[float, float]: + def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]: """Extract the configured reference position from the geometry. Uses `soundevent.geometry.get_geometry_point`. @@ -268,9 +231,6 @@ class BBoxEncoder(ROITargetMapper): ---------- geom : soundevent.data.Geometry Input geometry (e.g., BoundingBox). - position : Positions, optional - Overrides the default `position` configured for the encoder. - If provided, this position will be used instead of `self.position`. Returns ------- @@ -279,42 +239,33 @@ class BBoxEncoder(ROITargetMapper): """ from soundevent import geometry - position = position or self.position - return geometry.get_geometry_point(geom, position=position) + geom = sound_event.geometry - def get_roi_size(self, geom: data.Geometry) -> np.ndarray: - """Calculate the scaled [width, height] from the geometry's bounds. + if geom is None: + raise ValueError( + "Cannot encode the geometry of a sound event without geometry." + f" Sound event: {sound_event}" + ) - Computes the bounding box, extracts duration and bandwidth, and applies - the configured `time_scale` and `frequency_scale`. - - Parameters - ---------- - geom : soundevent.data.Geometry - Input geometry. - - Returns - ------- - np.ndarray - A 1D NumPy array: `[scaled_width, scaled_height]`. - """ - from soundevent import geometry + position = geometry.get_geometry_point(geom, position=self.anchor) start_time, low_freq, end_time, high_freq = geometry.compute_bounds( geom ) - return np.array( + + size = np.array( [ (end_time - start_time) * self.time_scale, (high_freq - low_freq) * self.frequency_scale, ] ) - def recover_roi( + return position, size + + def decode( self, - pos: tuple[float, float], - dims: np.ndarray, - position: Optional[Positions] = None, + position: Position, + size: Size, ) -> data.Geometry: """Recover a BoundingBox from a position and scaled dimensions. @@ -329,10 +280,6 @@ class BBoxEncoder(ROITargetMapper): dims : np.ndarray NumPy array containing the *scaled* dimensions, expected order is [scaled_width, scaled_height]. - position : Positions, optional - Overrides the default `position` configured for the encoder. - If provided, this position will be used instead of `self.position` - when reconstructing the bounding box. Returns ------- @@ -344,28 +291,113 @@ class BBoxEncoder(ROITargetMapper): ValueError If `dims` does not have the expected shape (length 2). """ - position = position or self.position - if dims.ndim != 1 or dims.shape[0] != 2: + if size.ndim != 1 or size.shape[0] != 2: raise ValueError( "Dimension array does not have the expected shape. " - f"({dims.shape = }) != ([2])" + f"({size.shape = }) != ([2])" ) - width, height = dims + width, height = size return _build_bounding_box( - pos, + position, duration=float(width) / self.time_scale, bandwidth=float(height) / self.frequency_scale, - position=self.position, + anchor=self.anchor, ) -def build_roi_mapper(config: ROIConfig) -> ROITargetMapper: - """Factory function to create an ROITargetMapper from configuration. +class PeakEnergyBBoxMapperConfig(BaseConfig): + name: Literal["peak_energy_bbox"] + preprocessing: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + loading_buffer: float = 0.01 + time_scale: float = DEFAULT_TIME_SCALE + frequency_scale: float = DEFAULT_FREQUENCY_SCALE - Currently creates a `BBoxEncoder` instance based on the provided - `ROIConfig`. + +class PeakEnergyBBoxMapper(ROITargetMapper): + """ + Encodes the ROI using the location of the peak energy within the bounding box + as the 'position' and the distances from that point to the box edges as the 'size'. + """ + + dimension_names = ["left", "bottom", "right", "top"] + + def __init__( + self, + preprocessor: PreprocessorProtocol, + time_scale: float = DEFAULT_TIME_SCALE, + frequency_scale: float = DEFAULT_FREQUENCY_SCALE, + loading_buffer: float = 0.01, + ): + self.preprocessor = preprocessor + self.time_scale = time_scale + self.frequency_scale = frequency_scale + self.loading_buffer = loading_buffer + + def encode( + self, + sound_event: data.SoundEvent, + ) -> tuple[Position, Size]: + from soundevent import geometry + + geom = sound_event.geometry + + if geom is None: + raise ValueError( + "Cannot encode the geometry of a sound event without geometry." + f" Sound event: {sound_event}" + ) + + start_time, low_freq, end_time, high_freq = geometry.compute_bounds( + geom + ) + + time, freq = get_peak_energy_coordinates( + recording=sound_event.recording, + preprocessor=self.preprocessor, + start_time=start_time, + end_time=end_time, + low_freq=low_freq, + high_freq=high_freq, + loading_buffer=self.loading_buffer, + ) + + size = np.array( + [ + (time - start_time) * self.time_scale, + (freq - low_freq) * self.frequency_scale, + (end_time - time) * self.time_scale, + (high_freq - freq) * self.frequency_scale, + ] + ) + + return (time, freq), size + + def decode(self, position: Position, size: Size) -> data.Geometry: + time, freq = position + left, bottom, right, top = size + + return data.BoundingBox( + coordinates=[ + time - max(0, float(left)) / self.time_scale, + freq - max(0, float(bottom)) / self.frequency_scale, + time + max(0, float(right)) / self.time_scale, + freq + max(0, float(top)) / self.frequency_scale, + ] + ) + + +ROIMapperConfig = Annotated[ + Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig], + Field(discriminator="name"), +] + + +def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: + """Factory function to create an ROITargetMapper from configuration. Parameters ---------- @@ -378,10 +410,24 @@ def build_roi_mapper(config: ROIConfig) -> ROITargetMapper: An initialized `BBoxEncoder` instance configured with the settings from `config`. """ - return BBoxEncoder( - position=config.position, - time_scale=config.time_scale, - frequency_scale=config.frequency_scale, + if config.name == "anchor_bbox": + return AnchorBBoxMapper( + anchor=config.anchor, + time_scale=config.time_scale, + frequency_scale=config.frequency_scale, + ) + + if config.name == "peak_energy_bbox": + preprocessor = build_preprocessor(config.preprocessing) + return PeakEnergyBBoxMapper( + preprocessor=preprocessor, + time_scale=config.time_scale, + frequency_scale=config.frequency_scale, + loading_buffer=config.loading_buffer, + ) + + raise NotImplementedError( + f"No ROI mapper of name {config.name} is implemented" ) @@ -414,11 +460,11 @@ def load_roi_mapper( If the configuration file cannot be found, parsed, validated, or if the specified `field` is invalid. """ - config = load_config(path=path, schema=ROIConfig, field=field) + config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field) return build_roi_mapper(config) -VALID_POSITIONS = [ +VALID_ANCHORS = [ "bottom-left", "bottom-right", "top-left", @@ -437,7 +483,7 @@ def _build_bounding_box( pos: tuple[float, float], duration: float, bandwidth: float, - position: Positions = DEFAULT_POSITION, + anchor: Anchor = DEFAULT_ANCHOR, ) -> data.BoundingBox: """Construct a BoundingBox from a reference point, size, and position type. @@ -455,7 +501,7 @@ def _build_bounding_box( bandwidth : float The required *unscaled* frequency bandwidth (height) of the bounding box. - position : Positions, default="bottom-left" + anchor : Anchor, default="bottom-left" Specifies which part of the bounding box the input `pos` corresponds to. Returns @@ -466,12 +512,12 @@ def _build_bounding_box( Raises ------ ValueError - If `position` is not a recognized value or format. + If `anchor` is not a recognized value or format. """ time, freq = map(float, pos) duration = max(0, duration) bandwidth = max(0, bandwidth) - if position in ["center", "centroid", "point_on_surface"]: + if anchor in ["center", "centroid", "point_on_surface"]: return data.BoundingBox( coordinates=[ max(time - duration / 2, 0), @@ -481,13 +527,12 @@ def _build_bounding_box( ] ) - if position not in VALID_POSITIONS: + if anchor not in VALID_ANCHORS: raise ValueError( - f"Invalid position: {position}. " - f"Valid options are: {VALID_POSITIONS}" + f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}" ) - y, x = position.split("-") + y, x = anchor.split("-") start_time = { "left": time, @@ -509,3 +554,43 @@ def _build_bounding_box( max(0, low_freq + bandwidth), ] ) + + +def get_peak_energy_coordinates( + recording: data.Recording, + preprocessor: PreprocessorProtocol, + start_time: float = 0, + end_time: Optional[float] = None, + low_freq: float = 0, + high_freq: Optional[float] = None, + loading_buffer: float = 0.05, +) -> Position: + if end_time is None: + end_time = recording.duration + end_time = min(end_time, recording.duration) + + if high_freq is None: + high_freq = recording.samplerate / 2 + + clip_start = max(0, start_time - loading_buffer) + clip_end = min(recording.duration, end_time + loading_buffer) + + clip = data.Clip( + recording=recording, + start_time=clip_start, + end_time=clip_end, + ) + + spec = preprocessor.preprocess_clip(clip) + low_freq = max(low_freq, preprocessor.min_freq) + high_freq = min(high_freq, preprocessor.max_freq) + selection = spec.sel( + time=slice(start_time, end_time), + frequency=slice(low_freq, high_freq), + ) + + index = selection.argmax(dim=["time", "frequency"]) + point = selection.isel(index) # type: ignore + peak_time: float = point.time.item() + peak_freq: float = point.frequency.item() + return peak_time, peak_freq diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 49858f1..79b5c45 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -4,12 +4,12 @@ from soundevent import data from batdetect2.targets.rois import ( DEFAULT_FREQUENCY_SCALE, - DEFAULT_POSITION, + DEFAULT_ANCHOR, DEFAULT_TIME_SCALE, SIZE_HEIGHT, SIZE_WIDTH, - BBoxEncoder, - ROIConfig, + AnchorBBoxMapper, + BBoxAnchorMapperConfig, _build_bounding_box, build_roi_mapper, load_roi_mapper, @@ -29,36 +29,36 @@ def zero_bbox() -> data.BoundingBox: @pytest.fixture -def default_encoder() -> BBoxEncoder: +def default_encoder() -> AnchorBBoxMapper: """A BBoxEncoder with default settings.""" - return BBoxEncoder() + return AnchorBBoxMapper() @pytest.fixture -def custom_encoder() -> BBoxEncoder: +def custom_encoder() -> AnchorBBoxMapper: """A BBoxEncoder with custom settings.""" - return BBoxEncoder(position="center", time_scale=1.0, frequency_scale=10.0) + return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0) def test_roi_config_defaults(): """Test ROIConfig default values.""" - config = ROIConfig() - assert config.position == DEFAULT_POSITION + config = BBoxAnchorMapperConfig() + assert config.anchor == DEFAULT_ANCHOR assert config.time_scale == DEFAULT_TIME_SCALE assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE def test_roi_config_custom(): """Test creating ROIConfig with custom values.""" - config = ROIConfig(position="center", time_scale=1.0, frequency_scale=10.0) - assert config.position == "center" + config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0) + assert config.anchor == "center" assert config.time_scale == 1.0 assert config.frequency_scale == 10.0 def test_bbox_encoder_init_defaults(default_encoder): """Test BBoxEncoder initialization with default arguments.""" - assert default_encoder.position == DEFAULT_POSITION + assert default_encoder.position == DEFAULT_ANCHOR assert default_encoder.time_scale == DEFAULT_TIME_SCALE assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] @@ -92,15 +92,15 @@ def test_bbox_encoder_get_roi_position( sample_bbox, position_type, expected_pos ): """Test get_roi_position for various position types.""" - encoder = BBoxEncoder(position=position_type) - actual_pos = encoder.get_roi_position(sample_bbox) + encoder = AnchorBBoxMapper(anchor=position_type) + actual_pos = encoder.encode_position(sample_bbox) assert actual_pos == pytest.approx(expected_pos) def test_bbox_encoder_get_roi_position_zero_box(zero_bbox): """Test get_roi_position for a zero-sized box.""" - encoder = BBoxEncoder(position="center") - assert encoder.get_roi_position(zero_bbox) == pytest.approx((15.0, 150.0)) + encoder = AnchorBBoxMapper(anchor="center") + assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0)) def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder): @@ -160,7 +160,7 @@ def test_build_bounding_box(position_type, expected_coords): duration = 10.0 bandwidth = 100.0 bbox = _build_bounding_box( - ref_pos, duration, bandwidth, position=position_type + ref_pos, duration, bandwidth, anchor=position_type ) assert isinstance(bbox, data.BoundingBox) np.testing.assert_allclose(bbox.coordinates, expected_coords) @@ -173,17 +173,17 @@ def test_build_bounding_box_invalid_position(): (0, 0), 1, 1, - position="invalid-spot", # type: ignore + anchor="invalid-spot", # type: ignore ) @pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES) def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos): """Test recover_roi correctly reconstructs the original bbox.""" - encoder = BBoxEncoder(position=position_type) - scaled_dims = encoder.get_roi_size(sample_bbox) + encoder = AnchorBBoxMapper(anchor=position_type) + scaled_dims = encoder.encode_size(sample_bbox) - recovered_bbox = encoder.recover_roi(ref_pos, scaled_dims) + recovered_bbox = encoder.decode(ref_pos, scaled_dims) assert isinstance(recovered_bbox, data.BoundingBox) np.testing.assert_allclose( @@ -227,13 +227,13 @@ def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder): def test_build_roi_mapper(): """Test build_roi_mapper creates a configured BBoxEncoder.""" - config = ROIConfig( - position="top-right", time_scale=2.0, frequency_scale=20.0 + config = BBoxAnchorMapperConfig( + anchor="top-right", time_scale=2.0, frequency_scale=20.0 ) mapper = build_roi_mapper(config) - assert isinstance(mapper, BBoxEncoder) - assert mapper.position == config.position + assert isinstance(mapper, AnchorBBoxMapper) + assert mapper.anchor == config.anchor assert mapper.time_scale == config.time_scale assert mapper.frequency_scale == config.frequency_scale @@ -270,8 +270,8 @@ def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content): mapper = load_roi_mapper(config_path) - assert isinstance(mapper, BBoxEncoder) - assert mapper.position == "center" + assert isinstance(mapper, AnchorBBoxMapper) + assert mapper.anchor == "center" assert mapper.time_scale == 500.0 assert mapper.frequency_scale == pytest.approx(1 / 1000.0) @@ -283,8 +283,8 @@ def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content): mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping") - assert isinstance(mapper, BBoxEncoder) - assert mapper.position == "bottom-right" + assert isinstance(mapper, AnchorBBoxMapper) + assert mapper.anchor == "bottom-right" assert mapper.time_scale == DEFAULT_TIME_SCALE assert mapper.frequency_scale == 0.01 diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 91000e4..27f7a1d 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -5,7 +5,7 @@ import xarray as xr from soundevent import data from batdetect2.targets import TargetConfig, TargetProtocol, build_targets -from batdetect2.targets.rois import ROIConfig +from batdetect2.targets.rois import BBoxAnchorMapperConfig from batdetect2.targets.terms import TagInfo, TermRegistry from batdetect2.train.labels import generate_heatmaps @@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( ): config = sample_target_config.model_copy( update=dict( - roi=ROIConfig( + roi=BBoxAnchorMapperConfig( time_scale=1, frequency_scale=1, )