From e352dc40bdc37fee6ea3083c75ce7589369ab6d1 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 21 Jun 2025 13:47:04 +0100 Subject: [PATCH] Fixed Target object after changes to roi --- batdetect2/data/iterators.py | 2 +- batdetect2/evaluate/match.py | 2 +- batdetect2/postprocess/__init__.py | 4 +- batdetect2/targets/__init__.py | 59 ++--- batdetect2/targets/classes.py | 12 +- batdetect2/targets/filtering.py | 8 +- batdetect2/targets/rois.py | 364 ++++++++++++++++++----------- batdetect2/targets/terms.py | 20 +- batdetect2/targets/transform.py | 51 ++-- batdetect2/targets/types.py | 48 ++-- batdetect2/train/callbacks.py | 2 +- batdetect2/train/labels.py | 8 +- notebooks/signal_generation.py | 19 ++ pyproject.toml | 3 + tests/test_targets/test_rois.py | 235 +++++++------------ 15 files changed, 426 insertions(+), 411 deletions(-) create mode 100644 notebooks/signal_generation.py diff --git a/batdetect2/data/iterators.py b/batdetect2/data/iterators.py index 4a3e11d..289f7ce 100644 --- a/batdetect2/data/iterators.py +++ b/batdetect2/data/iterators.py @@ -72,7 +72,7 @@ def iterate_over_sound_events( sound_event_annotation ) - class_name = targets.encode(sound_event_annotation) + class_name = targets.encode_class(sound_event_annotation) if class_name is None and exclude_generic: continue diff --git a/batdetect2/evaluate/match.py b/batdetect2/evaluate/match.py index ccae73f..b20b361 100644 --- a/batdetect2/evaluate/match.py +++ b/batdetect2/evaluate/match.py @@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions( gt_uuid = target.uuid if target is not None else None gt_det = target is not None - gt_class = targets.encode(target) if target is not None else None + gt_class = targets.encode_class(target) if target is not None else None pred_score = float(prediction.detection_score) if prediction else 0 diff --git a/batdetect2/postprocess/__init__.py b/batdetect2/postprocess/__init__.py index cc9295a..7a79289 100644 --- a/batdetect2/postprocess/__init__.py +++ b/batdetect2/postprocess/__init__.py @@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol): return [ convert_xr_dataset_to_raw_prediction( dataset, - self.targets.recover_roi, + self.targets.decode_roi, ) for dataset in detection_datasets ] @@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol): convert_raw_predictions_to_clip_prediction( prediction, clip, - sound_event_decoder=self.targets.decode, + sound_event_decoder=self.targets.decode_class, generic_class_tags=self.targets.generic_class_tags, classification_threshold=self.config.classification_threshold, ) diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index 7d1ab3e..68a8095 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -23,7 +23,6 @@ object is via the `build_targets` or `load_targets` functions. from typing import List, Optional -import numpy as np from pydantic import Field from soundevent import data @@ -63,7 +62,7 @@ from batdetect2.targets.terms import ( get_term_from_key, individual, register_term, - term_registry, + default_term_registry, ) from batdetect2.targets.transform import ( DerivationRegistry, @@ -73,13 +72,13 @@ from batdetect2.targets.transform import ( SoundEventTransformation, TransformConfig, build_transformation_from_config, - derivation_registry, + default_derivation_registry, get_derivation, load_transformation_config, load_transformation_from_config, register_derivation, ) -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import Position, Size, TargetProtocol __all__ = [ "ClassesConfig", @@ -291,7 +290,9 @@ class Targets(TargetProtocol): return True return self._filter_fn(sound_event) - def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]: + def encode_class( + self, sound_event: data.SoundEventAnnotation + ) -> Optional[str]: """Encode a sound event annotation to its target class name. Applies the configured class definition rules (including priority) @@ -312,7 +313,7 @@ class Targets(TargetProtocol): """ return self._encode_fn(sound_event) - def decode(self, class_label: str) -> List[data.Tag]: + def decode_class(self, class_label: str) -> List[data.Tag]: """Decode a predicted class name back into representative tags. Uses the configured mapping (based on `TargetClass.output_tags` or @@ -352,9 +353,9 @@ class Targets(TargetProtocol): return self._transform_fn(sound_event) return sound_event - def get_position( + def encode_roi( self, sound_event: data.SoundEventAnnotation - ) -> tuple[float, float]: + ) -> tuple[Position, Size]: """Extract the target reference position from the annotation's roi. Delegates to the internal ROI mapper's `get_roi_position` method. @@ -374,37 +375,9 @@ class Targets(TargetProtocol): ValueError If the annotation lacks geometry. """ - return self._roi_mapper.encode_position(sound_event.sound_event) + return self._roi_mapper.encode(sound_event.sound_event) - def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: - """Calculate the target size dimensions from the annotation's geometry. - - Delegates to the internal ROI mapper's `get_roi_size` method, which - applies configured scaling factors. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation containing the geometry (ROI). - - Returns - ------- - np.ndarray - NumPy array containing the size dimensions, matching the - order in `self.dimension_names` (e.g., `[width, height]`). - - Raises - ------ - ValueError - If the annotation lacks geometry. - """ - return self._roi_mapper.encode_size(sound_event.sound_event) - - def recover_roi( - self, - pos: tuple[float, float], - dims: np.ndarray, - ) -> data.Geometry: + def decode_roi(self, position: Position, size: Size) -> data.Geometry: """Recover an approximate geometric ROI from a position and dimensions. Delegates to the internal ROI mapper's `recover_roi` method, which @@ -424,7 +397,7 @@ class Targets(TargetProtocol): data.Geometry The reconstructed geometry (typically `BoundingBox`). """ - return self._roi_mapper.decode(pos, dims) + return self._roi_mapper.decode(position, size) DEFAULT_CLASSES = [ @@ -528,8 +501,8 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( def build_targets( config: Optional[TargetConfig] = None, - term_registry: TermRegistry = term_registry, - derivation_registry: DerivationRegistry = derivation_registry, + term_registry: TermRegistry = default_term_registry, + derivation_registry: DerivationRegistry = default_derivation_registry, ) -> Targets: """Build a Targets object from a loaded TargetConfig. @@ -613,8 +586,8 @@ def build_targets( def load_targets( config_path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = term_registry, - derivation_registry: DerivationRegistry = derivation_registry, + term_registry: TermRegistry = default_term_registry, + derivation_registry: DerivationRegistry = default_derivation_registry, ) -> Targets: """Load a Targets object directly from a configuration file. diff --git a/batdetect2/targets/classes.py b/batdetect2/targets/classes.py index 0f223cb..7b947bd 100644 --- a/batdetect2/targets/classes.py +++ b/batdetect2/targets/classes.py @@ -11,7 +11,7 @@ from batdetect2.targets.terms import ( TagInfo, TermRegistry, get_tag_from_info, - term_registry, + default_term_registry, ) __all__ = [ @@ -339,7 +339,7 @@ def _encode_with_multiple_classifiers( def build_sound_event_encoder( config: ClassesConfig, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventEncoder: """Build a sound event encoder function from the classes configuration. @@ -433,7 +433,7 @@ def _decode_class( def build_sound_event_decoder( config: ClassesConfig, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, raise_on_unmapped: bool = False, ) -> SoundEventDecoder: """Build a sound event decoder function from the classes configuration. @@ -488,7 +488,7 @@ def build_sound_event_decoder( def build_generic_class_tags( config: ClassesConfig, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> List[data.Tag]: """Extract and build the list of tags for the generic class from config. @@ -553,7 +553,7 @@ def load_classes_config( def load_encoder_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventEncoder: """Load a class encoder function directly from a configuration file. @@ -594,7 +594,7 @@ def load_encoder_from_config( def load_decoder_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, raise_on_unmapped: bool = False, ) -> SoundEventDecoder: """Load a class decoder function directly from a configuration file. diff --git a/batdetect2/targets/filtering.py b/batdetect2/targets/filtering.py index a1172d0..4f30dc8 100644 --- a/batdetect2/targets/filtering.py +++ b/batdetect2/targets/filtering.py @@ -10,7 +10,7 @@ from batdetect2.targets.terms import ( TagInfo, TermRegistry, get_tag_from_info, - term_registry, + default_term_registry, ) __all__ = [ @@ -156,7 +156,7 @@ def equal_tags( def build_filter_from_rule( rule: FilterRule, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventFilter: """Creates a callable filter function from a single FilterRule. @@ -243,7 +243,7 @@ class FilterConfig(BaseConfig): def build_sound_event_filter( config: FilterConfig, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventFilter: """Builds a merged filter function from a FilterConfig object. @@ -291,7 +291,7 @@ def load_filter_config( def load_filter_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventFilter: """Loads filter configuration from a file and builds the filter function. diff --git a/batdetect2/targets/rois.py b/batdetect2/targets/rois.py index 484506a..8ecf886 100644 --- a/batdetect2/targets/rois.py +++ b/batdetect2/targets/rois.py @@ -1,23 +1,23 @@ """Handles mapping between geometric ROIs and target representations. -This module defines the interface and provides implementation for converting -a sound event's Region of Interest (ROI), typically represented by a -`soundevent.data.Geometry` object like a `BoundingBox`, into a format -suitable for use as a machine learning target. This usually involves: +This module defines a standardized interface (`ROITargetMapper`) for converting +a sound event's Region of Interest (ROI) into a target representation suitable +for machine learning models, and for decoding model outputs back into geometric +ROIs. -1. Extracting a single reference point (time, frequency) from the geometry. -2. Calculating relevant size dimensions (e.g., duration/width, - bandwidth/height) and applying scaling factors. +The core operations are: +1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference + `Position` (time, frequency) and a `Size` array. The method for + determining the position and size varies by the mapper implementation + (e.g., using a bounding box anchor or the point of peak energy). +2. **Decoding**: A `Position` and `Size` array are mapped back to an + approximate `soundevent.data.Geometry` (typically a `BoundingBox`). -It also provides the inverse operation: recovering an approximate geometric ROI -(like a `BoundingBox`) from a predicted reference point and predicted size -dimensions. - -This logic is encapsulated within components adhering to the `ROITargetMapper` -protocol. Configuration for this mapping (e.g., which reference point to use, -scaling factors) is managed by the `ROIConfig`. This module separates the -*geometric* aspect of target definition from the *semantic* classification -handled in `batdetect2.targets.classes`. +This logic is encapsulated within specific mapper classes. Configuration for +each mapper (e.g., anchor point, scaling factors) is managed by a corresponding +Pydantic config object. The `ROIMapperConfig` type allows for flexibly +selecting and configuring the desired mapper. This module separates the +*geometric* aspect of target definition from *semantic* classification. """ from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union @@ -26,22 +26,26 @@ import numpy as np from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets.types import Position, Size __all__ = [ - "ROITargetMapper", - "BBoxAnchorMapperConfig", + "Anchor", "AnchorBBoxMapper", - "build_roi_mapper", - "load_roi_mapper", + "BBoxAnchorMapperConfig", "DEFAULT_ANCHOR", - "SIZE_WIDTH", + "DEFAULT_FREQUENCY_SCALE", + "DEFAULT_TIME_SCALE", + "PeakEnergyBBoxMapper", + "PeakEnergyBBoxMapperConfig", + "ROIMapperConfig", + "ROITargetMapper", "SIZE_HEIGHT", "SIZE_ORDER", - "DEFAULT_TIME_SCALE", - "DEFAULT_FREQUENCY_SCALE", + "SIZE_WIDTH", + "build_roi_mapper", ] Anchor = Literal[ @@ -73,104 +77,94 @@ DEFAULT_TIME_SCALE = 1000.0 DEFAULT_FREQUENCY_SCALE = 1 / 859.375 """Default scaling factor for frequency bandwidth.""" - DEFAULT_ANCHOR = "bottom-left" """Default reference position within the geometry ('bottom-left' corner).""" -Position = tuple[float, float] - -Size = np.ndarray - - class ROITargetMapper(Protocol): """Protocol defining the interface for ROI-to-target mapping. - Specifies the methods required for converting a geometric region of interest - (`soundevent.data.Geometry`) into a target representation (reference point - and scaled dimensions) and for recovering an approximate ROI from that + Specifies the `encode` and `decode` methods required for converting a + `soundevent.data.SoundEvent` into a target representation (a reference + position and a size vector) and for recovering an approximate ROI from that representation. Attributes ---------- dimension_names : List[str] - A list containing the names of the dimensions returned by - `get_roi_size` and expected by `recover_roi` - (e.g., ['width', 'height']). + A list containing the names of the dimensions in the `Size` array + returned by `encode` and expected by `decode`. """ dimension_names: List[str] def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]: - """Extract the reference position from a geometry. + """Encode a SoundEvent's geometry into a position and size. Parameters ---------- - geom : soundevent.data.Geometry - The input geometry (e.g., BoundingBox, Polygon). + sound_event : data.SoundEvent + The input sound event, which must have a geometry attribute. Returns ------- - Tuple[float, float] - The calculated reference position as (time, frequency) coordinates, - based on the implementing class's configuration (e.g., "center", - "bottom-left"). + Tuple[Position, Size] + A tuple containing: + - The reference position as (time, frequency) coordinates. + - A NumPy array with the calculated size dimensions. Raises ------ ValueError - If the position cannot be calculated for the given geometry type - or configured reference point. + If the sound event does not have a geometry. """ ... def decode(self, position: Position, size: Size) -> data.Geometry: - """Recover an approximate ROI from a position and target dimensions. + """Decode a position and size back into a geometric ROI. - Performs the inverse mapping: takes a reference position and the - predicted dimensions and reconstructs a geometric representation. + Performs the inverse mapping: takes a reference position and size + dimensions and reconstructs a geometric representation. Parameters ---------- - position : Tuple[float, float] + position : Position The reference position (time, frequency). - size : np.ndarray - NumPy array containing the dimensions, matching the order - specified by `dimension_names`. + size : Size + NumPy array containing the size dimensions, matching the order + and meaning specified by `dimension_names`. Returns ------- soundevent.data.Geometry - The reconstructed geometry. + The reconstructed geometry, typically a `BoundingBox`. Raises ------ ValueError - If the number of provided dimensions `dims` does not match - `dimension_names` or if reconstruction fails. + If the `size` array has an unexpected shape or if reconstruction + fails. """ ... class BBoxAnchorMapperConfig(BaseConfig): - """Configuration for mapping Regions of Interest (ROIs). + """Configuration for `AnchorBBoxMapper`. - Defines parameters controlling how geometric ROIs are converted into - target representations (reference points and scaled sizes). + Defines parameters for converting ROIs into targets using a fixed anchor + point on the bounding box. Attributes ---------- - anchor : Anchor, default="bottom-left" - Specifies the reference point within the geometry (e.g., bounding box) - to use as the target location (e.g., "center", "bottom-left"). - time_scale : float, default=1000.0 - Scaling factor applied to the time duration (width) of the ROI - when calculating the target size representation. Must match model - expectations. - frequency_scale : float, default=1/859.375 - Scaling factor applied to the frequency bandwidth (height) of the ROI - when calculating the target size representation. Must match model - expectations. + name : Literal["anchor_bbox"] + The unique identifier for this mapper type. + anchor : Anchor + Specifies the anchor point within the bounding box to use as the + target's reference position (e.g., "center", "bottom-left"). + time_scale : float + Scaling factor applied to the time duration (width) of the ROI. + frequency_scale : float + Scaling factor applied to the frequency bandwidth (height) of the ROI. """ name: Literal["anchor_bbox"] = "anchor_bbox" @@ -180,23 +174,28 @@ class BBoxAnchorMapperConfig(BaseConfig): class AnchorBBoxMapper(ROITargetMapper): - """Concrete implementation of `ROITargetMapper` focused on Bounding Boxes. + """Maps ROIs using a bounding box anchor point and width/height. - This class implements the ROI mapping protocol primarily for - `soundevent.data.BoundingBox` geometry. It extracts reference points, - calculates scaled width/height, and recovers bounding boxes based on - configured position and scaling factors. + This class implements the `ROITargetMapper` protocol for `BoundingBox` + geometries. + + **Encoding**: The `position` is a fixed anchor point on the bounding box + (e.g., "bottom-left"). The `size` is a 2-element array containing the + scaled width and height of the box. + + **Decoding**: Reconstructs a `BoundingBox` from an anchor point and + scaled width/height. Attributes ---------- dimension_names : List[str] - Specifies the output dimension names as ['width', 'height']. + The output dimension names: `['width', 'height']`. anchor : Anchor - The configured reference point type (e.g., "center", "bottom-left"). + The configured anchor point type (e.g., "center", "bottom-left"). time_scale : float - The configured scaling factor for the time dimension (width). + The scaling factor for the time dimension (width). frequency_scale : float - The configured scaling factor for the frequency dimension (height). + The scaling factor for the frequency dimension (height). """ dimension_names = [SIZE_WIDTH, SIZE_HEIGHT] @@ -211,11 +210,11 @@ class AnchorBBoxMapper(ROITargetMapper): Parameters ---------- - anchor : Anchor, default="bottom-left" + anchor : Anchor Reference point type within the bounding box. - time_scale : float, default=1000.0 + time_scale : float Scaling factor for time duration (width). - frequency_scale : float, default=1/859.375 + frequency_scale : float Scaling factor for frequency bandwidth (height). """ self.anchor: Anchor = anchor @@ -223,19 +222,20 @@ class AnchorBBoxMapper(ROITargetMapper): self.frequency_scale = frequency_scale def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]: - """Extract the configured reference position from the geometry. + """Encode a SoundEvent into an anchor position and scaled box size. - Uses `soundevent.geometry.get_geometry_point`. + The position is determined by the configured anchor on the sound + event's bounding box. The size is the scaled width and height. Parameters ---------- - geom : soundevent.data.Geometry - Input geometry (e.g., BoundingBox). + sound_event : data.SoundEvent + The input sound event with a geometry. Returns ------- - Tuple[float, float] - Reference position (time, frequency). + Tuple[Position, Size] + A tuple of (anchor_position, [scaled_width, scaled_height]). """ from soundevent import geometry @@ -267,29 +267,27 @@ class AnchorBBoxMapper(ROITargetMapper): position: Position, size: Size, ) -> data.Geometry: - """Recover a BoundingBox from a position and scaled dimensions. + """Recover a BoundingBox from an anchor position and scaled size. - Un-scales the input dimensions using the configured factors and - reconstructs a `soundevent.data.BoundingBox` centered or anchored at - the given reference `pos` according to the configured `position` type. + Un-scales the input dimensions and reconstructs a + `soundevent.data.BoundingBox` relative to the given anchor position. Parameters ---------- - pos : Tuple[float, float] - Reference position (time, frequency). - dims : np.ndarray - NumPy array containing the *scaled* dimensions, expected order is - [scaled_width, scaled_height]. + position : Position + Reference anchor position (time, frequency). + size : Size + NumPy array containing the scaled [width, height]. Returns ------- - soundevent.data.BoundingBox + data.BoundingBox The reconstructed bounding box. Raises ------ ValueError - If `dims` does not have the expected shape (length 2). + If `size` does not have the expected shape (length 2). """ if size.ndim != 1 or size.shape[0] != 2: @@ -308,6 +306,24 @@ class AnchorBBoxMapper(ROITargetMapper): class PeakEnergyBBoxMapperConfig(BaseConfig): + """Configuration for `PeakEnergyBBoxMapper`. + + Attributes + ---------- + name : Literal["peak_energy_bbox"] + The unique identifier for this mapper type. + preprocessing : PreprocessingConfig + Configuration for the spectrogram preprocessor needed to find the + peak energy. + loading_buffer : float + Seconds to add to each side of the ROI when loading audio to ensure + the peak is captured accurately, avoiding boundary effects. + time_scale : float + Scaling factor applied to the time dimensions. + frequency_scale : float + Scaling factor applied to the frequency dimensions. + """ + name: Literal["peak_energy_bbox"] preprocessing: PreprocessingConfig = Field( default_factory=PreprocessingConfig @@ -318,9 +334,30 @@ class PeakEnergyBBoxMapperConfig(BaseConfig): class PeakEnergyBBoxMapper(ROITargetMapper): - """ - Encodes the ROI using the location of the peak energy within the bounding box - as the 'position' and the distances from that point to the box edges as the 'size'. + """Maps ROIs using the peak energy point and distances to edges. + + This class implements the `ROITargetMapper` protocol. + + **Encoding**: The `position` is the (time, frequency) coordinate of the + point with the highest energy within the sound event's bounding box. The + `size` is a 4-element array representing the scaled distances from this + peak energy point to the left, bottom, right, and top edges of the box. + + **Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the + un-scaled distances from the peak energy point. + + Attributes + ---------- + dimension_names : List[str] + The output dimension names: `['left', 'bottom', 'right', 'top']`. + preprocessor : PreprocessorProtocol + The spectrogram preprocessor instance. + time_scale : float + The scaling factor for time-based distances. + frequency_scale : float + The scaling factor for frequency-based distances. + loading_buffer : float + The buffer used for loading audio around the ROI. """ dimension_names = ["left", "bottom", "right", "top"] @@ -332,6 +369,19 @@ class PeakEnergyBBoxMapper(ROITargetMapper): frequency_scale: float = DEFAULT_FREQUENCY_SCALE, loading_buffer: float = 0.01, ): + """Initialize the PeakEnergyBBoxMapper. + + Parameters + ---------- + preprocessor : PreprocessorProtocol + An initialized preprocessor for generating spectrograms. + time_scale : float + Scaling factor for time dimensions (left, right distances). + frequency_scale : float + Scaling factor for frequency dimensions (bottom, top distances). + loading_buffer : float + Buffer in seconds to add when loading audio clips. + """ self.preprocessor = preprocessor self.time_scale = time_scale self.frequency_scale = frequency_scale @@ -341,6 +391,21 @@ class PeakEnergyBBoxMapper(ROITargetMapper): self, sound_event: data.SoundEvent, ) -> tuple[Position, Size]: + """Encode a SoundEvent into a peak energy position and edge distances. + + Finds the peak energy coordinates within the event's bounding box + and calculates the scaled distances from this point to the box edges. + + Parameters + ---------- + sound_event : data.SoundEvent + The input sound event with a geometry and associated recording. + + Returns + ------- + Tuple[Position, Size] + A tuple of (peak_position, [l, b, r, t] distances). + """ from soundevent import geometry geom = sound_event.geometry @@ -377,6 +442,20 @@ class PeakEnergyBBoxMapper(ROITargetMapper): return (time, freq), size def decode(self, position: Position, size: Size) -> data.Geometry: + """Recover a BoundingBox from a peak position and edge distances. + + Parameters + ---------- + position : Position + The reference peak energy position (time, frequency). + size : Size + NumPy array with scaled distances [left, bottom, right, top]. + + Returns + ------- + data.BoundingBox + The reconstructed bounding box. + """ time, freq = position left, bottom, right, top = size @@ -394,21 +473,30 @@ ROIMapperConfig = Annotated[ Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig], Field(discriminator="name"), ] +"""A discriminated union of all supported ROI mapper configurations. + +This type allows for selecting and configuring different `ROITargetMapper` +implementations by using the `name` field as a discriminator. +""" def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: - """Factory function to create an ROITargetMapper from configuration. + """Factory function to create an ROITargetMapper from a config object. Parameters ---------- - config : ROIConfig - Configuration object specifying ROI mapping parameters. + config : ROIMapperConfig + A configuration object specifying the mapper type and its parameters. Returns ------- ROITargetMapper - An initialized `BBoxEncoder` instance configured with the settings - from `config`. + An initialized ROI mapper instance. + + Raises + ------ + NotImplementedError + If the `name` in the config does not correspond to a known mapper. """ if config.name == "anchor_bbox": return AnchorBBoxMapper( @@ -431,39 +519,6 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: ) -def load_roi_mapper( - path: data.PathLike, field: Optional[str] = None -) -> ROITargetMapper: - """Load ROI mapping configuration from a file and build the mapper. - - Convenience function that loads an `ROIConfig` from the specified file - (and optional field) and then uses `build_roi_mapper` to create the - corresponding `ROITargetMapper` instance. - - Parameters - ---------- - path : PathLike - Path to the configuration file (e.g., YAML). - field : str, optional - Dot-separated path to a nested section within the file containing the - ROI configuration. If None, the entire file content is used. - - Returns - ------- - ROITargetMapper - An initialized ROI mapper instance based on the configuration file. - - Raises - ------ - FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, - TypeError - If the configuration file cannot be found, parsed, validated, or if - the specified `field` is invalid. - """ - config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field) - return build_roi_mapper(config) - - VALID_ANCHORS = [ "bottom-left", "bottom-right", @@ -501,7 +556,7 @@ def _build_bounding_box( bandwidth : float The required *unscaled* frequency bandwidth (height) of the bounding box. - anchor : Anchor, default="bottom-left" + anchor : Anchor Specifies which part of the bounding box the input `pos` corresponds to. Returns @@ -565,6 +620,35 @@ def get_peak_energy_coordinates( high_freq: Optional[float] = None, loading_buffer: float = 0.05, ) -> Position: + """Find the coordinates of the highest energy point in a spectrogram. + + Generates a spectrogram for a specified time-frequency region of a + recording and returns the (time, frequency) coordinates of the pixel with + the maximum value. + + Parameters + ---------- + recording : data.Recording + The recording to analyze. + preprocessor : PreprocessorProtocol + The processor to convert audio to a spectrogram. + start_time : float, default=0 + The start time of the region of interest. + end_time : float, optional + The end time of the region of interest. Defaults to recording duration. + low_freq : float, default=0 + The low frequency of the region of interest. + high_freq : float, optional + The high frequency of the region of interest. Defaults to Nyquist. + loading_buffer : float, default=0.05 + Buffer in seconds to add around the time range when loading the clip + to mitigate border effects from transformations like STFT. + + Returns + ------- + Position + A (time, frequency) tuple for the peak energy location. + """ if end_time is None: end_time = recording.duration end_time = min(end_time, recording.duration) diff --git a/batdetect2/targets/terms.py b/batdetect2/targets/terms.py index 39367f2..d6a3814 100644 --- a/batdetect2/targets/terms.py +++ b/batdetect2/targets/terms.py @@ -230,7 +230,7 @@ class TermRegistry(Mapping[str, data.Term]): del self._terms[key] -term_registry = TermRegistry( +default_term_registry = TermRegistry( terms=dict( [ *getmembers(terms, lambda x: isinstance(x, data.Term)), @@ -252,7 +252,7 @@ is explicitly passed. def get_term_from_key( key: str, - term_registry: TermRegistry = term_registry, + term_registry: Optional[TermRegistry] = None, ) -> data.Term: """Convenience function to retrieve a term by key from a registry. @@ -277,10 +277,13 @@ def get_term_from_key( KeyError If the key is not found in the specified registry. """ + term_registry = term_registry or default_term_registry return term_registry.get_term(key) -def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]: +def get_term_keys( + term_registry: TermRegistry = default_term_registry, +) -> List[str]: """Convenience function to get all registered keys from a registry. Uses the global default registry unless a specific `term_registry` @@ -299,7 +302,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]: return term_registry.get_keys() -def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]: +def get_terms( + term_registry: TermRegistry = default_term_registry, +) -> List[data.Term]: """Convenience function to get all registered terms from a registry. Uses the global default registry unless a specific `term_registry` @@ -342,7 +347,7 @@ class TagInfo(BaseModel): def get_tag_from_info( tag_info: TagInfo, - term_registry: TermRegistry = term_registry, + term_registry: Optional[TermRegistry] = None, ) -> data.Tag: """Creates a soundevent.data.Tag object from TagInfo data. @@ -368,6 +373,7 @@ def get_tag_from_info( If the term key specified in `tag_info.key` is not found in the registry. """ + term_registry = term_registry or default_term_registry term = get_term_from_key(tag_info.key, term_registry=term_registry) return data.Tag(term=term, value=tag_info.value) @@ -439,7 +445,7 @@ class TermConfig(BaseModel): def load_terms_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = term_registry, + term_registry: TermRegistry = default_term_registry, ) -> Dict[str, data.Term]: """Loads term definitions from a configuration file and registers them. @@ -490,6 +496,6 @@ def load_terms_from_config( def register_term( - key: str, term: data.Term, registry: TermRegistry = term_registry + key: str, term: data.Term, registry: TermRegistry = default_term_registry ) -> None: registry.add_term(key, term) diff --git a/batdetect2/targets/transform.py b/batdetect2/targets/transform.py index 7cfda06..29056a7 100644 --- a/batdetect2/targets/transform.py +++ b/batdetect2/targets/transform.py @@ -21,9 +21,6 @@ from batdetect2.targets.terms import ( get_tag_from_info, get_term_from_key, ) -from batdetect2.targets.terms import ( - term_registry as default_term_registry, -) __all__ = [ "DerivationRegistry", @@ -34,7 +31,7 @@ __all__ = [ "TransformConfig", "build_transform_from_rule", "build_transformation_from_config", - "derivation_registry", + "default_derivation_registry", "get_derivation", "load_transformation_config", "load_transformation_from_config", @@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]): return list(self._derivations.values()) -derivation_registry = DerivationRegistry() +default_derivation_registry = DerivationRegistry() """Global instance of the DerivationRegistry. Register custom derivation functions here to make them available by key @@ -409,7 +406,7 @@ in `DeriveTagRule` configuration. def get_derivation( key: str, import_derivation: bool = False, - registry: DerivationRegistry = derivation_registry, + registry: Optional[DerivationRegistry] = None, ): """Retrieve a derivation function by key, optionally importing it. @@ -443,6 +440,8 @@ def get_derivation( AttributeError If dynamic import fails because the function name isn't in the module. """ + registry = registry or default_derivation_registry + if not import_derivation or key in registry: return registry.get_derivation(key) @@ -458,10 +457,16 @@ def get_derivation( ) from err +TranformationRule = Annotated[ + Union[ReplaceRule, MapValueRule, DeriveTagRule], + Field(discriminator="rule_type"), +] + + def build_transform_from_rule( - rule: Union[ReplaceRule, MapValueRule, DeriveTagRule], - derivation_registry: DerivationRegistry = derivation_registry, - term_registry: TermRegistry = default_term_registry, + rule: TranformationRule, + derivation_registry: Optional[DerivationRegistry] = None, + term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Build a specific SoundEventTransformation function from a rule config. @@ -559,8 +564,8 @@ def build_transform_from_rule( def build_transformation_from_config( config: TransformConfig, - derivation_registry: DerivationRegistry = derivation_registry, - term_registry: TermRegistry = default_term_registry, + derivation_registry: Optional[DerivationRegistry] = None, + term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Build a composite transformation function from a TransformConfig. @@ -581,6 +586,7 @@ def build_transformation_from_config( SoundEventTransformation A single function that applies all configured transformations in order. """ + transforms = [ build_transform_from_rule( rule, @@ -590,14 +596,16 @@ def build_transformation_from_config( for rule in config.rules ] - def transformation( - sound_event_annotation: data.SoundEventAnnotation, - ) -> data.SoundEventAnnotation: - for transform in transforms: - sound_event_annotation = transform(sound_event_annotation) - return sound_event_annotation + return partial(apply_sequence_of_transforms, transforms=transforms) - return transformation + +def apply_sequence_of_transforms( + sound_event_annotation: data.SoundEventAnnotation, + transforms: list[SoundEventTransformation], +) -> data.SoundEventAnnotation: + for transform in transforms: + sound_event_annotation = transform(sound_event_annotation) + return sound_event_annotation def load_transformation_config( @@ -631,8 +639,8 @@ def load_transformation_config( def load_transformation_from_config( path: data.PathLike, field: Optional[str] = None, - derivation_registry: DerivationRegistry = derivation_registry, - term_registry: TermRegistry = default_term_registry, + derivation_registry: Optional[DerivationRegistry] = None, + term_registry: Optional[TermRegistry] = None, ) -> SoundEventTransformation: """Load transformation config from a file and build the final function. @@ -677,7 +685,7 @@ def load_transformation_from_config( def register_derivation( key: str, derivation: Derivation, - derivation_registry: DerivationRegistry = derivation_registry, + derivation_registry: Optional[DerivationRegistry] = None, ) -> None: """Register a new derivation function in the global registry. @@ -696,4 +704,5 @@ def register_derivation( KeyError If a derivation function with the same key is already registered. """ + derivation_registry = derivation_registry or default_derivation_registry derivation_registry.register(key, derivation) diff --git a/batdetect2/targets/types.py b/batdetect2/targets/types.py index 5bb260e..19a0ea6 100644 --- a/batdetect2/targets/types.py +++ b/batdetect2/targets/types.py @@ -19,8 +19,16 @@ from soundevent import data __all__ = [ "TargetProtocol", + "Position", + "Size", ] +Position = tuple[float, float] +"""A tuple representing (time, frequency) coordinates.""" + +Size = np.ndarray +"""A NumPy array representing the size dimensions of a target.""" + class TargetProtocol(Protocol): """Protocol defining the interface for the target definition pipeline. @@ -102,7 +110,7 @@ class TargetProtocol(Protocol): """ ... - def encode( + def encode_class( self, sound_event: data.SoundEventAnnotation, ) -> Optional[str]: @@ -123,7 +131,7 @@ class TargetProtocol(Protocol): """ ... - def decode(self, class_label: str) -> List[data.Tag]: + def decode_class(self, class_label: str) -> List[data.Tag]: """Decode a predicted class name back into representative tags. Parameters @@ -147,9 +155,9 @@ class TargetProtocol(Protocol): """ ... - def get_position( + def encode_roi( self, sound_event: data.SoundEventAnnotation - ) -> tuple[float, float]: + ) -> tuple[Position, Size]: """Extract the target reference position from the annotation's geometry. Calculates the `(time, frequency)` coordinate representing the primary @@ -173,37 +181,7 @@ class TargetProtocol(Protocol): """ ... - def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: - """Calculate the target size dimensions from the annotation's geometry. - - Computes the relevant physical size (e.g., duration/width, - bandwidth/height from a bounding box) to produce - the numerical target values expected by the model. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation containing the geometry (ROI) to process. - - Returns - ------- - np.ndarray - A NumPy array containing the size dimensions, matching the - order specified by the `dimension_names` attribute (e.g., - `[width, height]`). - - Raises - ------ - ValueError - If the annotation lacks geometry or if the size cannot be computed. - TypeError - If geometry type is unsupported. - """ - ... - - def recover_roi( - self, pos: tuple[float, float], dims: np.ndarray - ) -> data.Geometry: + def decode_roi(self, position: Position, size: Size) -> data.Geometry: """Recover the ROI geometry from a position and dimensions. Performs the inverse mapping of `get_position` and `get_size`. It takes diff --git a/batdetect2/train/callbacks.py b/batdetect2/train/callbacks.py index fe30f40..013b863 100644 --- a/batdetect2/train/callbacks.py +++ b/batdetect2/train/callbacks.py @@ -97,7 +97,7 @@ def _is_in_subclip( start_time: float, end_time: float, ) -> bool: - time, _ = targets.get_position(sound_event_annotation) + time, _ = targets.encode_roi(sound_event_annotation) return start_time <= time <= end_time diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index c0acf45..48a1383 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -138,7 +138,7 @@ def generate_clip_label( logger.debug( "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events", uuid=clip_annotation.uuid, - num=len(clip_annotation.sound_events) + num=len(clip_annotation.sound_events), ) sound_events = [] @@ -260,7 +260,7 @@ def generate_heatmaps( continue # Get the position of the sound event - time, frequency = targets.get_position(sound_event_annotation) + (time, frequency), size = targets.encode_roi(sound_event_annotation) # Set 1.0 at the position of the sound event in the detection heatmap try: @@ -280,8 +280,6 @@ def generate_heatmaps( ) continue - size = targets.get_size(sound_event_annotation) - size_heatmap = arrays.set_value_at_pos( size_heatmap, size, @@ -291,7 +289,7 @@ def generate_heatmaps( # Get the class name of the sound event try: - class_name = targets.encode(sound_event_annotation) + class_name = targets.encode_class(sound_event_annotation) except ValueError as e: logger.warning( "Skipping annotation %s: Unexpected error while encoding " diff --git a/notebooks/signal_generation.py b/notebooks/signal_generation.py new file mode 100644 index 0000000..4dd6dbd --- /dev/null +++ b/notebooks/signal_generation.py @@ -0,0 +1,19 @@ +import marimo + +__generated_with = "0.13.15" +app = marimo.App(width="medium") + + +@app.cell +def _(): + from batdetect2.preprocess import build_preprocessor + return + + +@app.cell +def _(): + return + + +if __name__ == "__main__": + app.run() diff --git a/pyproject.toml b/pyproject.toml index 12cb0cb..bc5f805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ keywords = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.build.targets.wheel] +packages = ["batdetect2/"] + [project.scripts] batdetect2 = "batdetect2.cli:cli" diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 79b5c45..c665ade 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -3,8 +3,8 @@ import pytest from soundevent import data from batdetect2.targets.rois import ( - DEFAULT_FREQUENCY_SCALE, DEFAULT_ANCHOR, + DEFAULT_FREQUENCY_SCALE, DEFAULT_TIME_SCALE, SIZE_HEIGHT, SIZE_WIDTH, @@ -12,7 +12,6 @@ from batdetect2.targets.rois import ( BBoxAnchorMapperConfig, _build_bounding_box, build_roi_mapper, - load_roi_mapper, ) @@ -22,6 +21,16 @@ def sample_bbox() -> data.BoundingBox: return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0]) +@pytest.fixture +def sample_recording(create_recording) -> data.Recording: + return create_recording(duration=30, samplerate=4_000) + + +@pytest.fixture +def sample_sound_event(sample_bbox, sample_recording) -> data.SoundEvent: + return data.SoundEvent(geometry=sample_bbox, recording=sample_recording) + + @pytest.fixture def zero_bbox() -> data.BoundingBox: """A bounding box with zero duration and bandwidth.""" @@ -29,7 +38,13 @@ def zero_bbox() -> data.BoundingBox: @pytest.fixture -def default_encoder() -> AnchorBBoxMapper: +def zero_sound_event(zero_bbox, sample_recording) -> data.SoundEvent: + """A sample sound event with a zero-sized bounding box.""" + return data.SoundEvent(geometry=zero_bbox, recording=sample_recording) + + +@pytest.fixture +def default_mapper() -> AnchorBBoxMapper: """A BBoxEncoder with default settings.""" return AnchorBBoxMapper() @@ -37,36 +52,30 @@ def default_encoder() -> AnchorBBoxMapper: @pytest.fixture def custom_encoder() -> AnchorBBoxMapper: """A BBoxEncoder with custom settings.""" - return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0) + return AnchorBBoxMapper( + anchor="center", time_scale=1.0, frequency_scale=10.0 + ) -def test_roi_config_defaults(): - """Test ROIConfig default values.""" - config = BBoxAnchorMapperConfig() - assert config.anchor == DEFAULT_ANCHOR - assert config.time_scale == DEFAULT_TIME_SCALE - assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE +@pytest.fixture +def custom_mapper() -> AnchorBBoxMapper: + """An AnchorBBoxMapper with custom settings.""" + return AnchorBBoxMapper( + anchor="center", time_scale=1.0, frequency_scale=10.0 + ) -def test_roi_config_custom(): - """Test creating ROIConfig with custom values.""" - config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0) - assert config.anchor == "center" - assert config.time_scale == 1.0 - assert config.frequency_scale == 10.0 - - -def test_bbox_encoder_init_defaults(default_encoder): +def test_bbox_encoder_init_defaults(default_mapper): """Test BBoxEncoder initialization with default arguments.""" - assert default_encoder.position == DEFAULT_ANCHOR - assert default_encoder.time_scale == DEFAULT_TIME_SCALE - assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE - assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] + assert default_mapper.anchor == DEFAULT_ANCHOR + assert default_mapper.time_scale == DEFAULT_TIME_SCALE + assert default_mapper.frequency_scale == DEFAULT_FREQUENCY_SCALE + assert default_mapper.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] def test_bbox_encoder_init_custom(custom_encoder): """Test BBoxEncoder initialization with custom arguments.""" - assert custom_encoder.position == "center" + assert custom_encoder.anchor == "center" assert custom_encoder.time_scale == 1.0 assert custom_encoder.frequency_scale == 10.0 assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT] @@ -87,52 +96,50 @@ POSITION_TEST_CASES = [ ] -@pytest.mark.parametrize("position_type, expected_pos", POSITION_TEST_CASES) -def test_bbox_encoder_get_roi_position( - sample_bbox, position_type, expected_pos +@pytest.mark.parametrize("anchor, expected_pos", POSITION_TEST_CASES) +def test_anchor_bbox_mapper_encode_position( + sample_sound_event, anchor, expected_pos ): - """Test get_roi_position for various position types.""" - encoder = AnchorBBoxMapper(anchor=position_type) - actual_pos = encoder.encode_position(sample_bbox) + """Test encode returns the correct position for various anchors.""" + encoder = AnchorBBoxMapper(anchor=anchor) + actual_pos, _ = encoder.encode(sample_sound_event) assert actual_pos == pytest.approx(expected_pos) -def test_bbox_encoder_get_roi_position_zero_box(zero_bbox): - """Test get_roi_position for a zero-sized box.""" - encoder = AnchorBBoxMapper(anchor="center") - assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0)) - - -def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder): - """Test get_roi_size with default scaling.""" +def test_anchor_bbox_mapper_encode_defaults( + sample_sound_event, default_mapper +): + """Test encode with default settings returns correct position and size.""" + expected_pos = (10.0, 100.0) # bottom-left expected_size = np.array( [ 10.0 * DEFAULT_TIME_SCALE, 100.0 * DEFAULT_FREQUENCY_SCALE, ] ) - actual_size = default_encoder.get_roi_size(sample_bbox) + actual_pos, actual_size = default_mapper.encode(sample_sound_event) + assert actual_pos == pytest.approx(expected_pos) np.testing.assert_allclose(actual_size, expected_size) assert actual_size.shape == (2,) -def test_bbox_encoder_get_roi_size_custom(sample_bbox, custom_encoder): - """Test get_roi_size with custom scaling.""" - expected_size = np.array( - [ - 10.0 * 1.0, - 100.0 * 10.0, - ] - ) - actual_size = custom_encoder.get_roi_size(sample_bbox) +def test_anchor_bbox_mapper_encode_custom(sample_sound_event, custom_mapper): + """Test encode with custom settings returns correct position and size.""" + expected_pos = (15.0, 150.0) # center + expected_size = np.array([10.0 * 1.0, 100.0 * 10.0]) + + actual_pos, actual_size = custom_mapper.encode(sample_sound_event) + assert actual_pos == pytest.approx(expected_pos) np.testing.assert_allclose(actual_size, expected_size) assert actual_size.shape == (2,) -def test_bbox_encoder_get_roi_size_zero_box(zero_bbox, default_encoder): - """Test get_roi_size for a zero-sized box.""" +def test_anchor_bbox_mapper_encode_zero_box(zero_sound_event, default_mapper): + """Test encode for a zero-sized box.""" + expected_pos = (15.0, 150.0) expected_size = np.array([0.0, 0.0]) - actual_size = default_encoder.get_roi_size(zero_bbox) + actual_pos, actual_size = default_mapper.encode(zero_sound_event) + assert actual_pos == pytest.approx(expected_pos) np.testing.assert_allclose(actual_size, expected_size) @@ -166,9 +173,9 @@ def test_build_bounding_box(position_type, expected_coords): np.testing.assert_allclose(bbox.coordinates, expected_coords) -def test_build_bounding_box_invalid_position(): +def test_build_bounding_box_invalid_anchor(): """Test _build_bounding_box raises error for invalid position.""" - with pytest.raises(ValueError, match="Invalid position"): + with pytest.raises(ValueError, match="Invalid anchor"): _build_bounding_box( (0, 0), 1, @@ -177,13 +184,16 @@ def test_build_bounding_box_invalid_position(): ) -@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES) -def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos): - """Test recover_roi correctly reconstructs the original bbox.""" - encoder = AnchorBBoxMapper(anchor=position_type) - scaled_dims = encoder.encode_size(sample_bbox) - - recovered_bbox = encoder.decode(ref_pos, scaled_dims) +@pytest.mark.parametrize( + "anchor", [anchor for anchor, _ in POSITION_TEST_CASES] +) +def test_anchor_bbox_mapper_encode_decode_roundtrip( + sample_sound_event, sample_bbox, anchor +): + """Test encode-decode roundtrip reconstructs the original bbox.""" + mapper = AnchorBBoxMapper(anchor=anchor) + position, size = mapper.encode(sample_sound_event) + recovered_bbox = mapper.decode(position, size) assert isinstance(recovered_bbox, data.BoundingBox) np.testing.assert_allclose( @@ -191,12 +201,12 @@ def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos): ) -def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder): - """Test recover_roi with custom scaling factors.""" - ref_pos = custom_encoder.get_roi_position(sample_bbox) - scaled_dims = custom_encoder.get_roi_size(sample_bbox) - - recovered_bbox = custom_encoder.recover_roi(ref_pos, scaled_dims) +def test_anchor_bbox_mapper_roundtrip_custom_scale( + sample_sound_event, sample_bbox, custom_mapper +): + """Test encode-decode roundtrip with custom scaling factors.""" + position, size = custom_mapper.encode(sample_sound_event) + recovered_bbox = custom_mapper.decode(position, size) assert isinstance(recovered_bbox, data.BoundingBox) np.testing.assert_allclose( @@ -204,25 +214,26 @@ def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder): ) -def test_bbox_encoder_recover_roi_zero_box(zero_bbox, default_encoder): - """Test recover_roi for a zero-sized box.""" - ref_pos = default_encoder.get_roi_position(zero_bbox) - scaled_dims = default_encoder.get_roi_size(zero_bbox) - recovered_bbox = default_encoder.recover_roi(ref_pos, scaled_dims) +def test_anchor_bbox_mapper_roundtrip_zero_box( + zero_sound_event, zero_bbox, default_mapper +): + """Test encode-decode roundtrip for a zero-sized box.""" + position, size = default_mapper.encode(zero_sound_event) + recovered_bbox = default_mapper.decode(position, size) np.testing.assert_allclose( recovered_bbox.coordinates, zero_bbox.coordinates, atol=1e-6 ) -def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder): - """Test recover_roi raises ValueError for incorrect dims shape.""" +def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper): + """Test decode raises ValueError for incorrect size shape.""" ref_pos = (10, 100) - with pytest.raises(ValueError): - default_encoder.recover_roi(ref_pos, np.array([1.0])) - with pytest.raises(ValueError): - default_encoder.recover_roi(ref_pos, np.array([1.0, 2.0, 3.0])) - with pytest.raises(ValueError): - default_encoder.recover_roi(ref_pos, np.array([[1.0], [2.0]])) + with pytest.raises(ValueError, match="does not have the expected shape"): + default_mapper.decode(ref_pos, np.array([1.0])) + with pytest.raises(ValueError, match="does not have the expected shape"): + default_mapper.decode(ref_pos, np.array([1.0, 2.0, 3.0])) + with pytest.raises(ValueError, match="does not have the expected shape"): + default_mapper.decode(ref_pos, np.array([[1.0], [2.0]])) def test_build_roi_mapper(): @@ -236,69 +247,3 @@ def test_build_roi_mapper(): assert mapper.anchor == config.anchor assert mapper.time_scale == config.time_scale assert mapper.frequency_scale == config.frequency_scale - - -@pytest.fixture -def sample_config_yaml_content() -> str: - """YAML content for a sample ROIConfig.""" - return f""" -position: center -time_scale: 500.0 -frequency_scale: {1 / 1000.0} -""" - - -@pytest.fixture -def nested_config_yaml_content() -> str: - """YAML content with ROIConfig nested under a field.""" - return f""" -model_settings: - preprocessing: - whatever: true - roi_mapping: - position: bottom-right - time_scale: {DEFAULT_TIME_SCALE} - frequency_scale: 0.01 -other_stuff: 123 -""" - - -def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content): - """Test loading a simple ROIConfig from YAML.""" - config_path = tmp_path / "config.yaml" - config_path.write_text(sample_config_yaml_content) - - mapper = load_roi_mapper(config_path) - - assert isinstance(mapper, AnchorBBoxMapper) - assert mapper.anchor == "center" - assert mapper.time_scale == 500.0 - assert mapper.frequency_scale == pytest.approx(1 / 1000.0) - - -def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content): - """Test loading a nested ROIConfig from YAML using 'field'.""" - config_path = tmp_path / "nested_config.yaml" - config_path.write_text(nested_config_yaml_content) - - mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping") - - assert isinstance(mapper, AnchorBBoxMapper) - assert mapper.anchor == "bottom-right" - assert mapper.time_scale == DEFAULT_TIME_SCALE - assert mapper.frequency_scale == 0.01 - - -def test_load_roi_mapper_file_not_found(tmp_path): - """Test load_roi_mapper raises error if file doesn't exist.""" - non_existent_path = tmp_path / "not_real.yaml" - with pytest.raises(FileNotFoundError): - load_roi_mapper(non_existent_path) - - -def test_load_roi_mapper_invalid_field(tmp_path, sample_config_yaml_content): - """Test load_roi_mapper raises error for invalid field.""" - config_path = tmp_path / "config.yaml" - config_path.write_text(sample_config_yaml_content) - with pytest.raises(KeyError): - load_roi_mapper(config_path, field="invalid.path")