Compare commits

...

13 Commits

Author SHA1 Message Date
mbsantiago
a462beaeb8 Remove rogue print 2025-06-24 06:27:41 -06:00
mbsantiago
daab8ff0d7 Fix validation is_in_subclip after encoder changes 2025-06-24 06:26:08 -06:00
mbsantiago
235f0e27da Add load dataset config function 2025-06-24 12:40:29 +01:00
mbsantiago
b5b4229990 Fix testing issues 2025-06-23 19:08:55 +01:00
mbsantiago
8253b5bdc4 Update makefile 2025-06-23 18:53:15 +01:00
mbsantiago
c7ea361cf4 Implement changes needed to make roi encode/decode class dependent 2025-06-23 18:52:36 +01:00
mbsantiago
3407e1b5f0 Add other roi tests 2025-06-21 23:51:07 +01:00
mbsantiago
0a0d6f7162 Use standard dependency-groups instead of tool.uv section 2025-06-21 23:01:54 +01:00
mbsantiago
ad0f0bcb24 Add tests for peak energy function 2025-06-21 23:01:08 +01:00
mbsantiago
3103630c26 Update pyproject to use src layout 2025-06-21 13:49:06 +01:00
mbsantiago
960558be8b move to src layout 2025-06-21 13:48:40 +01:00
mbsantiago
e352dc40bd Fixed Target object after changes to roi 2025-06-21 13:47:04 +01:00
mbsantiago
c559bcc682 Changed ROIMapper protocol to only have encoder/decoder methods 2025-06-21 11:44:15 +01:00
110 changed files with 1844 additions and 869 deletions

2
.gitignore vendored
View File

@ -107,7 +107,7 @@ experiments/*
# DO Include # DO Include
!batdetect2_notebook.ipynb !batdetect2_notebook.ipynb
!batdetect2/models/checkpoints/*.pth.tar !src/batdetect2/models/checkpoints/*.pth.tar
!tests/data/*.wav !tests/data/*.wav
!notebooks/*.ipynb !notebooks/*.ipynb
!tests/data/**/*.wav !tests/data/**/*.wav

View File

@ -1,7 +1,7 @@
# Variables # Variables
SOURCE_DIR = batdetect2 SOURCE_DIR = src
TESTS_DIR = tests TESTS_DIR = tests
PYTHON_DIRS = batdetect2 tests PYTHON_DIRS = src tests
DOCS_SOURCE = docs/source DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov HTML_COVERAGE_DIR = htmlcov

View File

@ -1,511 +0,0 @@
"""Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.
This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
"""
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
Positions = Literal[
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
__all__ = [
"ROITargetMapper",
"ROIConfig",
"BBoxEncoder",
"build_roi_mapper",
"load_roi_mapper",
"DEFAULT_POSITION",
"SIZE_WIDTH",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_POSITION = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the methods required for converting a geometric region of interest
(`soundevent.data.Geometry`) into a target representation (reference point
and scaled dimensions) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions returned by
`get_roi_size` and expected by `recover_roi`
(e.g., ['width', 'height']).
"""
dimension_names: List[str]
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> tuple[float, float]:
"""Extract the reference position from a geometry.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default.
Returns
-------
Tuple[float, float]
The calculated reference position as (time, frequency) coordinates,
based on the implementing class's configuration (e.g., "center",
"bottom-left").
Raises
------
ValueError
If the position cannot be calculated for the given geometry type
or configured reference point.
"""
...
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled target dimensions from a geometry.
Computes the relevant size measures.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry.
Returns
-------
np.ndarray
A NumPy array containing the scaled dimensions corresponding to
`dimension_names`. For bounding boxes, typically contains
`[scaled_width, scaled_height]`.
Raises
------
TypeError, ValueError
If the size cannot be computed for the given geometry type.
"""
...
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
Performs the inverse mapping: takes a reference position and the
predicted dimensions and reconstructs a geometric representation.
Parameters
----------
pos : Tuple[float, float]
The reference position (time, frequency).
dims : np.ndarray
NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default when reconstructing the roi geometry.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
Raises
------
ValueError
If the number of provided dimensions `dims` does not match
`dimension_names` or if reconstruction fails.
"""
...
class ROIConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs).
Defines parameters controlling how geometric ROIs are converted into
target representations (reference points and scaled sizes).
Attributes
----------
position : Positions, default="bottom-left"
Specifies the reference point within the geometry (e.g., bounding box)
to use as the target location (e.g., "center", "bottom-left").
See `soundevent.geometry.operations.Positions`.
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width) of the ROI
when calculating the target size representation. Must match model
expectations.
frequency_scale : float, default=1/859.375
Scaling factor applied to the frequency bandwidth (height) of the ROI
when calculating the target size representation. Must match model
expectations.
"""
position: Positions = DEFAULT_POSITION
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class BBoxEncoder(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
This class implements the ROI mapping protocol primarily for
`soundevent.data.BoundingBox` geometry. It extracts reference points,
calculates scaled width/height, and recovers bounding boxes based on
configured position and scaling factors.
Attributes
----------
dimension_names : List[str]
Specifies the output dimension names as ['width', 'height'].
position : Positions
The configured reference point type (e.g., "center", "bottom-left").
time_scale : float
The configured scaling factor for the time dimension (width).
frequency_scale : float
The configured scaling factor for the frequency dimension (height).
"""
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
def __init__(
self,
position: Positions = DEFAULT_POSITION,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
):
"""Initialize the BBoxEncoder.
Parameters
----------
position : Positions, default="bottom-left"
Reference point type within the bounding box.
time_scale : float, default=1000.0
Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375
Scaling factor for frequency bandwidth (height).
"""
self.position: Positions = position
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> Tuple[float, float]:
"""Extract the configured reference position from the geometry.
Uses `soundevent.geometry.get_geometry_point`.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`.
Returns
-------
Tuple[float, float]
Reference position (time, frequency).
"""
from soundevent import geometry
position = position or self.position
return geometry.get_geometry_point(geom, position=position)
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled [width, height] from the geometry's bounds.
Computes the bounding box, extracts duration and bandwidth, and applies
the configured `time_scale` and `frequency_scale`.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry.
Returns
-------
np.ndarray
A 1D NumPy array: `[scaled_width, scaled_height]`.
"""
from soundevent import geometry
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
return np.array(
[
(end_time - start_time) * self.time_scale,
(high_freq - low_freq) * self.frequency_scale,
]
)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
Un-scales the input dimensions using the configured factors and
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
the given reference `pos` according to the configured `position` type.
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`
when reconstructing the bounding box.
Returns
-------
soundevent.data.BoundingBox
The reconstructed bounding box.
Raises
------
ValueError
If `dims` does not have the expected shape (length 2).
"""
position = position or self.position
if dims.ndim != 1 or dims.shape[0] != 2:
raise ValueError(
"Dimension array does not have the expected shape. "
f"({dims.shape = }) != ([2])"
)
width, height = dims
return _build_bounding_box(
pos,
duration=float(width) / self.time_scale,
bandwidth=float(height) / self.frequency_scale,
position=self.position,
)
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
Currently creates a `BBoxEncoder` instance based on the provided
`ROIConfig`.
Parameters
----------
config : ROIConfig
Configuration object specifying ROI mapping parameters.
Returns
-------
ROITargetMapper
An initialized `BBoxEncoder` instance configured with the settings
from `config`.
"""
return BBoxEncoder(
position=config.position,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
def load_roi_mapper(
path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
"""Load ROI mapping configuration from a file and build the mapper.
Convenience function that loads an `ROIConfig` from the specified file
(and optional field) and then uses `build_roi_mapper` to create the
corresponding `ROITargetMapper` instance.
Parameters
----------
path : PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
ROI configuration. If None, the entire file content is used.
Returns
-------
ROITargetMapper
An initialized ROI mapper instance based on the configuration file.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=ROIConfig, field=field)
return build_roi_mapper(config)
VALID_POSITIONS = [
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
def _build_bounding_box(
pos: tuple[float, float],
duration: float,
bandwidth: float,
position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
coordinates [start_time, low_freq, end_time, high_freq] based on where
the input `pos` (time, freq) is located relative to the box (e.g.,
center, corner).
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
duration : float
The required *unscaled* duration (width) of the bounding box.
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
position : Positions, default="bottom-left"
Specifies which part of the bounding box the input `pos` corresponds to.
Returns
-------
data.BoundingBox
The constructed bounding box object.
Raises
------
ValueError
If `position` is not a recognized value or format.
"""
time, freq = map(float, pos)
duration = max(0, duration)
bandwidth = max(0, bandwidth)
if position in ["center", "centroid", "point_on_surface"]:
return data.BoundingBox(
coordinates=[
max(time - duration / 2, 0),
max(freq - bandwidth / 2, 0),
max(time + duration / 2, 0),
max(freq + bandwidth / 2, 0),
]
)
if position not in VALID_POSITIONS:
raise ValueError(
f"Invalid position: {position}. "
f"Valid options are: {VALID_POSITIONS}"
)
y, x = position.split("-")
start_time = {
"left": time,
"center": time - duration / 2,
"right": time - duration,
}[x]
low_freq = {
"bottom": freq,
"center": freq - bandwidth / 2,
"top": freq - bandwidth,
}[y]
return data.BoundingBox(
coordinates=[
max(0, start_time),
max(0, low_freq),
max(0, start_time + duration),
max(0, low_freq + bandwidth),
]
)

View File

@ -0,0 +1,19 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
@app.cell
def _():
from batdetect2.preprocess import build_preprocessor
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

View File

@ -65,8 +65,12 @@ build-backend = "hatchling.build"
[project.scripts] [project.scripts]
batdetect2 = "batdetect2.cli:cli" batdetect2 = "batdetect2.cli:cli"
[tool.uv] [dependency-groups]
dev-dependencies = [ jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
marimo = [
"marimo>=0.12.2",
]
dev = [
"debugpy>=1.8.8", "debugpy>=1.8.8",
"hypothesis>=6.118.7", "hypothesis>=6.118.7",
"pytest>=7.2.2", "pytest>=7.2.2",
@ -98,25 +102,16 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
convention = "numpy" convention = "numpy"
[tool.pyright] [tool.pyright]
include = ["batdetect2", "tests"] include = ["src", "tests"]
pythonVersion = "3.9" pythonVersion = "3.9"
pythonPlatform = "All" pythonPlatform = "All"
exclude = [ exclude = [
"batdetect2/detector/", "src/batdetect2/detector/",
"batdetect2/finetune", "src/batdetect2/finetune",
"batdetect2/utils", "src/batdetect2/utils",
"batdetect2/plotting", "src/batdetect2/plotting",
"batdetect2/plot", "src/batdetect2/plot",
"batdetect2/api", "src/batdetect2/api",
"batdetect2/evaluate/legacy", "src/batdetect2/evaluate/legacy",
"batdetect2/train/legacy", "src/batdetect2/train/legacy",
]
[dependency-groups]
jupyter = [
"ipywidgets>=8.1.5",
"jupyter>=1.1.1",
]
marimo = [
"marimo>=0.12.2",
] ]

View File

@ -189,8 +189,7 @@ def train_command(
config=postprocess_config_loaded, config=postprocess_config_loaded,
) )
logger.debug( logger.debug(
"Loaded postprocessor from file {path}", "Loaded postprocessor from file {path}", path=postprocess_config
path=train_config,
) )
except IOError: except IOError:
logger.debug( logger.debug(

View File

@ -157,4 +157,4 @@ def load_config(
if field: if field:
config = get_object_field(config, field) config = get_object_field(config, field)
return schema.model_validate(config) return schema.model_validate(config or {})

View File

@ -8,6 +8,7 @@ from batdetect2.data.annotations import (
from batdetect2.data.datasets import ( from batdetect2.data.datasets import (
DatasetConfig, DatasetConfig,
load_dataset, load_dataset,
load_dataset_config,
load_dataset_from_config, load_dataset_from_config,
) )
@ -19,5 +20,6 @@ __all__ = [
"DatasetConfig", "DatasetConfig",
"load_annotated_dataset", "load_annotated_dataset",
"load_dataset", "load_dataset",
"load_dataset_config",
"load_dataset_from_config", "load_dataset_from_config",
] ]

View File

@ -161,6 +161,11 @@ def insert_source_tag(
) )
# TODO: add documentation
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config( def load_dataset_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,

View File

@ -72,7 +72,7 @@ def iterate_over_sound_events(
sound_event_annotation sound_event_annotation
) )
class_name = targets.encode(sound_event_annotation) class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic: if class_name is None and exclude_generic:
continue continue

View File

@ -1,4 +1,3 @@
import numpy as np import numpy as np
from sklearn.metrics import auc, roc_curve from sklearn.metrics import auc, roc_curve

View File

@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
gt_uuid = target.uuid if target is not None else None gt_uuid = target.uuid if target is not None else None
gt_det = target is not None gt_det = target is not None
gt_class = targets.encode(target) if target is not None else None gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0 pred_score = float(prediction.detection_score) if prediction else 0

View File

@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules: The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks. 1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency 2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
coordinates to raw model output arrays. model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points 3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k). (location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size, 4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
return [ return [
convert_xr_dataset_to_raw_prediction( convert_xr_dataset_to_raw_prediction(
dataset, dataset,
self.targets.recover_roi, self.targets.decode_roi,
) )
for dataset in detection_datasets for dataset in detection_datasets
] ]
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
convert_raw_predictions_to_clip_prediction( convert_raw_predictions_to_clip_prediction(
prediction, prediction,
clip, clip,
sound_event_decoder=self.targets.decode, sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags, generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold, classification_threshold=self.config.classification_threshold,
) )

View File

@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes, (typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it class probabilities, and features for each detection point) and converts it
into meaningful, standardized prediction objects based on the `soundevent` data into standardized prediction objects based on the `soundevent` data model.
model.
The process involves: The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction` 1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
@ -33,7 +32,7 @@ import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array from batdetect2.utils.arrays import iterate_over_array
@ -55,7 +54,7 @@ decoding.
def convert_xr_dataset_to_raw_prediction( def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset, detection_dataset: xr.Dataset,
geometry_builder: GeometryBuilder, geometry_decoder: GeometryDecoder,
) -> List[RawPrediction]: ) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects. """Convert an xarray.Dataset of detections to RawPrediction objects.
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
output by `extract_detection_xr_dataset`. Expected variables include output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'. 'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension. Must have a 'detection' dimension.
geometry_builder : GeometryBuilder geometry_decoder : GeometryDecoder
A function that takes a position tuple `(time, freq)` and a NumPy array A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`. `soundevent.data.Geometry`.
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]): for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num) det_info = detection_dataset.sel(detection=det_num)
geom = geometry_builder( # TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()
geom = geometry_decoder(
(det_info.time, det_info.frequency), (det_info.time, det_info.frequency),
det_info.dimensions, det_info.dimensions,
class_name=highest_scoring_class,
) )
detections.append( detections.append(
RawPrediction( RawPrediction(
detection_score=det_info.score, detection_score=det_info.scores,
geometry=geom, geometry=geom,
class_scores=det_info.classes, class_scores=det_info.classes,
features=det_info.features, features=det_info.features,

View File

@ -1,6 +1,6 @@
"""Extracts candidate detection points from a model output heatmap. """Extracts candidate detection points from a model output heatmap.
This module implements a specific step within the BatDetect2 postprocessing This module implements Step 3 within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and produced by the neural network (usually after Non-Maximum Suppression and

View File

@ -1,9 +1,9 @@
"""Extracts associated data for detected points from model output arrays. """Extracts associated data for detected points from model output arrays.
This module implements a key step (Step 4) in the BatDetect2 postprocessing This module implements a Step 4 in the BatDetect2 postprocessing pipeline.
pipeline. After candidate detection points (time, frequency, score) have been After candidate detection points (time, frequency, score) have been identified,
identified, this module extracts the corresponding values from other raw model this module extracts the corresponding values from other raw model output
output arrays, such as: arrays, such as:
- Predicted bounding box sizes (width, height). - Predicted bounding box sizes (width, height).
- Class probability scores for each defined target class. - Class probability scores for each defined target class.

View File

@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions. system that deal with model predictions.
""" """
from typing import Callable, List, NamedTuple, Protocol from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size
__all__ = [ __all__ = [
"RawPrediction", "RawPrediction",
"PostprocessorProtocol", "PostprocessorProtocol",
"GeometryBuilder", "GeometryDecoder",
] ]
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry] # TODO: update the docstring
"""Type alias for a function that recovers geometry from position and size. class GeometryDecoder(Protocol):
"""Type alias for a function that recovers geometry from position and size.
This callable takes: This callable takes:
1. A position tuple `(time, frequency)`. 1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`). 2. A NumPy array of size dimensions (e.g., `[width, height]`).
It should return the reconstructed `soundevent.data.Geometry` (typically a 3. Optionally a class name of the highest scoring class. This is to accomodate
`BoundingBox`). different ways of decoding geometry that depend on the predicted class.
""" It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None
) -> data.Geometry: ...
class RawPrediction(NamedTuple): class RawPrediction(NamedTuple):

View File

@ -23,7 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional from typing import List, Optional
import numpy as np from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -50,7 +50,8 @@ from batdetect2.targets.filtering import (
load_filter_from_config, load_filter_from_config,
) )
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
ROIConfig, AnchorBBoxMapperConfig,
ROIMapperConfig,
ROITargetMapper, ROITargetMapper,
build_roi_mapper, build_roi_mapper,
) )
@ -59,11 +60,11 @@ from batdetect2.targets.terms import (
TermInfo, TermInfo,
TermRegistry, TermRegistry,
call_type, call_type,
default_term_registry,
get_tag_from_info, get_tag_from_info,
get_term_from_key, get_term_from_key,
individual, individual,
register_term, register_term,
term_registry,
) )
from batdetect2.targets.transform import ( from batdetect2.targets.transform import (
DerivationRegistry, DerivationRegistry,
@ -73,13 +74,13 @@ from batdetect2.targets.transform import (
SoundEventTransformation, SoundEventTransformation,
TransformConfig, TransformConfig,
build_transformation_from_config, build_transformation_from_config,
derivation_registry, default_derivation_registry,
get_derivation, get_derivation,
load_transformation_config, load_transformation_config,
load_transformation_from_config, load_transformation_from_config,
register_derivation, register_derivation,
) )
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import Position, Size, TargetProtocol
__all__ = [ __all__ = [
"ClassesConfig", "ClassesConfig",
@ -88,7 +89,7 @@ __all__ = [
"FilterConfig", "FilterConfig",
"FilterRule", "FilterRule",
"MapValueRule", "MapValueRule",
"ROIConfig", "AnchorBBoxMapperConfig",
"ROITargetMapper", "ROITargetMapper",
"ReplaceRule", "ReplaceRule",
"SoundEventDecoder", "SoundEventDecoder",
@ -156,12 +157,12 @@ class TargetConfig(BaseConfig):
omitted, default ROI mapping settings are used. omitted, default ROI mapping settings are used.
""" """
filtering: Optional[FilterConfig] = None filtering: FilterConfig = Field(default_factory=FilterConfig)
transforms: Optional[TransformConfig] = None transforms: TransformConfig = Field(default_factory=TransformConfig)
classes: ClassesConfig = Field( classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG default_factory=lambda: DEFAULT_CLASSES_CONFIG
) )
roi: Optional[ROIConfig] = None roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
def load_target_config( def load_target_config(
@ -240,6 +241,7 @@ class Targets(TargetProtocol):
generic_class_tags: List[data.Tag], generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None, filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None, transform_fn: Optional[SoundEventTransformation] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
): ):
"""Initialize the Targets object. """Initialize the Targets object.
@ -272,6 +274,16 @@ class Targets(TargetProtocol):
self._encode_fn = encode_fn self._encode_fn = encode_fn
self._decode_fn = decode_fn self._decode_fn = decode_fn
self._transform_fn = transform_fn self._transform_fn = transform_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation. """Apply the configured filter to a sound event annotation.
@ -291,7 +303,9 @@ class Targets(TargetProtocol):
return True return True
return self._filter_fn(sound_event) return self._filter_fn(sound_event)
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]: def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name. """Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority) Applies the configured class definition rules (including priority)
@ -312,7 +326,7 @@ class Targets(TargetProtocol):
""" """
return self._encode_fn(sound_event) return self._encode_fn(sound_event)
def decode(self, class_label: str) -> List[data.Tag]: def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags. """Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or Uses the configured mapping (based on `TargetClass.output_tags` or
@ -352,9 +366,9 @@ class Targets(TargetProtocol):
return self._transform_fn(sound_event) return self._transform_fn(sound_event)
return sound_event return sound_event
def get_position( def encode_roi(
self, sound_event: data.SoundEventAnnotation self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]: ) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi. """Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method. Delegates to the internal ROI mapper's `get_roi_position` method.
@ -374,50 +388,20 @@ class Targets(TargetProtocol):
ValueError ValueError
If the annotation lacks geometry. If the annotation lacks geometry.
""" """
geom = sound_event.sound_event.geometry class_name = self.encode_class(sound_event)
if geom is None: if class_name in self._roi_mapper_overrides:
raise ValueError( return self._roi_mapper_overrides[class_name].encode(
"Sound event has no geometry, cannot get its position." sound_event.sound_event
) )
return self._roi_mapper.get_roi_position(geom) return self._roi_mapper.encode(sound_event.sound_event)
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: def decode_roi(
"""Calculate the target size dimensions from the annotation's geometry.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
def recover_roi(
self, self,
pos: tuple[float, float], position: Position,
dims: np.ndarray, size: Size,
class_name: Optional[str] = None,
) -> data.Geometry: ) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions. """Recover an approximate geometric ROI from a position and dimensions.
@ -438,7 +422,13 @@ class Targets(TargetProtocol):
data.Geometry data.Geometry
The reconstructed geometry (typically `BoundingBox`). The reconstructed geometry (typically `BoundingBox`).
""" """
return self._roi_mapper.recover_roi(pos, dims) if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_CLASSES = [ DEFAULT_CLASSES = [
@ -493,10 +483,12 @@ DEFAULT_CLASSES = [
TargetClass( TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")], tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei", name="nyclei",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
), ),
TargetClass( TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")], tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer", name="rhifer",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
), ),
TargetClass( TargetClass(
tags=[TagInfo(value="Plecotus auritus")], tags=[TagInfo(value="Plecotus auritus")],
@ -537,13 +529,14 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
] ]
), ),
classes=DEFAULT_CLASSES_CONFIG, classes=DEFAULT_CLASSES_CONFIG,
roi=AnchorBBoxMapperConfig(),
) )
def build_targets( def build_targets(
config: Optional[TargetConfig] = None, config: Optional[TargetConfig] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets: ) -> Targets:
"""Build a Targets object from a loaded TargetConfig. """Build a Targets object from a loaded TargetConfig.
@ -606,12 +599,17 @@ def build_targets(
if config.transforms if config.transforms
else None else None
) )
roi_mapper = build_roi_mapper(config.roi or ROIConfig()) roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classes) class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags( generic_class_tags = build_generic_class_tags(
config.classes, config.classes,
term_registry=term_registry, term_registry=term_registry,
) )
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
if class_config.roi is not None
}
return Targets( return Targets(
filter_fn=filter_fn, filter_fn=filter_fn,
@ -621,14 +619,15 @@ def build_targets(
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags, generic_class_tags=generic_class_tags,
transform_fn=transform_fn, transform_fn=transform_fn,
roi_mapper_overrides=roi_overrides,
) )
def load_targets( def load_targets(
config_path: data.PathLike, config_path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: DerivationRegistry = default_derivation_registry,
) -> Targets: ) -> Targets:
"""Load a Targets object directly from a configuration file. """Load a Targets object directly from a configuration file.

View File

@ -6,29 +6,27 @@ from pydantic import Field, field_validator
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import ( from batdetect2.targets.terms import (
GENERIC_CLASS_KEY, GENERIC_CLASS_KEY,
TagInfo, TagInfo,
TermRegistry, TermRegistry,
default_term_registry,
get_tag_from_info, get_tag_from_info,
term_registry,
) )
__all__ = [ __all__ = [
"SoundEventEncoder",
"SoundEventDecoder",
"TargetClass",
"ClassesConfig",
"load_classes_config",
"load_encoder_from_config",
"load_decoder_from_config",
"build_sound_event_encoder",
"build_sound_event_decoder",
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST", "DEFAULT_SPECIES_LIST",
"build_generic_class_tags",
"build_sound_event_decoder",
"build_sound_event_encoder",
"get_class_names_from_config",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
] ]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function. """Type alias for a sound event class encoder function.
@ -113,6 +111,7 @@ class TargetClass(BaseConfig):
tags: List[TagInfo] = Field(min_length=1) tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all") match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None output_tags: Optional[List[TagInfo]] = None
roi: Optional[ROIMapperConfig] = None
def _get_default_classes() -> List[TargetClass]: def _get_default_classes() -> List[TargetClass]:
@ -235,7 +234,7 @@ class ClassesConfig(BaseConfig):
return v return v
def _is_target_class( def is_target_class(
sound_event_annotation: data.SoundEventAnnotation, sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag], tags: Set[data.Tag],
match_all: bool = True, match_all: bool = True,
@ -316,7 +315,7 @@ def _encode_with_multiple_classifiers(
def build_sound_event_encoder( def build_sound_event_encoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder: ) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration. """Build a sound event encoder function from the classes configuration.
@ -350,7 +349,7 @@ def build_sound_event_encoder(
( (
class_info.name, class_info.name,
partial( partial(
_is_target_class, is_target_class,
tags={ tags={
get_tag_from_info(tag_info, term_registry=term_registry) get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags for tag_info in class_info.tags
@ -410,7 +409,7 @@ def _decode_class(
def build_sound_event_decoder( def build_sound_event_decoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False, raise_on_unmapped: bool = False,
) -> SoundEventDecoder: ) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration. """Build a sound event decoder function from the classes configuration.
@ -465,7 +464,7 @@ def build_sound_event_decoder(
def build_generic_class_tags( def build_generic_class_tags(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> List[data.Tag]: ) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config. """Extract and build the list of tags for the generic class from config.
@ -530,7 +529,7 @@ def load_classes_config(
def load_encoder_from_config( def load_encoder_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventEncoder: ) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file. """Load a class encoder function directly from a configuration file.
@ -571,7 +570,7 @@ def load_encoder_from_config(
def load_decoder_from_config( def load_decoder_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
raise_on_unmapped: bool = False, raise_on_unmapped: bool = False,
) -> SoundEventDecoder: ) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file. """Load a class decoder function directly from a configuration file.

View File

@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
TagInfo, TagInfo,
TermRegistry, TermRegistry,
get_tag_from_info, get_tag_from_info,
term_registry, default_term_registry,
) )
__all__ = [ __all__ = [
@ -156,7 +156,7 @@ def equal_tags(
def build_filter_from_rule( def build_filter_from_rule(
rule: FilterRule, rule: FilterRule,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule. """Creates a callable filter function from a single FilterRule.
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
def build_sound_event_filter( def build_sound_event_filter(
config: FilterConfig, config: FilterConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object. """Builds a merged filter function from a FilterConfig object.
@ -291,7 +291,7 @@ def load_filter_config(
def load_filter_from_config( def load_filter_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function. """Loads filter configuration from a file and builds the filter function.

View File

@ -0,0 +1,684 @@
"""Handles mapping between geometric ROIs and target representations.
This module defines a standardized interface (`ROITargetMapper`) for converting
a sound event's Region of Interest (ROI) into a target representation suitable
for machine learning models, and for decoding model outputs back into geometric
ROIs.
The core operations are:
1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
`Position` (time, frequency) and a `Size` array. The method for
determining the position and size varies by the mapper implementation
(e.g., using a bounding box anchor or the point of peak energy).
2. **Decoding**: A `Position` and `Size` array are mapped back to an
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).
This logic is encapsulated within specific mapper classes. Configuration for
each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
Pydantic config object. The `ROIMapperConfig` type allows for flexibly
selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, Size
__all__ = [
"Anchor",
"AnchorBBoxMapper",
"AnchorBBoxMapperConfig",
"DEFAULT_ANCHOR",
"DEFAULT_FREQUENCY_SCALE",
"DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig",
"ROIMapperConfig",
"ROITargetMapper",
"SIZE_HEIGHT",
"SIZE_ORDER",
"SIZE_WIDTH",
"build_roi_mapper",
]
Anchor = Literal[
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Position
The reference position (time, frequency).
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...
class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`.
Defines parameters for converting ROIs into targets using a fixed anchor
point on the bounding box.
Attributes
----------
name : Literal["anchor_bbox"]
The unique identifier for this mapper type.
anchor : Anchor
Specifies the anchor point within the bounding box to use as the
target's reference position (e.g., "center", "bottom-left").
time_scale : float
Scaling factor applied to the time duration (width) of the ROI.
frequency_scale : float
Scaling factor applied to the frequency bandwidth (height) of the ROI.
"""
name: Literal["anchor_bbox"] = "anchor_bbox"
anchor: Anchor = DEFAULT_ANCHOR
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class AnchorBBoxMapper(ROITargetMapper):
"""Maps ROIs using a bounding box anchor point and width/height.
This class implements the `ROITargetMapper` protocol for `BoundingBox`
geometries.
**Encoding**: The `position` is a fixed anchor point on the bounding box
(e.g., "bottom-left"). The `size` is a 2-element array containing the
scaled width and height of the box.
**Decoding**: Reconstructs a `BoundingBox` from an anchor point and
scaled width/height.
Attributes
----------
dimension_names : List[str]
The output dimension names: `['width', 'height']`.
anchor : Anchor
The configured anchor point type (e.g., "center", "bottom-left").
time_scale : float
The scaling factor for the time dimension (width).
frequency_scale : float
The scaling factor for the frequency dimension (height).
"""
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
def __init__(
self,
anchor: Anchor = DEFAULT_ANCHOR,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
):
"""Initialize the BBoxEncoder.
Parameters
----------
anchor : Anchor
Reference point type within the bounding box.
time_scale : float
Scaling factor for time duration (width).
frequency_scale : float
Scaling factor for frequency bandwidth (height).
"""
self.anchor: Anchor = anchor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
"""Encode a SoundEvent into an anchor position and scaled box size.
The position is determined by the configured anchor on the sound
event's bounding box. The size is the scaled width and height.
Parameters
----------
sound_event : data.SoundEvent
The input sound event with a geometry.
Returns
-------
Tuple[Position, Size]
A tuple of (anchor_position, [scaled_width, scaled_height]).
"""
from soundevent import geometry
geom = sound_event.geometry
if geom is None:
raise ValueError(
"Cannot encode the geometry of a sound event without geometry."
f" Sound event: {sound_event}"
)
position = geometry.get_geometry_point(geom, position=self.anchor)
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
size = np.array(
[
(end_time - start_time) * self.time_scale,
(high_freq - low_freq) * self.frequency_scale,
]
)
return position, size
def decode(
self,
position: Position,
size: Size,
) -> data.Geometry:
"""Recover a BoundingBox from an anchor position and scaled size.
Un-scales the input dimensions and reconstructs a
`soundevent.data.BoundingBox` relative to the given anchor position.
Parameters
----------
position : Position
Reference anchor position (time, frequency).
size : Size
NumPy array containing the scaled [width, height].
Returns
-------
data.BoundingBox
The reconstructed bounding box.
Raises
------
ValueError
If `size` does not have the expected shape (length 2).
"""
if size.ndim != 1 or size.shape[0] != 2:
raise ValueError(
"Dimension array does not have the expected shape. "
f"({size.shape = }) != ([2])"
)
width, height = size
return _build_bounding_box(
position,
duration=float(width) / self.time_scale,
bandwidth=float(height) / self.frequency_scale,
anchor=self.anchor,
)
class PeakEnergyBBoxMapperConfig(BaseConfig):
"""Configuration for `PeakEnergyBBoxMapper`.
Attributes
----------
name : Literal["peak_energy_bbox"]
The unique identifier for this mapper type.
preprocessing : PreprocessingConfig
Configuration for the spectrogram preprocessor needed to find the
peak energy.
loading_buffer : float
Seconds to add to each side of the ROI when loading audio to ensure
the peak is captured accurately, avoiding boundary effects.
time_scale : float
Scaling factor applied to the time dimensions.
frequency_scale : float
Scaling factor applied to the frequency dimensions.
"""
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
loading_buffer: float = 0.01
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class PeakEnergyBBoxMapper(ROITargetMapper):
"""Maps ROIs using the peak energy point and distances to edges.
This class implements the `ROITargetMapper` protocol.
**Encoding**: The `position` is the (time, frequency) coordinate of the
point with the highest energy within the sound event's bounding box. The
`size` is a 4-element array representing the scaled distances from this
peak energy point to the left, bottom, right, and top edges of the box.
**Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the
un-scaled distances from the peak energy point.
Attributes
----------
dimension_names : List[str]
The output dimension names: `['left', 'bottom', 'right', 'top']`.
preprocessor : PreprocessorProtocol
The spectrogram preprocessor instance.
time_scale : float
The scaling factor for time-based distances.
frequency_scale : float
The scaling factor for frequency-based distances.
loading_buffer : float
The buffer used for loading audio around the ROI.
"""
dimension_names = ["left", "bottom", "right", "top"]
def __init__(
self,
preprocessor: PreprocessorProtocol,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
loading_buffer: float = 0.01,
):
"""Initialize the PeakEnergyBBoxMapper.
Parameters
----------
preprocessor : PreprocessorProtocol
An initialized preprocessor for generating spectrograms.
time_scale : float
Scaling factor for time dimensions (left, right distances).
frequency_scale : float
Scaling factor for frequency dimensions (bottom, top distances).
loading_buffer : float
Buffer in seconds to add when loading audio clips.
"""
self.preprocessor = preprocessor
self.time_scale = time_scale
self.frequency_scale = frequency_scale
self.loading_buffer = loading_buffer
def encode(
self,
sound_event: data.SoundEvent,
) -> tuple[Position, Size]:
"""Encode a SoundEvent into a peak energy position and edge distances.
Finds the peak energy coordinates within the event's bounding box
and calculates the scaled distances from this point to the box edges.
Parameters
----------
sound_event : data.SoundEvent
The input sound event with a geometry and associated recording.
Returns
-------
Tuple[Position, Size]
A tuple of (peak_position, [l, b, r, t] distances).
"""
from soundevent import geometry
geom = sound_event.geometry
if geom is None:
raise ValueError(
"Cannot encode the geometry of a sound event without geometry."
f" Sound event: {sound_event}"
)
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
time, freq = get_peak_energy_coordinates(
recording=sound_event.recording,
preprocessor=self.preprocessor,
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
loading_buffer=self.loading_buffer,
)
size = np.array(
[
(time - start_time) * self.time_scale,
(freq - low_freq) * self.frequency_scale,
(end_time - time) * self.time_scale,
(high_freq - freq) * self.frequency_scale,
]
)
return (time, freq), size
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Recover a BoundingBox from a peak position and edge distances.
Parameters
----------
position : Position
The reference peak energy position (time, frequency).
size : Size
NumPy array with scaled distances [left, bottom, right, top].
Returns
-------
data.BoundingBox
The reconstructed bounding box.
"""
time, freq = position
left, bottom, right, top = size
return data.BoundingBox(
coordinates=[
time - max(0, float(left)) / self.time_scale,
freq - max(0, float(bottom)) / self.frequency_scale,
time + max(0, float(right)) / self.time_scale,
freq + max(0, float(top)) / self.frequency_scale,
]
)
ROIMapperConfig = Annotated[
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.
This type allows for selecting and configuring different `ROITargetMapper`
implementations by using the `name` field as a discriminator.
"""
def build_roi_mapper(
config: Optional[ROIMapperConfig] = None,
) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from a config object.
Parameters
----------
config : ROIMapperConfig
A configuration object specifying the mapper type and its parameters.
Returns
-------
ROITargetMapper
An initialized ROI mapper instance.
Raises
------
NotImplementedError
If the `name` in the config does not correspond to a known mapper.
"""
config = config or AnchorBBoxMapperConfig()
if config.name == "anchor_bbox":
return AnchorBBoxMapper(
anchor=config.anchor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
if config.name == "peak_energy_bbox":
preprocessor = build_preprocessor(config.preprocessing)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
loading_buffer=config.loading_buffer,
)
raise NotImplementedError(
f"No ROI mapper of name '{config.name}' is implemented"
)
VALID_ANCHORS = [
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
def _build_bounding_box(
pos: tuple[float, float],
duration: float,
bandwidth: float,
anchor: Anchor = DEFAULT_ANCHOR,
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
coordinates [start_time, low_freq, end_time, high_freq] based on where
the input `pos` (time, freq) is located relative to the box (e.g.,
center, corner).
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
duration : float
The required *unscaled* duration (width) of the bounding box.
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
anchor : Anchor
Specifies which part of the bounding box the input `pos` corresponds to.
Returns
-------
data.BoundingBox
The constructed bounding box object.
Raises
------
ValueError
If `anchor` is not a recognized value or format.
"""
time, freq = map(float, pos)
duration = max(0, duration)
bandwidth = max(0, bandwidth)
if anchor in ["center", "centroid", "point_on_surface"]:
return data.BoundingBox(
coordinates=[
max(time - duration / 2, 0),
max(freq - bandwidth / 2, 0),
max(time + duration / 2, 0),
max(freq + bandwidth / 2, 0),
]
)
if anchor not in VALID_ANCHORS:
raise ValueError(
f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
)
y, x = anchor.split("-")
start_time = {
"left": time,
"center": time - duration / 2,
"right": time - duration,
}[x]
low_freq = {
"bottom": freq,
"center": freq - bandwidth / 2,
"top": freq - bandwidth,
}[y]
return data.BoundingBox(
coordinates=[
max(0, start_time),
max(0, low_freq),
max(0, start_time + duration),
max(0, low_freq + bandwidth),
]
)
def get_peak_energy_coordinates(
recording: data.Recording,
preprocessor: PreprocessorProtocol,
start_time: float = 0,
end_time: Optional[float] = None,
low_freq: float = 0,
high_freq: Optional[float] = None,
loading_buffer: float = 0.05,
) -> Position:
"""Find the coordinates of the highest energy point in a spectrogram.
Generates a spectrogram for a specified time-frequency region of a
recording and returns the (time, frequency) coordinates of the pixel with
the maximum value.
Parameters
----------
recording : data.Recording
The recording to analyze.
preprocessor : PreprocessorProtocol
The processor to convert audio to a spectrogram.
start_time : float, default=0
The start time of the region of interest.
end_time : float, optional
The end time of the region of interest. Defaults to recording duration.
low_freq : float, default=0
The low frequency of the region of interest.
high_freq : float, optional
The high frequency of the region of interest. Defaults to Nyquist.
loading_buffer : float, default=0.05
Buffer in seconds to add around the time range when loading the clip
to mitigate border effects from transformations like STFT.
Returns
-------
Position
A (time, frequency) tuple for the peak energy location.
"""
if end_time is None:
end_time = recording.duration
end_time = min(end_time, recording.duration)
if high_freq is None:
high_freq = recording.samplerate / 2
clip_start = max(0, start_time - loading_buffer)
clip_end = min(recording.duration, end_time + loading_buffer)
clip = data.Clip(
recording=recording,
start_time=clip_start,
end_time=clip_end,
)
spec = preprocessor.preprocess_clip(clip)
low_freq = max(low_freq, preprocessor.min_freq)
high_freq = min(high_freq, preprocessor.max_freq)
selection = spec.sel(
time=slice(start_time, end_time),
frequency=slice(low_freq, high_freq),
)
index = selection.argmax(dim=["time", "frequency"])
point = selection.isel(index) # type: ignore
peak_time: float = point.time.item()
peak_freq: float = point.frequency.item()
return peak_time, peak_freq

View File

@ -230,11 +230,12 @@ class TermRegistry(Mapping[str, data.Term]):
del self._terms[key] del self._terms[key]
term_registry = TermRegistry( default_term_registry = TermRegistry(
terms=dict( terms=dict(
[ [
*getmembers(terms, lambda x: isinstance(x, data.Term)), *getmembers(terms, lambda x: isinstance(x, data.Term)),
("event", call_type), ("event", call_type),
("species", terms.scientific_name),
("individual", individual), ("individual", individual),
("data_source", data_source), ("data_source", data_source),
(GENERIC_CLASS_KEY, generic_class), (GENERIC_CLASS_KEY, generic_class),
@ -252,7 +253,7 @@ is explicitly passed.
def get_term_from_key( def get_term_from_key(
key: str, key: str,
term_registry: TermRegistry = term_registry, term_registry: Optional[TermRegistry] = None,
) -> data.Term: ) -> data.Term:
"""Convenience function to retrieve a term by key from a registry. """Convenience function to retrieve a term by key from a registry.
@ -277,10 +278,13 @@ def get_term_from_key(
KeyError KeyError
If the key is not found in the specified registry. If the key is not found in the specified registry.
""" """
term_registry = term_registry or default_term_registry
return term_registry.get_term(key) return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]: def get_term_keys(
term_registry: TermRegistry = default_term_registry,
) -> List[str]:
"""Convenience function to get all registered keys from a registry. """Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry` Uses the global default registry unless a specific `term_registry`
@ -299,7 +303,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
return term_registry.get_keys() return term_registry.get_keys()
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]: def get_terms(
term_registry: TermRegistry = default_term_registry,
) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry. """Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry` Uses the global default registry unless a specific `term_registry`
@ -342,7 +348,7 @@ class TagInfo(BaseModel):
def get_tag_from_info( def get_tag_from_info(
tag_info: TagInfo, tag_info: TagInfo,
term_registry: TermRegistry = term_registry, term_registry: Optional[TermRegistry] = None,
) -> data.Tag: ) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data. """Creates a soundevent.data.Tag object from TagInfo data.
@ -368,6 +374,7 @@ def get_tag_from_info(
If the term key specified in `tag_info.key` is not found If the term key specified in `tag_info.key` is not found
in the registry. in the registry.
""" """
term_registry = term_registry or default_term_registry
term = get_term_from_key(tag_info.key, term_registry=term_registry) term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value) return data.Tag(term=term, value=tag_info.value)
@ -439,7 +446,7 @@ class TermConfig(BaseModel):
def load_terms_from_config( def load_terms_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = default_term_registry,
) -> Dict[str, data.Term]: ) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them. """Loads term definitions from a configuration file and registers them.
@ -490,6 +497,6 @@ def load_terms_from_config(
def register_term( def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry key: str, term: data.Term, registry: TermRegistry = default_term_registry
) -> None: ) -> None:
registry.add_term(key, term) registry.add_term(key, term)

View File

@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
get_tag_from_info, get_tag_from_info,
get_term_from_key, get_term_from_key,
) )
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [ __all__ = [
"DerivationRegistry", "DerivationRegistry",
@ -34,7 +31,7 @@ __all__ = [
"TransformConfig", "TransformConfig",
"build_transform_from_rule", "build_transform_from_rule",
"build_transformation_from_config", "build_transformation_from_config",
"derivation_registry", "default_derivation_registry",
"get_derivation", "get_derivation",
"load_transformation_config", "load_transformation_config",
"load_transformation_from_config", "load_transformation_from_config",
@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
return list(self._derivations.values()) return list(self._derivations.values())
derivation_registry = DerivationRegistry() default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry. """Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key Register custom derivation functions here to make them available by key
@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
def get_derivation( def get_derivation(
key: str, key: str,
import_derivation: bool = False, import_derivation: bool = False,
registry: DerivationRegistry = derivation_registry, registry: Optional[DerivationRegistry] = None,
): ):
"""Retrieve a derivation function by key, optionally importing it. """Retrieve a derivation function by key, optionally importing it.
@ -443,6 +440,8 @@ def get_derivation(
AttributeError AttributeError
If dynamic import fails because the function name isn't in the module. If dynamic import fails because the function name isn't in the module.
""" """
registry = registry or default_derivation_registry
if not import_derivation or key in registry: if not import_derivation or key in registry:
return registry.get_derivation(key) return registry.get_derivation(key)
@ -458,10 +457,16 @@ def get_derivation(
) from err ) from err
TranformationRule = Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
def build_transform_from_rule( def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule], rule: TranformationRule,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config. """Build a specific SoundEventTransformation function from a rule config.
@ -559,8 +564,8 @@ def build_transform_from_rule(
def build_transformation_from_config( def build_transformation_from_config(
config: TransformConfig, config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig. """Build a composite transformation function from a TransformConfig.
@ -581,6 +586,7 @@ def build_transformation_from_config(
SoundEventTransformation SoundEventTransformation
A single function that applies all configured transformations in order. A single function that applies all configured transformations in order.
""" """
transforms = [ transforms = [
build_transform_from_rule( build_transform_from_rule(
rule, rule,
@ -590,14 +596,16 @@ def build_transformation_from_config(
for rule in config.rules for rule in config.rules
] ]
def transformation( return partial(apply_sequence_of_transforms, transforms=transforms)
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
return transformation
def apply_sequence_of_transforms(
sound_event_annotation: data.SoundEventAnnotation,
transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
def load_transformation_config( def load_transformation_config(
@ -631,8 +639,8 @@ def load_transformation_config(
def load_transformation_from_config( def load_transformation_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
term_registry: TermRegistry = default_term_registry, term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation: ) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function. """Load transformation config from a file and build the final function.
@ -677,7 +685,7 @@ def load_transformation_from_config(
def register_derivation( def register_derivation(
key: str, key: str,
derivation: Derivation, derivation: Derivation,
derivation_registry: DerivationRegistry = derivation_registry, derivation_registry: Optional[DerivationRegistry] = None,
) -> None: ) -> None:
"""Register a new derivation function in the global registry. """Register a new derivation function in the global registry.
@ -696,4 +704,5 @@ def register_derivation(
KeyError KeyError
If a derivation function with the same key is already registered. If a derivation function with the same key is already registered.
""" """
derivation_registry = derivation_registry or default_derivation_registry
derivation_registry.register(key, derivation) derivation_registry.register(key, derivation)

View File

@ -19,8 +19,16 @@ from soundevent import data
__all__ = [ __all__ = [
"TargetProtocol", "TargetProtocol",
"Position",
"Size",
] ]
Position = tuple[float, float]
"""A tuple representing (time, frequency) coordinates."""
Size = np.ndarray
"""A NumPy array representing the size dimensions of a target."""
class TargetProtocol(Protocol): class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline. """Protocol defining the interface for the target definition pipeline.
@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
""" """
... ...
def encode( def encode_class(
self, self,
sound_event: data.SoundEventAnnotation, sound_event: data.SoundEventAnnotation,
) -> Optional[str]: ) -> Optional[str]:
@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
""" """
... ...
def decode(self, class_label: str) -> List[data.Tag]: def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags. """Decode a predicted class name back into representative tags.
Parameters Parameters
@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
""" """
... ...
def get_position( def encode_roi(
self, sound_event: data.SoundEventAnnotation self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]: ) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's geometry. """Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary Calculates the `(time, frequency)` coordinate representing the primary
@ -173,36 +181,12 @@ class TargetProtocol(Protocol):
""" """
... ...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: # TODO: Update docstrings
"""Calculate the target size dimensions from the annotation's geometry. def decode_roi(
self,
Computes the relevant physical size (e.g., duration/width, position: Position,
bandwidth/height from a bounding box) to produce size: Size,
the numerical target values expected by the model. class_name: Optional[str] = None,
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry: ) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions. """Recover the ROI geometry from a position and dimensions.
@ -217,6 +201,8 @@ class TargetProtocol(Protocol):
dims : np.ndarray dims : np.ndarray
The NumPy array containing the dimensions (e.g., predicted The NumPy array containing the dimensions (e.g., predicted
by the model), corresponding to the order in `dimension_names`. by the model), corresponding to the order in `dimension_names`.
class_name: str
class
Returns Returns
------- -------

View File

@ -97,7 +97,7 @@ def _is_in_subclip(
start_time: float, start_time: float,
end_time: float, end_time: float,
) -> bool: ) -> bool:
time, _ = targets.get_position(sound_event_annotation) (time, _), _ = targets.encode_roi(sound_event_annotation)
return start_time <= time <= end_time return start_time <= time <= end_time

View File

@ -138,7 +138,7 @@ def generate_clip_label(
logger.debug( logger.debug(
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events", "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
uuid=clip_annotation.uuid, uuid=clip_annotation.uuid,
num=len(clip_annotation.sound_events) num=len(clip_annotation.sound_events),
) )
sound_events = [] sound_events = []
@ -260,7 +260,7 @@ def generate_heatmaps(
continue continue
# Get the position of the sound event # Get the position of the sound event
time, frequency = targets.get_position(sound_event_annotation) (time, frequency), size = targets.encode_roi(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap # Set 1.0 at the position of the sound event in the detection heatmap
try: try:
@ -280,8 +280,6 @@ def generate_heatmaps(
) )
continue continue
size = targets.get_size(sound_event_annotation)
size_heatmap = arrays.set_value_at_pos( size_heatmap = arrays.set_value_at_pos(
size_heatmap, size_heatmap,
size, size,
@ -291,7 +289,7 @@ def generate_heatmaps(
# Get the class name of the sound event # Get the class name of the sound event
try: try:
class_name = targets.encode(sound_event_annotation) class_name = targets.encode_class(sound_event_annotation)
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
"Skipping annotation %s: Unexpected error while encoding " "Skipping annotation %s: Unexpected error while encoding "

Some files were not shown because too many files have changed in this diff Show More