Fixed Target object after changes to roi

This commit is contained in:
mbsantiago 2025-06-21 13:47:04 +01:00
parent c559bcc682
commit e352dc40bd
15 changed files with 426 additions and 411 deletions

View File

@ -72,7 +72,7 @@ def iterate_over_sound_events(
sound_event_annotation
)
class_name = targets.encode(sound_event_annotation)
class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic:
continue

View File

@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None
gt_class = targets.encode(target) if target is not None else None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0

View File

@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.recover_roi,
self.targets.decode_roi,
)
for dataset in detection_datasets
]
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)

View File

@ -23,7 +23,6 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional
import numpy as np
from pydantic import Field
from soundevent import data
@ -63,7 +62,7 @@ from batdetect2.targets.terms import (
get_term_from_key,
individual,
register_term,
term_registry,
default_term_registry,
)
from batdetect2.targets.transform import (
DerivationRegistry,
@ -73,13 +72,13 @@ from batdetect2.targets.transform import (
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
derivation_registry,
default_derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import Position, Size, TargetProtocol
__all__ = [
"ClassesConfig",
@ -291,7 +290,9 @@ class Targets(TargetProtocol):
return True
return self._filter_fn(sound_event)
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
@ -312,7 +313,7 @@ class Targets(TargetProtocol):
"""
return self._encode_fn(sound_event)
def decode(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
@ -352,9 +353,9 @@ class Targets(TargetProtocol):
return self._transform_fn(sound_event)
return sound_event
def get_position(
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
) -> tuple[Position, Size]:
"""Extract the target reference position and size from the annotation's ROI.
Delegates to the internal ROI mapper's `encode` method.
@ -374,37 +375,9 @@ class Targets(TargetProtocol):
ValueError
If the annotation lacks geometry.
"""
return self._roi_mapper.encode_position(sound_event.sound_event)
return self._roi_mapper.encode(sound_event.sound_event)
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry.
"""
return self._roi_mapper.encode_size(sound_event.sound_event)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `decode` method, which
@ -424,7 +397,7 @@ class Targets(TargetProtocol):
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
return self._roi_mapper.decode(pos, dims)
return self._roi_mapper.decode(position, size)
DEFAULT_CLASSES = [
@ -528,8 +501,8 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
@ -613,8 +586,8 @@ def build_targets(
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.

View File

@ -11,7 +11,7 @@ from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
default_term_registry,
)
__all__ = [
@ -339,7 +339,7 @@ def _encode_with_multiple_classifiers(
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
@ -433,7 +433,7 @@ def _decode_class(
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
@ -488,7 +488,7 @@ def build_sound_event_decoder(
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
@ -553,7 +553,7 @@ def load_classes_config(
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
@ -594,7 +594,7 @@ def load_encoder_from_config(
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.

View File

@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
default_term_registry,
)
__all__ = [
@ -156,7 +156,7 @@ def equal_tags(
def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
@ -291,7 +291,7 @@ def load_filter_config(
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.

View File

@ -1,23 +1,23 @@
"""Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
This module defines a standardized interface (`ROITargetMapper`) for converting
a sound event's Region of Interest (ROI) into a target representation suitable
for machine learning models, and for decoding model outputs back into geometric
ROIs.
1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
The core operations are:
1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
`Position` (time, frequency) and a `Size` array. The method for
determining the position and size varies by the mapper implementation
(e.g., using a bounding box anchor or the point of peak energy).
2. **Decoding**: A `Position` and `Size` array are mapped back to an
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).
It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.
This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
This logic is encapsulated within specific mapper classes. Configuration for
each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
Pydantic config object. The `ROIMapperConfig` type allows for flexibly
selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
@ -26,22 +26,26 @@ import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, Size
__all__ = [
"ROITargetMapper",
"BBoxAnchorMapperConfig",
"Anchor",
"AnchorBBoxMapper",
"build_roi_mapper",
"load_roi_mapper",
"BBoxAnchorMapperConfig",
"DEFAULT_ANCHOR",
"SIZE_WIDTH",
"DEFAULT_FREQUENCY_SCALE",
"DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig",
"ROIMapperConfig",
"ROITargetMapper",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
"SIZE_WIDTH",
"build_roi_mapper",
]
Anchor = Literal[
@ -73,104 +77,94 @@ DEFAULT_TIME_SCALE = 1000.0
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
Position = tuple[float, float]
Size = np.ndarray
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the methods required for converting a geometric region of interest
(`soundevent.data.Geometry`) into a target representation (reference point
and scaled dimensions) and for recovering an approximate ROI from that
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions returned by
`get_roi_size` and expected by `recover_roi`
(e.g., ['width', 'height']).
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Extract the reference position from a geometry.
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[float, float]
The calculated reference position as (time, frequency) coordinates,
based on the implementing class's configuration (e.g., "center",
"bottom-left").
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the position cannot be calculated for the given geometry type
or configured reference point.
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and the
predicted dimensions and reconstructs a geometric representation.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Tuple[float, float]
position : Position
The reference position (time, frequency).
size : np.ndarray
NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the number of provided dimensions `dims` does not match
`dimension_names` or if reconstruction fails.
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...
class BBoxAnchorMapperConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs).
"""Configuration for `AnchorBBoxMapper`.
Defines parameters controlling how geometric ROIs are converted into
target representations (reference points and scaled sizes).
Defines parameters for converting ROIs into targets using a fixed anchor
point on the bounding box.
Attributes
----------
anchor : Anchor, default="bottom-left"
Specifies the reference point within the geometry (e.g., bounding box)
to use as the target location (e.g., "center", "bottom-left").
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width) of the ROI
when calculating the target size representation. Must match model
expectations.
frequency_scale : float, default=1/859.375
Scaling factor applied to the frequency bandwidth (height) of the ROI
when calculating the target size representation. Must match model
expectations.
name : Literal["anchor_bbox"]
The unique identifier for this mapper type.
anchor : Anchor
Specifies the anchor point within the bounding box to use as the
target's reference position (e.g., "center", "bottom-left").
time_scale : float
Scaling factor applied to the time duration (width) of the ROI.
frequency_scale : float
Scaling factor applied to the frequency bandwidth (height) of the ROI.
"""
name: Literal["anchor_bbox"] = "anchor_bbox"
@ -180,23 +174,28 @@ class BBoxAnchorMapperConfig(BaseConfig):
class AnchorBBoxMapper(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
"""Maps ROIs using a bounding box anchor point and width/height.
This class implements the ROI mapping protocol primarily for
`soundevent.data.BoundingBox` geometry. It extracts reference points,
calculates scaled width/height, and recovers bounding boxes based on
configured position and scaling factors.
This class implements the `ROITargetMapper` protocol for `BoundingBox`
geometries.
**Encoding**: The `position` is a fixed anchor point on the bounding box
(e.g., "bottom-left"). The `size` is a 2-element array containing the
scaled width and height of the box.
**Decoding**: Reconstructs a `BoundingBox` from an anchor point and
scaled width/height.
Attributes
----------
dimension_names : List[str]
Specifies the output dimension names as ['width', 'height'].
The output dimension names: `['width', 'height']`.
anchor : Anchor
The configured reference point type (e.g., "center", "bottom-left").
The configured anchor point type (e.g., "center", "bottom-left").
time_scale : float
The configured scaling factor for the time dimension (width).
The scaling factor for the time dimension (width).
frequency_scale : float
The configured scaling factor for the frequency dimension (height).
The scaling factor for the frequency dimension (height).
"""
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
@ -211,11 +210,11 @@ class AnchorBBoxMapper(ROITargetMapper):
Parameters
----------
anchor : Anchor, default="bottom-left"
anchor : Anchor
Reference point type within the bounding box.
time_scale : float, default=1000.0
time_scale : float
Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375
frequency_scale : float
Scaling factor for frequency bandwidth (height).
"""
self.anchor: Anchor = anchor
@ -223,19 +222,20 @@ class AnchorBBoxMapper(ROITargetMapper):
self.frequency_scale = frequency_scale
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
"""Extract the configured reference position from the geometry.
"""Encode a SoundEvent into an anchor position and scaled box size.
Uses `soundevent.geometry.get_geometry_point`.
The position is determined by the configured anchor on the sound
event's bounding box. The size is the scaled width and height.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
sound_event : data.SoundEvent
The input sound event with a geometry.
Returns
-------
Tuple[float, float]
Reference position (time, frequency).
Tuple[Position, Size]
A tuple of (anchor_position, [scaled_width, scaled_height]).
"""
from soundevent import geometry
@ -267,29 +267,27 @@ class AnchorBBoxMapper(ROITargetMapper):
position: Position,
size: Size,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
"""Recover a BoundingBox from an anchor position and scaled size.
Un-scales the input dimensions using the configured factors and
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
the given reference `pos` according to the configured `position` type.
Un-scales the input dimensions and reconstructs a
`soundevent.data.BoundingBox` relative to the given anchor position.
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
position : Position
Reference anchor position (time, frequency).
size : Size
NumPy array containing the scaled [width, height].
Returns
-------
soundevent.data.BoundingBox
data.BoundingBox
The reconstructed bounding box.
Raises
------
ValueError
If `dims` does not have the expected shape (length 2).
If `size` does not have the expected shape (length 2).
"""
if size.ndim != 1 or size.shape[0] != 2:
@ -308,6 +306,24 @@ class AnchorBBoxMapper(ROITargetMapper):
class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`.
Attributes
----------
name : Literal["peak_energy_bbox"]
The unique identifier for this mapper type.
preprocessing : PreprocessingConfig
Configuration for the spectrogram preprocessor needed to find the
peak energy.
loading_buffer : float
Seconds to add to each side of the ROI when loading audio to ensure
the peak is captured accurately, avoiding boundary effects.
time_scale : float
Scaling factor applied to the time dimensions.
frequency_scale : float
Scaling factor applied to the frequency dimensions.
"""
name: Literal["peak_energy_bbox"]
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
@ -318,9 +334,30 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
class PeakEnergyBBoxMapper(ROITargetMapper):
"""
Encodes the ROI using the location of the peak energy within the bounding box
as the 'position' and the distances from that point to the box edges as the 'size'.
"""Maps ROIs using the peak energy point and distances to edges.
This class implements the `ROITargetMapper` protocol.
**Encoding**: The `position` is the (time, frequency) coordinate of the
point with the highest energy within the sound event's bounding box. The
`size` is a 4-element array representing the scaled distances from this
peak energy point to the left, bottom, right, and top edges of the box.
**Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the
un-scaled distances from the peak energy point.
Attributes
----------
dimension_names : List[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance.
time_scale : float
The scaling factor for time-based distances.
frequency_scale : float
The scaling factor for frequency-based distances.
loading_buffer : float
The buffer used for loading audio around the ROI.
"""
dimension_names = ["left", "bottom", "right", "top"]
@ -332,6 +369,19 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01,
):
"""Initialize the PeakEnergyBBoxMapper.
Parameters
----------
preprocessor : PreprocessorProtocol
An initialized preprocessor for generating spectrograms.
time_scale : float
Scaling factor for time dimensions (left, right distances).
frequency_scale : float
Scaling factor for frequency dimensions (bottom, top distances).
loading_buffer : float
Buffer in seconds to add when loading audio clips.
"""
self.preprocessor = preprocessor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
@ -341,6 +391,21 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
self,
sound_event: data.SoundEvent,
) -> tuple[Position, Size]:
"""Encode a SoundEvent into a peak energy position and edge distances.
Finds the peak energy coordinates within the event's bounding box
and calculates the scaled distances from this point to the box edges.
Parameters
----------
sound_event : data.SoundEvent
The input sound event with a geometry and associated recording.
Returns
-------
Tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances).
"""
from soundevent import geometry
geom = sound_event.geometry
@ -377,6 +442,20 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
return (time, freq), size
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover a BoundingBox from a peak position and edge distances.
Parameters
----------
position : Position
The reference peak energy position (time, frequency).
size : Size
NumPy array with scaled distances [left, bottom, right, top].
Returns
-------
data.BoundingBox
The reconstructed bounding box.
"""
time, freq = position
left, bottom, right, top = size
@ -394,21 +473,30 @@ ROIMapperConfig = Annotated[
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.
This type allows for selecting and configuring different `ROITargetMapper`
implementations by using the `name` field as a discriminator.
"""
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
"""Factory function to create an ROITargetMapper from a config object.
Parameters
----------
config : ROIConfig
Configuration object specifying ROI mapping parameters.
config : ROIMapperConfig
A configuration object specifying the mapper type and its parameters.
Returns
-------
ROITargetMapper
An initialized `BBoxEncoder` instance configured with the settings
from `config`.
An initialized ROI mapper instance.
Raises
------
NotImplementedError
If the `name` in the config does not correspond to a known mapper.
"""
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
@ -431,39 +519,6 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
)
def load_roi_mapper(
path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
"""Load ROI mapping configuration from a file and build the mapper.
Convenience function that loads an `ROIConfig` from the specified file
(and optional field) and then uses `build_roi_mapper` to create the
corresponding `ROITargetMapper` instance.
Parameters
----------
path : PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
ROI configuration. If None, the entire file content is used.
Returns
-------
ROITargetMapper
An initialized ROI mapper instance based on the configuration file.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
return build_roi_mapper(config)
VALID_ANCHORS = [
"bottom-left",
"bottom-right",
@ -501,7 +556,7 @@ def _build_bounding_box(
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
anchor : Anchor, default="bottom-left"
anchor : Anchor
Specifies which part of the bounding box the input `pos` corresponds to.
Returns
@ -565,6 +620,35 @@ def get_peak_energy_coordinates(
high_freq: Optional[float] = None,
loading_buffer: float = 0.05,
) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram.
Generates a spectrogram for a specified time-frequency region of a
recording and returns the (time, frequency) coordinates of the pixel with
the maximum value.
Parameters
----------
recording : data.Recording
The recording to analyze.
preprocessor : PreprocessorProtocol
The processor to convert audio to a spectrogram.
start_time : float, default=0
The start time of the region of interest.
end_time : float, optional
The end time of the region of interest. Defaults to recording duration.
low_freq : float, default=0
The low frequency of the region of interest.
high_freq : float, optional
The high frequency of the region of interest. Defaults to Nyquist.
loading_buffer : float, default=0.05
Buffer in seconds to add around the time range when loading the clip
to mitigate border effects from transformations like STFT.
Returns
-------
Position
A (time, frequency) tuple for the peak energy location.
"""
if end_time is None:
end_time = recording.duration
end_time = min(end_time, recording.duration)

View File

@ -230,7 +230,7 @@ class TermRegistry(Mapping[str, data.Term]):
del self._terms[key]
term_registry = TermRegistry(
default_term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
@ -252,7 +252,7 @@ is explicitly passed.
def get_term_from_key(
key: str,
term_registry: TermRegistry = term_registry,
term_registry: Optional[TermRegistry] = None,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
@ -277,10 +277,13 @@ def get_term_from_key(
KeyError
If the key is not found in the specified registry.
"""
term_registry = term_registry or default_term_registry
return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
@ -299,7 +302,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
return term_registry.get_keys()
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry`
@ -342,7 +347,7 @@ class TagInfo(BaseModel):
def get_tag_from_info(
tag_info: TagInfo,
term_registry: TermRegistry = term_registry,
term_registry: Optional[TermRegistry] = None,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
@ -368,6 +373,7 @@ def get_tag_from_info(
If the term key specified in `tag_info.key` is not found
in the registry.
"""
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)
@ -439,7 +445,7 @@ class TermConfig(BaseModel):
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
@ -490,6 +496,6 @@ def load_terms_from_config(
def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry
key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None:
registry.add_term(key, term)

View File

@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [
"DerivationRegistry",
@ -34,7 +31,7 @@ __all__ = [
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"derivation_registry",
"default_derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
return list(self._derivations.values())
derivation_registry = DerivationRegistry()
default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
def get_derivation(
key: str,
import_derivation: bool = False,
registry: DerivationRegistry = derivation_registry,
registry: Optional[DerivationRegistry] = None,
):
"""Retrieve a derivation function by key, optionally importing it.
@ -443,6 +440,8 @@ def get_derivation(
AttributeError
If dynamic import fails because the function name isn't in the module.
"""
registry = registry or default_derivation_registry
if not import_derivation or key in registry:
return registry.get_derivation(key)
@ -458,10 +457,16 @@ def get_derivation(
) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
@ -559,8 +564,8 @@ def build_transform_from_rule(
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
@ -581,6 +586,7 @@ def build_transformation_from_config(
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [
build_transform_from_rule(
rule,
@ -590,15 +596,17 @@ def build_transformation_from_config(
for rule in config.rules
]
def transformation(
return partial(apply_sequence_of_transforms, transforms=transforms)
def apply_sequence_of_transforms(
    sound_event_annotation: data.SoundEventAnnotation,
    transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
    """Run an annotation through each transformation in order.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation fed to the first transformation.
    transforms : list[SoundEventTransformation]
        Callables applied one after another; each receives the output of
        the previous one.

    Returns
    -------
    data.SoundEventAnnotation
        The output of the last transformation, or the input annotation
        unchanged when `transforms` is empty.
    """
    current = sound_event_annotation
    for step in transforms:
        current = step(current)
    return current
return transformation
def load_transformation_config(
path: data.PathLike, field: Optional[str] = None
@ -631,8 +639,8 @@ def load_transformation_config(
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
@ -677,7 +685,7 @@ def load_transformation_from_config(
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: DerivationRegistry = derivation_registry,
derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
"""Register a new derivation function in the global registry.
@ -696,4 +704,5 @@ def register_derivation(
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation)

View File

@ -19,8 +19,16 @@ from soundevent import data
__all__ = [
"TargetProtocol",
"Position",
"Size",
]
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""
Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""
class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline.
@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
"""
...
def encode(
def encode_class(
self,
sound_event: data.SoundEventAnnotation,
) -> Optional[str]:
@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
"""
...
def decode(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Parameters
@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
"""
...
def get_position(
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary
@ -173,37 +181,7 @@ class TargetProtocol(Protocol):
"""
...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions.
Performs the inverse mapping of `encode_roi`. It takes

View File

@ -97,7 +97,7 @@ def _is_in_subclip(
start_time: float,
end_time: float,
) -> bool:
time, _ = targets.get_position(sound_event_annotation)
# encode_roi returns ((time, frequency), size); unpack the nested position
# so that `time` is a scalar rather than the whole position tuple.
(time, _), _ = targets.encode_roi(sound_event_annotation)
return start_time <= time <= end_time

View File

@ -138,7 +138,7 @@ def generate_clip_label(
logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events)
num=len(clip_annotation.sound_events),
)
sound_events = []
@ -260,7 +260,7 @@ def generate_heatmaps(
continue
# Get the position of the sound event
time, frequency = targets.get_position(sound_event_annotation)
(time, frequency), size = targets.encode_roi(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap
try:
@ -280,8 +280,6 @@ def generate_heatmaps(
)
continue
size = targets.get_size(sound_event_annotation)
size_heatmap = arrays.set_value_at_pos(
size_heatmap,
size,
@ -291,7 +289,7 @@ def generate_heatmaps(
# Get the class name of the sound event
try:
class_name = targets.encode(sound_event_annotation)
class_name = targets.encode_class(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "

View File

@ -0,0 +1,19 @@
# Marimo notebook scaffold: declares the app and two cells.
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    # Imports build_preprocessor; the bare return hands nothing back from
    # this cell.
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    # Empty placeholder cell.
    return


if __name__ == "__main__":
    # Run the app when this file is executed as a script.
    app.run()

View File

@ -62,6 +62,9 @@ keywords = [
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["batdetect2/"]
[project.scripts]
batdetect2 = "batdetect2.cli:cli"

View File

@ -3,8 +3,8 @@ import pytest
from soundevent import data
from batdetect2.targets.rois import (
DEFAULT_FREQUENCY_SCALE,
DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
DEFAULT_TIME_SCALE,
SIZE_HEIGHT,
SIZE_WIDTH,
@ -12,7 +12,6 @@ from batdetect2.targets.rois import (
BBoxAnchorMapperConfig,
_build_bounding_box,
build_roi_mapper,
load_roi_mapper,
)
@ -22,6 +21,16 @@ def sample_bbox() -> data.BoundingBox:
return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0])
@pytest.fixture
def sample_recording(create_recording) -> data.Recording:
    """Provide a recording created with duration=30 and samplerate=4_000."""
    recording = create_recording(duration=30, samplerate=4_000)
    return recording
@pytest.fixture
def sample_sound_event(sample_bbox, sample_recording) -> data.SoundEvent:
    """Provide a sound event with `sample_bbox` as geometry on `sample_recording`."""
    sound_event = data.SoundEvent(
        recording=sample_recording,
        geometry=sample_bbox,
    )
    return sound_event
@pytest.fixture
def zero_bbox() -> data.BoundingBox:
"""A bounding box with zero duration and bandwidth."""
@ -29,7 +38,13 @@ def zero_bbox() -> data.BoundingBox:
@pytest.fixture
def default_encoder() -> AnchorBBoxMapper:
def zero_sound_event(zero_bbox, sample_recording) -> data.SoundEvent:
"""A sample sound event with a zero-sized bounding box."""
return data.SoundEvent(geometry=zero_bbox, recording=sample_recording)
@pytest.fixture
def default_mapper() -> AnchorBBoxMapper:
    """Provide an AnchorBBoxMapper constructed with no arguments."""
    mapper = AnchorBBoxMapper()
    return mapper
@ -37,36 +52,30 @@ def default_encoder() -> AnchorBBoxMapper:
@pytest.fixture
def custom_encoder() -> AnchorBBoxMapper:
"""A BBoxEncoder with custom settings."""
return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0)
return AnchorBBoxMapper(
anchor="center", time_scale=1.0, frequency_scale=10.0
)
def test_roi_config_defaults():
"""Test ROIConfig default values."""
config = BBoxAnchorMapperConfig()
assert config.anchor == DEFAULT_ANCHOR
assert config.time_scale == DEFAULT_TIME_SCALE
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE
@pytest.fixture
def custom_mapper() -> AnchorBBoxMapper:
    """Provide an AnchorBBoxMapper anchored at the center with custom scales."""
    mapper = AnchorBBoxMapper(
        anchor="center",
        time_scale=1.0,
        frequency_scale=10.0,
    )
    return mapper
def test_roi_config_custom():
"""Test creating ROIConfig with custom values."""
config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0)
assert config.anchor == "center"
assert config.time_scale == 1.0
assert config.frequency_scale == 10.0
def test_bbox_encoder_init_defaults(default_encoder):
def test_bbox_encoder_init_defaults(default_mapper):
"""Test BBoxEncoder initialization with default arguments."""
assert default_encoder.position == DEFAULT_ANCHOR
assert default_encoder.time_scale == DEFAULT_TIME_SCALE
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
assert default_mapper.anchor == DEFAULT_ANCHOR
assert default_mapper.time_scale == DEFAULT_TIME_SCALE
assert default_mapper.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_mapper.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
def test_bbox_encoder_init_custom(custom_encoder):
"""Test BBoxEncoder initialization with custom arguments."""
assert custom_encoder.position == "center"
assert custom_encoder.anchor == "center"
assert custom_encoder.time_scale == 1.0
assert custom_encoder.frequency_scale == 10.0
assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
@ -87,52 +96,50 @@ POSITION_TEST_CASES = [
]
@pytest.mark.parametrize("position_type, expected_pos", POSITION_TEST_CASES)
def test_bbox_encoder_get_roi_position(
sample_bbox, position_type, expected_pos
@pytest.mark.parametrize("anchor, expected_pos", POSITION_TEST_CASES)
def test_anchor_bbox_mapper_encode_position(
sample_sound_event, anchor, expected_pos
):
"""Test get_roi_position for various position types."""
encoder = AnchorBBoxMapper(anchor=position_type)
actual_pos = encoder.encode_position(sample_bbox)
"""Test encode returns the correct position for various anchors."""
encoder = AnchorBBoxMapper(anchor=anchor)
actual_pos, _ = encoder.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)
def test_bbox_encoder_get_roi_position_zero_box(zero_bbox):
"""Test get_roi_position for a zero-sized box."""
encoder = AnchorBBoxMapper(anchor="center")
assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0))
def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
"""Test get_roi_size with default scaling."""
def test_anchor_bbox_mapper_encode_defaults(
sample_sound_event, default_mapper
):
"""Test encode with default settings returns correct position and size."""
expected_pos = (10.0, 100.0) # bottom-left
expected_size = np.array(
[
10.0 * DEFAULT_TIME_SCALE,
100.0 * DEFAULT_FREQUENCY_SCALE,
]
)
actual_size = default_encoder.get_roi_size(sample_bbox)
actual_pos, actual_size = default_mapper.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size)
assert actual_size.shape == (2,)
def test_bbox_encoder_get_roi_size_custom(sample_bbox, custom_encoder):
"""Test get_roi_size with custom scaling."""
expected_size = np.array(
[
10.0 * 1.0,
100.0 * 10.0,
]
)
actual_size = custom_encoder.get_roi_size(sample_bbox)
def test_anchor_bbox_mapper_encode_custom(sample_sound_event, custom_mapper):
    """Check position and size produced by the custom (center-anchored) mapper."""
    position, size = custom_mapper.encode(sample_sound_event)

    # Expected anchor for the sample box when anchored at its center.
    center = (15.0, 150.0)
    assert position == pytest.approx(center)

    # Box extents multiplied by the custom time and frequency scales.
    scaled_extents = np.array([10.0 * 1.0, 100.0 * 10.0])
    np.testing.assert_allclose(size, scaled_extents)
    assert size.shape == (2,)
def test_bbox_encoder_get_roi_size_zero_box(zero_bbox, default_encoder):
"""Test get_roi_size for a zero-sized box."""
def test_anchor_bbox_mapper_encode_zero_box(zero_sound_event, default_mapper):
"""Test encode for a zero-sized box."""
expected_pos = (15.0, 150.0)
expected_size = np.array([0.0, 0.0])
actual_size = default_encoder.get_roi_size(zero_bbox)
actual_pos, actual_size = default_mapper.encode(zero_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size)
@ -166,9 +173,9 @@ def test_build_bounding_box(position_type, expected_coords):
np.testing.assert_allclose(bbox.coordinates, expected_coords)
def test_build_bounding_box_invalid_position():
def test_build_bounding_box_invalid_anchor():
"""Test _build_bounding_box raises error for invalid position."""
with pytest.raises(ValueError, match="Invalid position"):
with pytest.raises(ValueError, match="Invalid anchor"):
_build_bounding_box(
(0, 0),
1,
@ -177,13 +184,16 @@ def test_build_bounding_box_invalid_position():
)
@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES)
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
"""Test recover_roi correctly reconstructs the original bbox."""
encoder = AnchorBBoxMapper(anchor=position_type)
scaled_dims = encoder.encode_size(sample_bbox)
recovered_bbox = encoder.decode(ref_pos, scaled_dims)
@pytest.mark.parametrize(
"anchor", [anchor for anchor, _ in POSITION_TEST_CASES]
)
def test_anchor_bbox_mapper_encode_decode_roundtrip(
sample_sound_event, sample_bbox, anchor
):
"""Test encode-decode roundtrip reconstructs the original bbox."""
mapper = AnchorBBoxMapper(anchor=anchor)
position, size = mapper.encode(sample_sound_event)
recovered_bbox = mapper.decode(position, size)
assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose(
@ -191,12 +201,12 @@ def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
)
def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
"""Test recover_roi with custom scaling factors."""
ref_pos = custom_encoder.get_roi_position(sample_bbox)
scaled_dims = custom_encoder.get_roi_size(sample_bbox)
recovered_bbox = custom_encoder.recover_roi(ref_pos, scaled_dims)
def test_anchor_bbox_mapper_roundtrip_custom_scale(
sample_sound_event, sample_bbox, custom_mapper
):
"""Test encode-decode roundtrip with custom scaling factors."""
position, size = custom_mapper.encode(sample_sound_event)
recovered_bbox = custom_mapper.decode(position, size)
assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose(
@ -204,25 +214,26 @@ def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
)
def test_bbox_encoder_recover_roi_zero_box(zero_bbox, default_encoder):
"""Test recover_roi for a zero-sized box."""
ref_pos = default_encoder.get_roi_position(zero_bbox)
scaled_dims = default_encoder.get_roi_size(zero_bbox)
recovered_bbox = default_encoder.recover_roi(ref_pos, scaled_dims)
def test_anchor_bbox_mapper_roundtrip_zero_box(
    zero_sound_event, zero_bbox, default_mapper
):
    """Check that encoding then decoding a zero-sized box returns its coordinates."""
    encoded_position, encoded_size = default_mapper.encode(zero_sound_event)
    decoded = default_mapper.decode(encoded_position, encoded_size)

    # Compare with a small absolute tolerance, as the expected values are exact.
    np.testing.assert_allclose(
        decoded.coordinates,
        zero_bbox.coordinates,
        atol=1e-6,
    )
def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder):
"""Test recover_roi raises ValueError for incorrect dims shape."""
def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
"""Test decode raises ValueError for incorrect size shape."""
ref_pos = (10, 100)
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([1.0]))
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([1.0, 2.0, 3.0]))
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([[1.0], [2.0]]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([1.0]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([1.0, 2.0, 3.0]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([[1.0], [2.0]]))
def test_build_roi_mapper():
@ -236,69 +247,3 @@ def test_build_roi_mapper():
assert mapper.anchor == config.anchor
assert mapper.time_scale == config.time_scale
assert mapper.frequency_scale == config.frequency_scale
@pytest.fixture
def sample_config_yaml_content() -> str:
"""YAML content for a sample ROIConfig."""
return f"""
position: center
time_scale: 500.0
frequency_scale: {1 / 1000.0}
"""
@pytest.fixture
def nested_config_yaml_content() -> str:
"""YAML content with ROIConfig nested under a field."""
return f"""
model_settings:
preprocessing:
whatever: true
roi_mapping:
position: bottom-right
time_scale: {DEFAULT_TIME_SCALE}
frequency_scale: 0.01
other_stuff: 123
"""
def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
"""Test loading a simple ROIConfig from YAML."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)
mapper = load_roi_mapper(config_path)
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "center"
assert mapper.time_scale == 500.0
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)
def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
"""Test loading a nested ROIConfig from YAML using 'field'."""
config_path = tmp_path / "nested_config.yaml"
config_path.write_text(nested_config_yaml_content)
mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")
assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "bottom-right"
assert mapper.time_scale == DEFAULT_TIME_SCALE
assert mapper.frequency_scale == 0.01
def test_load_roi_mapper_file_not_found(tmp_path):
"""Test load_roi_mapper raises error if file doesn't exist."""
non_existent_path = tmp_path / "not_real.yaml"
with pytest.raises(FileNotFoundError):
load_roi_mapper(non_existent_path)
def test_load_roi_mapper_invalid_field(tmp_path, sample_config_yaml_content):
"""Test load_roi_mapper raises error for invalid field."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)
with pytest.raises(KeyError):
load_roi_mapper(config_path, field="invalid.path")