Compare commits

..

No commits in common. "a462beaeb8619256a1f995f5594c503d4d1bc24b" and "ebad489cb196929532e52091f77504127e7e3e87" have entirely different histories.

110 changed files with 881 additions and 1856 deletions

2
.gitignore vendored
View File

@ -107,7 +107,7 @@ experiments/*
# DO Include
!batdetect2_notebook.ipynb
!src/batdetect2/models/checkpoints/*.pth.tar
!batdetect2/models/checkpoints/*.pth.tar
!tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav

View File

@ -1,7 +1,7 @@
# Variables
SOURCE_DIR = src
SOURCE_DIR = batdetect2
TESTS_DIR = tests
PYTHON_DIRS = src tests
PYTHON_DIRS = batdetect2 tests
DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov

View File

@ -189,7 +189,8 @@ def train_command(
config=postprocess_config_loaded,
)
logger.debug(
"Loaded postprocessor from file {path}", path=postprocess_config
"Loaded postprocessor from file {path}",
path=train_config,
)
except IOError:
logger.debug(

View File

@ -157,4 +157,4 @@ def load_config(
if field:
config = get_object_field(config, field)
return schema.model_validate(config or {})
return schema.model_validate(config)

View File

@ -8,7 +8,6 @@ from batdetect2.data.annotations import (
from batdetect2.data.datasets import (
DatasetConfig,
load_dataset,
load_dataset_config,
load_dataset_from_config,
)
@ -20,6 +19,5 @@ __all__ = [
"DatasetConfig",
"load_annotated_dataset",
"load_dataset",
"load_dataset_config",
"load_dataset_from_config",
]

View File

@ -161,11 +161,6 @@ def insert_source_tag(
)
# TODO: add documentation
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,

View File

@ -72,7 +72,7 @@ def iterate_over_sound_events(
sound_event_annotation
)
class_name = targets.encode_class(sound_event_annotation)
class_name = targets.encode(sound_event_annotation)
if class_name is None and exclude_generic:
continue

View File

@ -1,3 +1,4 @@
import numpy as np
from sklearn.metrics import auc, roc_curve

View File

@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_class = targets.encode(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0

View File

@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
model output arrays.
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
coordinates to raw model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.decode_roi,
self.targets.recover_roi,
)
for dataset in detection_datasets
]
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode_class,
sound_event_decoder=self.targets.decode,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)

View File

@ -4,7 +4,8 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it
into standardized prediction objects based on the `soundevent` data model.
into meaningful, standardized prediction objects based on the `soundevent` data
model.
The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
@ -32,7 +33,7 @@ import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array
@ -54,7 +55,7 @@ decoding.
def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset,
geometry_decoder: GeometryDecoder,
geometry_builder: GeometryBuilder,
) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects.
@ -71,7 +72,7 @@ def convert_xr_dataset_to_raw_prediction(
output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension.
geometry_decoder : GeometryDecoder
geometry_builder : GeometryBuilder
A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`.
@ -95,20 +96,14 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
# TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()
geom = geometry_decoder(
geom = geometry_builder(
(det_info.time, det_info.frequency),
det_info.dimensions,
class_name=highest_scoring_class,
)
detections.append(
RawPrediction(
detection_score=det_info.scores,
detection_score=det_info.score,
geometry=geom,
class_scores=det_info.classes,
features=det_info.features,

View File

@ -1,6 +1,6 @@
"""Extracts candidate detection points from a model output heatmap.
This module implements Step 3 within the BatDetect2 postprocessing
This module implements a specific step within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and

View File

@ -1,9 +1,9 @@
"""Extracts associated data for detected points from model output arrays.
This module implements Step 4 in the BatDetect2 postprocessing pipeline.
After candidate detection points (time, frequency, score) have been identified,
this module extracts the corresponding values from other raw model output
arrays, such as:
This module implements a key step (Step 4) in the BatDetect2 postprocessing
pipeline. After candidate detection points (time, frequency, score) have been
identified, this module extracts the corresponding values from other raw model
output arrays, such as:
- Predicted bounding box sizes (width, height).
- Class probability scores for each defined target class.

View File

@ -11,37 +11,30 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from typing import List, NamedTuple, Optional, Protocol
from typing import Callable, List, NamedTuple, Protocol
import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size
__all__ = [
"RawPrediction",
"PostprocessorProtocol",
"GeometryDecoder",
"GeometryBuilder",
]
# TODO: update the docstring
class GeometryDecoder(Protocol):
"""Type alias for a function that recovers geometry from position and size.
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
"""Type alias for a function that recovers geometry from position and size.
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
3. Optionally a class name of the highest scoring class. This is to accommodate
different ways of decoding geometry that depend on the predicted class.
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None
) -> data.Geometry: ...
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
class RawPrediction(NamedTuple):

View File

@ -23,7 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional
from loguru import logger
import numpy as np
from pydantic import Field
from soundevent import data
@ -50,8 +50,7 @@ from batdetect2.targets.filtering import (
load_filter_from_config,
)
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
ROIConfig,
ROITargetMapper,
build_roi_mapper,
)
@ -60,11 +59,11 @@ from batdetect2.targets.terms import (
TermInfo,
TermRegistry,
call_type,
default_term_registry,
get_tag_from_info,
get_term_from_key,
individual,
register_term,
term_registry,
)
from batdetect2.targets.transform import (
DerivationRegistry,
@ -74,13 +73,13 @@ from batdetect2.targets.transform import (
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
default_derivation_registry,
derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import Position, Size, TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClassesConfig",
@ -89,7 +88,7 @@ __all__ = [
"FilterConfig",
"FilterRule",
"MapValueRule",
"AnchorBBoxMapperConfig",
"ROIConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
@ -157,12 +156,12 @@ class TargetConfig(BaseConfig):
omitted, default ROI mapping settings are used.
"""
filtering: FilterConfig = Field(default_factory=FilterConfig)
transforms: TransformConfig = Field(default_factory=TransformConfig)
filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None
classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
roi: Optional[ROIConfig] = None
def load_target_config(
@ -241,7 +240,6 @@ class Targets(TargetProtocol):
generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
):
"""Initialize the Targets object.
@ -274,16 +272,6 @@ class Targets(TargetProtocol):
self._encode_fn = encode_fn
self._decode_fn = decode_fn
self._transform_fn = transform_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
@ -303,9 +291,7 @@ class Targets(TargetProtocol):
return True
return self._filter_fn(sound_event)
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
@ -326,7 +312,7 @@ class Targets(TargetProtocol):
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
def decode(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
@ -366,9 +352,9 @@ class Targets(TargetProtocol):
return self._transform_fn(sound_event)
return sound_event
def encode_roi(
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
@ -388,20 +374,50 @@ class Targets(TargetProtocol):
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
geom = sound_event.sound_event.geometry
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its position."
)
return self._roi_mapper.encode(sound_event.sound_event)
return self._roi_mapper.get_roi_position(geom)
def decode_roi(
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
def recover_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
@ -422,13 +438,7 @@ class Targets(TargetProtocol):
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
return self._roi_mapper.recover_roi(pos, dims)
DEFAULT_CLASSES = [
@ -483,12 +493,10 @@ DEFAULT_CLASSES = [
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
@ -529,14 +537,13 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
]
),
classes=DEFAULT_CLASSES_CONFIG,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
@ -599,17 +606,12 @@ def build_targets(
if config.transforms
else None
)
roi_mapper = build_roi_mapper(config.roi)
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
if class_config.roi is not None
}
return Targets(
filter_fn=filter_fn,
@ -619,15 +621,14 @@ def build_targets(
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
roi_mapper_overrides=roi_overrides,
)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = default_derivation_registry,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.

View File

@ -6,27 +6,29 @@ from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import (
GENERIC_CLASS_KEY,
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
term_registry,
)
__all__ = [
"DEFAULT_SPECIES_LIST",
"build_generic_class_tags",
"build_sound_event_decoder",
"build_sound_event_encoder",
"get_class_names_from_config",
"SoundEventEncoder",
"SoundEventDecoder",
"TargetClass",
"ClassesConfig",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
"load_decoder_from_config",
"build_sound_event_encoder",
"build_sound_event_decoder",
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST",
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
@ -111,7 +113,6 @@ class TargetClass(BaseConfig):
tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None
roi: Optional[ROIMapperConfig] = None
def _get_default_classes() -> List[TargetClass]:
@ -234,7 +235,7 @@ class ClassesConfig(BaseConfig):
return v
def is_target_class(
def _is_target_class(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
match_all: bool = True,
@ -315,7 +316,7 @@ def _encode_with_multiple_classifiers(
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
@ -349,7 +350,7 @@ def build_sound_event_encoder(
(
class_info.name,
partial(
is_target_class,
_is_target_class,
tags={
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags
@ -409,7 +410,7 @@ def _decode_class(
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
@ -464,7 +465,7 @@ def build_sound_event_decoder(
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
@ -529,7 +530,7 @@ def load_classes_config(
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
@ -570,7 +571,7 @@ def load_encoder_from_config(
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.

View File

@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
default_term_registry,
term_registry,
)
__all__ = [
@ -156,7 +156,7 @@ def equal_tags(
def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
@ -291,7 +291,7 @@ def load_filter_config(
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.

511
batdetect2/targets/rois.py Normal file
View File

@ -0,0 +1,511 @@
"""Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.
This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
"""
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
# Closed set of reference-point names accepted as the `position` argument by
# the ROI mappers in this module. `VALID_POSITIONS` further down lists the
# same names for runtime validation; keep the two in sync.
Positions = Literal[
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]
# Names exported by `from batdetect2.targets.rois import *`.
__all__ = [
    "ROITargetMapper",
    "ROIConfig",
    "BBoxEncoder",
    "build_roi_mapper",
    "load_roi_mapper",
    "DEFAULT_POSITION",
    "SIZE_WIDTH",
    "SIZE_HEIGHT",
    "SIZE_ORDER",
    "DEFAULT_TIME_SCALE",
    "DEFAULT_FREQUENCY_SCALE",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""
# NOTE(review): 1000.0 presumably converts seconds to milliseconds - confirm
# against the model's expected target units.
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""
# NOTE(review): 859.375 looks like Hz per spectrogram frequency bin - confirm
# against the preprocessing configuration before changing it.
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_POSITION = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
    """Structural interface for ROI <-> target conversions.

    An implementation turns a `soundevent.data.Geometry` into the pair the
    model is trained on (a reference point plus an array of scaled size
    dimensions) and can go the other way, rebuilding an approximate geometry
    from such a pair.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the entries of the size array produced by `get_roi_size`
        and consumed by `recover_roi`, in order (e.g. ['width', 'height']).
    """

    dimension_names: List[str]

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> tuple[float, float]:
        """Return the reference point of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to inspect (e.g. a BoundingBox or Polygon).
        position : Positions, optional
            Reference-point name to use for this call only. When omitted,
            the mapper falls back to its own configured default.

        Returns
        -------
        Tuple[float, float]
            The reference point as `(time, frequency)`.

        Raises
        ------
        ValueError
            If no reference point can be derived for this geometry type or
            reference-point name.
        """
        ...

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Return the scaled size dimensions of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to measure.

        Returns
        -------
        np.ndarray
            Scaled dimensions ordered as in `dimension_names`; for bounding
            boxes this is typically `[scaled_width, scaled_height]`.

        Raises
        ------
        TypeError, ValueError
            If the size cannot be computed for this geometry type.
        """
        ...

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Rebuild an approximate geometry from a point and size dimensions.

        This is the inverse of `get_roi_position` / `get_roi_size`.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference point as `(time, frequency)`.
        dims : np.ndarray
            Size dimensions ordered as in `dimension_names`.
        position : Positions, optional
            Reference-point name to use for this call only when rebuilding
            the geometry. When omitted, the mapper falls back to its own
            configured default.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If `dims` does not match `dimension_names` or the geometry
            cannot be reconstructed.
        """
        ...
class ROIConfig(BaseConfig):
    """Settings that control how ROIs are turned into targets.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Which point of the geometry (e.g. of its bounding box) is used as
        the target location, such as "center" or "bottom-left".
    time_scale : float, default=1000.0
        Multiplier applied to the ROI duration (width) when building the
        size target. Must match what the model expects.
    frequency_scale : float, default=1/859.375
        Multiplier applied to the ROI bandwidth (height) when building the
        size target. Must match what the model expects.
    """

    position: Positions = DEFAULT_POSITION
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class BBoxEncoder(ROITargetMapper):
    """Bounding-box based implementation of `ROITargetMapper`.

    Extracts a reference point from a geometry, computes the scaled
    width/height of its bounds, and rebuilds a
    `soundevent.data.BoundingBox` from a reference point and scaled
    dimensions.

    Attributes
    ----------
    dimension_names : List[str]
        Output dimension names, `['width', 'height']`.
    position : Positions
        Default reference-point name (e.g. "center", "bottom-left").
    time_scale : float
        Scaling factor applied to the time dimension (width).
    frequency_scale : float
        Scaling factor applied to the frequency dimension (height).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        position: Positions = DEFAULT_POSITION,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Initialize the BBoxEncoder.

        Parameters
        ----------
        position : Positions, default="bottom-left"
            Default reference-point name within the bounding box.
        time_scale : float, default=1000.0
            Scaling factor for time duration (width).
        frequency_scale : float, default=1/859.375
            Scaling factor for frequency bandwidth (height).
        """
        self.position: Positions = position
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> Tuple[float, float]:
        """Extract the reference position from the geometry.

        Delegates to `soundevent.geometry.get_geometry_point`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry (e.g., BoundingBox).
        position : Positions, optional
            Per-call override of the configured reference point. When
            omitted (or falsy), `self.position` is used.

        Returns
        -------
        Tuple[float, float]
            Reference position (time, frequency).
        """
        from soundevent import geometry

        position = position or self.position
        return geometry.get_geometry_point(geom, position=position)

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Calculate the scaled [width, height] from the geometry's bounds.

        Computes the bounds with `soundevent.geometry.compute_bounds`, then
        multiplies the duration by `time_scale` and the bandwidth by
        `frequency_scale`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry.

        Returns
        -------
        np.ndarray
            A 1D NumPy array: `[scaled_width, scaled_height]`.
        """
        from soundevent import geometry

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )
        return np.array(
            [
                (end_time - start_time) * self.time_scale,
                (high_freq - low_freq) * self.frequency_scale,
            ]
        )

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Recover a BoundingBox from a position and scaled dimensions.

        Divides the dimensions by the configured scaling factors and builds
        a `soundevent.data.BoundingBox` anchored at `pos` according to the
        effective reference point.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference position (time, frequency).
        dims : np.ndarray
            The *scaled* dimensions, in the order
            `[scaled_width, scaled_height]`.
        position : Positions, optional
            Per-call override of the configured reference point. When
            omitted (or falsy), `self.position` is used.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `dims` is not a 1D array of length 2, or if the effective
            position is not a recognized reference point.
        """
        position = position or self.position

        if dims.ndim != 1 or dims.shape[0] != 2:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({dims.shape = }) != ([2])"
            )

        width, height = dims
        return _build_bounding_box(
            pos,
            duration=float(width) / self.time_scale,
            bandwidth=float(height) / self.frequency_scale,
            # Use the resolved value so a per-call override is honoured,
            # consistent with `get_roi_position`.
            position=position,
        )
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
    """Create the ROI mapper described by a configuration object.

    At present this always yields a `BBoxEncoder`.

    Parameters
    ----------
    config : ROIConfig
        ROI mapping settings (reference position and scaling factors).

    Returns
    -------
    ROITargetMapper
        A `BBoxEncoder` initialized with the `position`, `time_scale` and
        `frequency_scale` taken from `config`.
    """
    mapper = BBoxEncoder(
        position=config.position,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
    return mapper
def load_roi_mapper(
    path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
    """Build an ROI mapper straight from a configuration file.

    Reads an `ROIConfig` from `path` (optionally from the nested section
    named by `field`) via `load_config` and hands it to `build_roi_mapper`.

    Parameters
    ----------
    path : PathLike
        Location of the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to the nested section holding the ROI
        configuration. If None, the whole file content is used.

    Returns
    -------
    ROITargetMapper
        The mapper built from the loaded configuration.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        If the file cannot be found, parsed or validated, or if `field` is
        invalid.
    """
    return build_roi_mapper(
        load_config(path=path, schema=ROIConfig, field=field)
    )
# Runtime list of accepted reference-point names, used by
# `_build_bounding_box` to validate its `position` argument. It lists the same
# names as the `Positions` literal; keep the two in sync when adding options.
VALID_POSITIONS = [
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]
def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
    """Build a BoundingBox around a reference point.

    Computes the coordinates [start_time, low_freq, end_time, high_freq]
    of a box with the given size, placed so that `pos` sits at the location
    of the box named by `position` (a corner, an edge midpoint or the
    centre). Negative sizes are treated as zero and every coordinate is
    clamped to be non-negative.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position (time, frequency).
    duration : float
        The *unscaled* duration (width) of the bounding box.
    bandwidth : float
        The *unscaled* frequency bandwidth (height) of the bounding box.
    position : Positions, default=DEFAULT_POSITION
        Which part of the bounding box `pos` corresponds to. The values
        "center", "centroid" and "point_on_surface" all place `pos` at the
        centre of the box.

    Returns
    -------
    data.BoundingBox
        The constructed bounding box.

    Raises
    ------
    ValueError
        If `position` is not one of `VALID_POSITIONS`.
    """
    time, freq = (float(value) for value in pos)

    # A box cannot have a negative extent.
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    # All "centre-like" references put the point in the middle of the box.
    if position in ("center", "centroid", "point_on_surface"):
        half_duration = duration / 2
        half_bandwidth = bandwidth / 2
        centred_coordinates = [
            max(time - half_duration, 0),
            max(freq - half_bandwidth, 0),
            max(time + half_duration, 0),
            max(freq + half_bandwidth, 0),
        ]
        return data.BoundingBox(coordinates=centred_coordinates)

    if position not in VALID_POSITIONS:
        raise ValueError(
            f"Invalid position: {position}. "
            f"Valid options are: {VALID_POSITIONS}"
        )

    # Remaining names have the form "<vertical>-<horizontal>".
    vertical, horizontal = position.split("-")

    # Shift the reference time back to the left edge of the box.
    if horizontal == "left":
        start_time = time
    elif horizontal == "center":
        start_time = time - duration / 2
    else:  # "right"
        start_time = time - duration

    # Shift the reference frequency down to the bottom edge of the box.
    if vertical == "bottom":
        low_freq = freq
    elif vertical == "center":
        low_freq = freq - bandwidth / 2
    else:  # "top"
        low_freq = freq - bandwidth

    end_time = start_time + duration
    high_freq = low_freq + bandwidth

    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, end_time),
            max(0, high_freq),
        ]
    )

View File

@ -230,12 +230,11 @@ class TermRegistry(Mapping[str, data.Term]):
del self._terms[key]
default_term_registry = TermRegistry(
term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
("event", call_type),
("species", terms.scientific_name),
("individual", individual),
("data_source", data_source),
(GENERIC_CLASS_KEY, generic_class),
@ -253,7 +252,7 @@ is explicitly passed.
def get_term_from_key(
key: str,
term_registry: Optional[TermRegistry] = None,
term_registry: TermRegistry = term_registry,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
@ -278,13 +277,10 @@ def get_term_from_key(
KeyError
If the key is not found in the specified registry.
"""
term_registry = term_registry or default_term_registry
return term_registry.get_term(key)
def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
@ -303,9 +299,7 @@ def get_term_keys(
return term_registry.get_keys()
def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry`
@ -348,7 +342,7 @@ class TagInfo(BaseModel):
def get_tag_from_info(
tag_info: TagInfo,
term_registry: Optional[TermRegistry] = None,
term_registry: TermRegistry = term_registry,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
@ -374,7 +368,6 @@ def get_tag_from_info(
If the term key specified in `tag_info.key` is not found
in the registry.
"""
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)
@ -446,7 +439,7 @@ class TermConfig(BaseModel):
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = default_term_registry,
term_registry: TermRegistry = term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
@ -497,6 +490,6 @@ def load_terms_from_config(
def register_term(
key: str, term: data.Term, registry: TermRegistry = default_term_registry
key: str, term: data.Term, registry: TermRegistry = term_registry
) -> None:
registry.add_term(key, term)

View File

@ -21,6 +21,9 @@ from batdetect2.targets.terms import (
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [
"DerivationRegistry",
@ -31,7 +34,7 @@ __all__ = [
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"default_derivation_registry",
"derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
@ -395,7 +398,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
return list(self._derivations.values())
default_derivation_registry = DerivationRegistry()
derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
@ -406,7 +409,7 @@ in `DeriveTagRule` configuration.
def get_derivation(
key: str,
import_derivation: bool = False,
registry: Optional[DerivationRegistry] = None,
registry: DerivationRegistry = derivation_registry,
):
"""Retrieve a derivation function by key, optionally importing it.
@ -440,8 +443,6 @@ def get_derivation(
AttributeError
If dynamic import fails because the function name isn't in the module.
"""
registry = registry or default_derivation_registry
if not import_derivation or key in registry:
return registry.get_derivation(key)
@ -457,16 +458,10 @@ def get_derivation(
) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule(
rule: TranformationRule,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
@ -564,8 +559,8 @@ def build_transform_from_rule(
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
@ -586,7 +581,6 @@ def build_transformation_from_config(
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [
build_transform_from_rule(
rule,
@ -596,17 +590,15 @@ def build_transformation_from_config(
for rule in config.rules
]
return partial(apply_sequence_of_transforms, transforms=transforms)
def apply_sequence_of_transforms(
def transformation(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
return transformation
def load_transformation_config(
path: data.PathLike, field: Optional[str] = None
@ -639,8 +631,8 @@ def load_transformation_config(
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: Optional[DerivationRegistry] = None,
term_registry: Optional[TermRegistry] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
@ -685,7 +677,7 @@ def load_transformation_from_config(
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: Optional[DerivationRegistry] = None,
derivation_registry: DerivationRegistry = derivation_registry,
) -> None:
"""Register a new derivation function in the global registry.
@ -704,5 +696,4 @@ def register_derivation(
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation)

View File

@ -19,16 +19,8 @@ from soundevent import data
__all__ = [
"TargetProtocol",
"Position",
"Size",
]
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""
Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""
class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline.
@ -110,7 +102,7 @@ class TargetProtocol(Protocol):
"""
...
def encode_class(
def encode(
self,
sound_event: data.SoundEventAnnotation,
) -> Optional[str]:
@ -131,7 +123,7 @@ class TargetProtocol(Protocol):
"""
...
def decode_class(self, class_label: str) -> List[data.Tag]:
def decode(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Parameters
@ -155,9 +147,9 @@ class TargetProtocol(Protocol):
"""
...
def encode_roi(
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary
@ -181,12 +173,36 @@ class TargetProtocol(Protocol):
"""
...
# TODO: Update docstrings
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions.
@ -201,8 +217,6 @@ class TargetProtocol(Protocol):
dims : np.ndarray
The NumPy array containing the dimensions (e.g., predicted
by the model), corresponding to the order in `dimension_names`.
class_name: str
class
Returns
-------

View File

@ -97,7 +97,7 @@ def _is_in_subclip(
start_time: float,
end_time: float,
) -> bool:
(time, _), _ = targets.encode_roi(sound_event_annotation)
time, _ = targets.get_position(sound_event_annotation)
return start_time <= time <= end_time

View File

@ -138,7 +138,7 @@ def generate_clip_label(
logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events),
num=len(clip_annotation.sound_events)
)
sound_events = []
@ -260,7 +260,7 @@ def generate_heatmaps(
continue
# Get the position of the sound event
(time, frequency), size = targets.encode_roi(sound_event_annotation)
time, frequency = targets.get_position(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap
try:
@ -280,6 +280,8 @@ def generate_heatmaps(
)
continue
size = targets.get_size(sound_event_annotation)
size_heatmap = arrays.set_value_at_pos(
size_heatmap,
size,
@ -289,7 +291,7 @@ def generate_heatmaps(
# Get the class name of the sound event
try:
class_name = targets.encode_class(sound_event_annotation)
class_name = targets.encode(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "

View File

@ -1,19 +0,0 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
from batdetect2.preprocess import build_preprocessor
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

Some files were not shown because too many files have changed in this diff Show More