Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)

Commit e352dc40bd: "Fixed Target object after changes to roi"
Parent commit: c559bcc682
@@ -72,7 +72,7 @@ def iterate_over_sound_events(
sound_event_annotation
)

class_name = targets.encode(sound_event_annotation)
class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic:
continue

@@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None
gt_class = targets.encode(target) if target is not None else None
gt_class = targets.encode_class(target) if target is not None else None

pred_score = float(prediction.detection_score) if prediction else 0

@@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.recover_roi,
self.targets.decode_roi,
)
for dataset in detection_datasets
]

@@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)
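For orientation, a minimal usage sketch of the renamed Targets methods touched by this commit; the `targets` and `annotation` objects below are assumed to exist and are not part of the diff:

# Hypothetical usage of the renamed API (sketch only)
class_name = targets.encode_class(annotation)        # previously targets.encode(...)
tags = targets.decode_class(class_name) if class_name else targets.generic_class_tags
position, size = targets.encode_roi(annotation)      # previously get_position() / get_size()
geometry = targets.decode_roi(position, size)        # previously recover_roi(...)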
@@ -23,7 +23,6 @@ object is via the `build_targets` or `load_targets` functions.

from typing import List, Optional

import numpy as np
from pydantic import Field
from soundevent import data

@@ -63,7 +62,7 @@ from batdetect2.targets.terms import (
get_term_from_key,
individual,
register_term,
term_registry,
default_term_registry,
)
from batdetect2.targets.transform import (
DerivationRegistry,
@@ -73,13 +72,13 @@ from batdetect2.targets.transform import (
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
derivation_registry,
default_derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import Position, Size, TargetProtocol

__all__ = [
"ClassesConfig",
@@ -291,7 +290,9 @@ class Targets(TargetProtocol):
return True
return self._filter_fn(sound_event)

def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.

Applies the configured class definition rules (including priority)
@@ -312,7 +313,7 @@ class Targets(TargetProtocol):
"""
return self._encode_fn(sound_event)

def decode(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.

Uses the configured mapping (based on `TargetClass.output_tags` or
@@ -352,9 +353,9 @@ class Targets(TargetProtocol):
return self._transform_fn(sound_event)
return sound_event

def get_position(
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.

Delegates to the internal ROI mapper's `get_roi_position` method.
@@ -374,37 +375,9 @@ class Targets(TargetProtocol):
ValueError
If the annotation lacks geometry.
"""
return self._roi_mapper.encode_position(sound_event.sound_event)
return self._roi_mapper.encode(sound_event.sound_event)

def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.

Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.

Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).

Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).

Raises
------
ValueError
If the annotation lacks geometry.
"""
return self._roi_mapper.encode_size(sound_event.sound_event)

def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.

Delegates to the internal ROI mapper's `recover_roi` method, which
@@ -424,7 +397,7 @@ class Targets(TargetProtocol):
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
return self._roi_mapper.decode(pos, dims)
return self._roi_mapper.decode(position, size)


DEFAULT_CLASSES = [
@@ -528,8 +501,8 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(

def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.

@@ -613,8 +586,8 @@ def build_targets(
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
@@ -11,7 +11,7 @@ from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
default_term_registry,
)

__all__ = [
@@ -339,7 +339,7 @@ def _encode_with_multiple_classifiers(

def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.

@@ -433,7 +433,7 @@ def _decode_class(

def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
@@ -488,7 +488,7 @@ def build_sound_event_decoder(

def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.

@@ -553,7 +553,7 @@ def load_classes_config(
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.

@@ -594,7 +594,7 @@ def load_encoder_from_config(
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.
@@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
default_term_registry,
)

__all__ = [
@@ -156,7 +156,7 @@ def equal_tags(

def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.

@@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):

def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.

@@ -291,7 +291,7 @@ def load_filter_config(
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.
@@ -1,23 +1,23 @@
"""Handles mapping between geometric ROIs and target representations.

This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
This module defines a standardized interface (`ROITargetMapper`) for converting
a sound event's Region of Interest (ROI) into a target representation suitable
for machine learning models, and for decoding model outputs back into geometric
ROIs.

1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
The core operations are:
1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
`Position` (time, frequency) and a `Size` array. The method for
determining the position and size varies by the mapper implementation
(e.g., using a bounding box anchor or the point of peak energy).
2. **Decoding**: A `Position` and `Size` array are mapped back to an
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).

It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.

This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
This logic is encapsulated within specific mapper classes. Configuration for
each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
Pydantic config object. The `ROIMapperConfig` type allows for flexibly
selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""

from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
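To make the encode/decode contract described in the rewritten docstring concrete, a small sketch (the `mapper` and `sound_event` objects are assumed and not part of the diff):

# Round trip through any ROITargetMapper implementation (sketch only)
position, size = mapper.encode(sound_event)   # Position: (time, frequency); Size: np.ndarray
geometry = mapper.decode(position, size)      # approximate BoundingBox reconstruction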
@@ -26,22 +26,26 @@ import numpy as np
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, Size

__all__ = [
"ROITargetMapper",
"BBoxAnchorMapperConfig",
"Anchor",
"AnchorBBoxMapper",
"build_roi_mapper",
"load_roi_mapper",
"BBoxAnchorMapperConfig",
"DEFAULT_ANCHOR",
"SIZE_WIDTH",
"DEFAULT_FREQUENCY_SCALE",
"DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig",
"ROIMapperConfig",
"ROITargetMapper",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
"SIZE_WIDTH",
"build_roi_mapper",
]

Anchor = Literal[
@@ -73,104 +77,94 @@ DEFAULT_TIME_SCALE = 1000.0
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""

DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""

Position = tuple[float, float]

Size = np.ndarray


class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.

Specifies the methods required for converting a geometric region of interest
(`soundevent.data.Geometry`) into a target representation (reference point
and scaled dimensions) and for recovering an approximate ROI from that
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.

Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions returned by
`get_roi_size` and expected by `recover_roi`
(e.g., ['width', 'height']).
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""

dimension_names: List[str]

def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Extract the reference position from a geometry.
"""Encode a SoundEvent's geometry into a position and size.

Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.

Returns
-------
Tuple[float, float]
The calculated reference position as (time, frequency) coordinates,
based on the implementing class's configuration (e.g., "center",
"bottom-left").
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.

Raises
------
ValueError
If the position cannot be calculated for the given geometry type
or configured reference point.
If the sound event does not have a geometry.
"""
...

def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
"""Decode a position and size back into a geometric ROI.

Performs the inverse mapping: takes a reference position and the
predicted dimensions and reconstructs a geometric representation.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.

Parameters
----------
position : Tuple[float, float]
position : Position
The reference position (time, frequency).
size : np.ndarray
NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.

Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
The reconstructed geometry, typically a `BoundingBox`.

Raises
------
ValueError
If the number of provided dimensions `dims` does not match
`dimension_names` or if reconstruction fails.
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...


class BBoxAnchorMapperConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs).
"""Configuration for `AnchorBBoxMapper`.

Defines parameters controlling how geometric ROIs are converted into
target representations (reference points and scaled sizes).
Defines parameters for converting ROIs into targets using a fixed anchor
point on the bounding box.

Attributes
----------
anchor : Anchor, default="bottom-left"
Specifies the reference point within the geometry (e.g., bounding box)
to use as the target location (e.g., "center", "bottom-left").
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width) of the ROI
when calculating the target size representation. Must match model
expectations.
frequency_scale : float, default=1/859.375
Scaling factor applied to the frequency bandwidth (height) of the ROI
when calculating the target size representation. Must match model
expectations.
name : Literal["anchor_bbox"]
The unique identifier for this mapper type.
anchor : Anchor
Specifies the anchor point within the bounding box to use as the
target's reference position (e.g., "center", "bottom-left").
time_scale : float
Scaling factor applied to the time duration (width) of the ROI.
frequency_scale : float
Scaling factor applied to the frequency bandwidth (height) of the ROI.
"""

name: Literal["anchor_bbox"] = "anchor_bbox"
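A minimal, hypothetical mapper that satisfies the `ROITargetMapper` protocol above, using the bounding-box centre and unscaled duration/bandwidth; it is illustrative only and not part of this commit:

# Sketch: a custom mapper conforming to ROITargetMapper (assumes BoundingBox geometry)
from typing import List, Tuple
import numpy as np
from soundevent import data

class CenterBBoxMapper:
    dimension_names: List[str] = ["width", "height"]

    def encode(self, sound_event: data.SoundEvent) -> Tuple[tuple, np.ndarray]:
        geom = sound_event.geometry
        if geom is None:
            raise ValueError("Sound event has no geometry")
        start_time, low_freq, end_time, high_freq = geom.coordinates
        position = ((start_time + end_time) / 2, (low_freq + high_freq) / 2)
        size = np.array([end_time - start_time, high_freq - low_freq])
        return position, size

    def decode(self, position, size) -> data.Geometry:
        time, freq = position
        width, height = size
        return data.BoundingBox(
            coordinates=[
                time - width / 2,
                freq - height / 2,
                time + width / 2,
                freq + height / 2,
            ]
        )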
@@ -180,23 +174,28 @@ class BBoxAnchorMapperConfig(BaseConfig):


class AnchorBBoxMapper(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
"""Maps ROIs using a bounding box anchor point and width/height.

This class implements the ROI mapping protocol primarily for
`soundevent.data.BoundingBox` geometry. It extracts reference points,
calculates scaled width/height, and recovers bounding boxes based on
configured position and scaling factors.
This class implements the `ROITargetMapper` protocol for `BoundingBox`
geometries.

**Encoding**: The `position` is a fixed anchor point on the bounding box
(e.g., "bottom-left"). The `size` is a 2-element array containing the
scaled width and height of the box.

**Decoding**: Reconstructs a `BoundingBox` from an anchor point and
scaled width/height.

Attributes
----------
dimension_names : List[str]
Specifies the output dimension names as ['width', 'height'].
The output dimension names: `['width', 'height']`.
anchor : Anchor
The configured reference point type (e.g., "center", "bottom-left").
The configured anchor point type (e.g., "center", "bottom-left").
time_scale : float
The configured scaling factor for the time dimension (width).
The scaling factor for the time dimension (width).
frequency_scale : float
The configured scaling factor for the frequency dimension (height).
The scaling factor for the frequency dimension (height).
"""

dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
@@ -211,11 +210,11 @@ class AnchorBBoxMapper(ROITargetMapper):

Parameters
----------
anchor : Anchor, default="bottom-left"
anchor : Anchor
Reference point type within the bounding box.
time_scale : float, default=1000.0
time_scale : float
Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375
frequency_scale : float
Scaling factor for frequency bandwidth (height).
"""
self.anchor: Anchor = anchor
@@ -223,19 +222,20 @@ class AnchorBBoxMapper(ROITargetMapper):
self.frequency_scale = frequency_scale

def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
"""Extract the configured reference position from the geometry.
"""Encode a SoundEvent into an anchor position and scaled box size.

Uses `soundevent.geometry.get_geometry_point`.
The position is determined by the configured anchor on the sound
event's bounding box. The size is the scaled width and height.

Parameters
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
sound_event : data.SoundEvent
The input sound event with a geometry.

Returns
-------
Tuple[float, float]
Reference position (time, frequency).
Tuple[Position, Size]
A tuple of (anchor_position, [scaled_width, scaled_height]).
"""
from soundevent import geometry

@@ -267,29 +267,27 @@ class AnchorBBoxMapper(ROITargetMapper):
position: Position,
size: Size,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
"""Recover a BoundingBox from an anchor position and scaled size.

Un-scales the input dimensions using the configured factors and
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
the given reference `pos` according to the configured `position` type.
Un-scales the input dimensions and reconstructs a
`soundevent.data.BoundingBox` relative to the given anchor position.

Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
position : Position
Reference anchor position (time, frequency).
size : Size
NumPy array containing the scaled [width, height].

Returns
-------
soundevent.data.BoundingBox
data.BoundingBox
The reconstructed bounding box.

Raises
------
ValueError
If `dims` does not have the expected shape (length 2).
If `size` does not have the expected shape (length 2).
"""

if size.ndim != 1 or size.shape[0] != 2:
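A usage sketch of the anchor mapper described above (made-up coordinates; `recording` is an assumed existing `data.Recording`):

# Encode with a bottom-left anchor and recover the box (sketch only)
mapper = AnchorBBoxMapper(anchor="bottom-left")
sound_event = data.SoundEvent(
    geometry=data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0]),
    recording=recording,
)
position, size = mapper.encode(sound_event)
# position == (10.0, 100.0); size == [10.0 * time_scale, 100.0 * frequency_scale]
roundtrip = mapper.decode(position, size)  # approximately the original bounding box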
@@ -308,6 +306,24 @@ class AnchorBBoxMapper(ROITargetMapper):


class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`.

Attributes
----------
name : Literal["peak_energy_bbox"]
The unique identifier for this mapper type.
preprocessing : PreprocessingConfig
Configuration for the spectrogram preprocessor needed to find the
peak energy.
loading_buffer : float
Seconds to add to each side of the ROI when loading audio to ensure
the peak is captured accurately, avoiding boundary effects.
time_scale : float
Scaling factor applied to the time dimensions.
frequency_scale : float
Scaling factor applied to the frequency dimensions.
"""

name: Literal["peak_energy_bbox"]
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
@@ -318,9 +334,30 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):


class PeakEnergyBBoxMapper(ROITargetMapper):
"""
Encodes the ROI using the location of the peak energy within the bounding box
as the 'position' and the distances from that point to the box edges as the 'size'.
"""Maps ROIs using the peak energy point and distances to edges.

This class implements the `ROITargetMapper` protocol.

**Encoding**: The `position` is the (time, frequency) coordinate of the
point with the highest energy within the sound event's bounding box. The
`size` is a 4-element array representing the scaled distances from this
peak energy point to the left, bottom, right, and top edges of the box.

**Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the
un-scaled distances from the peak energy point.

Attributes
----------
dimension_names : List[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance.
time_scale : float
The scaling factor for time-based distances.
frequency_scale : float
The scaling factor for frequency-based distances.
loading_buffer : float
The buffer used for loading audio around the ROI.
"""

dimension_names = ["left", "bottom", "right", "top"]
@@ -332,6 +369,19 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01,
):
"""Initialize the PeakEnergyBBoxMapper.

Parameters
----------
preprocessor : PreprocessorProtocol
An initialized preprocessor for generating spectrograms.
time_scale : float
Scaling factor for time dimensions (left, right distances).
frequency_scale : float
Scaling factor for frequency dimensions (bottom, top distances).
loading_buffer : float
Buffer in seconds to add when loading audio clips.
"""
self.preprocessor = preprocessor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
@@ -341,6 +391,21 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
self,
sound_event: data.SoundEvent,
) -> tuple[Position, Size]:
"""Encode a SoundEvent into a peak energy position and edge distances.

Finds the peak energy coordinates within the event's bounding box
and calculates the scaled distances from this point to the box edges.

Parameters
----------
sound_event : data.SoundEvent
The input sound event with a geometry and associated recording.

Returns
-------
Tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances).
"""
from soundevent import geometry

geom = sound_event.geometry
@@ -377,6 +442,20 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
return (time, freq), size

def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover a BoundingBox from a peak position and edge distances.

Parameters
----------
position : Position
The reference peak energy position (time, frequency).
size : Size
NumPy array with scaled distances [left, bottom, right, top].

Returns
-------
data.BoundingBox
The reconstructed bounding box.
"""
time, freq = position
left, bottom, right, top = size
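The decode step described in the docstring, written out as a hedged sketch: it mirrors the documented behaviour of un-scaling the four edge distances around the peak point, but it is not the actual implementation.

# Sketch of PeakEnergyBBoxMapper-style decoding (illustrative only)
def decode_peak_energy_roi(position, size, time_scale, frequency_scale):
    time, freq = position
    left, bottom, right, top = size
    return data.BoundingBox(
        coordinates=[
            time - left / time_scale,
            freq - bottom / frequency_scale,
            time + right / time_scale,
            freq + top / frequency_scale,
        ]
    )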
@@ -394,21 +473,30 @@ ROIMapperConfig = Annotated[
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.

This type allows for selecting and configuring different `ROITargetMapper`
implementations by using the `name` field as a discriminator.
"""


def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
"""Factory function to create an ROITargetMapper from a config object.

Parameters
----------
config : ROIConfig
Configuration object specifying ROI mapping parameters.
config : ROIMapperConfig
A configuration object specifying the mapper type and its parameters.

Returns
-------
ROITargetMapper
An initialized `BBoxEncoder` instance configured with the settings
from `config`.
An initialized ROI mapper instance.

Raises
------
NotImplementedError
If the `name` in the config does not correspond to a known mapper.
"""
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
@@ -431,39 +519,6 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
)


def load_roi_mapper(
path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
"""Load ROI mapping configuration from a file and build the mapper.

Convenience function that loads an `ROIConfig` from the specified file
(and optional field) and then uses `build_roi_mapper` to create the
corresponding `ROITargetMapper` instance.

Parameters
----------
path : PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
ROI configuration. If None, the entire file content is used.

Returns
-------
ROITargetMapper
An initialized ROI mapper instance based on the configuration file.

Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=BBoxAnchorMapperConfig, field=field)
return build_roi_mapper(config)


VALID_ANCHORS = [
"bottom-left",
"bottom-right",
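A sketch of how the discriminated union is meant to be used: the `name` field selects the mapper implementation when the config is built. The peak-energy line assumes the remaining config fields have usable defaults.

# Selecting a mapper through ROIMapperConfig (sketch only)
anchor_mapper = build_roi_mapper(BBoxAnchorMapperConfig(anchor="center"))  # name="anchor_bbox"
peak_mapper = build_roi_mapper(PeakEnergyBBoxMapperConfig(name="peak_energy_bbox"))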
@@ -501,7 +556,7 @@ def _build_bounding_box(
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
anchor : Anchor, default="bottom-left"
anchor : Anchor
Specifies which part of the bounding box the input `pos` corresponds to.

Returns
@@ -565,6 +620,35 @@ def get_peak_energy_coordinates(
high_freq: Optional[float] = None,
loading_buffer: float = 0.05,
) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram.

Generates a spectrogram for a specified time-frequency region of a
recording and returns the (time, frequency) coordinates of the pixel with
the maximum value.

Parameters
----------
recording : data.Recording
The recording to analyze.
preprocessor : PreprocessorProtocol
The processor to convert audio to a spectrogram.
start_time : float, default=0
The start time of the region of interest.
end_time : float, optional
The end time of the region of interest. Defaults to recording duration.
low_freq : float, default=0
The low frequency of the region of interest.
high_freq : float, optional
The high frequency of the region of interest. Defaults to Nyquist.
loading_buffer : float, default=0.05
Buffer in seconds to add around the time range when loading the clip
to mitigate border effects from transformations like STFT.

Returns
-------
Position
A (time, frequency) tuple for the peak energy location.
"""
if end_time is None:
end_time = recording.duration
end_time = min(end_time, recording.duration)
@@ -230,7 +230,7 @@ class TermRegistry(Mapping[str, data.Term]):
del self._terms[key]


term_registry = TermRegistry(
default_term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
@@ -252,7 +252,7 @@ is explicitly passed.

def get_term_from_key(
key: str,
term_registry: TermRegistry = term_registry,
term_registry: Optional[TermRegistry] = None,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.

@@ -277,10 +277,13 @@ def get_term_from_key(
KeyError
If the key is not found in the specified registry.
"""
term_registry = term_registry or default_term_registry
return term_registry.get_term(key)


def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
"""Convenience function to get all registered keys from a registry.

Uses the global default registry unless a specific `term_registry`
@@ -299,7 +302,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
return term_registry.get_keys()


def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.

Uses the global default registry unless a specific `term_registry`
@@ -342,7 +347,7 @@ class TagInfo(BaseModel):

def get_tag_from_info(
tag_info: TagInfo,
term_registry: TermRegistry = term_registry,
term_registry: Optional[TermRegistry] = None,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.

@@ -368,6 +373,7 @@ def get_tag_from_info(
If the term key specified in `tag_info.key` is not found
in the registry.
"""
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)

@@ -439,7 +445,7 @@ class TermConfig(BaseModel):
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.

@@ -490,6 +496,6 @@ def load_terms_from_config(


def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry
key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None:
registry.add_term(key, term)
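The pattern adopted across these helpers, shown as a small sketch: the default argument becomes `None` and the module-level default registry is resolved at call time rather than captured at import time (the helper name below is hypothetical):

# Sketch of the "resolve the default registry at call time" pattern
def lookup_term(key: str, term_registry: Optional[TermRegistry] = None) -> data.Term:
    term_registry = term_registry or default_term_registry
    return term_registry.get_term(key)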
@@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)

__all__ = [
"DerivationRegistry",
@@ -34,7 +31,7 @@ __all__ = [
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"derivation_registry",
"default_derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
@@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
return list(self._derivations.values())


derivation_registry = DerivationRegistry()
default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.

Register custom derivation functions here to make them available by key
@@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
def get_derivation(
key: str,
import_derivation: bool = False,
registry: DerivationRegistry = derivation_registry,
registry: Optional[DerivationRegistry] = None,
):
"""Retrieve a derivation function by key, optionally importing it.

@@ -443,6 +440,8 @@ def get_derivation(
AttributeError
If dynamic import fails because the function name isn't in the module.
"""
registry = registry or default_derivation_registry

if not import_derivation or key in registry:
return registry.get_derivation(key)

@@ -458,10 +457,16 @@ def get_derivation(
) from err


TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]


def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.

@@ -559,8 +564,8 @@ def build_transform_from_rule(

def build_transformation_from_config(
config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.

@@ -581,6 +586,7 @@ def build_transformation_from_config(
SoundEventTransformation
A single function that applies all configured transformations in order.
"""

transforms = [
build_transform_from_rule(
rule,
@@ -590,14 +596,16 @@ def build_transformation_from_config(
for rule in config.rules
]

def transformation(
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
return partial(apply_sequence_of_transforms, transforms=transforms)

return transformation

def apply_sequence_of_transforms(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation


def load_transformation_config(
@@ -631,8 +639,8 @@ def load_transformation_config(
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.

@@ -677,7 +685,7 @@ def load_transformation_from_config(
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: DerivationRegistry = derivation_registry,
derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
"""Register a new derivation function in the global registry.

@@ -696,4 +704,5 @@ def register_derivation(
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation)
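The composition returned by `build_transformation_from_config` after this change, sketched with hypothetical transforms to show the equivalence with the old inner closure:

# Sketch: composing transforms via the new module-level helper
from functools import partial

transforms = [transform_a, transform_b]  # hypothetical SoundEventTransformation callables
composed = partial(apply_sequence_of_transforms, transforms=transforms)
# composed(annotation) is equivalent to transform_b(transform_a(annotation))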
@@ -19,8 +19,16 @@ from soundevent import data

__all__ = [
"TargetProtocol",
"Position",
"Size",
]

Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""

Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""


class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline.
@@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
"""
...

def encode(
def encode_class(
self,
sound_event: data.SoundEventAnnotation,
) -> Optional[str]:
@@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
"""
...

def decode(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.

Parameters
@@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
"""
...

def get_position(
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's geometry.

Calculates the `(time, frequency)` coordinate representing the primary
@@ -173,37 +181,7 @@ class TargetProtocol(Protocol):
"""
...

def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.

Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.

Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.

Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).

Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...

def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions.

Performs the inverse mapping of `get_position` and `get_size`. It takes
@@ -97,7 +97,7 @@ def _is_in_subclip(
start_time: float,
end_time: float,
) -> bool:
time, _ = targets.get_position(sound_event_annotation)
time, _ = targets.encode_roi(sound_event_annotation)
return start_time <= time <= end_time


@@ -138,7 +138,7 @@ def generate_clip_label(
logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events)
num=len(clip_annotation.sound_events),
)

sound_events = []
@@ -260,7 +260,7 @@ def generate_heatmaps(
continue

# Get the position of the sound event
time, frequency = targets.get_position(sound_event_annotation)
(time, frequency), size = targets.encode_roi(sound_event_annotation)

# Set 1.0 at the position of the sound event in the detection heatmap
try:
@@ -280,8 +280,6 @@ def generate_heatmaps(
)
continue

size = targets.get_size(sound_event_annotation)

size_heatmap = arrays.set_value_at_pos(
size_heatmap,
size,
@@ -291,7 +289,7 @@ def generate_heatmaps(

# Get the class name of the sound event
try:
class_name = targets.encode(sound_event_annotation)
class_name = targets.encode_class(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "
notebooks/signal_generation.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
@@ -62,6 +62,9 @@ keywords = [
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["batdetect2/"]

[project.scripts]
batdetect2 = "batdetect2.cli:cli"
@@ -3,8 +3,8 @@ import pytest
from soundevent import data

from batdetect2.targets.rois import (
DEFAULT_FREQUENCY_SCALE,
DEFAULT_ANCHOR,
DEFAULT_FREQUENCY_SCALE,
DEFAULT_TIME_SCALE,
SIZE_HEIGHT,
SIZE_WIDTH,
@@ -12,7 +12,6 @@ from batdetect2.targets.rois import (
BBoxAnchorMapperConfig,
_build_bounding_box,
build_roi_mapper,
load_roi_mapper,
)


@@ -22,6 +21,16 @@ def sample_bbox() -> data.BoundingBox:
return data.BoundingBox(coordinates=[10.0, 100.0, 20.0, 200.0])


@pytest.fixture
def sample_recording(create_recording) -> data.Recording:
return create_recording(duration=30, samplerate=4_000)


@pytest.fixture
def sample_sound_event(sample_bbox, sample_recording) -> data.SoundEvent:
return data.SoundEvent(geometry=sample_bbox, recording=sample_recording)


@pytest.fixture
def zero_bbox() -> data.BoundingBox:
"""A bounding box with zero duration and bandwidth."""
@@ -29,7 +38,13 @@ def zero_bbox() -> data.BoundingBox:


@pytest.fixture
def default_encoder() -> AnchorBBoxMapper:
def zero_sound_event(zero_bbox, sample_recording) -> data.SoundEvent:
"""A sample sound event with a zero-sized bounding box."""
return data.SoundEvent(geometry=zero_bbox, recording=sample_recording)


@pytest.fixture
def default_mapper() -> AnchorBBoxMapper:
"""A BBoxEncoder with default settings."""
return AnchorBBoxMapper()

@@ -37,36 +52,30 @@ def default_encoder() -> AnchorBBoxMapper:
@pytest.fixture
def custom_encoder() -> AnchorBBoxMapper:
"""A BBoxEncoder with custom settings."""
return AnchorBBoxMapper(anchor="center", time_scale=1.0, frequency_scale=10.0)
return AnchorBBoxMapper(
anchor="center", time_scale=1.0, frequency_scale=10.0
)


def test_roi_config_defaults():
"""Test ROIConfig default values."""
config = BBoxAnchorMapperConfig()
assert config.anchor == DEFAULT_ANCHOR
assert config.time_scale == DEFAULT_TIME_SCALE
assert config.frequency_scale == DEFAULT_FREQUENCY_SCALE
@pytest.fixture
def custom_mapper() -> AnchorBBoxMapper:
"""An AnchorBBoxMapper with custom settings."""
return AnchorBBoxMapper(
anchor="center", time_scale=1.0, frequency_scale=10.0
)


def test_roi_config_custom():
"""Test creating ROIConfig with custom values."""
config = BBoxAnchorMapperConfig(anchor="center", time_scale=1.0, frequency_scale=10.0)
assert config.anchor == "center"
assert config.time_scale == 1.0
assert config.frequency_scale == 10.0


def test_bbox_encoder_init_defaults(default_encoder):
def test_bbox_encoder_init_defaults(default_mapper):
"""Test BBoxEncoder initialization with default arguments."""
assert default_encoder.position == DEFAULT_ANCHOR
assert default_encoder.time_scale == DEFAULT_TIME_SCALE
assert default_encoder.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
assert default_mapper.anchor == DEFAULT_ANCHOR
assert default_mapper.time_scale == DEFAULT_TIME_SCALE
assert default_mapper.frequency_scale == DEFAULT_FREQUENCY_SCALE
assert default_mapper.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]


def test_bbox_encoder_init_custom(custom_encoder):
"""Test BBoxEncoder initialization with custom arguments."""
assert custom_encoder.position == "center"
assert custom_encoder.anchor == "center"
assert custom_encoder.time_scale == 1.0
assert custom_encoder.frequency_scale == 10.0
assert custom_encoder.dimension_names == [SIZE_WIDTH, SIZE_HEIGHT]
@@ -87,52 +96,50 @@ POSITION_TEST_CASES = [
]


@pytest.mark.parametrize("position_type, expected_pos", POSITION_TEST_CASES)
def test_bbox_encoder_get_roi_position(
sample_bbox, position_type, expected_pos
@pytest.mark.parametrize("anchor, expected_pos", POSITION_TEST_CASES)
def test_anchor_bbox_mapper_encode_position(
sample_sound_event, anchor, expected_pos
):
"""Test get_roi_position for various position types."""
encoder = AnchorBBoxMapper(anchor=position_type)
actual_pos = encoder.encode_position(sample_bbox)
"""Test encode returns the correct position for various anchors."""
encoder = AnchorBBoxMapper(anchor=anchor)
actual_pos, _ = encoder.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)


def test_bbox_encoder_get_roi_position_zero_box(zero_bbox):
"""Test get_roi_position for a zero-sized box."""
encoder = AnchorBBoxMapper(anchor="center")
assert encoder.encode_position(zero_bbox) == pytest.approx((15.0, 150.0))


def test_bbox_encoder_get_roi_size_defaults(sample_bbox, default_encoder):
"""Test get_roi_size with default scaling."""
def test_anchor_bbox_mapper_encode_defaults(
sample_sound_event, default_mapper
):
"""Test encode with default settings returns correct position and size."""
expected_pos = (10.0, 100.0) # bottom-left
expected_size = np.array(
[
10.0 * DEFAULT_TIME_SCALE,
100.0 * DEFAULT_FREQUENCY_SCALE,
]
)
actual_size = default_encoder.get_roi_size(sample_bbox)
actual_pos, actual_size = default_mapper.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size)
assert actual_size.shape == (2,)


def test_bbox_encoder_get_roi_size_custom(sample_bbox, custom_encoder):
"""Test get_roi_size with custom scaling."""
expected_size = np.array(
[
10.0 * 1.0,
100.0 * 10.0,
]
)
actual_size = custom_encoder.get_roi_size(sample_bbox)
def test_anchor_bbox_mapper_encode_custom(sample_sound_event, custom_mapper):
"""Test encode with custom settings returns correct position and size."""
expected_pos = (15.0, 150.0) # center
expected_size = np.array([10.0 * 1.0, 100.0 * 10.0])

actual_pos, actual_size = custom_mapper.encode(sample_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size)
assert actual_size.shape == (2,)


def test_bbox_encoder_get_roi_size_zero_box(zero_bbox, default_encoder):
"""Test get_roi_size for a zero-sized box."""
def test_anchor_bbox_mapper_encode_zero_box(zero_sound_event, default_mapper):
"""Test encode for a zero-sized box."""
expected_pos = (15.0, 150.0)
expected_size = np.array([0.0, 0.0])
actual_size = default_encoder.get_roi_size(zero_bbox)
actual_pos, actual_size = default_mapper.encode(zero_sound_event)
assert actual_pos == pytest.approx(expected_pos)
np.testing.assert_allclose(actual_size, expected_size)


@@ -166,9 +173,9 @@ def test_build_bounding_box(position_type, expected_coords):
np.testing.assert_allclose(bbox.coordinates, expected_coords)


def test_build_bounding_box_invalid_position():
def test_build_bounding_box_invalid_anchor():
"""Test _build_bounding_box raises error for invalid position."""
with pytest.raises(ValueError, match="Invalid position"):
with pytest.raises(ValueError, match="Invalid anchor"):
_build_bounding_box(
(0, 0),
1,
@@ -177,13 +184,16 @@ def test_build_bounding_box_invalid_position():
)


@pytest.mark.parametrize("position_type, ref_pos", POSITION_TEST_CASES)
def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
"""Test recover_roi correctly reconstructs the original bbox."""
encoder = AnchorBBoxMapper(anchor=position_type)
scaled_dims = encoder.encode_size(sample_bbox)

recovered_bbox = encoder.decode(ref_pos, scaled_dims)
@pytest.mark.parametrize(
"anchor", [anchor for anchor, _ in POSITION_TEST_CASES]
)
def test_anchor_bbox_mapper_encode_decode_roundtrip(
sample_sound_event, sample_bbox, anchor
):
"""Test encode-decode roundtrip reconstructs the original bbox."""
mapper = AnchorBBoxMapper(anchor=anchor)
position, size = mapper.encode(sample_sound_event)
recovered_bbox = mapper.decode(position, size)

assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose(
@@ -191,12 +201,12 @@ def test_bbox_encoder_recover_roi(sample_bbox, position_type, ref_pos):
)


def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
"""Test recover_roi with custom scaling factors."""
ref_pos = custom_encoder.get_roi_position(sample_bbox)
scaled_dims = custom_encoder.get_roi_size(sample_bbox)

recovered_bbox = custom_encoder.recover_roi(ref_pos, scaled_dims)
def test_anchor_bbox_mapper_roundtrip_custom_scale(
sample_sound_event, sample_bbox, custom_mapper
):
"""Test encode-decode roundtrip with custom scaling factors."""
position, size = custom_mapper.encode(sample_sound_event)
recovered_bbox = custom_mapper.decode(position, size)

assert isinstance(recovered_bbox, data.BoundingBox)
np.testing.assert_allclose(
@@ -204,25 +214,26 @@ def test_bbox_encoder_recover_roi_custom_scale(sample_bbox, custom_encoder):
)


def test_bbox_encoder_recover_roi_zero_box(zero_bbox, default_encoder):
"""Test recover_roi for a zero-sized box."""
ref_pos = default_encoder.get_roi_position(zero_bbox)
scaled_dims = default_encoder.get_roi_size(zero_bbox)
recovered_bbox = default_encoder.recover_roi(ref_pos, scaled_dims)
def test_anchor_bbox_mapper_roundtrip_zero_box(
zero_sound_event, zero_bbox, default_mapper
):
"""Test encode-decode roundtrip for a zero-sized box."""
position, size = default_mapper.encode(zero_sound_event)
recovered_bbox = default_mapper.decode(position, size)
np.testing.assert_allclose(
recovered_bbox.coordinates, zero_bbox.coordinates, atol=1e-6
)


def test_bbox_encoder_recover_roi_invalid_dims_shape(default_encoder):
"""Test recover_roi raises ValueError for incorrect dims shape."""
def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
"""Test decode raises ValueError for incorrect size shape."""
ref_pos = (10, 100)
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([1.0]))
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([1.0, 2.0, 3.0]))
with pytest.raises(ValueError):
default_encoder.recover_roi(ref_pos, np.array([[1.0], [2.0]]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([1.0]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([1.0, 2.0, 3.0]))
with pytest.raises(ValueError, match="does not have the expected shape"):
default_mapper.decode(ref_pos, np.array([[1.0], [2.0]]))


def test_build_roi_mapper():
@@ -236,69 +247,3 @@ def test_build_roi_mapper():
assert mapper.anchor == config.anchor
assert mapper.time_scale == config.time_scale
assert mapper.frequency_scale == config.frequency_scale


@pytest.fixture
def sample_config_yaml_content() -> str:
"""YAML content for a sample ROIConfig."""
return f"""
position: center
time_scale: 500.0
frequency_scale: {1 / 1000.0}
"""


@pytest.fixture
def nested_config_yaml_content() -> str:
"""YAML content with ROIConfig nested under a field."""
return f"""
model_settings:
preprocessing:
whatever: true
roi_mapping:
position: bottom-right
time_scale: {DEFAULT_TIME_SCALE}
frequency_scale: 0.01
other_stuff: 123
"""


def test_load_roi_mapper_simple(tmp_path, sample_config_yaml_content):
"""Test loading a simple ROIConfig from YAML."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)

mapper = load_roi_mapper(config_path)

assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "center"
assert mapper.time_scale == 500.0
assert mapper.frequency_scale == pytest.approx(1 / 1000.0)


def test_load_roi_mapper_nested(tmp_path, nested_config_yaml_content):
"""Test loading a nested ROIConfig from YAML using 'field'."""
config_path = tmp_path / "nested_config.yaml"
config_path.write_text(nested_config_yaml_content)

mapper = load_roi_mapper(config_path, field="model_settings.roi_mapping")

assert isinstance(mapper, AnchorBBoxMapper)
assert mapper.anchor == "bottom-right"
assert mapper.time_scale == DEFAULT_TIME_SCALE
assert mapper.frequency_scale == 0.01


def test_load_roi_mapper_file_not_found(tmp_path):
"""Test load_roi_mapper raises error if file doesn't exist."""
non_existent_path = tmp_path / "not_real.yaml"
with pytest.raises(FileNotFoundError):
load_roi_mapper(non_existent_path)


def test_load_roi_mapper_invalid_field(tmp_path, sample_config_yaml_content):
"""Test load_roi_mapper raises error for invalid field."""
config_path = tmp_path / "config.yaml"
config_path.write_text(sample_config_yaml_content)
with pytest.raises(KeyError):
load_roi_mapper(config_path, field="invalid.path")