mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Compare commits
13 Commits
ebad489cb1
...
a462beaeb8
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a462beaeb8 | ||
![]() |
daab8ff0d7 | ||
![]() |
235f0e27da | ||
![]() |
b5b4229990 | ||
![]() |
8253b5bdc4 | ||
![]() |
c7ea361cf4 | ||
![]() |
3407e1b5f0 | ||
![]() |
0a0d6f7162 | ||
![]() |
ad0f0bcb24 | ||
![]() |
3103630c26 | ||
![]() |
960558be8b | ||
![]() |
e352dc40bd | ||
![]() |
c559bcc682 |
2
.gitignore
vendored
2
.gitignore
vendored
@ -107,7 +107,7 @@ experiments/*
|
||||
|
||||
# DO Include
|
||||
!batdetect2_notebook.ipynb
|
||||
!batdetect2/models/checkpoints/*.pth.tar
|
||||
!src/batdetect2/models/checkpoints/*.pth.tar
|
||||
!tests/data/*.wav
|
||||
!notebooks/*.ipynb
|
||||
!tests/data/**/*.wav
|
||||
|
4
Makefile
4
Makefile
@ -1,7 +1,7 @@
|
||||
# Variables
|
||||
SOURCE_DIR = batdetect2
|
||||
SOURCE_DIR = src
|
||||
TESTS_DIR = tests
|
||||
PYTHON_DIRS = batdetect2 tests
|
||||
PYTHON_DIRS = src tests
|
||||
DOCS_SOURCE = docs/source
|
||||
DOCS_BUILD = docs/build
|
||||
HTML_COVERAGE_DIR = htmlcov
|
||||
|
@ -1,511 +0,0 @@
|
||||
"""Handles mapping between geometric ROIs and target representations.
|
||||
|
||||
This module defines the interface and provides implementation for converting
|
||||
a sound event's Region of Interest (ROI), typically represented by a
|
||||
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
|
||||
suitable for use as a machine learning target. This usually involves:
|
||||
|
||||
1. Extracting a single reference point (time, frequency) from the geometry.
|
||||
2. Calculating relevant size dimensions (e.g., duration/width,
|
||||
bandwidth/height) and applying scaling factors.
|
||||
|
||||
It also provides the inverse operation: recovering an approximate geometric ROI
|
||||
(like a `BoundingBox`) from a predicted reference point and predicted size
|
||||
dimensions.
|
||||
|
||||
This logic is encapsulated within components adhering to the `ROITargetMapper`
|
||||
protocol. Configuration for this mapping (e.g., which reference point to use,
|
||||
scaling factors) is managed by the `ROIConfig`. This module separates the
|
||||
*geometric* aspect of target definition from the *semantic* classification
|
||||
handled in `batdetect2.targets.classes`.
|
||||
"""
|
||||
|
||||
from typing import List, Literal, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
Positions = Literal[
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"center-left",
|
||||
"center-right",
|
||||
"top-center",
|
||||
"bottom-center",
|
||||
"center",
|
||||
"centroid",
|
||||
"point_on_surface",
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"ROITargetMapper",
|
||||
"ROIConfig",
|
||||
"BBoxEncoder",
|
||||
"build_roi_mapper",
|
||||
"load_roi_mapper",
|
||||
"DEFAULT_POSITION",
|
||||
"SIZE_WIDTH",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
]
|
||||
|
||||
SIZE_WIDTH = "width"
|
||||
"""Standard name for the width/time dimension component ('width')."""
|
||||
|
||||
SIZE_HEIGHT = "height"
|
||||
"""Standard name for the height/frequency dimension component ('height')."""
|
||||
|
||||
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
|
||||
"""Standard order of dimensions for size arrays ([width, height])."""
|
||||
|
||||
DEFAULT_TIME_SCALE = 1000.0
|
||||
"""Default scaling factor for time duration."""
|
||||
|
||||
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
|
||||
"""Default scaling factor for frequency bandwidth."""
|
||||
|
||||
|
||||
DEFAULT_POSITION = "bottom-left"
|
||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
    """Interface for translating between geometries and target values.

    An implementation converts a `soundevent.data.Geometry` into a
    reference point plus an array of size values, and can rebuild an
    approximate geometry from such a pair.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the entries of the array returned by `get_roi_size` and
        consumed by `recover_roi` (e.g., ['width', 'height']).
    """

    dimension_names: List[str]

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> tuple[float, float]:
        """Return the reference point of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to take the reference point from (e.g., BoundingBox,
            Polygon).
        position : Positions, optional
            Reference point type to use for this call. When omitted, the
            mapper falls back to its own configured default.

        Returns
        -------
        Tuple[float, float]
            The reference point as (time, frequency).

        Raises
        ------
        ValueError
            If no reference point can be computed for the geometry type or
            the requested position.
        """
        ...

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Return the scaled size values of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to measure.

        Returns
        -------
        np.ndarray
            Scaled size values ordered as in `dimension_names`. For
            bounding boxes this is typically
            `[scaled_width, scaled_height]`.

        Raises
        ------
        TypeError, ValueError
            If the size cannot be computed for the geometry type.
        """
        ...

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Rebuild an approximate geometry from a point and size values.

        This is the inverse of `get_roi_position` / `get_roi_size`.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference point as (time, frequency).
        dims : np.ndarray
            Size values ordered as in `dimension_names`.
        position : Positions, optional
            Reference point type that `pos` represents. When omitted, the
            mapper falls back to its own configured default.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If `dims` does not match `dimension_names` or the geometry
            cannot be reconstructed.
        """
        ...
|
||||
|
||||
|
||||
class ROIConfig(BaseConfig):
    """Settings that control how ROIs become target values.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Which point of the geometry is used as the target location
        (e.g., "center", "bottom-left").
    time_scale : float, default=1000.0
        Multiplier applied to the ROI duration (width) when computing the
        target size.
    frequency_scale : float, default=1/859.375
        Multiplier applied to the ROI bandwidth (height) when computing
        the target size.
    """

    position: Positions = DEFAULT_POSITION
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
class BBoxEncoder(ROITargetMapper):
    """`ROITargetMapper` implementation based on bounding boxes.

    Extracts a reference point from a geometry, computes its scaled
    width/height from the geometry bounds, and rebuilds a
    `soundevent.data.BoundingBox` from a reference point and scaled size.

    Attributes
    ----------
    dimension_names : List[str]
        Output dimension names, ['width', 'height'].
    position : Positions
        Default reference point type (e.g., "center", "bottom-left").
    time_scale : float
        Multiplier applied to the duration (width).
    frequency_scale : float
        Multiplier applied to the bandwidth (height).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        position: Positions = DEFAULT_POSITION,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Initialize the BBoxEncoder.

        Parameters
        ----------
        position : Positions, default="bottom-left"
            Default reference point type within the bounding box.
        time_scale : float, default=1000.0
            Multiplier for the duration (width).
        frequency_scale : float, default=1/859.375
            Multiplier for the bandwidth (height).
        """
        self.position: Positions = position
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> Tuple[float, float]:
        """Return the reference point of the geometry.

        Delegates to `soundevent.geometry.get_geometry_point`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry (e.g., BoundingBox).
        position : Positions, optional
            Reference point type for this call. Falls back to
            `self.position` when omitted.

        Returns
        -------
        Tuple[float, float]
            Reference point as (time, frequency).
        """
        from soundevent import geometry

        position = position or self.position
        return geometry.get_geometry_point(geom, position=position)

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Return the scaled [width, height] of the geometry's bounds.

        The bounds come from `soundevent.geometry.compute_bounds`; the
        duration is multiplied by `time_scale` and the bandwidth by
        `frequency_scale`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry.

        Returns
        -------
        np.ndarray
            A 1D array `[scaled_width, scaled_height]`.
        """
        from soundevent import geometry

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )
        return np.array(
            [
                (end_time - start_time) * self.time_scale,
                (high_freq - low_freq) * self.frequency_scale,
            ]
        )

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Rebuild a BoundingBox from a reference point and scaled size.

        Divides the size values by the configured scales and places the
        box so that `pos` sits at the requested reference point.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference point as (time, frequency).
        dims : np.ndarray
            Scaled size values in the order
            `[scaled_width, scaled_height]`.
        position : Positions, optional
            Reference point type that `pos` represents. Falls back to
            `self.position` when omitted.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `dims` is not a 1D array of length 2.
        """
        position = position or self.position

        if dims.ndim != 1 or dims.shape[0] != 2:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({dims.shape = }) != ([2])"
            )

        width, height = dims
        return _build_bounding_box(
            pos,
            duration=float(width) / self.time_scale,
            bandwidth=float(height) / self.frequency_scale,
            # Use the resolved value so a per-call override is honoured,
            # mirroring `get_roi_position`; previously `self.position`
            # was always passed and the argument was silently ignored.
            position=position,
        )
|
||||
|
||||
|
||||
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
    """Create an ROI mapper from a configuration object.

    Parameters
    ----------
    config : ROIConfig
        ROI mapping settings.

    Returns
    -------
    ROITargetMapper
        A `BBoxEncoder` initialised with the position and scales from
        `config`.
    """
    encoder = BBoxEncoder(
        frequency_scale=config.frequency_scale,
        time_scale=config.time_scale,
        position=config.position,
    )
    return encoder
|
||||
|
||||
|
||||
def load_roi_mapper(
    path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
    """Read an `ROIConfig` from a file and build the matching mapper.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Forwarded to `load_config` to select a section of the file.

    Returns
    -------
    ROITargetMapper
        The mapper produced by `build_roi_mapper` for the loaded config.

    Raises
    ------
    Exception
        Any error raised by `load_config` while reading or validating the
        file propagates unchanged.
    """
    return build_roi_mapper(
        load_config(path=path, schema=ROIConfig, field=field)
    )
|
||||
|
||||
|
||||
VALID_POSITIONS = [
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"center-left",
|
||||
"center-right",
|
||||
"top-center",
|
||||
"bottom-center",
|
||||
"center",
|
||||
"centroid",
|
||||
"point_on_surface",
|
||||
]
|
||||
|
||||
|
||||
def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
    """Create a BoundingBox anchored at a reference point.

    Helper for `BBoxEncoder.recover_roi`. Negative sizes are treated as
    zero and every resulting coordinate is clamped to be non-negative.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference point as (time, frequency).
    duration : float
        Unscaled duration (width) of the box.
    bandwidth : float
        Unscaled bandwidth (height) of the box.
    position : Positions, default="bottom-left"
        Which part of the box `pos` corresponds to. "center", "centroid"
        and "point_on_surface" all place `pos` in the middle of the box.

    Returns
    -------
    data.BoundingBox
        Box with coordinates [start_time, low_freq, end_time, high_freq].

    Raises
    ------
    ValueError
        If `position` is not one of `VALID_POSITIONS`.
    """
    time, freq = (float(value) for value in pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    # The three "middle" variants share the same symmetric construction.
    if position in ["center", "centroid", "point_on_surface"]:
        centered = [
            max(time - duration / 2, 0),
            max(freq - bandwidth / 2, 0),
            max(time + duration / 2, 0),
            max(freq + bandwidth / 2, 0),
        ]
        return data.BoundingBox(coordinates=centered)

    if position not in VALID_POSITIONS:
        raise ValueError(
            f"Invalid position: {position}. "
            f"Valid options are: {VALID_POSITIONS}"
        )

    # Remaining names have the form "<vertical>-<horizontal>".
    vertical, horizontal = position.split("-")

    if horizontal == "left":
        start_time = time
    elif horizontal == "center":
        start_time = time - duration / 2
    elif horizontal == "right":
        start_time = time - duration
    else:
        raise KeyError(horizontal)

    if vertical == "bottom":
        low_freq = freq
    elif vertical == "center":
        low_freq = freq - bandwidth / 2
    elif vertical == "top":
        low_freq = freq - bandwidth
    else:
        raise KeyError(vertical)

    end_time = start_time + duration
    high_freq = low_freq + bandwidth

    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, end_time),
            max(0, high_freq),
        ]
    )
|
19
notebooks/signal_generation.py
Normal file
19
notebooks/signal_generation.py
Normal file
@ -0,0 +1,19 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.13.15"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
return
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
@ -65,8 +65,12 @@ build-backend = "hatchling.build"
|
||||
[project.scripts]
|
||||
batdetect2 = "batdetect2.cli:cli"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
[dependency-groups]
|
||||
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
||||
marimo = [
|
||||
"marimo>=0.12.2",
|
||||
]
|
||||
dev = [
|
||||
"debugpy>=1.8.8",
|
||||
"hypothesis>=6.118.7",
|
||||
"pytest>=7.2.2",
|
||||
@ -98,25 +102,16 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
||||
convention = "numpy"
|
||||
|
||||
[tool.pyright]
|
||||
include = ["batdetect2", "tests"]
|
||||
include = ["src", "tests"]
|
||||
pythonVersion = "3.9"
|
||||
pythonPlatform = "All"
|
||||
exclude = [
|
||||
"batdetect2/detector/",
|
||||
"batdetect2/finetune",
|
||||
"batdetect2/utils",
|
||||
"batdetect2/plotting",
|
||||
"batdetect2/plot",
|
||||
"batdetect2/api",
|
||||
"batdetect2/evaluate/legacy",
|
||||
"batdetect2/train/legacy",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
jupyter = [
|
||||
"ipywidgets>=8.1.5",
|
||||
"jupyter>=1.1.1",
|
||||
]
|
||||
marimo = [
|
||||
"marimo>=0.12.2",
|
||||
"src/batdetect2/detector/",
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
"src/batdetect2/plotting",
|
||||
"src/batdetect2/plot",
|
||||
"src/batdetect2/api",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/train/legacy",
|
||||
]
|
||||
|
@ -189,8 +189,7 @@ def train_command(
|
||||
config=postprocess_config_loaded,
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded postprocessor from file {path}",
|
||||
path=train_config,
|
||||
"Loaded postprocessor from file {path}", path=postprocess_config
|
||||
)
|
||||
except IOError:
|
||||
logger.debug(
|
@ -157,4 +157,4 @@ def load_config(
|
||||
if field:
|
||||
config = get_object_field(config, field)
|
||||
|
||||
return schema.model_validate(config)
|
||||
return schema.model_validate(config or {})
|
@ -8,6 +8,7 @@ from batdetect2.data.annotations import (
|
||||
from batdetect2.data.datasets import (
|
||||
DatasetConfig,
|
||||
load_dataset,
|
||||
load_dataset_config,
|
||||
load_dataset_from_config,
|
||||
)
|
||||
|
||||
@ -19,5 +20,6 @@ __all__ = [
|
||||
"DatasetConfig",
|
||||
"load_annotated_dataset",
|
||||
"load_dataset",
|
||||
"load_dataset_config",
|
||||
"load_dataset_from_config",
|
||||
]
|
@ -161,6 +161,11 @@ def insert_source_tag(
|
||||
)
|
||||
|
||||
|
||||
# TODO: add documentation
|
||||
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||
return load_config(path=path, schema=DatasetConfig, field=field)
|
||||
|
||||
|
||||
def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
@ -72,7 +72,7 @@ def iterate_over_sound_events(
|
||||
sound_event_annotation
|
||||
)
|
||||
|
||||
class_name = targets.encode(sound_event_annotation)
|
||||
class_name = targets.encode_class(sound_event_annotation)
|
||||
if class_name is None and exclude_generic:
|
||||
continue
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
|
@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
|
||||
|
||||
gt_uuid = target.uuid if target is not None else None
|
||||
gt_det = target is not None
|
||||
gt_class = targets.encode(target) if target is not None else None
|
||||
gt_class = targets.encode_class(target) if target is not None else None
|
||||
|
||||
pred_score = float(prediction.detection_score) if prediction else 0
|
||||
|
@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
||||
coordinates to raw model output arrays.
|
||||
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||
model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
return [
|
||||
convert_xr_dataset_to_raw_prediction(
|
||||
dataset,
|
||||
self.targets.recover_roi,
|
||||
self.targets.decode_roi,
|
||||
)
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
sound_event_decoder=self.targets.decode,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
||||
model.
|
||||
into standardized prediction objects based on the `soundevent` data model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
@ -33,7 +32,7 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
|
||||
@ -55,7 +54,7 @@ decoding.
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_builder: GeometryBuilder,
|
||||
geometry_decoder: GeometryDecoder,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_builder : GeometryBuilder
|
||||
geometry_decoder : GeometryDecoder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
geom = geometry_builder(
|
||||
# TODO: Maybe clean this up
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
|
||||
geom = geometry_decoder(
|
||||
(det_info.time, det_info.frequency),
|
||||
det_info.dimensions,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.score,
|
||||
detection_score=det_info.scores,
|
||||
geometry=geom,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
@ -1,6 +1,6 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements a specific step within the BatDetect2 postprocessing
|
||||
This module implements Step 3 within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
@ -1,9 +1,9 @@
|
||||
"""Extracts associated data for detected points from model output arrays.
|
||||
|
||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
||||
pipeline. After candidate detection points (time, frequency, score) have been
|
||||
identified, this module extracts the corresponding values from other raw model
|
||||
output arrays, such as:
|
||||
This module implements Step 4 in the BatDetect2 postprocessing pipeline.
|
||||
After candidate detection points (time, frequency, score) have been identified,
|
||||
this module extracts the corresponding values from other raw model output
|
||||
arrays, such as:
|
||||
|
||||
- Predicted bounding box sizes (width, height).
|
||||
- Class probability scores for each defined target class.
|
@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, NamedTuple, Protocol
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.targets.types import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"RawPrediction",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryBuilder",
|
||||
"GeometryDecoder",
|
||||
]
|
||||
|
||||
|
||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
||||
"""Type alias for a function that recovers geometry from position and size.
|
||||
# TODO: update the docstring
|
||||
class GeometryDecoder(Protocol):
|
||||
"""Type alias for a function that recovers geometry from position and size.
|
||||
|
||||
This callable takes:
|
||||
1. A position tuple `(time, frequency)`.
|
||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||
`BoundingBox`).
|
||||
"""
|
||||
This callable takes:
|
||||
1. A position tuple `(time, frequency)`.
|
||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||
3. Optionally a class name of the highest scoring class. This is to accommodate
|
||||
different ways of decoding geometry that depend on the predicted class.
|
||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||
`BoundingBox`).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, position: Position, size: Size, class_name: Optional[str] = None
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
@ -23,7 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -50,7 +50,8 @@ from batdetect2.targets.filtering import (
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
ROIConfig,
|
||||
AnchorBBoxMapperConfig,
|
||||
ROIMapperConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
@ -59,11 +60,11 @@ from batdetect2.targets.terms import (
|
||||
TermInfo,
|
||||
TermRegistry,
|
||||
call_type,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
individual,
|
||||
register_term,
|
||||
term_registry,
|
||||
)
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
@ -73,13 +74,13 @@ from batdetect2.targets.transform import (
|
||||
SoundEventTransformation,
|
||||
TransformConfig,
|
||||
build_transformation_from_config,
|
||||
derivation_registry,
|
||||
default_derivation_registry,
|
||||
get_derivation,
|
||||
load_transformation_config,
|
||||
load_transformation_from_config,
|
||||
register_derivation,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassesConfig",
|
||||
@ -88,7 +89,7 @@ __all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"MapValueRule",
|
||||
"ROIConfig",
|
||||
"AnchorBBoxMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"ReplaceRule",
|
||||
"SoundEventDecoder",
|
||||
@ -156,12 +157,12 @@ class TargetConfig(BaseConfig):
|
||||
omitted, default ROI mapping settings are used.
|
||||
"""
|
||||
|
||||
filtering: Optional[FilterConfig] = None
|
||||
transforms: Optional[TransformConfig] = None
|
||||
filtering: FilterConfig = Field(default_factory=FilterConfig)
|
||||
transforms: TransformConfig = Field(default_factory=TransformConfig)
|
||||
classes: ClassesConfig = Field(
|
||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||
)
|
||||
roi: Optional[ROIConfig] = None
|
||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||
|
||||
|
||||
def load_target_config(
|
||||
@ -240,6 +241,7 @@ class Targets(TargetProtocol):
|
||||
generic_class_tags: List[data.Tag],
|
||||
filter_fn: Optional[SoundEventFilter] = None,
|
||||
transform_fn: Optional[SoundEventTransformation] = None,
|
||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||
):
|
||||
"""Initialize the Targets object.
|
||||
|
||||
@ -272,6 +274,16 @@ class Targets(TargetProtocol):
|
||||
self._encode_fn = encode_fn
|
||||
self._decode_fn = decode_fn
|
||||
self._transform_fn = transform_fn
|
||||
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||
|
||||
for class_name in self._roi_mapper_overrides:
|
||||
if class_name not in self.class_names:
|
||||
# TODO: improve this warning
|
||||
logger.warning(
|
||||
"The ROI mapper overrides contains a class ({class_name}) "
|
||||
"not present in the class names.",
|
||||
class_name=class_name,
|
||||
)
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
@ -291,7 +303,9 @@ class Targets(TargetProtocol):
|
||||
return True
|
||||
return self._filter_fn(sound_event)
|
||||
|
||||
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
|
||||
def encode_class(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> Optional[str]:
|
||||
"""Encode a sound event annotation to its target class name.
|
||||
|
||||
Applies the configured class definition rules (including priority)
|
||||
@ -312,7 +326,7 @@ class Targets(TargetProtocol):
|
||||
"""
|
||||
return self._encode_fn(sound_event)
|
||||
|
||||
def decode(self, class_label: str) -> List[data.Tag]:
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
"""Decode a predicted class name back into representative tags.
|
||||
|
||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||
@ -352,9 +366,9 @@ class Targets(TargetProtocol):
|
||||
return self._transform_fn(sound_event)
|
||||
return sound_event
|
||||
|
||||
def get_position(
|
||||
def encode_roi(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
) -> tuple[Position, Size]:
|
||||
"""Extract the target reference position from the annotation's roi.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||
@ -374,50 +388,20 @@ class Targets(TargetProtocol):
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
class_name = self.encode_class(sound_event)
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its position."
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].encode(
|
||||
sound_event.sound_event
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_position(geom)
|
||||
return self._roi_mapper.encode(sound_event.sound_event)
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_size` method, which
|
||||
applies configured scaling factors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
NumPy array containing the size dimensions, matching the
|
||||
order in `self.dimension_names` (e.g., `[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its size."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_size(geom)
|
||||
|
||||
def recover_roi(
|
||||
def decode_roi(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
@ -438,7 +422,13 @@ class Targets(TargetProtocol):
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
return self._roi_mapper.recover_roi(pos, dims)
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].decode(
|
||||
position,
|
||||
size,
|
||||
)
|
||||
|
||||
return self._roi_mapper.decode(position, size)
|
||||
|
||||
|
||||
DEFAULT_CLASSES = [
|
||||
@ -493,10 +483,12 @@ DEFAULT_CLASSES = [
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||
name="nyclei",
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||
name="rhifer",
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Plecotus auritus")],
|
||||
@ -537,13 +529,14 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
]
|
||||
),
|
||||
classes=DEFAULT_CLASSES_CONFIG,
|
||||
roi=AnchorBBoxMapperConfig(),
|
||||
)
|
||||
|
||||
|
||||
def build_targets(
|
||||
config: Optional[TargetConfig] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||
) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
@ -606,12 +599,17 @@ def build_targets(
|
||||
if config.transforms
|
||||
else None
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
||||
roi_mapper = build_roi_mapper(config.roi)
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
roi_overrides = {
|
||||
class_config.name: build_roi_mapper(class_config.roi)
|
||||
for class_config in config.classes.classes
|
||||
if class_config.roi is not None
|
||||
}
|
||||
|
||||
return Targets(
|
||||
filter_fn=filter_fn,
|
||||
@ -621,14 +619,15 @@ def build_targets(
|
||||
roi_mapper=roi_mapper,
|
||||
generic_class_tags=generic_class_tags,
|
||||
transform_fn=transform_fn,
|
||||
roi_mapper_overrides=roi_overrides,
|
||||
)
|
||||
|
||||
|
||||
def load_targets(
|
||||
config_path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||
) -> Targets:
|
||||
"""Load a Targets object directly from a configuration file.
|
||||
|
@ -6,29 +6,27 @@ from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import (
|
||||
GENERIC_CLASS_KEY,
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SoundEventEncoder",
|
||||
"SoundEventDecoder",
|
||||
"TargetClass",
|
||||
"ClassesConfig",
|
||||
"load_classes_config",
|
||||
"load_encoder_from_config",
|
||||
"load_decoder_from_config",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_decoder",
|
||||
"build_generic_class_tags",
|
||||
"get_class_names_from_config",
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
"build_generic_class_tags",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"get_class_names_from_config",
|
||||
"load_classes_config",
|
||||
"load_decoder_from_config",
|
||||
"load_encoder_from_config",
|
||||
]
|
||||
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
@ -113,6 +111,7 @@ class TargetClass(BaseConfig):
|
||||
tags: List[TagInfo] = Field(min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
output_tags: Optional[List[TagInfo]] = None
|
||||
roi: Optional[ROIMapperConfig] = None
|
||||
|
||||
|
||||
def _get_default_classes() -> List[TargetClass]:
|
||||
@ -235,7 +234,7 @@ class ClassesConfig(BaseConfig):
|
||||
return v
|
||||
|
||||
|
||||
def _is_target_class(
|
||||
def is_target_class(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
match_all: bool = True,
|
||||
@ -316,7 +315,7 @@ def _encode_with_multiple_classifiers(
|
||||
|
||||
def build_sound_event_encoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Build a sound event encoder function from the classes configuration.
|
||||
|
||||
@ -350,7 +349,7 @@ def build_sound_event_encoder(
|
||||
(
|
||||
class_info.name,
|
||||
partial(
|
||||
_is_target_class,
|
||||
is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
@ -410,7 +409,7 @@ def _decode_class(
|
||||
|
||||
def build_sound_event_decoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Build a sound event decoder function from the classes configuration.
|
||||
@ -465,7 +464,7 @@ def build_sound_event_decoder(
|
||||
|
||||
def build_generic_class_tags(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Tag]:
|
||||
"""Extract and build the list of tags for the generic class from config.
|
||||
|
||||
@ -530,7 +529,7 @@ def load_classes_config(
|
||||
def load_encoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Load a class encoder function directly from a configuration file.
|
||||
|
||||
@ -571,7 +570,7 @@ def load_encoder_from_config(
|
||||
def load_decoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Load a class decoder function directly from a configuration file.
|
@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
term_registry,
|
||||
default_term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -156,7 +156,7 @@ def equal_tags(
|
||||
|
||||
def build_filter_from_rule(
|
||||
rule: FilterRule,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Creates a callable filter function from a single FilterRule.
|
||||
|
||||
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
|
||||
|
||||
def build_sound_event_filter(
|
||||
config: FilterConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Builds a merged filter function from a FilterConfig object.
|
||||
|
||||
@ -291,7 +291,7 @@ def load_filter_config(
|
||||
def load_filter_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Loads filter configuration from a file and builds the filter function.
|
||||
|
684
src/batdetect2/targets/rois.py
Normal file
684
src/batdetect2/targets/rois.py
Normal file
@ -0,0 +1,684 @@
|
||||
"""Handles mapping between geometric ROIs and target representations.
|
||||
|
||||
This module defines a standardized interface (`ROITargetMapper`) for converting
|
||||
a sound event's Region of Interest (ROI) into a target representation suitable
|
||||
for machine learning models, and for decoding model outputs back into geometric
|
||||
ROIs.
|
||||
|
||||
The core operations are:
|
||||
1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
|
||||
`Position` (time, frequency) and a `Size` array. The method for
|
||||
determining the position and size varies by the mapper implementation
|
||||
(e.g., using a bounding box anchor or the point of peak energy).
|
||||
2. **Decoding**: A `Position` and `Size` array are mapped back to an
|
||||
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).
|
||||
|
||||
This logic is encapsulated within specific mapper classes. Configuration for
|
||||
each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
|
||||
Pydantic config object. The `ROIMapperConfig` type allows for flexibly
|
||||
selecting and configuring the desired mapper. This module separates the
|
||||
*geometric* aspect of target definition from *semantic* classification.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
"AnchorBBoxMapper",
|
||||
"AnchorBBoxMapperConfig",
|
||||
"DEFAULT_ANCHOR",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"PeakEnergyBBoxMapper",
|
||||
"PeakEnergyBBoxMapperConfig",
|
||||
"ROIMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"SIZE_WIDTH",
|
||||
"build_roi_mapper",
|
||||
]
|
||||
|
||||
# Names of the reference points that can be taken on a geometry. Compound
# names are written "<vertical>-<horizontal>".
Anchor = Literal[
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]

SIZE_WIDTH: str = "width"
"""Name of the size component measured along the time axis ('width')."""

SIZE_HEIGHT: str = "height"
"""Name of the size component measured along the frequency axis ('height')."""

SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Order of the components in size arrays: width first, then height."""

DEFAULT_TIME_SCALE: float = 1000.0
"""Multiplier applied to time durations when none is configured."""

DEFAULT_FREQUENCY_SCALE: float = 1 / 859.375
"""Multiplier applied to frequency bandwidths when none is configured."""

DEFAULT_ANCHOR: Anchor = "bottom-left"
"""Anchor used when none is configured (the 'bottom-left' corner)."""
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
    """Interface shared by every ROI-to-target mapper.

    A mapper converts the geometry of a `soundevent.data.SoundEvent` into a
    target representation, made of a reference position and a size vector
    (`encode`), and turns such a pair back into an approximate geometry
    (`decode`).

    Attributes
    ----------
    dimension_names : List[str]
        Names of the entries of the `Size` array, as produced by `encode`
        and expected by `decode`.
    """

    dimension_names: List[str]

    def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
        """Map the geometry of a sound event to a position and a size.

        Parameters
        ----------
        sound_event : data.SoundEvent
            Sound event to encode. It must carry a geometry.

        Returns
        -------
        Tuple[Position, Size]
            Two items:
            - the reference position, as (time, frequency) coordinates;
            - a NumPy array holding the computed size dimensions.

        Raises
        ------
        ValueError
            When the sound event has no geometry.
        """
        ...

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Rebuild a geometric ROI from a position and a size.

        This is the inverse of `encode`: given a reference position and
        the size dimensions, produce a geometry.

        Parameters
        ----------
        position : Position
            Reference position, as (time, frequency).
        size : Size
            NumPy array of size dimensions, ordered and interpreted as
            described by `dimension_names`.

        Returns
        -------
        soundevent.data.Geometry
            The rebuilt geometry, typically a `BoundingBox`.

        Raises
        ------
        ValueError
            When `size` has an unexpected shape or the geometry cannot be
            rebuilt.
        """
        ...
|
||||
|
||||
|
||||
class AnchorBBoxMapperConfig(BaseConfig):
    """Settings for `AnchorBBoxMapper`.

    Describes how an ROI is turned into a target when the reference
    position is a fixed anchor point of its bounding box.

    Attributes
    ----------
    name : Literal["anchor_bbox"]
        Identifier of this mapper type; always "anchor_bbox".
    anchor : Anchor
        Which point of the bounding box serves as the reference position
        of the target (for instance "center" or "bottom-left").
    time_scale : float
        Factor by which the time duration (width) of the ROI is multiplied.
    frequency_scale : float
        Factor by which the frequency bandwidth (height) of the ROI is
        multiplied.
    """

    name: Literal["anchor_bbox"] = "anchor_bbox"
    anchor: Anchor = DEFAULT_ANCHOR
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
class AnchorBBoxMapper(ROITargetMapper):
    """ROI mapper based on a bounding-box anchor plus width and height.

    Implements the `ROITargetMapper` protocol.

    **Encoding**: the `position` is the configured anchor point of the
    geometry (e.g., "bottom-left") and the `size` is a 2-element array with
    the scaled width and height of its bounds.

    **Decoding**: a `BoundingBox` is rebuilt from the anchor point and the
    scaled width/height.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components: `['width', 'height']`.
    anchor : Anchor
        Anchor point in use (e.g., "center", "bottom-left").
    time_scale : float
        Multiplier of the time dimension (width).
    frequency_scale : float
        Multiplier of the frequency dimension (height).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        anchor: Anchor = DEFAULT_ANCHOR,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Initialize the AnchorBBoxMapper.

        Parameters
        ----------
        anchor : Anchor
            Which point of the bounding box is used as reference.
        time_scale : float
            Multiplier of the time duration (width).
        frequency_scale : float
            Multiplier of the frequency bandwidth (height).
        """
        self.anchor: Anchor = anchor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
        """Turn a SoundEvent into an anchor position and a scaled box size.

        The position is the configured anchor point of the geometry; the
        size holds the extent of the geometry bounds along time and
        frequency, each multiplied by its scale.

        Parameters
        ----------
        sound_event : data.SoundEvent
            Sound event to encode; it must have a geometry.

        Returns
        -------
        Tuple[Position, Size]
            The pair (anchor_position, [scaled_width, scaled_height]).

        Raises
        ------
        ValueError
            When the sound event has no geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry

        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        anchor_point = geometry.get_geometry_point(geom, position=self.anchor)

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )
        duration = end_time - start_time
        bandwidth = high_freq - low_freq

        scaled_size = np.array(
            [
                duration * self.time_scale,
                bandwidth * self.frequency_scale,
            ]
        )
        return anchor_point, scaled_size

    def decode(
        self,
        position: Position,
        size: Size,
    ) -> data.Geometry:
        """Rebuild a BoundingBox from an anchor position and a scaled size.

        The scaling applied by `encode` is undone and the box is laid out
        around `position` according to the configured anchor.

        Parameters
        ----------
        position : Position
            Anchor position, as (time, frequency).
        size : Size
            NumPy array with the scaled [width, height].

        Returns
        -------
        data.BoundingBox
            The rebuilt bounding box.

        Raises
        ------
        ValueError
            When `size` is not a one-dimensional array of length 2.
        """
        if not (size.ndim == 1 and size.shape[0] == 2):
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({size.shape = }) != ([2])"
            )

        scaled_width, scaled_height = size

        # Divide by the scales to get back to unscaled extents.
        duration = float(scaled_width) / self.time_scale
        bandwidth = float(scaled_height) / self.frequency_scale

        return _build_bounding_box(
            position,
            duration=duration,
            bandwidth=bandwidth,
            anchor=self.anchor,
        )
|
||||
|
||||
|
||||
class PeakEnergyBBoxMapperConfig(BaseConfig):
    """Settings for `PeakEnergyBBoxMapper`.

    Attributes
    ----------
    name : Literal["peak_energy_bbox"]
        Identifier of this mapper type; always "peak_energy_bbox".
    preprocessing : PreprocessingConfig
        Settings of the spectrogram preprocessor used to locate the peak
        energy.
    loading_buffer : float
        Number of seconds added on each side of the ROI when the audio is
        loaded, so that the peak is found accurately and boundary effects
        are avoided.
    time_scale : float
        Factor by which time dimensions are multiplied.
    frequency_scale : float
        Factor by which frequency dimensions are multiplied.
    """

    name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    loading_buffer: float = 0.01
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
class PeakEnergyBBoxMapper(ROITargetMapper):
    """Maps ROIs using the peak energy point and distances to edges.

    This class implements the `ROITargetMapper` protocol.

    **Encoding**: The `position` is the (time, frequency) coordinate
    returned by `get_peak_energy_coordinates` for the bounds of the sound
    event's geometry. The `size` is a 4-element array with the scaled
    distances from that point to the left, bottom, right, and top edges of
    the bounds.

    **Decoding**: Reconstructs a `BoundingBox` by adding/subtracting the
    un-scaled distances from the peak energy point.

    Attributes
    ----------
    dimension_names : List[str]
        The output dimension names: `['left', 'bottom', 'right', 'top']`.
    preprocessor : PreprocessorProtocol
        The spectrogram preprocessor instance.
    time_scale : float
        The scaling factor for time-based distances.
    frequency_scale : float
        The scaling factor for frequency-based distances.
    loading_buffer : float
        The buffer used for loading audio around the ROI.
    """

    dimension_names = ["left", "bottom", "right", "top"]

    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
        loading_buffer: float = 0.01,
    ):
        """Initialize the PeakEnergyBBoxMapper.

        Parameters
        ----------
        preprocessor : PreprocessorProtocol
            An initialized preprocessor for generating spectrograms.
        time_scale : float
            Scaling factor for time dimensions (left, right distances).
        frequency_scale : float
            Scaling factor for frequency dimensions (bottom, top distances).
        loading_buffer : float
            Buffer in seconds to add when loading audio clips.
        """
        self.preprocessor = preprocessor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale
        self.loading_buffer = loading_buffer

    def encode(
        self,
        sound_event: data.SoundEvent,
    ) -> tuple[Position, Size]:
        """Encode a SoundEvent into a peak energy position and edge distances.

        Computes the bounds of the event's geometry, asks
        `get_peak_energy_coordinates` for the peak location inside them and
        returns the scaled distances from that point to each edge.

        Parameters
        ----------
        sound_event : data.SoundEvent
            The input sound event with a geometry and associated recording.

        Returns
        -------
        Tuple[Position, Size]
            A tuple of (peak_position, [l, b, r, t] distances).

        Raises
        ------
        ValueError
            If the sound event does not have a geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry

        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        time, freq = get_peak_energy_coordinates(
            recording=sound_event.recording,
            preprocessor=self.preprocessor,
            start_time=start_time,
            end_time=end_time,
            low_freq=low_freq,
            high_freq=high_freq,
            loading_buffer=self.loading_buffer,
        )

        # Order must match `dimension_names`: left, bottom, right, top.
        size = np.array(
            [
                (time - start_time) * self.time_scale,
                (freq - low_freq) * self.frequency_scale,
                (end_time - time) * self.time_scale,
                (high_freq - freq) * self.frequency_scale,
            ]
        )

        return (time, freq), size

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Recover a BoundingBox from a peak position and edge distances.

        Parameters
        ----------
        position : Position
            The reference peak energy position (time, frequency).
        size : Size
            NumPy array with scaled distances [left, bottom, right, top].

        Returns
        -------
        data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `size` does not have the expected shape (length 4).
        """
        # Validate up front, as `AnchorBBoxMapper.decode` does, so that a
        # wrongly shaped input fails with an explicit message instead of an
        # unpacking or `float()` conversion error further down.
        # `np.shape` also accepts plain sequences, so array-likes of
        # length 4 keep working.
        shape = np.shape(size)
        if shape != (4,):
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"(size.shape = {shape}) != ([4])"
            )

        time, freq = position
        left, bottom, right, top = size

        # Negative distances are treated as zero so that no edge crosses
        # the reference point.
        # NOTE(review): unlike `_build_bounding_box`, the resulting
        # coordinates are not clamped at zero - confirm this is intended.
        return data.BoundingBox(
            coordinates=[
                time - max(0, float(left)) / self.time_scale,
                freq - max(0, float(bottom)) / self.frequency_scale,
                time + max(0, float(right)) / self.time_scale,
                freq + max(0, float(top)) / self.frequency_scale,
            ]
        )
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
    Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
    Field(discriminator="name"),
]
"""Discriminated union of the supported ROI mapper configurations.

The `name` field of each configuration acts as the discriminator, which is
how a particular `ROITargetMapper` implementation is selected and
configured.
"""


def build_roi_mapper(
    config: Optional[ROIMapperConfig] = None,
) -> ROITargetMapper:
    """Create the ROITargetMapper described by a configuration object.

    Parameters
    ----------
    config : ROIMapperConfig, optional
        Configuration selecting the mapper type and its parameters. When
        omitted, a default `AnchorBBoxMapperConfig` is used.

    Returns
    -------
    ROITargetMapper
        The initialized ROI mapper.

    Raises
    ------
    NotImplementedError
        When `config.name` does not match any known mapper.
    """
    if not config:
        config = AnchorBBoxMapperConfig()

    if config.name == "peak_energy_bbox":
        # This mapper needs a spectrogram preprocessor of its own.
        spectrogram_builder = build_preprocessor(config.preprocessing)
        return PeakEnergyBBoxMapper(
            preprocessor=spectrogram_builder,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
            loading_buffer=config.loading_buffer,
        )

    if config.name == "anchor_bbox":
        return AnchorBBoxMapper(
            anchor=config.anchor,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )

    raise NotImplementedError(
        f"No ROI mapper of name '{config.name}' is implemented"
    )
|
||||
|
||||
|
||||
# Anchor names accepted by `_build_bounding_box`.
VALID_ANCHORS = [
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]


def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    anchor: Anchor = DEFAULT_ANCHOR,
) -> data.BoundingBox:
    """Build a BoundingBox of a given size around a reference point.

    Internal helper used by `AnchorBBoxMapper.decode`. The coordinates
    [start_time, low_freq, end_time, high_freq] are derived from the place
    that `pos` (time, freq) occupies in the box (e.g., its center or one of
    its corners). Negative sizes are treated as zero and every coordinate
    is clamped so that it is never below zero.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position, as (time, frequency).
    duration : float
        *Unscaled* duration (width) of the box.
    bandwidth : float
        *Unscaled* frequency bandwidth (height) of the box.
    anchor : Anchor
        The part of the box that `pos` corresponds to.

    Returns
    -------
    data.BoundingBox
        The resulting bounding box.

    Raises
    ------
    ValueError
        When `anchor` is not one of `VALID_ANCHORS`.
    """
    time, freq = (float(value) for value in pos)

    # A box cannot have a negative extent.
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    # These three anchors are all treated as the center of the box.
    if anchor in ("center", "centroid", "point_on_surface"):
        half_duration = duration / 2
        half_bandwidth = bandwidth / 2
        return data.BoundingBox(
            coordinates=[
                max(time - half_duration, 0),
                max(freq - half_bandwidth, 0),
                max(time + half_duration, 0),
                max(freq + half_bandwidth, 0),
            ]
        )

    if anchor not in VALID_ANCHORS:
        raise ValueError(
            f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
        )

    # Remaining anchors are written "<vertical>-<horizontal>".
    vertical, horizontal = anchor.split("-")

    if horizontal == "left":
        start_time = time
    elif horizontal == "center":
        start_time = time - duration / 2
    else:  # "right"
        start_time = time - duration

    if vertical == "bottom":
        low_freq = freq
    elif vertical == "center":
        low_freq = freq - bandwidth / 2
    else:  # "top"
        low_freq = freq - bandwidth

    end_time = start_time + duration
    high_freq = low_freq + bandwidth

    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, end_time),
            max(0, high_freq),
        ]
    )
|
||||
|
||||
|
||||
def get_peak_energy_coordinates(
    recording: data.Recording,
    preprocessor: PreprocessorProtocol,
    start_time: float = 0,
    end_time: Optional[float] = None,
    low_freq: float = 0,
    high_freq: Optional[float] = None,
    loading_buffer: float = 0.05,
) -> Position:
    """Locate the maximum-valued spectrogram bin inside a time-frequency window.

    A clip covering the requested time range (padded by `loading_buffer` on
    both sides and clamped to the recording bounds) is converted to a
    spectrogram with `preprocessor`. The spectrogram is then restricted to
    the requested window and the coordinates of its largest value are
    returned.

    Parameters
    ----------
    recording : data.Recording
        Recording from which the clip is taken.
    preprocessor : PreprocessorProtocol
        Object used to turn the clip into a spectrogram. Its `min_freq` and
        `max_freq` attributes bound the frequency window that is searched.
    start_time : float, default=0
        Start of the time window to search.
    end_time : float, optional
        End of the time window to search. When omitted, or when larger than
        the recording duration, the recording duration is used.
    low_freq : float, default=0
        Lower bound of the frequency window. Raised to
        `preprocessor.min_freq` if it falls below it.
    high_freq : float, optional
        Upper bound of the frequency window. Defaults to half the recording
        samplerate and is lowered to `preprocessor.max_freq` if it exceeds
        it.
    loading_buffer : float, default=0.05
        Extra time added before and after the window when building the clip,
        intended to reduce border effects of the spectrogram computation.
        The search itself is still limited to `start_time`..`end_time`.

    Returns
    -------
    Position
        The `(time, frequency)` coordinates of the maximum value.
    """
    recording_duration = recording.duration
    if end_time is None:
        window_end = recording_duration
    else:
        window_end = min(end_time, recording_duration)

    upper_freq = recording.samplerate / 2 if high_freq is None else high_freq

    # Load slightly more audio than requested; the padding is discarded again
    # by the label-based selection below.
    padded_clip = data.Clip(
        recording=recording,
        start_time=max(0, start_time - loading_buffer),
        end_time=min(recording_duration, window_end + loading_buffer),
    )
    spectrogram = preprocessor.preprocess_clip(padded_clip)

    # Only search frequencies the preprocessor actually produces.
    freq_window = slice(
        max(low_freq, preprocessor.min_freq),
        min(upper_freq, preprocessor.max_freq),
    )
    region = spectrogram.sel(
        time=slice(start_time, window_end),
        frequency=freq_window,
    )

    peak_index = region.argmax(dim=["time", "frequency"])
    peak = region.isel(peak_index)  # type: ignore
    return peak.time.item(), peak.frequency.item()
|
@ -230,11 +230,12 @@ class TermRegistry(Mapping[str, data.Term]):
|
||||
del self._terms[key]
|
||||
|
||||
|
||||
term_registry = TermRegistry(
|
||||
default_term_registry = TermRegistry(
|
||||
terms=dict(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
("event", call_type),
|
||||
("species", terms.scientific_name),
|
||||
("individual", individual),
|
||||
("data_source", data_source),
|
||||
(GENERIC_CLASS_KEY, generic_class),
|
||||
@ -252,7 +253,7 @@ is explicitly passed.
|
||||
|
||||
def get_term_from_key(
|
||||
key: str,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> data.Term:
|
||||
"""Convenience function to retrieve a term by key from a registry.
|
||||
|
||||
@ -277,10 +278,13 @@ def get_term_from_key(
|
||||
KeyError
|
||||
If the key is not found in the specified registry.
|
||||
"""
|
||||
term_registry = term_registry or default_term_registry
|
||||
return term_registry.get_term(key)
|
||||
|
||||
|
||||
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
||||
def get_term_keys(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[str]:
|
||||
"""Convenience function to get all registered keys from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
@ -299,7 +303,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
||||
return term_registry.get_keys()
|
||||
|
||||
|
||||
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
|
||||
def get_terms(
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> List[data.Term]:
|
||||
"""Convenience function to get all registered terms from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
@ -342,7 +348,7 @@ class TagInfo(BaseModel):
|
||||
|
||||
def get_tag_from_info(
|
||||
tag_info: TagInfo,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> data.Tag:
|
||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||
|
||||
@ -368,6 +374,7 @@ def get_tag_from_info(
|
||||
If the term key specified in `tag_info.key` is not found
|
||||
in the registry.
|
||||
"""
|
||||
term_registry = term_registry or default_term_registry
|
||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||
return data.Tag(term=term, value=tag_info.value)
|
||||
|
||||
@ -439,7 +446,7 @@ class TermConfig(BaseModel):
|
||||
def load_terms_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> Dict[str, data.Term]:
|
||||
"""Loads term definitions from a configuration file and registers them.
|
||||
|
||||
@ -490,6 +497,6 @@ def load_terms_from_config(
|
||||
|
||||
|
||||
def register_term(
|
||||
key: str, term: data.Term, registry: TermRegistry = term_registry
|
||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
||||
) -> None:
|
||||
registry.add_term(key, term)
|
@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
term_registry as default_term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
@ -34,7 +31,7 @@ __all__ = [
|
||||
"TransformConfig",
|
||||
"build_transform_from_rule",
|
||||
"build_transformation_from_config",
|
||||
"derivation_registry",
|
||||
"default_derivation_registry",
|
||||
"get_derivation",
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
|
||||
return list(self._derivations.values())
|
||||
|
||||
|
||||
derivation_registry = DerivationRegistry()
|
||||
default_derivation_registry = DerivationRegistry()
|
||||
"""Global instance of the DerivationRegistry.
|
||||
|
||||
Register custom derivation functions here to make them available by key
|
||||
@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
|
||||
def get_derivation(
|
||||
key: str,
|
||||
import_derivation: bool = False,
|
||||
registry: DerivationRegistry = derivation_registry,
|
||||
registry: Optional[DerivationRegistry] = None,
|
||||
):
|
||||
"""Retrieve a derivation function by key, optionally importing it.
|
||||
|
||||
@ -443,6 +440,8 @@ def get_derivation(
|
||||
AttributeError
|
||||
If dynamic import fails because the function name isn't in the module.
|
||||
"""
|
||||
registry = registry or default_derivation_registry
|
||||
|
||||
if not import_derivation or key in registry:
|
||||
return registry.get_derivation(key)
|
||||
|
||||
@ -458,10 +457,16 @@ def get_derivation(
|
||||
) from err
|
||||
|
||||
|
||||
TranformationRule = Annotated[
|
||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
Field(discriminator="rule_type"),
|
||||
]
|
||||
|
||||
|
||||
def build_transform_from_rule(
|
||||
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
rule: TranformationRule,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a specific SoundEventTransformation function from a rule config.
|
||||
|
||||
@ -559,8 +564,8 @@ def build_transform_from_rule(
|
||||
|
||||
def build_transformation_from_config(
|
||||
config: TransformConfig,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a composite transformation function from a TransformConfig.
|
||||
|
||||
@ -581,6 +586,7 @@ def build_transformation_from_config(
|
||||
SoundEventTransformation
|
||||
A single function that applies all configured transformations in order.
|
||||
"""
|
||||
|
||||
transforms = [
|
||||
build_transform_from_rule(
|
||||
rule,
|
||||
@ -590,14 +596,16 @@ def build_transformation_from_config(
|
||||
for rule in config.rules
|
||||
]
|
||||
|
||||
def transformation(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
for transform in transforms:
|
||||
sound_event_annotation = transform(sound_event_annotation)
|
||||
return sound_event_annotation
|
||||
return partial(apply_sequence_of_transforms, transforms=transforms)
|
||||
|
||||
return transformation
|
||||
|
||||
def apply_sequence_of_transforms(
    sound_event_annotation: data.SoundEventAnnotation,
    transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
    """Run an annotation through a list of transformations, in order.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation to transform.
    transforms : list[SoundEventTransformation]
        Callables applied one after another; each receives the output of the
        previous one.

    Returns
    -------
    data.SoundEventAnnotation
        The result of the last transformation, or the input annotation
        unchanged when `transforms` is empty.
    """
    current = sound_event_annotation
    for step in transforms:
        current = step(current)
    return current
|
||||
|
||||
|
||||
def load_transformation_config(
|
||||
@ -631,8 +639,8 @@ def load_transformation_config(
|
||||
def load_transformation_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
term_registry: Optional[TermRegistry] = None,
|
||||
) -> SoundEventTransformation:
|
||||
"""Load transformation config from a file and build the final function.
|
||||
|
||||
@ -677,7 +685,7 @@ def load_transformation_from_config(
|
||||
def register_derivation(
|
||||
key: str,
|
||||
derivation: Derivation,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
derivation_registry: Optional[DerivationRegistry] = None,
|
||||
) -> None:
|
||||
"""Register a new derivation function in the global registry.
|
||||
|
||||
@ -696,4 +704,5 @@ def register_derivation(
|
||||
KeyError
|
||||
If a derivation function with the same key is already registered.
|
||||
"""
|
||||
derivation_registry = derivation_registry or default_derivation_registry
|
||||
derivation_registry.register(key, derivation)
|
@ -19,8 +19,16 @@ from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"TargetProtocol",
|
||||
"Position",
|
||||
"Size",
|
||||
]
|
||||
|
||||
Position = tuple[float, float]
|
||||
"""A tuple representing (time, frequency) coordinates."""
|
||||
|
||||
Size = np.ndarray
|
||||
"""A NumPy array representing the size dimensions of a target."""
|
||||
|
||||
|
||||
class TargetProtocol(Protocol):
|
||||
"""Protocol defining the interface for the target definition pipeline.
|
||||
@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def encode(
|
||||
def encode_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> Optional[str]:
|
||||
@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def decode(self, class_label: str) -> List[data.Tag]:
|
||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||
"""Decode a predicted class name back into representative tags.
|
||||
|
||||
Parameters
|
||||
@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_position(
|
||||
def encode_roi(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
) -> tuple[Position, Size]:
|
||||
"""Extract the target reference position from the annotation's geometry.
|
||||
|
||||
Calculates the `(time, frequency)` coordinate representing the primary
|
||||
@ -173,36 +181,12 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Computes the relevant physical size (e.g., duration/width,
|
||||
bandwidth/height from a bounding box) to produce
|
||||
the numerical target values expected by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI) to process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A NumPy array containing the size dimensions, matching the
|
||||
order specified by the `dimension_names` attribute (e.g.,
|
||||
`[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry or if the size cannot be computed.
|
||||
TypeError
|
||||
If geometry type is unsupported.
|
||||
"""
|
||||
...
|
||||
|
||||
def recover_roi(
|
||||
self, pos: tuple[float, float], dims: np.ndarray
|
||||
# TODO: Update docstrings
|
||||
def decode_roi(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover the ROI geometry from a position and dimensions.
|
||||
|
||||
@ -217,6 +201,8 @@ class TargetProtocol(Protocol):
|
||||
dims : np.ndarray
|
||||
The NumPy array containing the dimensions (e.g., predicted
|
||||
by the model), corresponding to the order in `dimension_names`.
|
||||
class_name : str, optional
|
||||
Name of the class associated with the ROI, if known. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
@ -97,7 +97,7 @@ def _is_in_subclip(
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> bool:
|
||||
time, _ = targets.get_position(sound_event_annotation)
|
||||
(time, _), _ = targets.encode_roi(sound_event_annotation)
|
||||
return start_time <= time <= end_time
|
||||
|
||||
|
@ -138,7 +138,7 @@ def generate_clip_label(
|
||||
logger.debug(
|
||||
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
|
||||
uuid=clip_annotation.uuid,
|
||||
num=len(clip_annotation.sound_events)
|
||||
num=len(clip_annotation.sound_events),
|
||||
)
|
||||
|
||||
sound_events = []
|
||||
@ -260,7 +260,7 @@ def generate_heatmaps(
|
||||
continue
|
||||
|
||||
# Get the position of the sound event
|
||||
time, frequency = targets.get_position(sound_event_annotation)
|
||||
(time, frequency), size = targets.encode_roi(sound_event_annotation)
|
||||
|
||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||
try:
|
||||
@ -280,8 +280,6 @@ def generate_heatmaps(
|
||||
)
|
||||
continue
|
||||
|
||||
size = targets.get_size(sound_event_annotation)
|
||||
|
||||
size_heatmap = arrays.set_value_at_pos(
|
||||
size_heatmap,
|
||||
size,
|
||||
@ -291,7 +289,7 @@ def generate_heatmaps(
|
||||
|
||||
# Get the class name of the sound event
|
||||
try:
|
||||
class_name = targets.encode(sound_event_annotation)
|
||||
class_name = targets.encode_class(sound_event_annotation)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Skipping annotation %s: Unexpected error while encoding "
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user