batdetect2/batdetect2/targets/rois.py

597 lines
18 KiB
Python

"""Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.
This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the mapper configuration classes
`BBoxAnchorMapperConfig` and `PeakEnergyBBoxMapperConfig`, which are combined
in the `ROIMapperConfig` union. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
"""
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
# Names exported by `from batdetect2.targets.rois import *`.
__all__ = [
"ROITargetMapper",
"BBoxAnchorMapperConfig",
"AnchorBBoxMapper",
"build_roi_mapper",
"load_roi_mapper",
"DEFAULT_ANCHOR",
"SIZE_WIDTH",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
]
# Closed set of reference-point names accepted by the anchor-based mapper.
# Names of the form "<vertical>-<horizontal>" are split on "-" by
# `_build_bounding_box`; "center", "centroid" and "point_on_surface" are all
# treated there as the centre of the box.
Anchor = Literal[
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
# A reference point expressed as (time, frequency).
Position = tuple[float, float]
# Array of scaled size components; its entries follow the `dimension_names`
# of the mapper that produced it.
Size = np.ndarray
class ROITargetMapper(Protocol):
    """Protocol defining the interface for ROI-to-target mapping.

    Specifies the methods required for converting the region of interest of
    a sound event (its `soundevent.data.Geometry`) into a target
    representation (reference position and scaled size array) and for
    recovering an approximate ROI from that representation.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components, in the order in which they appear in
        the size array returned by `encode` and expected by `decode`
        (e.g., ['width', 'height']).
    """

    dimension_names: List[str]

    def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
        """Encode the ROI of a sound event as a position and a size array.

        Parameters
        ----------
        sound_event : soundevent.data.SoundEvent
            The sound event whose geometry should be encoded.

        Returns
        -------
        position : Tuple[float, float]
            The reference position as (time, frequency) coordinates. How it
            is chosen depends on the implementing class.
        size : np.ndarray
            The scaled size components, ordered as in `dimension_names`.

        Raises
        ------
        ValueError
            If the ROI cannot be encoded. The implementations in this module
            raise it when the sound event has no geometry.
        """
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Recover an approximate ROI from a position and a size array.

        Performs the inverse mapping of `encode`: takes a reference position
        and the (typically predicted) size components and reconstructs a
        geometric representation.

        Parameters
        ----------
        position : Tuple[float, float]
            The reference position (time, frequency).
        size : np.ndarray
            NumPy array containing the scaled size components, matching the
            order specified by `dimension_names`.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If `size` does not match `dimension_names` or if reconstruction
            fails.
        """
        ...
class BBoxAnchorMapperConfig(BaseConfig):
    """Configuration for the anchor-based bounding box ROI mapper.

    Defines parameters controlling how geometric ROIs are converted into
    target representations (reference points and scaled sizes) by
    `AnchorBBoxMapper`.

    Attributes
    ----------
    name : Literal["anchor_bbox"], default="anchor_bbox"
        Tag identifying this configuration. It is the discriminator field of
        the `ROIMapperConfig` union and is checked by `build_roi_mapper`.
    anchor : Anchor, default="bottom-left"
        Specifies the reference point within the geometry (e.g., bounding box)
        to use as the target location (e.g., "center", "bottom-left").
    time_scale : float, default=1000.0
        Scaling factor applied to the time duration (width) of the ROI
        when calculating the target size representation. Must match model
        expectations.
    frequency_scale : float, default=1/859.375
        Scaling factor applied to the frequency bandwidth (height) of the ROI
        when calculating the target size representation. Must match model
        expectations.
    """

    name: Literal["anchor_bbox"] = "anchor_bbox"
    anchor: Anchor = DEFAULT_ANCHOR
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class AnchorBBoxMapper(ROITargetMapper):
    """ROI mapper that describes a bounding box by an anchor point and size.

    The position is a fixed reference point of the geometry (the configured
    `anchor`), and the size is the scaled extent of the geometry's bounds
    along time (width) and frequency (height). Decoding rebuilds a
    `soundevent.data.BoundingBox` from those two pieces of information.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components: ['width', 'height'].
    anchor : Anchor
        The reference point type (e.g., "center", "bottom-left").
    time_scale : float
        Multiplier applied to the duration when encoding (and divided out
        when decoding).
    frequency_scale : float
        Multiplier applied to the bandwidth when encoding (and divided out
        when decoding).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        anchor: Anchor = DEFAULT_ANCHOR,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Store the anchor type and the scaling factors.

        Parameters
        ----------
        anchor : Anchor, default="bottom-left"
            Reference point type within the bounding box.
        time_scale : float, default=1000.0
            Scaling factor for time duration (width).
        frequency_scale : float, default=1/859.375
            Scaling factor for frequency bandwidth (height).
        """
        self.anchor: Anchor = anchor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
        """Encode a sound event as an anchor position and a scaled size.

        Parameters
        ----------
        sound_event : soundevent.data.SoundEvent
            The sound event to encode. It must have a geometry.

        Returns
        -------
        position : Tuple[float, float]
            The anchor point of the geometry, as returned by
            `soundevent.geometry.get_geometry_point`.
        size : np.ndarray
            Array `[duration * time_scale, bandwidth * frequency_scale]`
            computed from the bounds of the geometry.

        Raises
        ------
        ValueError
            If the sound event has no geometry.
        """
        from soundevent import geometry

        roi = sound_event.geometry
        if roi is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        anchor_point = geometry.get_geometry_point(roi, position=self.anchor)

        # Bounds are unpacked as (start_time, low_freq, end_time, high_freq).
        t_min, f_min, t_max, f_max = geometry.compute_bounds(roi)
        duration = t_max - t_min
        bandwidth = f_max - f_min

        scaled = np.array(
            [duration * self.time_scale, bandwidth * self.frequency_scale]
        )
        return anchor_point, scaled

    def decode(
        self,
        position: Position,
        size: Size,
    ) -> data.Geometry:
        """Rebuild a bounding box from an anchor position and a scaled size.

        The scaled width and height are divided by the configured scaling
        factors and the box is placed so that `position` sits at the
        configured `anchor` of the box.

        Parameters
        ----------
        position : Tuple[float, float]
            Reference position (time, frequency).
        size : np.ndarray
            One-dimensional array `[scaled_width, scaled_height]`.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `size` is not a one-dimensional array of length 2.
        """
        has_expected_shape = size.ndim == 1 and size.shape[0] == 2
        if not has_expected_shape:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({size.shape = }) != ([2])"
            )

        scaled_width, scaled_height = size
        duration = float(scaled_width) / self.time_scale
        bandwidth = float(scaled_height) / self.frequency_scale
        return _build_bounding_box(
            position,
            duration=duration,
            bandwidth=bandwidth,
            anchor=self.anchor,
        )
class PeakEnergyBBoxMapperConfig(BaseConfig):
    """Configuration for the peak-energy bounding box ROI mapper.

    Attributes
    ----------
    name : Literal["peak_energy_bbox"], default="peak_energy_bbox"
        Tag identifying this configuration. It is the discriminator field of
        the `ROIMapperConfig` union and is checked by `build_roi_mapper`.
    preprocessing : PreprocessingConfig
        Configuration passed to `build_preprocessor` to create the
        preprocessor used to locate the peak energy.
    loading_buffer : float, default=0.01
        Margin added before and after the ROI when loading the clip in
        which the peak is searched (same time units as the ROI bounds,
        presumably seconds).
    time_scale : float, default=1000.0
        Scaling factor applied to the time components of the size.
    frequency_scale : float, default=1/859.375
        Scaling factor applied to the frequency components of the size.
    """

    # Default added for consistency with `BBoxAnchorMapperConfig`, so the
    # config can be instantiated directly without repeating its own tag.
    name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    loading_buffer: float = 0.01
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class PeakEnergyBBoxMapper(ROITargetMapper):
    """ROI mapper anchored at the point of peak energy inside the box.

    Encodes the ROI using the location of the peak energy within the
    bounding box as the 'position' and the scaled distances from that point
    to the four box edges as the 'size'.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components: ['left', 'bottom', 'right', 'top'].
    preprocessor : PreprocessorProtocol
        Preprocessor used to compute the representation in which the peak
        is searched.
    time_scale : float
        Multiplier applied to the time distances (left, right).
    frequency_scale : float
        Multiplier applied to the frequency distances (bottom, top).
    loading_buffer : float
        Margin added around the ROI when loading the clip.
    """

    dimension_names = ["left", "bottom", "right", "top"]

    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
        loading_buffer: float = 0.01,
    ):
        """Store the preprocessor, scaling factors and loading buffer.

        Parameters
        ----------
        preprocessor : PreprocessorProtocol
            Preprocessor handed to `get_peak_energy_coordinates`.
        time_scale : float, default=1000.0
            Scaling factor for the time distances.
        frequency_scale : float, default=1/859.375
            Scaling factor for the frequency distances.
        loading_buffer : float, default=0.01
            Margin added around the ROI when loading the clip.
        """
        self.preprocessor = preprocessor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale
        self.loading_buffer = loading_buffer

    def encode(
        self,
        sound_event: data.SoundEvent,
    ) -> tuple[Position, Size]:
        """Encode a sound event as a peak position and distances to edges.

        Parameters
        ----------
        sound_event : soundevent.data.SoundEvent
            The sound event to encode. It must have a geometry; its
            recording is used to locate the peak.

        Returns
        -------
        position : Tuple[float, float]
            The (time, frequency) of the peak found by
            `get_peak_energy_coordinates` within the bounds of the geometry.
        size : np.ndarray
            Array `[left, bottom, right, top]` with the scaled distances
            from the peak to the start time, low frequency, end time and
            high frequency of the bounds.

        Raises
        ------
        ValueError
            If the sound event has no geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry
        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        time, freq = get_peak_energy_coordinates(
            recording=sound_event.recording,
            preprocessor=self.preprocessor,
            start_time=start_time,
            end_time=end_time,
            low_freq=low_freq,
            high_freq=high_freq,
            loading_buffer=self.loading_buffer,
        )

        size = np.array(
            [
                (time - start_time) * self.time_scale,
                (freq - low_freq) * self.frequency_scale,
                (end_time - time) * self.time_scale,
                (high_freq - freq) * self.frequency_scale,
            ]
        )
        return (time, freq), size

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Rebuild a bounding box from a peak position and edge distances.

        Parameters
        ----------
        position : Tuple[float, float]
            The (time, frequency) of the peak.
        size : np.ndarray
            One-dimensional array `[left, bottom, right, top]` of scaled
            distances. Negative distances are treated as zero.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box, with coordinates clamped so
            that none is negative.

        Raises
        ------
        ValueError
            If `size` is not a one-dimensional array of length 4.
        """
        # Fail early with an explicit message, as `AnchorBBoxMapper.decode`
        # does, instead of relying on the unpacking below to go wrong.
        shape = np.shape(size)
        if shape != (4,):
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"(size.shape = {shape}) != ([4])"
            )

        time, freq = position
        left, bottom, right, top = size

        # Distances are clamped at zero so the box always contains the peak;
        # the resulting coordinates are clamped at zero as well, matching
        # what `_build_bounding_box` does for the anchor-based mapper.
        return data.BoundingBox(
            coordinates=[
                max(0, time - max(0, float(left)) / self.time_scale),
                max(0, freq - max(0, float(bottom)) / self.frequency_scale),
                max(0, time + max(0, float(right)) / self.time_scale),
                max(0, freq + max(0, float(top)) / self.frequency_scale),
            ]
        )
# Union of all ROI mapper configurations, discriminated by their `name` field.
ROIMapperConfig = Annotated[
    Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
    Field(discriminator="name"),
]


def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
    """Create the ROI mapper described by a configuration object.

    Parameters
    ----------
    config : ROIMapperConfig
        Either a `BBoxAnchorMapperConfig` or a `PeakEnergyBBoxMapperConfig`.
        The mapper type is selected from `config.name`.

    Returns
    -------
    ROITargetMapper
        An `AnchorBBoxMapper` for "anchor_bbox" or a `PeakEnergyBBoxMapper`
        for "peak_energy_bbox", initialised with the settings in `config`.

    Raises
    ------
    NotImplementedError
        If `config.name` does not correspond to a known mapper.
    """
    if config.name == "anchor_bbox":
        mapper = AnchorBBoxMapper(
            anchor=config.anchor,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )
    elif config.name == "peak_energy_bbox":
        # The peak-energy mapper needs a preprocessor to locate the peak.
        mapper = PeakEnergyBBoxMapper(
            preprocessor=build_preprocessor(config.preprocessing),
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
            loading_buffer=config.loading_buffer,
        )
    else:
        raise NotImplementedError(
            f"No ROI mapper of name {config.name} is implemented"
        )

    return mapper
def load_roi_mapper(
    path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
    """Read an ROI mapper configuration from a file and build the mapper.

    The configuration is loaded with `load_config` and then handed to
    `build_roi_mapper`.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        ROI configuration. If None, the entire file content is used.

    Returns
    -------
    ROITargetMapper
        An initialized ROI mapper instance based on the configuration file.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        If the configuration file cannot be found, parsed, validated, or if
        the specified `field` is invalid.
    """
    # NOTE(review): the schema is `BBoxAnchorMapperConfig`, so a
    # "peak_energy_bbox" configuration presumably cannot be loaded through
    # this function - confirm whether `load_config` can accept the
    # `ROIMapperConfig` union instead.
    return build_roi_mapper(
        load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
    )
# Anchor names accepted by `_build_bounding_box`.
VALID_ANCHORS = [
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]


def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    anchor: Anchor = DEFAULT_ANCHOR,
) -> data.BoundingBox:
    """Place a box of the given size so that `pos` sits at its `anchor`.

    Internal helper used by `AnchorBBoxMapper.decode`. Negative durations
    and bandwidths are treated as zero, and every resulting coordinate is
    clamped so that it is never negative.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position (time, frequency).
    duration : float
        The *unscaled* duration (width) of the bounding box.
    bandwidth : float
        The *unscaled* frequency bandwidth (height) of the bounding box.
    anchor : Anchor, default="bottom-left"
        Which part of the bounding box `pos` corresponds to. "center",
        "centroid" and "point_on_surface" all place `pos` at the centre.

    Returns
    -------
    data.BoundingBox
        Box with coordinates [start_time, low_freq, end_time, high_freq].

    Raises
    ------
    ValueError
        If `anchor` is not one of `VALID_ANCHORS`.
    """
    time, freq = map(float, pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    # Anchors without a "<vertical>-<horizontal>" form: box centred on pos.
    if anchor in ["center", "centroid", "point_on_surface"]:
        half_duration = duration / 2
        half_bandwidth = bandwidth / 2
        return data.BoundingBox(
            coordinates=[
                max(time - half_duration, 0),
                max(freq - half_bandwidth, 0),
                max(time + half_duration, 0),
                max(freq + half_bandwidth, 0),
            ]
        )

    if anchor not in VALID_ANCHORS:
        raise ValueError(
            f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
        )

    # Remaining anchors are spelled "<vertical>-<horizontal>".
    vertical, horizontal = anchor.split("-")

    if horizontal == "left":
        start_time = time
    elif horizontal == "center":
        start_time = time - duration / 2
    else:  # "right"
        start_time = time - duration

    if vertical == "bottom":
        low_freq = freq
    elif vertical == "center":
        low_freq = freq - bandwidth / 2
    else:  # "top"
        low_freq = freq - bandwidth

    end_time = start_time + duration
    high_freq = low_freq + bandwidth
    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, end_time),
            max(0, high_freq),
        ]
    )
def get_peak_energy_coordinates(
    recording: data.Recording,
    preprocessor: PreprocessorProtocol,
    start_time: float = 0,
    end_time: Optional[float] = None,
    low_freq: float = 0,
    high_freq: Optional[float] = None,
    loading_buffer: float = 0.05,
) -> Position:
    """Find the (time, frequency) of the maximum value inside a region.

    A clip covering the requested time range, widened by `loading_buffer`
    on each side and limited to the recording, is preprocessed. The result
    is then restricted to the requested time/frequency region and the
    coordinates of its maximum are returned.

    Parameters
    ----------
    recording : data.Recording
        Recording in which to search.
    preprocessor : PreprocessorProtocol
        Object providing `preprocess_clip`, `min_freq` and `max_freq`. The
        output of `preprocess_clip` must have `time` and `frequency`
        dimensions (presumably an `xarray.DataArray` - confirm).
    start_time : float, default=0
        Start of the search region.
    end_time : float, optional
        End of the search region. Defaults to, and is capped at, the
        recording duration.
    low_freq : float, default=0
        Lower frequency bound; raised to `preprocessor.min_freq` if below.
    high_freq : float, optional
        Upper frequency bound. Defaults to half the recording samplerate and
        is lowered to `preprocessor.max_freq` if above.
    loading_buffer : float, default=0.05
        Margin added around the time range when loading the clip.

    Returns
    -------
    Tuple[float, float]
        The (time, frequency) coordinates of the maximum.
    """
    # NOTE(review): if the selected region contains no samples, the argmax
    # below presumably fails - confirm whether callers can hit that case.
    total_duration = recording.duration
    end_time = (
        total_duration
        if end_time is None
        else min(end_time, total_duration)
    )

    if high_freq is None:
        high_freq = recording.samplerate / 2

    # Load slightly more audio than requested, within the recording limits.
    spec = preprocessor.preprocess_clip(
        data.Clip(
            recording=recording,
            start_time=max(0, start_time - loading_buffer),
            end_time=min(recording.duration, end_time + loading_buffer),
        )
    )

    # Keep the frequency bounds within the range covered by the preprocessor.
    low_freq = max(low_freq, preprocessor.min_freq)
    high_freq = min(high_freq, preprocessor.max_freq)

    region = spec.sel(
        time=slice(start_time, end_time),
        frequency=slice(low_freq, high_freq),
    )
    peak_index = region.argmax(dim=["time", "frequency"])
    peak = region.isel(peak_index)  # type: ignore

    peak_time: float = peak.time.item()
    peak_freq: float = peak.frequency.item()
    return peak_time, peak_freq