Fixed Target object after changes to roi

This commit is contained in:
mbsantiago 2025-06-21 13:47:04 +01:00
parent c559bcc682
commit e352dc40bd
15 changed files with 426 additions and 411 deletions

View File

@ -72,7 +72,7 @@ def iterate_over_sound_events(
sound_event_annotation sound_event_annotation
) )
class_name = targets.encode(sound_event_annotation) class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic: if class_name is None and exclude_generic:
continue continue

View File

@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
gt_uuid = target.uuid if target is not None else None gt_uuid = target.uuid if target is not None else None
gt_det = target is not None gt_det = target is not None
gt_class = targets.encode(target) if target is not None else None gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0 pred_score = float(prediction.detection_score) if prediction else 0

View File

@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
return [ return [
convert_xr_dataset_to_raw_prediction( convert_xr_dataset_to_raw_prediction(
dataset, dataset,
self.targets.recover_roi, self.targets.decode_roi,
) )
for dataset in detection_datasets for dataset in detection_datasets
] ]
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction( convert_raw_predictions_to_clip_prediction(
prediction, prediction,
clip, clip,
sound_event_decoder=self.targets.decode, sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags, generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold, classification_threshold=self.config.classification_threshold,
) )

View File

@ -23,7 +23,6 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional from typing import List, Optional
import numpy as np
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -63,7 +62,7 @@ from batdetect2.targets.terms import (
get_term_from_key, get_term_from_key,
individual, individual,
register_term, register_term,
term_registry, default_term_registry,
) )
from batdetect2.targets.transform import ( from batdetect2.targets.transform import (
DerivationRegistry, DerivationRegistry,
@ -73,13 +72,13 @@ from batdetect2.targets.transform import (
SoundEventTransformation, SoundEventTransformation,
TransformConfig, TransformConfig,
build_transformation_from_config, build_transformation_from_config,
derivation_registry, default_derivation_registry,
get_derivation, get_derivation,
load_transformation_config, load_transformation_config,
load_transformation_from_config, load_transformation_from_config,
register_derivation, register_derivation,
) )
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import Position, Size, TargetProtocol
__all__ = [ __all__ = [
"ClassesConfig", "ClassesConfig",
@ -291,7 +290,9 @@ class Targets(TargetProtocol):
return True return True
return self._filter_fn(sound_event) return self._filter_fn(sound_event)
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]: def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name. """Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority) Applies the configured class definition rules (including priority)
@ -312,7 +313,7 @@ class Targets(TargetProtocol):
""" """
return self._encode_fn(sound_event) return self._encode_fn(sound_event)
def decode(self, class_label: str) -> List[data.Tag]: def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags. """Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or Uses the configured mapping (based on `TargetClass.output_tags` or
@ -352,9 +353,9 @@ class Targets(TargetProtocol):
return self._transform_fn(sound_event) return self._transform_fn(sound_event)
return sound_event return sound_event
def get_position( def encode_roi(
self, sound_event: data.SoundEventAnnotation self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]: ) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi. """Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method. Delegates to the internal ROI mapper's `get_roi_position` method.
@ -374,37 +375,9 @@ class Targets(TargetProtocol):
ValueError ValueError
If the annotation lacks geometry. If the annotation lacks geometry.
""" """
return self._roi_mapper.encode_position(sound_event.sound_event) return self._roi_mapper.encode(sound_event.sound_event)
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Calculate the target size dimensions from the annotation's geometry.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry.
"""
return self._roi_mapper.encode_size(sound_event.sound_event)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions. """Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which Delegates to the internal ROI mapper's `recover_roi` method, which
@ -424,7 +397,7 @@ class Targets(TargetProtocol):
data.Geometry data.Geometry
The reconstructed geometry (typically `BoundingBox`). The reconstructed geometry (typically `BoundingBox`).
""" """
return self._roi_mapper.decode(pos, dims) return self._roi_mapper.decode(position, size)
DEFAULT_CLASSES = [ DEFAULT_CLASSES = [
@ -528,8 +501,8 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
def build_targets( def build_targets(
config: Optional[TargetConfig] = None, config: Optional[TargetConfig] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets: ) -> Targets:
"""Build a Targets object from a loaded TargetConfig. """Build a Targets object from a loaded TargetConfig.
@ -613,8 +586,8 @@ def build_targets(
def load_targets( def load_targets(
config_path: data.PathLike, config_path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets: ) -> Targets:
"""Load a Targets object directly from a configuration file. """Load a Targets object directly from a configuration file.

View File

@ -11,7 +11,7 @@ from batdetect2.targets.terms import (
TagInfo, TagInfo,
TermRegistry, TermRegistry,
get_tag_from_info, get_tag_from_info,
term_registry, default_term_registry,
) )
__all__ = [ __all__ = [
@ -339,7 +339,7 @@ def _encode_with_multiple_classifiers(
def build_sound_event_encoder( def build_sound_event_encoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder: ) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration. """Build a sound event encoder function from the classes configuration.
@ -433,7 +433,7 @@ def _decode_class(
def build_sound_event_decoder( def build_sound_event_decoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False, raise_on_unmapped: bool = False,
) -> SoundEventDecoder: ) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration. """Build a sound event decoder function from the classes configuration.
@ -488,7 +488,7 @@ def build_sound_event_decoder(
def build_generic_class_tags( def build_generic_class_tags(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]: ) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config. """Extract and build the list of tags for the generic class from config.
@ -553,7 +553,7 @@ def load_classes_config(
def load_encoder_from_config( def load_encoder_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder: ) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file. """Load a class encoder function directly from a configuration file.
@ -594,7 +594,7 @@ def load_encoder_from_config(
def load_decoder_from_config( def load_decoder_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False, raise_on_unmapped: bool = False,
) -> SoundEventDecoder: ) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file. """Load a class decoder function directly from a configuration file.

View File

@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
TagInfo, TagInfo,
TermRegistry, TermRegistry,
get_tag_from_info, get_tag_from_info,
term_registry, default_term_registry,
) )
__all__ = [ __all__ = [
@ -156,7 +156,7 @@ def equal_tags(
def build_filter_from_rule( def build_filter_from_rule(
rule: FilterRule, rule: FilterRule,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule. """Creates a callable filter function from a single FilterRule.
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
def build_sound_event_filter( def build_sound_event_filter(
config: FilterConfig, config: FilterConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object. """Builds a merged filter function from a FilterConfig object.
@ -291,7 +291,7 @@ def load_filter_config(
def load_filter_from_config( def load_filter_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function. """Loads filter configuration from a file and builds the filter function.

View File

@ -1,23 +1,23 @@
"""Handles mapping between geometric ROIs and target representations. """Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting This module defines a standardized interface (`ROITargetMapper`) for converting
a sound event's Region of Interest (ROI), typically represented by a a sound event's Region of Interest (ROI) into a target representation suitable
`soundevent.data.Geometry` object like a `BoundingBox`, into a format for machine learning models, and for decoding model outputs back into geometric
suitable for use as a machine learning target. This usually involves: ROIs.
1. Extracting a single reference point (time, frequency) from the geometry. The core operations are:
2. Calculating relevant size dimensions (e.g., duration/width, 1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
bandwidth/height) and applying scaling factors. `Position` (time, frequency) and a `Size` array. The method for
determining the position and size varies by the mapper implementation
(e.g., using a bounding box anchor or the point of peak energy).
2. **Decoding**: A `Position` and `Size` array are mapped back to an
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).
It also provides the inverse operation: recovering an approximate geometric ROI This logic is encapsulated within specific mapper classes. Configuration for
(like a `BoundingBox`) from a predicted reference point and predicted size each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
dimensions. Pydantic config object. The `ROIMapperConfig` type allows for flexibly
selecting and configuring the desired mapper. This module separates the
This logic is encapsulated within components adhering to the `ROITargetMapper` *geometric* aspect of target definition from *semantic* classification.
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
""" """
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
@ -26,22 +26,26 @@ import numpy as np
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, Size
__all__ = [ __all__ = [
"ROITargetMapper", "Anchor",
"BBoxAnchorMapperConfig",
"AnchorBBoxMapper", "AnchorBBoxMapper",
"build_roi_mapper", "BBoxAnchorMapperConfig",
"load_roi_mapper",
"DEFAULT_ANCHOR", "DEFAULT_ANCHOR",
"SIZE_WIDTH", "DEFAULT_FREQUENCY_SCALE",
"DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig",
"ROIMapperConfig",
"ROITargetMapper",
"SIZE_HEIGHT", "SIZE_HEIGHT",
"SIZE_ORDER", "SIZE_ORDER",
"DEFAULT_TIME_SCALE", "SIZE_WIDTH",
"DEFAULT_FREQUENCY_SCALE", "build_roi_mapper",
] ]
Anchor = Literal[ Anchor = Literal[
@ -73,104 +77,94 @@ DEFAULT_TIME_SCALE = 1000.0
DEFAULT_FREQUENCY_SCALE = 1 / 859.375 DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth.""" """Default scaling factor for frequency bandwidth."""
DEFAULT_ANCHOR = "bottom-left" DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner).""" """Default reference position within the geometry ('bottom-left' corner)."""
Position = tuple[float, float]
Size = np.ndarray
class ROITargetMapper(Protocol): class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping. """Protocol defining the interface for ROI-to-target mapping.
Specifies the methods required for converting a geometric region of interest Specifies the `encode` and `decode` methods required for converting a
(`soundevent.data.Geometry`) into a target representation (reference point `soundevent.data.SoundEvent` into a target representation (a reference
and scaled dimensions) and for recovering an approximate ROI from that position and a size vector) and for recovering an approximate ROI from that
representation. representation.
Attributes Attributes
---------- ----------
dimension_names : List[str] dimension_names : List[str]
A list containing the names of the dimensions returned by A list containing the names of the dimensions in the `Size` array
`get_roi_size` and expected by `recover_roi` returned by `encode` and expected by `decode`.
(e.g., ['width', 'height']).
""" """
dimension_names: List[str] dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]: def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Extract the reference position from a geometry. """Encode a SoundEvent's geometry into a position and size.
Parameters Parameters
---------- ----------
geom : soundevent.data.Geometry sound_event : data.SoundEvent
The input geometry (e.g., BoundingBox, Polygon). The input sound event, which must have a geometry attribute.
Returns Returns
------- -------
Tuple[float, float] Tuple[Position, Size]
The calculated reference position as (time, frequency) coordinates, A tuple containing:
based on the implementing class's configuration (e.g., "center", - The reference position as (time, frequency) coordinates.
"bottom-left"). - A NumPy array with the calculated size dimensions.
Raises Raises
------ ------
ValueError ValueError
If the position cannot be calculated for the given geometry type If the sound event does not have a geometry.
or configured reference point.
""" """
... ...
def decode(self, position: Position, size: Size) -> data.Geometry: def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions. """Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and the Performs the inverse mapping: takes a reference position and size
predicted dimensions and reconstructs a geometric representation. dimensions and reconstructs a geometric representation.
Parameters Parameters
---------- ----------
position : Tuple[float, float] position : Position
The reference position (time, frequency). The reference position (time, frequency).
size : np.ndarray size : Size
NumPy array containing the dimensions, matching the order NumPy array containing the size dimensions, matching the order
specified by `dimension_names`. and meaning specified by `dimension_names`.
Returns Returns
------- -------
soundevent.data.Geometry soundevent.data.Geometry
The reconstructed geometry. The reconstructed geometry, typically a `BoundingBox`.
Raises Raises
------ ------
ValueError ValueError
If the number of provided dimensions `dims` does not match If the `size` array has an unexpected shape or if reconstruction
`dimension_names` or if reconstruction fails. fails.
""" """
... ...
class BBoxAnchorMapperConfig(BaseConfig): class BBoxAnchorMapperConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs). """Configuration for `AnchorBBoxMapper`.
Defines parameters controlling how geometric ROIs are converted into Defines parameters for converting ROIs into targets using a fixed anchor
target representations (reference points and scaled sizes). point on the bounding box.
Attributes Attributes
---------- ----------
anchor : Anchor, default="bottom-left" name : Literal["anchor_bbox"]
Specifies the reference point within the geometry (e.g., bounding box) The unique identifier for this mapper type.
to use as the target location (e.g., "center", "bottom-left"). anchor : Anchor
time_scale : float, default=1000.0 Specifies the anchor point within the bounding box to use as the
Scaling factor applied to the time duration (width) of the ROI target's reference position (e.g., "center", "bottom-left").
when calculating the target size representation. Must match model time_scale : float
expectations. Scaling factor applied to the time duration (width) of the ROI.
frequency_scale : float, default=1/859.375 frequency_scale : float
Scaling factor applied to the frequency bandwidth (height) of the ROI Scaling factor applied to the frequency bandwidth (height) of the ROI.
when calculating the target size representation. Must match model
expectations.
""" """
name: Literal["anchor_bbox"] = "anchor_bbox" name: Literal["anchor_bbox"] = "anchor_bbox"
@ -180,23 +174,28 @@ class BBoxAnchorMapperConfig(BaseConfig):
class AnchorBBoxMapper(ROITargetMapper): class AnchorBBoxMapper(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes. """Maps ROIs using a bounding box anchor point and width/height.
This class implements the ROI mapping protocol primarily for This class implements the `ROITargetMapper` protocol for `BoundingBox`
`soundevent.data.BoundingBox` geometry. It extracts reference points, geometries.
calculates scaled width/height, and recovers bounding boxes based on
configured position and scaling factors. **Encoding**: The `position` is a fixed anchor point on the bounding box
(e.g., "bottom-left"). The `size` is a 2-element array containing the
scaled width and height of the box.
**Decoding**: Reconstructs a `BoundingBox` from an anchor point and
scaled width/height.
Attributes Attributes
---------- ----------
dimension_names : List[str] dimension_names : List[str]
Specifies the output dimension names as ['width', 'height']. The output dimension names: `['width', 'height']`.
anchor : Anchor anchor : Anchor
The configured reference point type (e.g., "center", "bottom-left"). The configured anchor point type (e.g., "center", "bottom-left").
time_scale : float time_scale : float
The configured scaling factor for the time dimension (width). The scaling factor for the time dimension (width).
frequency_scale : float frequency_scale : float
The configured scaling factor for the frequency dimension (height). The scaling factor for the frequency dimension (height).
""" """
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT] dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
@ -211,11 +210,11 @@ class AnchorBBoxMapper(ROITargetMapper):
Parameters Parameters
---------- ----------
anchor : Anchor, default="bottom-left" anchor : Anchor
Reference point type within the bounding box. Reference point type within the bounding box.
time_scale : float, default=1000.0 time_scale : float
Scaling factor for time duration (width). Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375 frequency_scale : float
Scaling factor for frequency bandwidth (height). Scaling factor for frequency bandwidth (height).
""" """
self.anchor: Anchor = anchor self.anchor: Anchor = anchor
@ -223,19 +222,20 @@ class AnchorBBoxMapper(ROITargetMapper):
self.frequency_scale = frequency_scale self.frequency_scale = frequency_scale
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]: def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
"""Extract the configured reference position from the geometry. """Encode a SoundEvent into an anchor position and scaled box size.
Uses `soundevent.geometry.get_geometry_point`. The position is determined by the configured anchor on the sound
event's bounding box. The size is the scaled width and height.
Parameters Parameters
---------- ----------
geom : soundevent.data.Geometry sound_event : data.SoundEvent
Input geometry (e.g., BoundingBox). The input sound event with a geometry.
Returns Returns
------- -------
Tuple[float, float] Tuple[Position, Size]
Reference position (time, frequency). A tuple of (anchor_position, [scaled_width, scaled_height]).
""" """
from soundevent import geometry from soundevent import geometry
@ -267,29 +267,27 @@ class AnchorBBoxMapper(ROITargetMapper):
position: Position, position: Position,
size: Size, size: Size,
) -> data.Geometry: ) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions. """Recover a BoundingBox from an anchor position and scaled size.
Un-scales the input dimensions using the configured factors and Un-scales the input dimensions and reconstructs a
reconstructs a `soundevent.data.BoundingBox` centered or anchored at `soundevent.data.BoundingBox` relative to the given anchor position.
the given reference `pos` according to the configured `position` type.
Parameters Parameters
---------- ----------
pos : Tuple[float, float] position : Position
Reference position (time, frequency). Reference anchor position (time, frequency).
dims : np.ndarray size : Size
NumPy array containing the *scaled* dimensions, expected order is NumPy array containing the scaled [width, height].
[scaled_width, scaled_height].
Returns Returns
------- -------
soundevent.data.BoundingBox data.BoundingBox
The reconstructed bounding box. The reconstructed bounding box.
Raises Raises
------ ------
ValueError ValueError
If `dims` does not have the expected shape (length 2). If `size` does not have the expected shape (length 2).
""" """
if size.ndim != 1 or size.shape[0] != 2: if size.ndim != 1 or size.shape[0] != 2:
@ -308,6 +306,24 @@ class AnchorBBoxMapper(ROITargetMapper):
class PeakEnergyBBoxMapperConfig(BaseConfig): class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`.
Attributes
----------
name : Literal["peak_energy_bbox"]
The unique identifier for this mapper type.
preprocessing : PreprocessingConfig
Configuration for the spectrogram preprocessor needed to find the
peak energy.
loading_buffer : float
Seconds to add to each side of the ROI when loading audio to ensure
the peak is captured accurately, avoiding boundary effects.
time_scale : float
Scaling factor applied to the time dimensions.
frequency_scale : float
Scaling factor applied to the frequency dimensions.
"""
name: Literal["peak_energy_bbox"] name: Literal["peak_energy_bbox"]
preprocessing: PreprocessingConfig = Field( preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
@ -318,9 +334,30 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
class PeakEnergyBBoxMapper(ROITargetMapper): class PeakEnergyBBoxMapper(ROITargetMapper):
""" """Maps ROIs using the peak energy point and distances to edges.
Encodes the ROI using the location of the peak energy within the bounding box
as the 'position' and the distances from that point to the box edges as the 'size'. This class implements the `ROITargetMapper` protocol.
**Encoding**: The `position` is the (time, frequency) coordinate of the
point with the highest energy within the sound event's bounding box. The
`size` is a 4-element array representing the scaled distances from this
peak energy point to the left, bottom, right, and top edges of the box.
**Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the
un-scaled distances from the peak energy point.
Attributes
----------
dimension_names : List[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance.
time_scale : float
The scaling factor for time-based distances.
frequency_scale : float
The scaling factor for frequency-based distances.
loading_buffer : float
The buffer used for loading audio around the ROI.
""" """
dimension_names = ["left", "bottom", "right", "top"] dimension_names = ["left", "bottom", "right", "top"]
@ -332,6 +369,19 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
frequency_scale: float = DEFAULT_FREQUENCY_SCALE, frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01, loading_buffer: float = 0.01,
): ):
"""Initialize the PeakEnergyBBoxMapper.
Parameters
----------
preprocessor : PreprocessorProtocol
An initialized preprocessor for generating spectrograms.
time_scale : float
Scaling factor for time dimensions (left, right distances).
frequency_scale : float
Scaling factor for frequency dimensions (bottom, top distances).
loading_buffer : float
Buffer in seconds to add when loading audio clips.
"""
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.time_scale = time_scale self.time_scale = time_scale
self.frequency_scale = frequency_scale self.frequency_scale = frequency_scale
@ -341,6 +391,21 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
self, self,
sound_event: data.SoundEvent, sound_event: data.SoundEvent,
) -> tuple[Position, Size]: ) -> tuple[Position, Size]:
"""Encode a SoundEvent into a peak energy position and edge distances.
Finds the peak energy coordinates within the event's bounding box
and calculates the scaled distances from this point to the box edges.
Parameters
----------
sound_event : data.SoundEvent
The input sound event with a geometry and associated recording.
Returns
-------
Tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances).
"""
from soundevent import geometry from soundevent import geometry
geom = sound_event.geometry geom = sound_event.geometry
@ -377,6 +442,20 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
return (time, freq), size return (time, freq), size
def decode(self, position: Position, size: Size) -> data.Geometry: def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover a BoundingBox from a peak position and edge distances.
Parameters
----------
position : Position
The reference peak energy position (time, frequency).
size : Size
NumPy array with scaled distances [left, bottom, right, top].
Returns
-------
data.BoundingBox
The reconstructed bounding box.
"""
time, freq = position time, freq = position
left, bottom, right, top = size left, bottom, right, top = size
@ -394,21 +473,30 @@ ROIMapperConfig = Annotated[
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig], Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""A discriminated union of all supported ROI mapper configurations.
This type allows for selecting and configuring different `ROITargetMapper`
implementations by using the `name` field as a discriminator.
"""
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration. """Factory function to create an ROITargetMapper from a config object.
Parameters Parameters
---------- ----------
config : ROIConfig config : ROIMapperConfig
Configuration object specifying ROI mapping parameters. A configuration object specifying the mapper type and its parameters.
Returns Returns
------- -------
ROITargetMapper ROITargetMapper
An initialized `BBoxEncoder` instance configured with the settings An initialized ROI mapper instance.
from `config`.
Raises
------
NotImplementedError
If the `name` in the config does not correspond to a known mapper.
""" """
if config.name == "anchor_bbox": if config.name == "anchor_bbox":
return AnchorBBoxMapper( return AnchorBBoxMapper(
@ -431,39 +519,6 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
) )
def load_roi_mapper(
path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
"""Load ROI mapping configuration from a file and build the mapper.
Convenience function that loads an `ROIConfig` from the specified file
(and optional field) and then uses `build_roi_mapper` to create the
corresponding `ROITargetMapper` instance.
Parameters
----------
path : PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
ROI configuration. If None, the entire file content is used.
Returns
-------
ROITargetMapper
An initialized ROI mapper instance based on the configuration file.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
return build_roi_mapper(config)
VALID_ANCHORS = [ VALID_ANCHORS = [
"bottom-left", "bottom-left",
"bottom-right", "bottom-right",
@ -501,7 +556,7 @@ def _build_bounding_box(
bandwidth : float bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding The required *unscaled* frequency bandwidth (height) of the bounding
box. box.
anchor : Anchor, default="bottom-left" anchor : Anchor
Specifies which part of the bounding box the input `pos` corresponds to. Specifies which part of the bounding box the input `pos` corresponds to.
Returns Returns
@ -565,6 +620,35 @@ def get_peak_energy_coordinates(
high_freq: Optional[float] = None, high_freq: Optional[float] = None,
loading_buffer: float = 0.05, loading_buffer: float = 0.05,
) -> Position: ) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram.
Generates a spectrogram for a specified time-frequency region of a
recording and returns the (time, frequency) coordinates of the pixel with
the maximum value.
Parameters
----------
recording : data.Recording
The recording to analyze.
preprocessor : PreprocessorProtocol
The processor to convert audio to a spectrogram.
start_time : float, default=0
The start time of the region of interest.
end_time : float, optional
The end time of the region of interest. Defaults to recording duration.
low_freq : float, default=0
The low frequency of the region of interest.
high_freq : float, optional
The high frequency of the region of interest. Defaults to Nyquist.
loading_buffer : float, default=0.05
Buffer in seconds to add around the time range when loading the clip
to mitigate border effects from transformations like STFT.
Returns
-------
Position
A (time, frequency) tuple for the peak energy location.
"""
if end_time is None: if end_time is None:
end_time = recording.duration end_time = recording.duration
end_time = min(end_time, recording.duration) end_time = min(end_time, recording.duration)

View File

@ -230,7 +230,7 @@ class TermRegistry(Mapping[str, data.Term]):
del self._terms[key] del self._terms[key]
term_registry = TermRegistry( default_term_registry = TermRegistry(
terms=dict( terms=dict(
[ [
*getmembers(terms, lambda x: isinstance(x, data.Term)), *getmembers(terms, lambda x: isinstance(x, data.Term)),
@ -252,7 +252,7 @@ is explicitly passed.
def get_term_from_key( def get_term_from_key(
key: str, key: str,
term_registry: TermRegistry = term_registry, term_registry: Optional[TermRegistry] = None,
) -> data.Term: ) -> data.Term:
"""Convenience function to retrieve a term by key from a registry. """Convenience function to retrieve a term by key from a registry.
@ -277,10 +277,13 @@ def get_term_from_key(
KeyError KeyError
If the key is not found in the specified registry. If the key is not found in the specified registry.
""" """
term_registry = term_registry or default_term_registry
return term_registry.get_term(key) return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]: def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
"""Convenience function to get all registered keys from a registry. """Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry` Uses the global default registry unless a specific `term_registry`
@ -299,7 +302,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
return term_registry.get_keys() return term_registry.get_keys()
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]: def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry. """Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry` Uses the global default registry unless a specific `term_registry`
@ -342,7 +347,7 @@ class TagInfo(BaseModel):
def get_tag_from_info( def get_tag_from_info(
tag_info: TagInfo, tag_info: TagInfo,
term_registry: TermRegistry = term_registry, term_registry: Optional[TermRegistry] = None,
) -> data.Tag: ) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data. """Creates a soundevent.data.Tag object from TagInfo data.
@ -368,6 +373,7 @@ def get_tag_from_info(
If the term key specified in `tag_info.key` is not found If the term key specified in `tag_info.key` is not found
in the registry. in the registry.
""" """
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry) term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value) return data.Tag(term=term, value=tag_info.value)
@ -439,7 +445,7 @@ class TermConfig(BaseModel):
def load_terms_from_config( def load_terms_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]: ) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them. """Loads term definitions from a configuration file and registers them.
@ -490,6 +496,6 @@ def load_terms_from_config(
def register_term( def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None: ) -> None:
registry.add_term(key, term) registry.add_term(key, term)

View File

@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
get_tag_from_info, get_tag_from_info,
get_term_from_key, get_term_from_key,
) )
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [ __all__ = [
"DerivationRegistry", "DerivationRegistry",
@ -34,7 +31,7 @@ __all__ = [
"TransformConfig", "TransformConfig",
"build_transform_from_rule", "build_transform_from_rule",
"build_transformation_from_config", "build_transformation_from_config",
"derivation_registry", "default_derivation_registry",
"get_derivation", "get_derivation",
"load_transformation_config", "load_transformation_config",
"load_transformation_from_config", "load_transformation_from_config",
@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
return list(self._derivations.values()) return list(self._derivations.values())
derivation_registry = DerivationRegistry() default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry. """Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key Register custom derivation functions here to make them available by key
@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
def get_derivation( def get_derivation(
key: str, key: str,
import_derivation: bool = False, import_derivation: bool = False,
registry: DerivationRegistry = derivation_registry, registry: Optional[DerivationRegistry] = None,
): ):
"""Retrieve a derivation function by key, optionally importing it. """Retrieve a derivation function by key, optionally importing it.
@ -443,6 +440,8 @@ def get_derivation(
AttributeError AttributeError
If dynamic import fails because the function name isn't in the module. If dynamic import fails because the function name isn't in the module.
""" """
registry = registry or default_derivation_registry
if not import_derivation or key in registry: if not import_derivation or key in registry:
return registry.get_derivation(key) return registry.get_derivation(key)
@ -458,10 +457,16 @@ def get_derivation(
) from err ) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule( def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule], rule: TranformationRule,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config. """Build a specific SoundEventTransformation function from a rule config.
@ -559,8 +564,8 @@ def build_transform_from_rule(
def build_transformation_from_config( def build_transformation_from_config(
config: TransformConfig, config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig. """Build a composite transformation function from a TransformConfig.
@ -581,6 +586,7 @@ def build_transformation_from_config(
SoundEventTransformation SoundEventTransformation
A single function that applies all configured transformations in order. A single function that applies all configured transformations in order.
""" """
transforms = [ transforms = [
build_transform_from_rule( build_transform_from_rule(
rule, rule,
@ -590,14 +596,16 @@ def build_transformation_from_config(
for rule in config.rules for rule in config.rules
] ]
def transformation( return partial(apply_sequence_of_transforms, transforms=transforms)
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
return transformation
def apply_sequence_of_transforms(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
def load_transformation_config( def load_transformation_config(
@ -631,8 +639,8 @@ def load_transformation_config(
def load_transformation_from_config( def load_transformation_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function. """Load transformation config from a file and build the final function.
@ -677,7 +685,7 @@ def load_transformation_from_config(
def register_derivation( def register_derivation(
key: str, key: str,
derivation: Derivation, derivation: Derivation,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
) -> None: ) -> None:
"""Register a new derivation function in the global registry. """Register a new derivation function in the global registry.
@ -696,4 +704,5 @@ def register_derivation(
KeyError KeyError
If a derivation function with the same key is already registered. If a derivation function with the same key is already registered.
""" """
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation) derivation_registry.register(key, derivation)

View File

@ -19,8 +19,16 @@ from soundevent import data
__all__ = [ __all__ = [
"TargetProtocol", "TargetProtocol",
"Position",
"Size",
] ]
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""
Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""
class TargetProtocol(Protocol): class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline. """Protocol defining the interface for the target definition pipeline.
@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
""" """
... ...
def encode( def encode_class(
self, self,
sound_event: data.SoundEventAnnotation, sound_event: data.SoundEventAnnotation,
) -> Optional[str]: ) -> Optional[str]:
@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
""" """
... ...
def decode(self, class_label: str) -> List[data.Tag]: def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags. """Decode a predicted class name back into representative tags.
Parameters Parameters
@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
""" """
... ...
def get_position( def encode_roi(
self, sound_event: data.SoundEventAnnotation self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]: ) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's geometry. """Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary Calculates the `(time, frequency)` coordinate representing the primary
@ -173,37 +181,7 @@ class TargetProtocol(Protocol):
""" """
... ...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Calculate the target size dimensions from the annotation's geometry.
Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions. """Recover the ROI geometry from a position and dimensions.
Performs the inverse mapping of `get_position` and `get_size`. It takes Performs the inverse mapping of `get_position` and `get_size`. It takes

View File

@ -97,7 +97,7 @@ def _is_in_subclip(
start_time: float, start_time: float,
end_time: float, end_time: float,
) -> bool: ) -> bool:
time, _ = targets.get_position(sound_event_annotation) time, _ = targets.encode_roi(sound_event_annotation)
return start_time <= time <= end_time return start_time <= time <= end_time

View File

@ -138,7 +138,7 @@ def generate_clip_label(
logger.debug( logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events", "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid, uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events) num=len(clip_annotation.sound_events),
) )
sound_events = [] sound_events = []
@ -260,7 +260,7 @@ def generate_heatmaps(
continue continue
# Get the position of the sound event # Get the position of the sound event
time, frequency = targets.get_position(sound_event_annotation) (time, frequency), size = targets.encode_roi(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap # Set 1.0 at the position of the sound event in the detection heatmap
try: try:
@ -280,8 +280,6 @@ def generate_heatmaps(
) )
continue continue
size = targets.get_size(sound_event_annotation)
size_heatmap = arrays.set_value_at_pos( size_heatmap = arrays.set_value_at_pos(
size_heatmap, size_heatmap,
size, size,
@ -291,7 +289,7 @@ def generate_heatmaps(
# Get the class name of the sound event # Get the class name of the sound event
try: try:
class_name = targets.encode(sound_event_annotation) class_name = targets.encode_class(sound_event_annotation)
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
"Skipping annotation %s: Unexpected error while encoding " "Skipping annotation %s: Unexpected error while encoding "

View File

@ -0,0 +1,19 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
from batdetect2.preprocess import build_preprocessor
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

View File

@ -62,6 +62,9 @@ keywords = [
requires = ["hatchling"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["batdetect2/"]
[project.scripts] [project.scripts]
batdetect2 = "batdetect2.cli:cli" batdetect2 = "batdetect2.cli:cli"

View File

@ -3,8 +3,8 @@ import pytest
from soundevent import data from soundevent import data
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
DEFAULT_FREQUENCY_SCALE,
DEFAULT_ANCHOR, DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
DEFAULT_TIME_SCALE, DEFAULT_TIME_SCALE,
SIZE_HEIGHT, SIZE_HEIGHT,
SIZE_WIDTH, SIZE_WIDTH,
@ -12,7 +12,6 @@ from batdetect2.targets.rois import (
BBoxAnchorMapperConfig, BBoxAnchorMapperConfig,
_build_bounding_box, _build_bounding_box,
build_roi_mapper, build_roi_mapper,
load_roi_mapper,
) )
@ -22,6 +21,16 @@ def sample_bbox() -> data.BoundingBox:
return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0]) return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0])
@pytest.fixture
def sample_recording(create_recording) -> data.Recording:
return create_recording(duration=30, samplerate=4_000)
@pytest.fixture
def sample_sound_event(sample_bbox, sample_recording) -> data.SoundEvent:
return data.SoundEvent(geometry=sample_bbox, recording=sample_recording)
@pytest.fixture @pytest.fixture
def zero_bbox() -> data.BoundingBox: def zero_bbox() -> data.BoundingBox:
"""A bounding box with zero duration and bandwidth.""" """A bounding box with zero duration and bandwidth."""
@ -29,7 +38,13 @@ def zero_bbox() -> data.BoundingBox:
@pytest.fixture @pytest.fixture
def default_encoder() -> AnchorBBoxMapper: def zero_sound_event(zero_bbox, sample_recording) -> data.SoundEvent:
"""A sample sound event with a zero-sized bounding box."""
return data.SoundEvent(geometry=zero_bbox, recording=sample_recording)
@pytest.fixture
def default_mapper() -> AnchorBBoxMapper:
"""A BBoxEncoder with default settings.""" """A BBoxEncoder with default settings."""
return AnchorBBoxMapper() return AnchorBBoxMapper()
@ -37,36 +52,30 @@ def default_encoder() -> AnchorBBoxMapper:
@pytest.fixture @pytest.fixture
def custom_encoder() -> AnchorBBoxMapper: def custom_encoder() -> AnchorBBoxMapper:
"""A BBoxEncoder with custom settings.""" """A BBoxEncoder with custom settings."""
return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0) return AnchorBBoxMapper(
anchor="center", time_scale=1.0, frequency_scale=10.0
)
def test_roi_config_defaults(): @pytest.fixture
"""Test ROIConfig default values.""" def custom_mapper() -> AnchorBBoxMapper:
config = BBoxAnchorMapperConfig() """An AnchorBBoxMapper with custom settings."""
assert config.anchor == DEFAULT_ANCHOR return AnchorBBoxMapper(
assert config.time_scale == DEFAULT_TIME_SCALE anchor="center", time_scale=1.0, frequency_scale=10.0
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE )
def test_roi_config_custom(): def test_bbox_encoder_init_defaults(default_mapper):
"""Test creating ROIConfig with custom values."""
config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0)
assert config.anchor == "center"
assert config.time_scale == 1.0
assert config.frequency_scale == 10.0
def test_bbox_encoder_init_defaults(default_encoder):
"""Test BBoxEncoder initialization with default arguments.""" """Test BBoxEncoder initialization with default arguments."""
assert default_encoder.position == DEFAULT_ANCHOR assert default_mapper.anchor == DEFAULT_ANCHOR
assert default_encoder.time_scale == DEFAULT_TIME_SCALE assert default_mapper.time_scale == DEFAULT_TIME_SCALE
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE assert default_mapper.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] assert default_mapper.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
def test_bbox_encoder_init_custom(custom_encoder): def test_bbox_encoder_init_custom(custom_encoder):
"""Test BBoxEncoder initialization with custom arguments.""" """Test BBoxEncoder initialization with custom arguments."""
assert custom_encoder.position == "center" assert custom_encoder.anchor == "center"
assert custom_encoder.time_scale == 1.0 assert custom_encoder.time_scale == 1.0
assert custom_encoder.frequency_scale == 10.0 assert custom_encoder.frequency_scale == 10.0
assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
@ -87,52 +96,50 @@ POSITION_TEST_CASES = [
] ]
@pytest.mark.parametrize("position_type, expected_pos", POSITION_TEST_CASES) @pytest.mark.parametrize("anchor, expected_pos", POSITION_TEST_CASES)
def test_bbox_encoder_get_roi_position( def test_anchor_bbox_mapper_encode_position(
sample_bbox, position_type, expected_pos sample_sound_event, anchor, expected_pos
): ):
"""Test get_roi_position for various position types.""" """Test encode returns the correct position for various anchors."""
encoder = AnchorBBoxMapper(anchor=position_type) encoder = AnchorBBoxMapper(anchor=anchor)
actual_pos = encoder.encode_position(sample_bbox) actual_pos, _ = encoder.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos) assert actual_pos == pytest.approx(expected_pos)
def test_bbox_encoder_get_roi_position_zero_box(zero_bbox): def test_anchor_bbox_mapper_encode_defaults(
"""Test get_roi_position for a zero-sized box.""" sample_sound_event, default_mapper
encoder = AnchorBBoxMapper(anchor="center") ):
assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0)) """Test encode with default settings returns correct position and size."""
expected_pos = (10.0, 100.0) # bottom-left
def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
"""Test get_roi_size with default scaling."""
expected_size = np.array( expected_size = np.array(
[ [
10.0 * DEFAULT_TIME_SCALE, 10.0 * DEFAULT_TIME_SCALE,
100.0 * DEFAULT_FREQUENCY_SCALE, 100.0 * DEFAULT_FREQUENCY_SCALE,
] ]
) )
actual_size = default_encoder.get_roi_size(sample_bbox) actual_pos, actual_size = default_mapper.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size) np.testing.assert_allclose(actual_size, expected_size)
assert actual_size.shape == (2,) assert actual_size.shape == (2,)
def test_bbox_encoder_get_roi_size_custom(sample_bbox, custom_encoder): def test_anchor_bbox_mapper_encode_custom(sample_sound_event, custom_mapper):
"""Test get_roi_size with custom scaling.""" """Test encode with custom settings returns correct position and size."""
expected_size = np.array( expected_pos = (15.0, 150.0) # center
[ expected_size = np.array([10.0 * 1.0, 100.0 * 10.0])
10.0 * 1.0,
100.0 * 10.0, actual_pos, actual_size = custom_mapper.encode(sample_sound_event)
] assert actual_pos == pytest.approx(expected_pos)
)
actual_size = custom_encoder.get_roi_size(sample_bbox)
np.testing.assert_allclose(actual_size, expected_size) np.testing.assert_allclose(actual_size, expected_size)
assert actual_size.shape == (2,) assert actual_size.shape == (2,)
def test_bbox_encoder_get_roi_size_zero_box(zero_bbox, default_encoder): def test_anchor_bbox_mapper_encode_zero_box(zero_sound_event, default_mapper):
"""Test get_roi_size for a zero-sized box.""" """Test encode for a zero-sized box."""
expected_pos = (15.0, 150.0)
expected_size = np.array([0.0, 0.0]) expected_size = np.array([0.0, 0.0])
actual_size = default_encoder.get_roi_size(zero_bbox) actual_pos, actual_size = default_mapper.encode(zero_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size) np.testing.assert_allclose(actual_size, expected_size)
@ -166,9 +173,9 @@ def test_build_bounding_box(position_type, expected_coords):
np.testing.assert_allclose(bbox.coordinates, expected_coords) np.testing.assert_allclose(bbox.coordinates, expected_coords)
def test_build_bounding_box_invalid_position(): def test_build_bounding_box_invalid_anchor():
"""Test _build_bounding_box raises error for invalid position.""" """Test _build_bounding_box raises error for invalid position."""
with pytest.raises(ValueError, match="Invalid position"): with pytest.raises(ValueError, match="Invalid anchor"):
_build_bounding_box( _build_bounding_box(
(0, 0), (0, 0),
1, 1,
@ -177,13 +184,16 @@ def test_build_bounding_box_invalid_position():
) )
@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES) @pytest.mark.parametrize(
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos): "anchor", [anchor for anchor, _ in POSITION_TEST_CASES]
"""Test recover_roi correctly reconstructs the original bbox.""" )
encoder = AnchorBBoxMapper(anchor=position_type) def test_anchor_bbox_mapper_encode_decode_roundtrip(
scaled_dims = encoder.encode_size(sample_bbox) sample_sound_event, sample_bbox, anchor
):
recovered_bbox = encoder.decode(ref_pos, scaled_dims) """Test encode-decode roundtrip reconstructs the original bbox."""
mapper = AnchorBBoxMapper(anchor=anchor)
position, size = mapper.encode(sample_sound_event)
recovered_bbox = mapper.decode(position, size)
assert isinstance(recovered_bbox, data.BoundingBox) assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose( np.testing.assert_allclose(
@ -191,12 +201,12 @@ def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
) )
def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder): def test_anchor_bbox_mapper_roundtrip_custom_scale(
"""Test recover_roi with custom scaling factors.""" sample_sound_event, sample_bbox, custom_mapper
ref_pos = custom_encoder.get_roi_position(sample_bbox) ):
scaled_dims = custom_encoder.get_roi_size(sample_bbox) """Test encode-decode roundtrip with custom scaling factors."""
position, size = custom_mapper.encode(sample_sound_event)
recovered_bbox = custom_encoder.recover_roi(ref_pos, scaled_dims) recovered_bbox = custom_mapper.decode(position, size)
assert isinstance(recovered_bbox, data.BoundingBox) assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose( np.testing.assert_allclose(
@ -204,25 +214,26 @@ def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
) )
def test_bbox_encoder_recover_roi_zero_box(zero_bbox, default_encoder): def test_anchor_bbox_mapper_roundtrip_zero_box(
"""Test recover_roi for a zero-sized box.""" zero_sound_event, zero_bbox, default_mapper
ref_pos = default_encoder.get_roi_position(zero_bbox) ):
scaled_dims = default_encoder.get_roi_size(zero_bbox) """Test encode-decode roundtrip for a zero-sized box."""
recovered_bbox = default_encoder.recover_roi(ref_pos, scaled_dims) position, size = default_mapper.encode(zero_sound_event)
recovered_bbox = default_mapper.decode(position, size)
np.testing.assert_allclose( np.testing.assert_allclose(
recovered_bbox.coordinates, zero_bbox.coordinates, atol=1e-6 recovered_bbox.coordinates, zero_bbox.coordinates, atol=1e-6
) )
def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder): def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
"""Test recover_roi raises ValueError for incorrect dims shape.""" """Test decode raises ValueError for incorrect size shape."""
ref_pos = (10, 100) ref_pos = (10, 100)
with pytest.raises(ValueError): with pytest.raises(ValueError, match="does not have the expected shape"):
default_encoder.recover_roi(ref_pos, np.array([1.0])) default_mapper.decode(ref_pos, np.array([1.0]))
with pytest.raises(ValueError): with pytest.raises(ValueError, match="does not have the expected shape"):
default_encoder.recover_roi(ref_pos, np.array([1.0, 2.0, 3.0])) default_mapper.decode(ref_pos, np.array([1.0, 2.0, 3.0]))
with pytest.raises(ValueError): with pytest.raises(ValueError, match="does not have the expected shape"):
default_encoder.recover_roi(ref_pos, np.array([[1.0], [2.0]])) default_mapper.decode(ref_pos, np.array([[1.0], [2.0]]))
def test_build_roi_mapper(): def test_build_roi_mapper():
@ -236,69 +247,3 @@ def test_build_roi_mapper():
assert mapper.anchor == config.anchor assert mapper.anchor == config.anchor
assert mapper.time_scale == config.time_scale assert mapper.time_scale == config.time_scale
assert mapper.frequency_scale == config.frequency_scale assert mapper.frequency_scale == config.frequency_scale
@pytest.fixture
def sample_config_yaml_content() -> str:
"""YAML content for a sample ROIConfig."""
return f"""
position: center
time_scale: 500.0
frequency_scale: {1 / 1000.0}
"""
@pytest.fixture
def nested_config_yaml_content() -> str:
"""YAML content with ROIConfig nested under a field."""
return f"""
model_settings:
preprocessing:
whatever: true
roi_mapping:
position: bottom-right
time_scale: {DEFAULT_TIME_SCALE}
frequency_scale: 0.01
other_stuff: 123
"""
def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
"""Test loading a simple ROIConfig from YAML."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)
mapper = load_roi_mapper(config_path)
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "center"
assert mapper.time_scale == 500.0
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)
def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
"""Test loading a nested ROIConfig from YAML using 'field'."""
config_path = tmp_path / "nested_config.yaml"
config_path.write_text(nested_config_yaml_content)
mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "bottom-right"
assert mapper.time_scale == DEFAULT_TIME_SCALE
assert mapper.frequency_scale == 0.01
def test_load_roi_mapper_file_not_found(tmp_path):
"""Test load_roi_mapper raises error if file doesn't exist."""
non_existent_path = tmp_path / "not_real.yaml"
with pytest.raises(FileNotFoundError):
load_roi_mapper(non_existent_path)
def test_load_roi_mapper_invalid_field(tmp_path, sample_config_yaml_content):
"""Test load_roi_mapper raises error for invalid field."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)
with pytest.raises(KeyError):
load_roi_mapper(config_path, field="invalid.path")