mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Changed ROIMapper protocol to only have encoder/decoder methods
This commit is contained in:
parent
ebad489cb1
commit
c559bcc682
@ -50,7 +50,7 @@ from batdetect2.targets.filtering import (
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
ROIConfig,
|
||||
BBoxAnchorMapperConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
@ -88,7 +88,7 @@ __all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"MapValueRule",
|
||||
"ROIConfig",
|
||||
"BBoxAnchorMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"ReplaceRule",
|
||||
"SoundEventDecoder",
|
||||
@ -156,12 +156,12 @@ class TargetConfig(BaseConfig):
|
||||
omitted, default ROI mapping settings are used.
|
||||
"""
|
||||
|
||||
filtering: Optional[FilterConfig] = None
|
||||
transforms: Optional[TransformConfig] = None
|
||||
filtering: FilterConfig = Field(default_factory=FilterConfig)
|
||||
transforms: TransformConfig = Field(default_factory=TransformConfig)
|
||||
classes: ClassesConfig = Field(
|
||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||
)
|
||||
roi: Optional[ROIConfig] = None
|
||||
roi: Optional[BBoxAnchorMapperConfig] = None
|
||||
|
||||
|
||||
def load_target_config(
|
||||
@ -374,14 +374,7 @@ class Targets(TargetProtocol):
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its position."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_position(geom)
|
||||
return self._roi_mapper.encode_position(sound_event.sound_event)
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
@ -405,14 +398,7 @@ class Targets(TargetProtocol):
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its size."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_size(geom)
|
||||
return self._roi_mapper.encode_size(sound_event.sound_event)
|
||||
|
||||
def recover_roi(
|
||||
self,
|
||||
@ -438,7 +424,7 @@ class Targets(TargetProtocol):
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
return self._roi_mapper.recover_roi(pos, dims)
|
||||
return self._roi_mapper.decode(pos, dims)
|
||||
|
||||
|
||||
DEFAULT_CLASSES = [
|
||||
@ -606,7 +592,7 @@ def build_targets(
|
||||
if config.transforms
|
||||
else None
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
||||
roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig())
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
|
@ -27,8 +27,29 @@ __all__ = [
|
||||
"build_generic_class_tags",
|
||||
"get_class_names_from_config",
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
"PositionMethod",
|
||||
"CornerPosition",
|
||||
"SizeMethod",
|
||||
"BoundingBoxSize",
|
||||
]
|
||||
|
||||
class PositionMethod(BaseConfig):
    """Parent config for strategies that pick a position from a geometry.

    Attributes
    ----------
    method_type : str
        Identifier of the position-selection strategy.
    """

    method_type: str
|
||||
|
||||
class CornerPosition(PositionMethod):
    """Position strategy that uses a corner, or the center, of the bounding box.

    Attributes
    ----------
    method_type : Literal["corner"], default="corner"
        Fixed identifier of this strategy.
    corner : Literal["upper_left", "lower_left", "center"], default="lower_left"
        Which reference point of the bounding box is selected.
    """

    method_type: Literal["corner"] = "corner"
    corner: Literal["upper_left", "lower_left", "center"] = "lower_left"
|
||||
|
||||
class SizeMethod(BaseConfig):
    """Parent config for strategies that derive a size from a geometry.

    Attributes
    ----------
    method_type : str
        Identifier of the size-selection strategy.
    """

    method_type: str
|
||||
|
||||
class BoundingBoxSize(SizeMethod):
    """Size strategy that takes the bounding box width and height as the size.

    Attributes
    ----------
    method_type : Literal["bounding_box"], default="bounding_box"
        Fixed identifier of this strategy.
    """

    method_type: Literal["bounding_box"] = "bounding_box"
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
@ -113,6 +134,8 @@ class TargetClass(BaseConfig):
|
||||
tags: List[TagInfo] = Field(min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
output_tags: Optional[List[TagInfo]] = None
|
||||
position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left"))
|
||||
size_method: SizeMethod = Field(default_factory=BoundingBoxSize)
|
||||
|
||||
|
||||
def _get_default_classes() -> List[TargetClass]:
|
||||
|
@ -20,14 +20,31 @@ scaling factors) is managed by the `ROIConfig`. This module separates the
|
||||
handled in `batdetect2.targets.classes`.
|
||||
"""
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, Tuple
|
||||
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
Positions = Literal[
|
||||
__all__ = [
|
||||
"ROITargetMapper",
|
||||
"BBoxAnchorMapperConfig",
|
||||
"AnchorBBoxMapper",
|
||||
"build_roi_mapper",
|
||||
"load_roi_mapper",
|
||||
"DEFAULT_ANCHOR",
|
||||
"SIZE_WIDTH",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
]
|
||||
|
||||
Anchor = Literal[
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
"top-left",
|
||||
@ -41,20 +58,6 @@ Positions = Literal[
|
||||
"point_on_surface",
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"ROITargetMapper",
|
||||
"ROIConfig",
|
||||
"BBoxEncoder",
|
||||
"build_roi_mapper",
|
||||
"load_roi_mapper",
|
||||
"DEFAULT_POSITION",
|
||||
"SIZE_WIDTH",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
]
|
||||
|
||||
SIZE_WIDTH = "width"
|
||||
"""Standard name for the width/time dimension component ('width')."""
|
||||
|
||||
@ -71,10 +74,15 @@ DEFAULT_FREQUENCY_SCALE = 1 / 859.375
|
||||
"""Default scaling factor for frequency bandwidth."""
|
||||
|
||||
|
||||
DEFAULT_POSITION = "bottom-left"
|
||||
DEFAULT_ANCHOR = "bottom-left"
|
||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||
|
||||
|
||||
Position = tuple[float, float]
|
||||
|
||||
Size = np.ndarray
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
|
||||
"""Protocol defining the interface for ROI-to-target mapping.
|
||||
|
||||
@ -93,21 +101,13 @@ class ROITargetMapper(Protocol):
|
||||
|
||||
dimension_names: List[str]
|
||||
|
||||
def get_roi_position(
|
||||
self,
|
||||
geom: data.Geometry,
|
||||
position: Optional[Positions] = None,
|
||||
) -> tuple[float, float]:
|
||||
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||
"""Extract the reference position from a geometry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
The input geometry (e.g., BoundingBox, Polygon).
|
||||
position : Positions, optional
|
||||
Overrides the default `position` configured for the mapper.
|
||||
If provided, this position will be used instead of the mapper's
|
||||
internal default.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -124,36 +124,7 @@ class ROITargetMapper(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||
"""Calculate the scaled target dimensions from a geometry.
|
||||
|
||||
Computes the relevant size measures.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
The input geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A NumPy array containing the scaled dimensions corresponding to
|
||||
`dimension_names`. For bounding boxes, typically contains
|
||||
`[scaled_width, scaled_height]`.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError, ValueError
|
||||
If the size cannot be computed for the given geometry type.
|
||||
"""
|
||||
...
|
||||
|
||||
def recover_roi(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
position: Optional[Positions] = None,
|
||||
) -> data.Geometry:
|
||||
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||
"""Recover an approximate ROI from a position and target dimensions.
|
||||
|
||||
Performs the inverse mapping: takes a reference position and the
|
||||
@ -161,15 +132,11 @@ class ROITargetMapper(Protocol):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
position : Tuple[float, float]
|
||||
The reference position (time, frequency).
|
||||
dims : np.ndarray
|
||||
size : np.ndarray
|
||||
NumPy array containing the dimensions, matching the order
|
||||
specified by `dimension_names`.
|
||||
position : Positions, optional
|
||||
Overrides the default `position` configured for the mapper.
|
||||
If provided, this position will be used instead of the mapper's
|
||||
internal default when reconstructing the roi geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -185,7 +152,7 @@ class ROITargetMapper(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class ROIConfig(BaseConfig):
|
||||
class BBoxAnchorMapperConfig(BaseConfig):
|
||||
"""Configuration for mapping Regions of Interest (ROIs).
|
||||
|
||||
Defines parameters controlling how geometric ROIs are converted into
|
||||
@ -193,10 +160,9 @@ class ROIConfig(BaseConfig):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
position : Positions, default="bottom-left"
|
||||
anchor : Anchor, default="bottom-left"
|
||||
Specifies the reference point within the geometry (e.g., bounding box)
|
||||
to use as the target location (e.g., "center", "bottom-left").
|
||||
See `soundevent.geometry.operations.Positions`.
|
||||
time_scale : float, default=1000.0
|
||||
Scaling factor applied to the time duration (width) of the ROI
|
||||
when calculating the target size representation. Must match model
|
||||
@ -207,12 +173,13 @@ class ROIConfig(BaseConfig):
|
||||
expectations.
|
||||
"""
|
||||
|
||||
position: Positions = DEFAULT_POSITION
|
||||
name: Literal["anchor_bbox"] = "anchor_bbox"
|
||||
anchor: Anchor = DEFAULT_ANCHOR
|
||||
time_scale: float = DEFAULT_TIME_SCALE
|
||||
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
class BBoxEncoder(ROITargetMapper):
|
||||
class AnchorBBoxMapper(ROITargetMapper):
|
||||
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
|
||||
|
||||
This class implements the ROI mapping protocol primarily for
|
||||
@ -224,7 +191,7 @@ class BBoxEncoder(ROITargetMapper):
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
Specifies the output dimension names as ['width', 'height'].
|
||||
position : Positions
|
||||
anchor : Anchor
|
||||
The configured reference point type (e.g., "center", "bottom-left").
|
||||
time_scale : float
|
||||
The configured scaling factor for the time dimension (width).
|
||||
@ -236,7 +203,7 @@ class BBoxEncoder(ROITargetMapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
position: Positions = DEFAULT_POSITION,
|
||||
anchor: Anchor = DEFAULT_ANCHOR,
|
||||
time_scale: float = DEFAULT_TIME_SCALE,
|
||||
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
|
||||
):
|
||||
@ -244,22 +211,18 @@ class BBoxEncoder(ROITargetMapper):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : Positions, default="bottom-left"
|
||||
anchor : Anchor, default="bottom-left"
|
||||
Reference point type within the bounding box.
|
||||
time_scale : float, default=1000.0
|
||||
Scaling factor for time duration (width).
|
||||
frequency_scale : float, default=1/859.375
|
||||
Scaling factor for frequency bandwidth (height).
|
||||
"""
|
||||
self.position: Positions = position
|
||||
self.anchor: Anchor = anchor
|
||||
self.time_scale = time_scale
|
||||
self.frequency_scale = frequency_scale
|
||||
|
||||
def get_roi_position(
|
||||
self,
|
||||
geom: data.Geometry,
|
||||
position: Optional[Positions] = None,
|
||||
) -> Tuple[float, float]:
|
||||
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
|
||||
"""Extract the configured reference position from the geometry.
|
||||
|
||||
Uses `soundevent.geometry.get_geometry_point`.
|
||||
@ -268,9 +231,6 @@ class BBoxEncoder(ROITargetMapper):
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
Input geometry (e.g., BoundingBox).
|
||||
position : Positions, optional
|
||||
Overrides the default `position` configured for the encoder.
|
||||
If provided, this position will be used instead of `self.position`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -279,42 +239,33 @@ class BBoxEncoder(ROITargetMapper):
|
||||
"""
|
||||
from soundevent import geometry
|
||||
|
||||
position = position or self.position
|
||||
return geometry.get_geometry_point(geom, position=position)
|
||||
geom = sound_event.geometry
|
||||
|
||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||
"""Calculate the scaled [width, height] from the geometry's bounds.
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Cannot encode the geometry of a sound event without geometry."
|
||||
f" Sound event: {sound_event}"
|
||||
)
|
||||
|
||||
Computes the bounding box, extracts duration and bandwidth, and applies
|
||||
the configured `time_scale` and `frequency_scale`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
Input geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
||||
"""
|
||||
from soundevent import geometry
|
||||
position = geometry.get_geometry_point(geom, position=self.anchor)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||
geom
|
||||
)
|
||||
return np.array(
|
||||
|
||||
size = np.array(
|
||||
[
|
||||
(end_time - start_time) * self.time_scale,
|
||||
(high_freq - low_freq) * self.frequency_scale,
|
||||
]
|
||||
)
|
||||
|
||||
def recover_roi(
|
||||
return position, size
|
||||
|
||||
def decode(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
position: Optional[Positions] = None,
|
||||
position: Position,
|
||||
size: Size,
|
||||
) -> data.Geometry:
|
||||
"""Recover a BoundingBox from a position and scaled dimensions.
|
||||
|
||||
@ -329,10 +280,6 @@ class BBoxEncoder(ROITargetMapper):
|
||||
dims : np.ndarray
|
||||
NumPy array containing the *scaled* dimensions, expected order is
|
||||
[scaled_width, scaled_height].
|
||||
position : Positions, optional
|
||||
Overrides the default `position` configured for the encoder.
|
||||
If provided, this position will be used instead of `self.position`
|
||||
when reconstructing the bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -344,28 +291,113 @@ class BBoxEncoder(ROITargetMapper):
|
||||
ValueError
|
||||
If `dims` does not have the expected shape (length 2).
|
||||
"""
|
||||
position = position or self.position
|
||||
|
||||
if dims.ndim != 1 or dims.shape[0] != 2:
|
||||
if size.ndim != 1 or size.shape[0] != 2:
|
||||
raise ValueError(
|
||||
"Dimension array does not have the expected shape. "
|
||||
f"({dims.shape = }) != ([2])"
|
||||
f"({size.shape = }) != ([2])"
|
||||
)
|
||||
|
||||
width, height = dims
|
||||
width, height = size
|
||||
return _build_bounding_box(
|
||||
pos,
|
||||
position,
|
||||
duration=float(width) / self.time_scale,
|
||||
bandwidth=float(height) / self.frequency_scale,
|
||||
position=self.position,
|
||||
anchor=self.anchor,
|
||||
)
|
||||
|
||||
|
||||
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
|
||||
"""Factory function to create an ROITargetMapper from configuration.
|
||||
class PeakEnergyBBoxMapperConfig(BaseConfig):
    """Configuration for the peak-energy bounding box ROI mapper.

    Attributes
    ----------
    name : Literal["peak_energy_bbox"], default="peak_energy_bbox"
        Tag identifying this mapper type. It is the discriminator used by
        `ROIMapperConfig` and checked by `build_roi_mapper`. A default is
        provided, mirroring `BBoxAnchorMapperConfig`, so the config can be
        instantiated without repeating the tag.
    preprocessing : PreprocessingConfig
        Settings handed to `build_preprocessor` to create the preprocessor
        used when locating the peak energy.
    loading_buffer : float, default=0.01
        Extra margin added before and after the ROI when loading audio.
        Expressed in the same unit as the ROI times (presumably seconds,
        confirm against the recording metadata).
    time_scale : float
        Scaling factor applied to time extents in the encoded size.
    frequency_scale : float
        Scaling factor applied to frequency extents in the encoded size.
    """

    name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    loading_buffer: float = 0.01
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
Currently creates a `BBoxEncoder` instance based on the provided
|
||||
`ROIConfig`.
|
||||
|
||||
class PeakEnergyBBoxMapper(ROITargetMapper):
    """ROI mapper anchored at the point of peak energy inside the ROI.

    Encodes the ROI using the location of the peak energy within the bounding
    box as the 'position' and the distances from that point to the box edges
    as the 'size'.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components, in order:
        ``["left", "bottom", "right", "top"]``.
    preprocessor : PreprocessorProtocol
        Preprocessor used to compute the spectrogram in which the peak is
        searched.
    time_scale : float
        Scaling factor applied to the time distances (left, right).
    frequency_scale : float
        Scaling factor applied to the frequency distances (bottom, top).
    loading_buffer : float
        Margin added around the ROI when loading audio for the peak search.
    """

    dimension_names = ["left", "bottom", "right", "top"]

    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
        loading_buffer: float = 0.01,
    ):
        """Store the preprocessor and the scaling/buffer settings.

        Parameters
        ----------
        preprocessor : PreprocessorProtocol
            Preprocessor used to compute spectrograms for the peak search.
        time_scale : float
            Scaling factor for time distances.
        frequency_scale : float
            Scaling factor for frequency distances.
        loading_buffer : float, default=0.01
            Margin added around the ROI when loading audio.
        """
        self.preprocessor = preprocessor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale
        self.loading_buffer = loading_buffer

    def encode(
        self,
        sound_event: data.SoundEvent,
    ) -> tuple[Position, Size]:
        """Encode a sound event as a peak-energy position and edge distances.

        Parameters
        ----------
        sound_event : soundevent.data.SoundEvent
            Sound event whose geometry and recording are used.

        Returns
        -------
        tuple[Position, Size]
            The ``(time, frequency)`` of the peak energy inside the geometry
            bounds, and an array with the scaled distances from that point to
            the left, bottom, right and top edges of the bounds.

        Raises
        ------
        ValueError
            If the sound event has no geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry

        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        peak_time, peak_freq = get_peak_energy_coordinates(
            recording=sound_event.recording,
            preprocessor=self.preprocessor,
            start_time=start_time,
            end_time=end_time,
            low_freq=low_freq,
            high_freq=high_freq,
            loading_buffer=self.loading_buffer,
        )

        # Order must match `dimension_names`: left, bottom, right, top.
        size = np.array(
            [
                (peak_time - start_time) * self.time_scale,
                (peak_freq - low_freq) * self.frequency_scale,
                (end_time - peak_time) * self.time_scale,
                (high_freq - peak_freq) * self.frequency_scale,
            ]
        )

        return (peak_time, peak_freq), size

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Rebuild a bounding box from a position and scaled edge distances.

        Parameters
        ----------
        position : Position
            The reference ``(time, frequency)`` point.
        size : np.ndarray
            1D array with the *scaled* distances, in the order given by
            `dimension_names` (left, bottom, right, top). Negative distances
            are treated as zero.

        Returns
        -------
        soundevent.data.BoundingBox
            The box spanning the unscaled distances around `position`.

        Raises
        ------
        ValueError
            If `size` does not have the expected shape (length 4).
        """
        # Same validation style as the anchor-based mapper: fail early with a
        # descriptive message instead of an opaque unpacking error.
        if size.ndim != 1 or size.shape[0] != 4:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({size.shape = }) != ([4])"
            )

        time, freq = position
        left, bottom, right, top = size

        return data.BoundingBox(
            coordinates=[
                time - max(0, float(left)) / self.time_scale,
                freq - max(0, float(bottom)) / self.frequency_scale,
                time + max(0, float(right)) / self.time_scale,
                freq + max(0, float(top)) / self.frequency_scale,
            ]
        )
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
|
||||
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
||||
"""Factory function to create an ROITargetMapper from configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -378,10 +410,24 @@ def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
|
||||
An initialized `BBoxEncoder` instance configured with the settings
|
||||
from `config`.
|
||||
"""
|
||||
return BBoxEncoder(
|
||||
position=config.position,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
if config.name == "anchor_bbox":
|
||||
return AnchorBBoxMapper(
|
||||
anchor=config.anchor,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
)
|
||||
|
||||
if config.name == "peak_energy_bbox":
|
||||
preprocessor = build_preprocessor(config.preprocessing)
|
||||
return PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
loading_buffer=config.loading_buffer,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"No ROI mapper of name {config.name} is implemented"
|
||||
)
|
||||
|
||||
|
||||
@ -414,11 +460,11 @@ def load_roi_mapper(
|
||||
If the configuration file cannot be found, parsed, validated, or if
|
||||
the specified `field` is invalid.
|
||||
"""
|
||||
config = load_config(path=path, schema=ROIConfig, field=field)
|
||||
config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
|
||||
return build_roi_mapper(config)
|
||||
|
||||
|
||||
VALID_POSITIONS = [
|
||||
VALID_ANCHORS = [
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
"top-left",
|
||||
@ -437,7 +483,7 @@ def _build_bounding_box(
|
||||
pos: tuple[float, float],
|
||||
duration: float,
|
||||
bandwidth: float,
|
||||
position: Positions = DEFAULT_POSITION,
|
||||
anchor: Anchor = DEFAULT_ANCHOR,
|
||||
) -> data.BoundingBox:
|
||||
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||
|
||||
@ -455,7 +501,7 @@ def _build_bounding_box(
|
||||
bandwidth : float
|
||||
The required *unscaled* frequency bandwidth (height) of the bounding
|
||||
box.
|
||||
position : Positions, default="bottom-left"
|
||||
anchor : Anchor, default="bottom-left"
|
||||
Specifies which part of the bounding box the input `pos` corresponds to.
|
||||
|
||||
Returns
|
||||
@ -466,12 +512,12 @@ def _build_bounding_box(
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `position` is not a recognized value or format.
|
||||
If `anchor` is not a recognized value or format.
|
||||
"""
|
||||
time, freq = map(float, pos)
|
||||
duration = max(0, duration)
|
||||
bandwidth = max(0, bandwidth)
|
||||
if position in ["center", "centroid", "point_on_surface"]:
|
||||
if anchor in ["center", "centroid", "point_on_surface"]:
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
max(time - duration / 2, 0),
|
||||
@ -481,13 +527,12 @@ def _build_bounding_box(
|
||||
]
|
||||
)
|
||||
|
||||
if position not in VALID_POSITIONS:
|
||||
if anchor not in VALID_ANCHORS:
|
||||
raise ValueError(
|
||||
f"Invalid position: {position}. "
|
||||
f"Valid options are: {VALID_POSITIONS}"
|
||||
f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
|
||||
)
|
||||
|
||||
y, x = position.split("-")
|
||||
y, x = anchor.split("-")
|
||||
|
||||
start_time = {
|
||||
"left": time,
|
||||
@ -509,3 +554,43 @@ def _build_bounding_box(
|
||||
max(0, low_freq + bandwidth),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_peak_energy_coordinates(
    recording: data.Recording,
    preprocessor: PreprocessorProtocol,
    start_time: float = 0,
    end_time: Optional[float] = None,
    low_freq: float = 0,
    high_freq: Optional[float] = None,
    loading_buffer: float = 0.05,
) -> Position:
    """Locate the point of maximum energy within a time-frequency region.

    A clip covering the region (plus `loading_buffer` on each side, limited
    to the recording extent) is preprocessed into a spectrogram. The position
    of the maximum value inside the requested region is returned.

    Parameters
    ----------
    recording : soundevent.data.Recording
        Recording containing the region of interest.
    preprocessor : PreprocessorProtocol
        Object providing `preprocess_clip`, `min_freq` and `max_freq`.
    start_time : float, default=0
        Start of the region along the time axis.
    end_time : float, optional
        End of the region. Defaults to, and is capped at, the recording
        duration.
    low_freq : float, default=0
        Lower frequency bound. Raised to `preprocessor.min_freq` if smaller.
    high_freq : float, optional
        Upper frequency bound. Defaults to half the recording samplerate and
        is capped at `preprocessor.max_freq`.
    loading_buffer : float, default=0.05
        Margin added before `start_time` and after `end_time` when building
        the clip to preprocess.

    Returns
    -------
    Position
        The ``(time, frequency)`` coordinates of the peak value.

    Raises
    ------
    ValueError
        If the requested region contains no spectrogram bins.
    """
    if end_time is None:
        end_time = recording.duration
    end_time = min(end_time, recording.duration)

    if high_freq is None:
        high_freq = recording.samplerate / 2

    clip = data.Clip(
        recording=recording,
        start_time=max(0, start_time - loading_buffer),
        end_time=min(recording.duration, end_time + loading_buffer),
    )

    # NOTE(review): `spec` is used as an xarray-style labelled array with
    # "time" and "frequency" dimensions; confirm against the preprocessor.
    spec = preprocessor.preprocess_clip(clip)

    # Restrict the frequency band to what the preprocessor covers.
    low_freq = max(low_freq, preprocessor.min_freq)
    high_freq = min(high_freq, preprocessor.max_freq)

    selection = spec.sel(
        time=slice(start_time, end_time),
        frequency=slice(low_freq, high_freq),
    )

    # An empty selection would make `argmax` fail with an opaque error;
    # report which region was empty instead.
    if selection.size == 0:
        raise ValueError(
            "Cannot locate the peak energy: the selected region is empty "
            f"(time=[{start_time}, {end_time}], "
            f"frequency=[{low_freq}, {high_freq}])."
        )

    index = selection.argmax(dim=["time", "frequency"])
    point = selection.isel(index)  # type: ignore
    peak_time: float = point.time.item()
    peak_freq: float = point.frequency.item()
    return peak_time, peak_freq
|
||||
|
@ -4,12 +4,12 @@ from soundevent import data
|
||||
|
||||
from batdetect2.targets.rois import (
|
||||
DEFAULT_FREQUENCY_SCALE,
|
||||
DEFAULT_POSITION,
|
||||
DEFAULT_ANCHOR,
|
||||
DEFAULT_TIME_SCALE,
|
||||
SIZE_HEIGHT,
|
||||
SIZE_WIDTH,
|
||||
BBoxEncoder,
|
||||
ROIConfig,
|
||||
AnchorBBoxMapper,
|
||||
BBoxAnchorMapperConfig,
|
||||
_build_bounding_box,
|
||||
build_roi_mapper,
|
||||
load_roi_mapper,
|
||||
@ -29,36 +29,36 @@ def zero_bbox() -> data.BoundingBox:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_encoder() -> BBoxEncoder:
|
||||
def default_encoder() -> AnchorBBoxMapper:
|
||||
"""A BBoxEncoder with default settings."""
|
||||
return BBoxEncoder()
|
||||
return AnchorBBoxMapper()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_encoder() -> BBoxEncoder:
|
||||
def custom_encoder() -> AnchorBBoxMapper:
|
||||
"""A BBoxEncoder with custom settings."""
|
||||
return BBoxEncoder(position="center", time_scale=1.0, frequency_scale=10.0)
|
||||
return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0)
|
||||
|
||||
|
||||
def test_roi_config_defaults():
|
||||
"""Test ROIConfig default values."""
|
||||
config = ROIConfig()
|
||||
assert config.position == DEFAULT_POSITION
|
||||
config = BBoxAnchorMapperConfig()
|
||||
assert config.anchor == DEFAULT_ANCHOR
|
||||
assert config.time_scale == DEFAULT_TIME_SCALE
|
||||
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
def test_roi_config_custom():
|
||||
"""Test creating ROIConfig with custom values."""
|
||||
config = ROIConfig(position="center", time_scale=1.0, frequency_scale=10.0)
|
||||
assert config.position == "center"
|
||||
config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0)
|
||||
assert config.anchor == "center"
|
||||
assert config.time_scale == 1.0
|
||||
assert config.frequency_scale == 10.0
|
||||
|
||||
|
||||
def test_bbox_encoder_init_defaults(default_encoder):
|
||||
"""Test BBoxEncoder initialization with default arguments."""
|
||||
assert default_encoder.position == DEFAULT_POSITION
|
||||
assert default_encoder.position == DEFAULT_ANCHOR
|
||||
assert default_encoder.time_scale == DEFAULT_TIME_SCALE
|
||||
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE
|
||||
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
|
||||
@ -92,15 +92,15 @@ def test_bbox_encoder_get_roi_position(
|
||||
sample_bbox, position_type, expected_pos
|
||||
):
|
||||
"""Test get_roi_position for various position types."""
|
||||
encoder = BBoxEncoder(position=position_type)
|
||||
actual_pos = encoder.get_roi_position(sample_bbox)
|
||||
encoder = AnchorBBoxMapper(anchor=position_type)
|
||||
actual_pos = encoder.encode_position(sample_bbox)
|
||||
assert actual_pos == pytest.approx(expected_pos)
|
||||
|
||||
|
||||
def test_bbox_encoder_get_roi_position_zero_box(zero_bbox):
|
||||
"""Test get_roi_position for a zero-sized box."""
|
||||
encoder = BBoxEncoder(position="center")
|
||||
assert encoder.get_roi_position(zero_bbox) == pytest.approx((15.0, 150.0))
|
||||
encoder = AnchorBBoxMapper(anchor="center")
|
||||
assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0))
|
||||
|
||||
|
||||
def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
|
||||
@ -160,7 +160,7 @@ def test_build_bounding_box(position_type, expected_coords):
|
||||
duration = 10.0
|
||||
bandwidth = 100.0
|
||||
bbox = _build_bounding_box(
|
||||
ref_pos, duration, bandwidth, position=position_type
|
||||
ref_pos, duration, bandwidth, anchor=position_type
|
||||
)
|
||||
assert isinstance(bbox, data.BoundingBox)
|
||||
np.testing.assert_allclose(bbox.coordinates, expected_coords)
|
||||
@ -173,17 +173,17 @@ def test_build_bounding_box_invalid_position():
|
||||
(0, 0),
|
||||
1,
|
||||
1,
|
||||
position="invalid-spot", # type: ignore
|
||||
anchor="invalid-spot", # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES)
|
||||
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
|
||||
"""Test recover_roi correctly reconstructs the original bbox."""
|
||||
encoder = BBoxEncoder(position=position_type)
|
||||
scaled_dims = encoder.get_roi_size(sample_bbox)
|
||||
encoder = AnchorBBoxMapper(anchor=position_type)
|
||||
scaled_dims = encoder.encode_size(sample_bbox)
|
||||
|
||||
recovered_bbox = encoder.recover_roi(ref_pos, scaled_dims)
|
||||
recovered_bbox = encoder.decode(ref_pos, scaled_dims)
|
||||
|
||||
assert isinstance(recovered_bbox, data.BoundingBox)
|
||||
np.testing.assert_allclose(
|
||||
@ -227,13 +227,13 @@ def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder):
|
||||
|
||||
def test_build_roi_mapper():
|
||||
"""Test build_roi_mapper creates a configured BBoxEncoder."""
|
||||
config = ROIConfig(
|
||||
position="top-right", time_scale=2.0, frequency_scale=20.0
|
||||
config = BBoxAnchorMapperConfig(
|
||||
anchor="top-right", time_scale=2.0, frequency_scale=20.0
|
||||
)
|
||||
mapper = build_roi_mapper(config)
|
||||
|
||||
assert isinstance(mapper, BBoxEncoder)
|
||||
assert mapper.position == config.position
|
||||
assert isinstance(mapper, AnchorBBoxMapper)
|
||||
assert mapper.anchor == config.anchor
|
||||
assert mapper.time_scale == config.time_scale
|
||||
assert mapper.frequency_scale == config.frequency_scale
|
||||
|
||||
@ -270,8 +270,8 @@ def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
|
||||
|
||||
mapper = load_roi_mapper(config_path)
|
||||
|
||||
assert isinstance(mapper, BBoxEncoder)
|
||||
assert mapper.position == "center"
|
||||
assert isinstance(mapper, AnchorBBoxMapper)
|
||||
assert mapper.anchor == "center"
|
||||
assert mapper.time_scale == 500.0
|
||||
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)
|
||||
|
||||
@ -283,8 +283,8 @@ def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
|
||||
|
||||
mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")
|
||||
|
||||
assert isinstance(mapper, BBoxEncoder)
|
||||
assert mapper.position == "bottom-right"
|
||||
assert isinstance(mapper, AnchorBBoxMapper)
|
||||
assert mapper.anchor == "bottom-right"
|
||||
assert mapper.time_scale == DEFAULT_TIME_SCALE
|
||||
assert mapper.frequency_scale == 0.01
|
||||
|
||||
|
@ -5,7 +5,7 @@ import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets.rois import ROIConfig
|
||||
from batdetect2.targets.rois import BBoxAnchorMapperConfig
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.train.labels import generate_heatmaps
|
||||
|
||||
@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
update=dict(
|
||||
roi=ROIConfig(
|
||||
roi=BBoxAnchorMapperConfig(
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user