Changed ROIMapper protocol to only have encoder/decoder methods

This commit is contained in:
mbsantiago 2025-06-21 11:44:15 +01:00
parent ebad489cb1
commit c559bcc682
5 changed files with 278 additions and 184 deletions

View File

@ -50,7 +50,7 @@ from batdetect2.targets.filtering import (
load_filter_from_config,
)
from batdetect2.targets.rois import (
ROIConfig,
BBoxAnchorMapperConfig,
ROITargetMapper,
build_roi_mapper,
)
@ -88,7 +88,7 @@ __all__ = [
"FilterConfig",
"FilterRule",
"MapValueRule",
"ROIConfig",
"BBoxAnchorMapperConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
@ -156,12 +156,12 @@ class TargetConfig(BaseConfig):
omitted, default ROI mapping settings are used.
"""
filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None
filtering: FilterConfig = Field(default_factory=FilterConfig)
transforms: TransformConfig = Field(default_factory=TransformConfig)
classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
)
roi: Optional[ROIConfig] = None
roi: Optional[BBoxAnchorMapperConfig] = None
def load_target_config(
@ -374,14 +374,7 @@ class Targets(TargetProtocol):
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its position."
)
return self._roi_mapper.get_roi_position(geom)
return self._roi_mapper.encode_position(sound_event.sound_event)
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
@ -405,14 +398,7 @@ class Targets(TargetProtocol):
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
return self._roi_mapper.encode_size(sound_event.sound_event)
def recover_roi(
self,
@ -438,7 +424,7 @@ class Targets(TargetProtocol):
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
return self._roi_mapper.recover_roi(pos, dims)
return self._roi_mapper.decode(pos, dims)
DEFAULT_CLASSES = [
@ -606,7 +592,7 @@ def build_targets(
if config.transforms
else None
)
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig())
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,

View File

@ -27,8 +27,29 @@ __all__ = [
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST",
"PositionMethod",
"CornerPosition",
"SizeMethod",
"BoundingBoxSize",
]
class PositionMethod(BaseConfig):
"""Base class for defining how to select a position from a geometry."""
method_type: str
class CornerPosition(PositionMethod):
"""Selects a position based on a corner or center of the bounding box."""
method_type: Literal["corner"] = "corner"
corner: Literal["upper_left", "lower_left", "center"] = "lower_left"
class SizeMethod(BaseConfig):
"""Base class for defining how to select a size from a geometry."""
method_type: str
class BoundingBoxSize(SizeMethod):
"""Uses the width and height of the bounding box as the size."""
method_type: Literal["bounding_box"] = "bounding_box"
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
@ -113,6 +134,8 @@ class TargetClass(BaseConfig):
tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None
position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left"))
size_method: SizeMethod = Field(default_factory=BoundingBoxSize)
def _get_default_classes() -> List[TargetClass]:

View File

@ -20,14 +20,31 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
handled in `batdetect2.targets.classes`.
"""
from typing import List, Literal, Optional, Protocol, Tuple
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
Positions = Literal[
__all__ = [
"ROITargetMapper",
"BBoxAnchorMapperConfig",
"AnchorBBoxMapper",
"build_roi_mapper",
"load_roi_mapper",
"DEFAULT_ANCHOR",
"SIZE_WIDTH",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
]
Anchor = Literal[
"bottom-left",
"bottom-right",
"top-left",
@ -41,20 +58,6 @@ Positions = Literal[
"point_on_surface",
]
__all__ = [
"ROITargetMapper",
"ROIConfig",
"BBoxEncoder",
"build_roi_mapper",
"load_roi_mapper",
"DEFAULT_POSITION",
"SIZE_WIDTH",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
@ -71,10 +74,15 @@ DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_POSITION = "bottom-left"
DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
Position = tuple[float, float]
Size = np.ndarray
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
@ -93,21 +101,13 @@ class ROITargetMapper(Protocol):
dimension_names: List[str]
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> tuple[float, float]:
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Extract the reference position from a geometry.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default.
Returns
-------
@ -124,36 +124,7 @@ class ROITargetMapper(Protocol):
"""
...
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled target dimensions from a geometry.
Computes the relevant size measures.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry.
Returns
-------
np.ndarray
A NumPy array containing the scaled dimensions corresponding to
`dimension_names`. For bounding boxes, typically contains
`[scaled_width, scaled_height]`.
Raises
------
TypeError, ValueError
If the size cannot be computed for the given geometry type.
"""
...
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
) -> data.Geometry:
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
Performs the inverse mapping: takes a reference position and the
@ -161,15 +132,11 @@ class ROITargetMapper(Protocol):
Parameters
----------
pos : Tuple[float, float]
position : Tuple[float, float]
The reference position (time, frequency).
dims : np.ndarray
size : np.ndarray
NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default when reconstructing the roi geometry.
Returns
-------
@ -185,7 +152,7 @@ class ROITargetMapper(Protocol):
...
class ROIConfig(BaseConfig):
class BBoxAnchorMapperConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs).
Defines parameters controlling how geometric ROIs are converted into
@ -193,10 +160,9 @@ class ROIConfig(BaseConfig):
Attributes
----------
position : Positions, default="bottom-left"
anchor : Anchor, default="bottom-left"
Specifies the reference point within the geometry (e.g., bounding box)
to use as the target location (e.g., "center", "bottom-left").
See `soundevent.geometry.operations.Positions`.
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width) of the ROI
when calculating the target size representation. Must match model
@ -207,12 +173,13 @@ class ROIConfig(BaseConfig):
expectations.
"""
position: Positions = DEFAULT_POSITION
name: Literal["anchor_bbox"] = "anchor_bbox"
anchor: Anchor = DEFAULT_ANCHOR
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class BBoxEncoder(ROITargetMapper):
class AnchorBBoxMapper(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
This class implements the ROI mapping protocol primarily for
@ -224,7 +191,7 @@ class BBoxEncoder(ROITargetMapper):
----------
dimension_names : List[str]
Specifies the output dimension names as ['width', 'height'].
position : Positions
anchor : Anchor
The configured reference point type (e.g., "center", "bottom-left").
time_scale : float
The configured scaling factor for the time dimension (width).
@ -236,7 +203,7 @@ class BBoxEncoder(ROITargetMapper):
def __init__(
self,
position: Positions = DEFAULT_POSITION,
anchor: Anchor = DEFAULT_ANCHOR,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
):
@ -244,22 +211,18 @@ class BBoxEncoder(ROITargetMapper):
Parameters
----------
position : Positions, default="bottom-left"
anchor : Anchor, default="bottom-left"
Reference point type within the bounding box.
time_scale : float, default=1000.0
Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375
Scaling factor for frequency bandwidth (height).
"""
self.position: Positions = position
self.anchor: Anchor = anchor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> Tuple[float, float]:
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
"""Extract the configured reference position from the geometry.
Uses `soundevent.geometry.get_geometry_point`.
@ -268,9 +231,6 @@ class BBoxEncoder(ROITargetMapper):
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`.
Returns
-------
@ -279,42 +239,33 @@ class BBoxEncoder(ROITargetMapper):
"""
from soundevent import geometry
position = position or self.position
return geometry.get_geometry_point(geom, position=position)
geom = sound_event.geometry
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled [width, height] from the geometry's bounds.
if geom is None:
raise ValueError(
"Cannot encode the geometry of a sound event without geometry."
f" Sound event: {sound_event}"
)
Computes the bounding box, extracts duration and bandwidth, and applies
the configured `time_scale` and `frequency_scale`.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry.
Returns
-------
np.ndarray
A 1D NumPy array: `[scaled_width, scaled_height]`.
"""
from soundevent import geometry
position = geometry.get_geometry_point(geom, position=self.anchor)
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
return np.array(
size = np.array(
[
(end_time - start_time) * self.time_scale,
(high_freq - low_freq) * self.frequency_scale,
]
)
def recover_roi(
return position, size
def decode(
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
position: Position,
size: Size,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
@ -329,10 +280,6 @@ class BBoxEncoder(ROITargetMapper):
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`
when reconstructing the bounding box.
Returns
-------
@ -344,28 +291,113 @@ class BBoxEncoder(ROITargetMapper):
ValueError
If `dims` does not have the expected shape (length 2).
"""
position = position or self.position
if dims.ndim != 1 or dims.shape[0] != 2:
if size.ndim != 1 or size.shape[0] != 2:
raise ValueError(
"Dimension array does not have the expected shape. "
f"({dims.shape = }) != ([2])"
f"({size.shape = }) != ([2])"
)
width, height = dims
width, height = size
return _build_bounding_box(
pos,
position,
duration=float(width) / self.time_scale,
bandwidth=float(height) / self.frequency_scale,
position=self.position,
anchor=self.anchor,
)
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
class PeakEnergyBBoxMapperConfig(BaseConfig):
name: Literal["peak_energy_bbox"]
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
loading_buffer: float = 0.01
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
Currently creates a `BBoxEncoder` instance based on the provided
`ROIConfig`.
class PeakEnergyBBoxMapper(ROITargetMapper):
"""
Encodes the ROI using the location of the peak energy within the bounding box
as the 'position' and the distances from that point to the box edges as the 'size'.
"""
dimension_names = ["left", "bottom", "right", "top"]
def __init__(
self,
preprocessor: PreprocessorProtocol,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01,
):
self.preprocessor = preprocessor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
self.loading_buffer = loading_buffer
def encode(
self,
sound_event: data.SoundEvent,
) -> tuple[Position, Size]:
from soundevent import geometry
geom = sound_event.geometry
if geom is None:
raise ValueError(
"Cannot encode the geometry of a sound event without geometry."
f" Sound event: {sound_event}"
)
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
time, freq = get_peak_energy_coordinates(
recording=sound_event.recording,
preprocessor=self.preprocessor,
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
loading_buffer=self.loading_buffer,
)
size = np.array(
[
(time - start_time) * self.time_scale,
(freq - low_freq) * self.frequency_scale,
(end_time - time) * self.time_scale,
(high_freq - freq) * self.frequency_scale,
]
)
return (time, freq), size
def decode(self, position: Position, size: Size) -> data.Geometry:
time, freq = position
left, bottom, right, top = size
return data.BoundingBox(
coordinates=[
time - max(0, float(left)) / self.time_scale,
freq - max(0, float(bottom)) / self.frequency_scale,
time + max(0, float(right)) / self.time_scale,
freq + max(0, float(top)) / self.frequency_scale,
]
)
ROIMapperConfig = Annotated[
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"),
]
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
Parameters
----------
@ -378,10 +410,24 @@ def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
An initialized `BBoxEncoder` instance configured with the settings
from `config`.
"""
return BBoxEncoder(
position=config.position,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
if config.name == "peak_energy_bbox":
preprocessor = build_preprocessor(config.preprocessing)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
raise NotImplementedError(
f"No ROI mapper of name {config.name} is implemented"
)
@ -414,11 +460,11 @@ def load_roi_mapper(
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=ROIConfig, field=field)
config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
return build_roi_mapper(config)
VALID_POSITIONS = [
VALID_ANCHORS = [
"bottom-left",
"bottom-right",
"top-left",
@ -437,7 +483,7 @@ def _build_bounding_box(
pos: tuple[float, float],
duration: float,
bandwidth: float,
position: Positions = DEFAULT_POSITION,
anchor: Anchor = DEFAULT_ANCHOR,
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
@ -455,7 +501,7 @@ def _build_bounding_box(
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
position : Positions, default="bottom-left"
anchor : Anchor, default="bottom-left"
Specifies which part of the bounding box the input `pos` corresponds to.
Returns
@ -466,12 +512,12 @@ def _build_bounding_box(
Raises
------
ValueError
If `position` is not a recognized value or format.
If `anchor` is not a recognized value or format.
"""
time, freq = map(float, pos)
duration = max(0, duration)
bandwidth = max(0, bandwidth)
if position in ["center", "centroid", "point_on_surface"]:
if anchor in ["center", "centroid", "point_on_surface"]:
return data.BoundingBox(
coordinates=[
max(time - duration / 2, 0),
@ -481,13 +527,12 @@ def _build_bounding_box(
]
)
if position not in VALID_POSITIONS:
if anchor not in VALID_ANCHORS:
raise ValueError(
f"Invalid position: {position}. "
f"Valid options are: {VALID_POSITIONS}"
f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
)
y, x = position.split("-")
y, x = anchor.split("-")
start_time = {
"left": time,
@ -509,3 +554,43 @@ def _build_bounding_box(
max(0, low_freq + bandwidth),
]
)
def get_peak_energy_coordinates(
recording: data.Recording,
preprocessor: PreprocessorProtocol,
start_time: float = 0,
end_time: Optional[float] = None,
low_freq: float = 0,
high_freq: Optional[float] = None,
loading_buffer: float = 0.05,
) -> Position:
if end_time is None:
end_time = recording.duration
end_time = min(end_time, recording.duration)
if high_freq is None:
high_freq = recording.samplerate / 2
clip_start = max(0, start_time - loading_buffer)
clip_end = min(recording.duration, end_time + loading_buffer)
clip = data.Clip(
recording=recording,
start_time=clip_start,
end_time=clip_end,
)
spec = preprocessor.preprocess_clip(clip)
low_freq = max(low_freq, preprocessor.min_freq)
high_freq = min(high_freq, preprocessor.max_freq)
selection = spec.sel(
time=slice(start_time, end_time),
frequency=slice(low_freq, high_freq),
)
index = selection.argmax(dim=["time", "frequency"])
point = selection.isel(index) # type: ignore
peak_time: float = point.time.item()
peak_freq: float = point.frequency.item()
return peak_time, peak_freq

View File

@ -4,12 +4,12 @@ from soundevent import data
from batdetect2.targets.rois import (
DEFAULT_FREQUENCY_SCALE,
DEFAULT_POSITION,
DEFAULT_ANCHOR,
DEFAULT_TIME_SCALE,
SIZE_HEIGHT,
SIZE_WIDTH,
BBoxEncoder,
ROIConfig,
AnchorBBoxMapper,
BBoxAnchorMapperConfig,
_build_bounding_box,
build_roi_mapper,
load_roi_mapper,
@ -29,36 +29,36 @@ def zero_bbox() -> data.BoundingBox:
@pytest.fixture
def default_encoder() -> BBoxEncoder:
def default_encoder() -> AnchorBBoxMapper:
"""A BBoxEncoder with default settings."""
return BBoxEncoder()
return AnchorBBoxMapper()
@pytest.fixture
def custom_encoder() -> BBoxEncoder:
def custom_encoder() -> AnchorBBoxMapper:
"""A BBoxEncoder with custom settings."""
return BBoxEncoder(position="center", time_scale=1.0, frequency_scale=10.0)
return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0)
def test_roi_config_defaults():
"""Test ROIConfig default values."""
config = ROIConfig()
assert config.position == DEFAULT_POSITION
config = BBoxAnchorMapperConfig()
assert config.anchor == DEFAULT_ANCHOR
assert config.time_scale == DEFAULT_TIME_SCALE
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE
def test_roi_config_custom():
"""Test creating ROIConfig with custom values."""
config = ROIConfig(position="center", time_scale=1.0, frequency_scale=10.0)
assert config.position == "center"
config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0)
assert config.anchor == "center"
assert config.time_scale == 1.0
assert config.frequency_scale == 10.0
def test_bbox_encoder_init_defaults(default_encoder):
"""Test BBoxEncoder initialization with default arguments."""
assert default_encoder.position == DEFAULT_POSITION
assert default_encoder.position == DEFAULT_ANCHOR
assert default_encoder.time_scale == DEFAULT_TIME_SCALE
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
@ -92,15 +92,15 @@ def test_bbox_encoder_get_roi_position(
sample_bbox, position_type, expected_pos
):
"""Test get_roi_position for various position types."""
encoder = BBoxEncoder(position=position_type)
actual_pos = encoder.get_roi_position(sample_bbox)
encoder = AnchorBBoxMapper(anchor=position_type)
actual_pos = encoder.encode_position(sample_bbox)
assert actual_pos == pytest.approx(expected_pos)
def test_bbox_encoder_get_roi_position_zero_box(zero_bbox):
"""Test get_roi_position for a zero-sized box."""
encoder = BBoxEncoder(position="center")
assert encoder.get_roi_position(zero_bbox) == pytest.approx((15.0, 150.0))
encoder = AnchorBBoxMapper(anchor="center")
assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0))
def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
@ -160,7 +160,7 @@ def test_build_bounding_box(position_type, expected_coords):
duration = 10.0
bandwidth = 100.0
bbox = _build_bounding_box(
ref_pos, duration, bandwidth, position=position_type
ref_pos, duration, bandwidth, anchor=position_type
)
assert isinstance(bbox, data.BoundingBox)
np.testing.assert_allclose(bbox.coordinates, expected_coords)
@ -173,17 +173,17 @@ def test_build_bounding_box_invalid_position():
(0, 0),
1,
1,
position="invalid-spot", # type: ignore
anchor="invalid-spot", # type: ignore
)
@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES)
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
"""Test recover_roi correctly reconstructs the original bbox."""
encoder = BBoxEncoder(position=position_type)
scaled_dims = encoder.get_roi_size(sample_bbox)
encoder = AnchorBBoxMapper(anchor=position_type)
scaled_dims = encoder.encode_size(sample_bbox)
recovered_bbox = encoder.recover_roi(ref_pos, scaled_dims)
recovered_bbox = encoder.decode(ref_pos, scaled_dims)
assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose(
@ -227,13 +227,13 @@ def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder):
def test_build_roi_mapper():
"""Test build_roi_mapper creates a configured BBoxEncoder."""
config = ROIConfig(
position="top-right", time_scale=2.0, frequency_scale=20.0
config = BBoxAnchorMapperConfig(
anchor="top-right", time_scale=2.0, frequency_scale=20.0
)
mapper = build_roi_mapper(config)
assert isinstance(mapper, BBoxEncoder)
assert mapper.position == config.position
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == config.anchor
assert mapper.time_scale == config.time_scale
assert mapper.frequency_scale == config.frequency_scale
@ -270,8 +270,8 @@ def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
mapper = load_roi_mapper(config_path)
assert isinstance(mapper, BBoxEncoder)
assert mapper.position == "center"
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "center"
assert mapper.time_scale == 500.0
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)
@ -283,8 +283,8 @@ def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")
assert isinstance(mapper, BBoxEncoder)
assert mapper.position == "bottom-right"
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "bottom-right"
assert mapper.time_scale == DEFAULT_TIME_SCALE
assert mapper.frequency_scale == 0.01

View File

@ -5,7 +5,7 @@ import xarray as xr
from soundevent import data
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets.rois import ROIConfig
from batdetect2.targets.rois import BBoxAnchorMapperConfig
from batdetect2.targets.terms import TagInfo, TermRegistry
from batdetect2.train.labels import generate_heatmaps
@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
):
config = sample_target_config.model_copy(
update=dict(
roi=ROIConfig(
roi=BBoxAnchorMapperConfig(
time_scale=1,
frequency_scale=1,
)