mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Create targets.rois module
This commit is contained in:
parent
07f065cf93
commit
9410112e41
462
batdetect2/targets/rois.py
Normal file
462
batdetect2/targets/rois.py
Normal file
@ -0,0 +1,462 @@
|
|||||||
|
"""Handles mapping between geometric ROIs and target representations.
|
||||||
|
|
||||||
|
This module defines the interface and provides implementation for converting
|
||||||
|
a sound event's Region of Interest (ROI), typically represented by a
|
||||||
|
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
|
||||||
|
suitable for use as a machine learning target. This usually involves:
|
||||||
|
|
||||||
|
1. Extracting a single reference point (time, frequency) from the geometry.
|
||||||
|
2. Calculating relevant size dimensions (e.g., duration/width,
|
||||||
|
bandwidth/height) and applying scaling factors.
|
||||||
|
|
||||||
|
It also provides the inverse operation: recovering an approximate geometric ROI
|
||||||
|
(like a `BoundingBox`) from a predicted reference point and predicted size
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
This logic is encapsulated within components adhering to the `ROITargetMapper`
|
||||||
|
protocol. Configuration for this mapping (e.g., which reference point to use,
|
||||||
|
scaling factors) is managed by the `ROIConfig`. This module separates the
|
||||||
|
*geometric* aspect of target definition from the *semantic* classification
|
||||||
|
handled in `batdetect2.targets.classes`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundevent import data, geometry
|
||||||
|
from soundevent.geometry.operations import Positions
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ROITargetMapper",
|
||||||
|
"ROIConfig",
|
||||||
|
"BBoxEncoder",
|
||||||
|
"build_roi_mapper",
|
||||||
|
"load_roi_mapper",
|
||||||
|
"DEFAULT_POSITION",
|
||||||
|
"SIZE_WIDTH",
|
||||||
|
"SIZE_HEIGHT",
|
||||||
|
"SIZE_ORDER",
|
||||||
|
"DEFAULT_TIME_SCALE",
|
||||||
|
"DEFAULT_FREQUENCY_SCALE",
|
||||||
|
]
|
||||||
|
|
||||||
|
SIZE_WIDTH = "width"
|
||||||
|
"""Standard name for the width/time dimension component ('width')."""
|
||||||
|
|
||||||
|
SIZE_HEIGHT = "height"
|
||||||
|
"""Standard name for the height/frequency dimension component ('height')."""
|
||||||
|
|
||||||
|
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
|
||||||
|
"""Standard order of dimensions for size arrays ([width, height])."""
|
||||||
|
|
||||||
|
DEFAULT_TIME_SCALE = 1000.0
|
||||||
|
"""Default scaling factor for time duration."""
|
||||||
|
|
||||||
|
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
|
||||||
|
"""Default scaling factor for frequency bandwidth."""
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_POSITION = "bottom-left"
|
||||||
|
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||||
|
|
||||||
|
|
||||||
|
class ROITargetMapper(Protocol):
|
||||||
|
"""Protocol defining the interface for ROI-to-target mapping.
|
||||||
|
|
||||||
|
Specifies the methods required for converting a geometric region of interest
|
||||||
|
(`soundevent.data.Geometry`) into a target representation (reference point
|
||||||
|
and scaled dimensions) and for recovering an approximate ROI from that
|
||||||
|
representation.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
dimension_names : List[str]
|
||||||
|
A list containing the names of the dimensions returned by
|
||||||
|
`get_roi_size` and expected by `recover_roi`
|
||||||
|
(e.g., ['width', 'height']).
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimension_names: List[str]
|
||||||
|
|
||||||
|
def get_roi_position(self, geom: data.Geometry) -> tuple[float, float]:
|
||||||
|
"""Extract the reference position from a geometry.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
geom : soundevent.data.Geometry
|
||||||
|
The input geometry (e.g., BoundingBox, Polygon).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[float, float]
|
||||||
|
The calculated reference position as (time, frequency) coordinates,
|
||||||
|
based on the implementing class's configuration (e.g., "center",
|
||||||
|
"bottom-left").
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the position cannot be calculated for the given geometry type
|
||||||
|
or configured reference point.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||||
|
"""Calculate the scaled target dimensions from a geometry.
|
||||||
|
|
||||||
|
Computes the relevant size measures.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
geom : soundevent.data.Geometry
|
||||||
|
The input geometry.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
A NumPy array containing the scaled dimensions corresponding to
|
||||||
|
`dimension_names`. For bounding boxes, typically contains
|
||||||
|
`[scaled_width, scaled_height]`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
TypeError, ValueError
|
||||||
|
If the size cannot be computed for the given geometry type.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def recover_roi(
|
||||||
|
self, pos: tuple[float, float], dims: np.ndarray
|
||||||
|
) -> data.Geometry:
|
||||||
|
"""Recover an approximate ROI from a position and target dimensions.
|
||||||
|
|
||||||
|
Performs the inverse mapping: takes a reference position and the
|
||||||
|
predicted dimensions and reconstructs a geometric representation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
The reference position (time, frequency).
|
||||||
|
dims : np.ndarray
|
||||||
|
The NumPy array containing the dimensions, matching the order
|
||||||
|
specified by `dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Geometry
|
||||||
|
The reconstructed geometry.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the number of provided dimensions `dims` does not match
|
||||||
|
`dimension_names` or if reconstruction fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ROIConfig(BaseConfig):
|
||||||
|
"""Configuration for mapping Regions of Interest (ROIs).
|
||||||
|
|
||||||
|
Defines parameters controlling how geometric ROIs are converted into
|
||||||
|
target representations (reference points and scaled sizes).
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
position : Positions, default="bottom-left"
|
||||||
|
Specifies the reference point within the geometry (e.g., bounding box)
|
||||||
|
to use as the target location (e.g., "center", "bottom-left").
|
||||||
|
See `soundevent.geometry.operations.Positions`.
|
||||||
|
time_scale : float, default=1000.0
|
||||||
|
Scaling factor applied to the time duration (width) of the ROI
|
||||||
|
when calculating the target size representation. Must match model
|
||||||
|
expectations.
|
||||||
|
frequency_scale : float, default=1/859.375
|
||||||
|
Scaling factor applied to the frequency bandwidth (height) of the ROI
|
||||||
|
when calculating the target size representation. Must match model
|
||||||
|
expectations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
position: Positions = DEFAULT_POSITION
|
||||||
|
time_scale: float = DEFAULT_TIME_SCALE
|
||||||
|
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxEncoder(ROITargetMapper):
|
||||||
|
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
|
||||||
|
|
||||||
|
This class implements the ROI mapping protocol primarily for
|
||||||
|
`soundevent.data.BoundingBox` geometry. It extracts reference points,
|
||||||
|
calculates scaled width/height, and recovers bounding boxes based on
|
||||||
|
configured position and scaling factors.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
dimension_names : List[str]
|
||||||
|
Specifies the output dimension names as ['width', 'height'].
|
||||||
|
position : Positions
|
||||||
|
The configured reference point type (e.g., "center", "bottom-left").
|
||||||
|
time_scale : float
|
||||||
|
The configured scaling factor for the time dimension (width).
|
||||||
|
frequency_scale : float
|
||||||
|
The configured scaling factor for the frequency dimension (height).
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
position: Positions = DEFAULT_POSITION,
|
||||||
|
time_scale: float = DEFAULT_TIME_SCALE,
|
||||||
|
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
|
||||||
|
):
|
||||||
|
"""Initialize the BBoxEncoder.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
position : Positions, default="bottom-left"
|
||||||
|
Reference point type within the bounding box.
|
||||||
|
time_scale : float, default=1000.0
|
||||||
|
Scaling factor for time duration (width).
|
||||||
|
frequency_scale : float, default=1/859.375
|
||||||
|
Scaling factor for frequency bandwidth (height).
|
||||||
|
"""
|
||||||
|
self.position: Positions = position
|
||||||
|
self.time_scale = time_scale
|
||||||
|
self.frequency_scale = frequency_scale
|
||||||
|
|
||||||
|
def get_roi_position(self, geom: data.Geometry) -> Tuple[float, float]:
|
||||||
|
"""Extract the configured reference position from the geometry.
|
||||||
|
|
||||||
|
Uses `soundevent.geometry.get_geometry_point`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
geom : soundevent.data.Geometry
|
||||||
|
Input geometry (e.g., BoundingBox).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[float, float]
|
||||||
|
Reference position (time, frequency).
|
||||||
|
"""
|
||||||
|
return geometry.get_geometry_point(geom, position=self.position)
|
||||||
|
|
||||||
|
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||||
|
"""Calculate the scaled [width, height] from the geometry's bounds.
|
||||||
|
|
||||||
|
Computes the bounding box, extracts duration and bandwidth, and applies
|
||||||
|
the configured `time_scale` and `frequency_scale`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
geom : soundevent.data.Geometry
|
||||||
|
Input geometry.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
||||||
|
"""
|
||||||
|
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||||
|
geom
|
||||||
|
)
|
||||||
|
return np.array(
|
||||||
|
[
|
||||||
|
(end_time - start_time) * self.time_scale,
|
||||||
|
(high_freq - low_freq) * self.frequency_scale,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def recover_roi(
|
||||||
|
self,
|
||||||
|
pos: tuple[float, float],
|
||||||
|
dims: np.ndarray,
|
||||||
|
) -> data.Geometry:
|
||||||
|
"""Recover a BoundingBox from a position and scaled dimensions.
|
||||||
|
|
||||||
|
Un-scales the input dimensions using the configured factors and
|
||||||
|
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
|
||||||
|
the given reference `pos` according to the configured `position` type.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
Reference position (time, frequency).
|
||||||
|
dims : np.ndarray
|
||||||
|
NumPy array containing the *scaled* dimensions, expected order is
|
||||||
|
[scaled_width, scaled_height].
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.BoundingBox
|
||||||
|
The reconstructed bounding box.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `dims` does not have the expected shape (length 2).
|
||||||
|
"""
|
||||||
|
if dims.ndim != 1 or dims.shape[0] != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Dimension array does not have the expected shape. "
|
||||||
|
f"({dims.shape = }) != ([2])"
|
||||||
|
)
|
||||||
|
|
||||||
|
width, height = dims
|
||||||
|
return _build_bounding_box(
|
||||||
|
pos,
|
||||||
|
duration=width / self.time_scale,
|
||||||
|
bandwidth=height / self.frequency_scale,
|
||||||
|
position=self.position,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
|
||||||
|
"""Factory function to create an ROITargetMapper from configuration.
|
||||||
|
|
||||||
|
Currently creates a `BBoxEncoder` instance based on the provided
|
||||||
|
`ROIConfig`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ROIConfig
|
||||||
|
Configuration object specifying ROI mapping parameters.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ROITargetMapper
|
||||||
|
An initialized `BBoxEncoder` instance configured with the settings
|
||||||
|
from `config`.
|
||||||
|
"""
|
||||||
|
return BBoxEncoder(
|
||||||
|
position=config.position,
|
||||||
|
time_scale=config.time_scale,
|
||||||
|
frequency_scale=config.frequency_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_roi_mapper(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> ROITargetMapper:
|
||||||
|
"""Load ROI mapping configuration from a file and build the mapper.
|
||||||
|
|
||||||
|
Convenience function that loads an `ROIConfig` from the specified file
|
||||||
|
(and optional field) and then uses `build_roi_mapper` to create the
|
||||||
|
corresponding `ROITargetMapper` instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
ROI configuration. If None, the entire file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ROITargetMapper
|
||||||
|
An initialized ROI mapper instance based on the configuration file.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
||||||
|
TypeError
|
||||||
|
If the configuration file cannot be found, parsed, validated, or if
|
||||||
|
the specified `field` is invalid.
|
||||||
|
"""
|
||||||
|
config = load_config(path=path, schema=ROIConfig, field=field)
|
||||||
|
return build_roi_mapper(config)
|
||||||
|
|
||||||
|
|
||||||
|
VALID_POSITIONS = [
|
||||||
|
"bottom-left",
|
||||||
|
"bottom-right",
|
||||||
|
"top-left",
|
||||||
|
"top-right",
|
||||||
|
"center-left",
|
||||||
|
"center-right",
|
||||||
|
"top-center",
|
||||||
|
"bottom-center",
|
||||||
|
"center",
|
||||||
|
"centroid",
|
||||||
|
"point_on_surface",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bounding_box(
|
||||||
|
pos: tuple[float, float],
|
||||||
|
duration: float,
|
||||||
|
bandwidth: float,
|
||||||
|
position: Positions = DEFAULT_POSITION,
|
||||||
|
) -> data.BoundingBox:
|
||||||
|
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||||
|
|
||||||
|
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
|
||||||
|
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
||||||
|
the input `pos` (time, freq) is located relative to the box (e.g.,
|
||||||
|
center, corner).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
Reference position (time, frequency).
|
||||||
|
duration : float
|
||||||
|
The required *unscaled* duration (width) of the bounding box.
|
||||||
|
bandwidth : float
|
||||||
|
The required *unscaled* frequency bandwidth (height) of the bounding
|
||||||
|
box.
|
||||||
|
position : Positions, default="bottom-left"
|
||||||
|
Specifies which part of the bounding box the input `pos` corresponds to.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.BoundingBox
|
||||||
|
The constructed bounding box object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `position` is not a recognized value or format.
|
||||||
|
"""
|
||||||
|
time, freq = pos
|
||||||
|
if position in ["center", "centroid", "point_on_surface"]:
|
||||||
|
return data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
time - duration / 2,
|
||||||
|
freq - bandwidth / 2,
|
||||||
|
time + duration / 2,
|
||||||
|
freq + bandwidth / 2,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if position not in VALID_POSITIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid position: {position}. "
|
||||||
|
f"Valid options are: {VALID_POSITIONS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
y, x = position.split("-")
|
||||||
|
|
||||||
|
start_time = {
|
||||||
|
"left": time,
|
||||||
|
"center": time - duration / 2,
|
||||||
|
"right": time - duration,
|
||||||
|
}[x]
|
||||||
|
|
||||||
|
low_freq = {
|
||||||
|
"bottom": freq,
|
||||||
|
"center": freq - bandwidth / 2,
|
||||||
|
"top": freq - bandwidth,
|
||||||
|
}[y]
|
||||||
|
|
||||||
|
return data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
start_time,
|
||||||
|
low_freq,
|
||||||
|
start_time + duration,
|
||||||
|
low_freq + bandwidth,
|
||||||
|
]
|
||||||
|
)
|
304
tests/test_targets/test_rois.py
Normal file
304
tests/test_targets/test_rois.py
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets.rois import (
|
||||||
|
DEFAULT_FREQUENCY_SCALE,
|
||||||
|
DEFAULT_POSITION,
|
||||||
|
DEFAULT_TIME_SCALE,
|
||||||
|
SIZE_HEIGHT,
|
||||||
|
SIZE_WIDTH,
|
||||||
|
BBoxEncoder,
|
||||||
|
ROIConfig,
|
||||||
|
_build_bounding_box,
|
||||||
|
build_roi_mapper,
|
||||||
|
load_roi_mapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_bbox() -> data.BoundingBox:
|
||||||
|
"""A standard bounding box for testing."""
|
||||||
|
return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def zero_bbox() -> data.BoundingBox:
|
||||||
|
"""A bounding box with zero duration and bandwidth."""
|
||||||
|
return data.BoundingBox(coordinates=[15.0, 150.0, 15.0, 150.0])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def default_encoder() -> BBoxEncoder:
|
||||||
|
"""A BBoxEncoder with default settings."""
|
||||||
|
return BBoxEncoder()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def custom_encoder() -> BBoxEncoder:
|
||||||
|
"""A BBoxEncoder with custom settings."""
|
||||||
|
return BBoxEncoder(position="center", time_scale=1.0, frequency_scale=10.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_roi_config_defaults():
|
||||||
|
"""Test ROIConfig default values."""
|
||||||
|
config = ROIConfig()
|
||||||
|
assert config.position == DEFAULT_POSITION
|
||||||
|
assert config.time_scale == DEFAULT_TIME_SCALE
|
||||||
|
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE
|
||||||
|
|
||||||
|
|
||||||
|
def test_roi_config_custom():
|
||||||
|
"""Test creating ROIConfig with custom values."""
|
||||||
|
config = ROIConfig(position="center", time_scale=1.0, frequency_scale=10.0)
|
||||||
|
assert config.position == "center"
|
||||||
|
assert config.time_scale == 1.0
|
||||||
|
assert config.frequency_scale == 10.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_init_defaults(default_encoder):
|
||||||
|
"""Test BBoxEncoder initialization with default arguments."""
|
||||||
|
assert default_encoder.position == DEFAULT_POSITION
|
||||||
|
assert default_encoder.time_scale == DEFAULT_TIME_SCALE
|
||||||
|
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE
|
||||||
|
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_init_custom(custom_encoder):
|
||||||
|
"""Test BBoxEncoder initialization with custom arguments."""
|
||||||
|
assert custom_encoder.position == "center"
|
||||||
|
assert custom_encoder.time_scale == 1.0
|
||||||
|
assert custom_encoder.frequency_scale == 10.0
|
||||||
|
assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
|
||||||
|
|
||||||
|
|
||||||
|
POSITION_TEST_CASES = [
|
||||||
|
("bottom-left", (10.0, 100.0)),
|
||||||
|
("bottom-right", (20.0, 100.0)),
|
||||||
|
("top-left", (10.0, 200.0)),
|
||||||
|
("top-right", (20.0, 200.0)),
|
||||||
|
("center-left", (10.0, 150.0)),
|
||||||
|
("center-right", (20.0, 150.0)),
|
||||||
|
("top-center", (15.0, 200.0)),
|
||||||
|
("bottom-center", (15.0, 100.0)),
|
||||||
|
("center", (15.0, 150.0)),
|
||||||
|
("centroid", (15.0, 150.0)),
|
||||||
|
("point_on_surface", (15.0, 150.0)),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("position_type, expected_pos", POSITION_TEST_CASES)
|
||||||
|
def test_bbox_encoder_get_roi_position(
|
||||||
|
sample_bbox, position_type, expected_pos
|
||||||
|
):
|
||||||
|
"""Test get_roi_position for various position types."""
|
||||||
|
encoder = BBoxEncoder(position=position_type)
|
||||||
|
actual_pos = encoder.get_roi_position(sample_bbox)
|
||||||
|
assert actual_pos == pytest.approx(expected_pos)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_get_roi_position_zero_box(zero_bbox):
|
||||||
|
"""Test get_roi_position for a zero-sized box."""
|
||||||
|
encoder = BBoxEncoder(position="center")
|
||||||
|
assert encoder.get_roi_position(zero_bbox) == pytest.approx((15.0, 150.0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
|
||||||
|
"""Test get_roi_size with default scaling."""
|
||||||
|
expected_size = np.array(
|
||||||
|
[
|
||||||
|
10.0 * DEFAULT_TIME_SCALE,
|
||||||
|
100.0 * DEFAULT_FREQUENCY_SCALE,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_size = default_encoder.get_roi_size(sample_bbox)
|
||||||
|
np.testing.assert_allclose(actual_size, expected_size)
|
||||||
|
assert actual_size.shape == (2,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_get_roi_size_custom(sample_bbox, custom_encoder):
|
||||||
|
"""Test get_roi_size with custom scaling."""
|
||||||
|
expected_size = np.array(
|
||||||
|
[
|
||||||
|
10.0 * 1.0,
|
||||||
|
100.0 * 10.0,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
actual_size = custom_encoder.get_roi_size(sample_bbox)
|
||||||
|
np.testing.assert_allclose(actual_size, expected_size)
|
||||||
|
assert actual_size.shape == (2,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_get_roi_size_zero_box(zero_bbox, default_encoder):
|
||||||
|
"""Test get_roi_size for a zero-sized box."""
|
||||||
|
expected_size = np.array([0.0, 0.0])
|
||||||
|
actual_size = default_encoder.get_roi_size(zero_bbox)
|
||||||
|
np.testing.assert_allclose(actual_size, expected_size)
|
||||||
|
|
||||||
|
|
||||||
|
BUILD_BOX_TEST_CASES = [
|
||||||
|
("bottom-left", [50.0, 500.0, 60.0, 600.0]),
|
||||||
|
("bottom-right", [40.0, 500.0, 50.0, 600.0]),
|
||||||
|
("top-left", [50.0, 400.0, 60.0, 500.0]),
|
||||||
|
("top-right", [40.0, 400.0, 50.0, 500.0]),
|
||||||
|
("center-left", [50.0, 450.0, 60.0, 550.0]),
|
||||||
|
("center-right", [40.0, 450.0, 50.0, 550.0]),
|
||||||
|
("top-center", [45.0, 400.0, 55.0, 500.0]),
|
||||||
|
("bottom-center", [45.0, 500.0, 55.0, 600.0]),
|
||||||
|
("center", [45.0, 450.0, 55.0, 550.0]),
|
||||||
|
("centroid", [45.0, 450.0, 55.0, 550.0]),
|
||||||
|
("point_on_surface", [45.0, 450.0, 55.0, 550.0]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"position_type, expected_coords", BUILD_BOX_TEST_CASES
|
||||||
|
)
|
||||||
|
def test_build_bounding_box(position_type, expected_coords):
|
||||||
|
"""Test _build_bounding_box for various position types."""
|
||||||
|
ref_pos = (50.0, 500.0)
|
||||||
|
duration = 10.0
|
||||||
|
bandwidth = 100.0
|
||||||
|
bbox = _build_bounding_box(
|
||||||
|
ref_pos, duration, bandwidth, position=position_type
|
||||||
|
)
|
||||||
|
assert isinstance(bbox, data.BoundingBox)
|
||||||
|
np.testing.assert_allclose(bbox.coordinates, expected_coords)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_bounding_box_invalid_position():
|
||||||
|
"""Test _build_bounding_box raises error for invalid position."""
|
||||||
|
with pytest.raises(ValueError, match="Invalid position"):
|
||||||
|
_build_bounding_box(
|
||||||
|
(0, 0),
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
position="invalid-spot", # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES)
|
||||||
|
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
|
||||||
|
"""Test recover_roi correctly reconstructs the original bbox."""
|
||||||
|
encoder = BBoxEncoder(position=position_type)
|
||||||
|
scaled_dims = encoder.get_roi_size(sample_bbox)
|
||||||
|
|
||||||
|
recovered_bbox = encoder.recover_roi(ref_pos, scaled_dims)
|
||||||
|
|
||||||
|
assert isinstance(recovered_bbox, data.BoundingBox)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
recovered_bbox.coordinates, sample_bbox.coordinates, atol=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
|
||||||
|
"""Test recover_roi with custom scaling factors."""
|
||||||
|
ref_pos = custom_encoder.get_roi_position(sample_bbox)
|
||||||
|
scaled_dims = custom_encoder.get_roi_size(sample_bbox)
|
||||||
|
|
||||||
|
recovered_bbox = custom_encoder.recover_roi(ref_pos, scaled_dims)
|
||||||
|
|
||||||
|
assert isinstance(recovered_bbox, data.BoundingBox)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
recovered_bbox.coordinates, sample_bbox.coordinates, atol=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_recover_roi_zero_box(zero_bbox, default_encoder):
|
||||||
|
"""Test recover_roi for a zero-sized box."""
|
||||||
|
ref_pos = default_encoder.get_roi_position(zero_bbox)
|
||||||
|
scaled_dims = default_encoder.get_roi_size(zero_bbox)
|
||||||
|
recovered_bbox = default_encoder.recover_roi(ref_pos, scaled_dims)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
recovered_bbox.coordinates, zero_bbox.coordinates, atol=1e-6
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder):
|
||||||
|
"""Test recover_roi raises ValueError for incorrect dims shape."""
|
||||||
|
ref_pos = (10, 100)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
default_encoder.recover_roi(ref_pos, np.array([1.0]))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
default_encoder.recover_roi(ref_pos, np.array([1.0, 2.0, 3.0]))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
default_encoder.recover_roi(ref_pos, np.array([[1.0], [2.0]]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_roi_mapper():
|
||||||
|
"""Test build_roi_mapper creates a configured BBoxEncoder."""
|
||||||
|
config = ROIConfig(
|
||||||
|
position="top-right", time_scale=2.0, frequency_scale=20.0
|
||||||
|
)
|
||||||
|
mapper = build_roi_mapper(config)
|
||||||
|
|
||||||
|
assert isinstance(mapper, BBoxEncoder)
|
||||||
|
assert mapper.position == config.position
|
||||||
|
assert mapper.time_scale == config.time_scale
|
||||||
|
assert mapper.frequency_scale == config.frequency_scale
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_config_yaml_content() -> str:
|
||||||
|
"""YAML content for a sample ROIConfig."""
|
||||||
|
return f"""
|
||||||
|
position: center
|
||||||
|
time_scale: 500.0
|
||||||
|
frequency_scale: {1 / 1000.0}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nested_config_yaml_content() -> str:
|
||||||
|
"""YAML content with ROIConfig nested under a field."""
|
||||||
|
return f"""
|
||||||
|
model_settings:
|
||||||
|
preprocessing:
|
||||||
|
whatever: true
|
||||||
|
roi_mapping:
|
||||||
|
position: bottom-right
|
||||||
|
time_scale: {DEFAULT_TIME_SCALE}
|
||||||
|
frequency_scale: 0.01
|
||||||
|
other_stuff: 123
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
|
||||||
|
"""Test loading a simple ROIConfig from YAML."""
|
||||||
|
config_path = tmp_path / "config.yaml"
|
||||||
|
config_path.write_text(sample_config_yaml_content)
|
||||||
|
|
||||||
|
mapper = load_roi_mapper(config_path)
|
||||||
|
|
||||||
|
assert isinstance(mapper, BBoxEncoder)
|
||||||
|
assert mapper.position == "center"
|
||||||
|
assert mapper.time_scale == 500.0
|
||||||
|
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
|
||||||
|
"""Test loading a nested ROIConfig from YAML using 'field'."""
|
||||||
|
config_path = tmp_path / "nested_config.yaml"
|
||||||
|
config_path.write_text(nested_config_yaml_content)
|
||||||
|
|
||||||
|
mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")
|
||||||
|
|
||||||
|
assert isinstance(mapper, BBoxEncoder)
|
||||||
|
assert mapper.position == "bottom-right"
|
||||||
|
assert mapper.time_scale == DEFAULT_TIME_SCALE
|
||||||
|
assert mapper.frequency_scale == 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_roi_mapper_file_not_found(tmp_path):
|
||||||
|
"""Test load_roi_mapper raises error if file doesn't exist."""
|
||||||
|
non_existent_path = tmp_path / "not_real.yaml"
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
load_roi_mapper(non_existent_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_roi_mapper_invalid_field(tmp_path, sample_config_yaml_content):
|
||||||
|
"""Test load_roi_mapper raises error for invalid field."""
|
||||||
|
config_path = tmp_path / "config.yaml"
|
||||||
|
config_path.write_text(sample_config_yaml_content)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
load_roi_mapper(config_path, field="invalid.path")
|
Loading…
Reference in New Issue
Block a user