mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-30 15:12:06 +02:00
Compare commits
13 Commits
ebad489cb1
...
a462beaeb8
Author | SHA1 | Date | |
---|---|---|---|
![]() |
a462beaeb8 | ||
![]() |
daab8ff0d7 | ||
![]() |
235f0e27da | ||
![]() |
b5b4229990 | ||
![]() |
8253b5bdc4 | ||
![]() |
c7ea361cf4 | ||
![]() |
3407e1b5f0 | ||
![]() |
0a0d6f7162 | ||
![]() |
ad0f0bcb24 | ||
![]() |
3103630c26 | ||
![]() |
960558be8b | ||
![]() |
e352dc40bd | ||
![]() |
c559bcc682 |
2
.gitignore
vendored
2
.gitignore
vendored
@ -107,7 +107,7 @@ experiments/*
|
|||||||
|
|
||||||
# DO Include
|
# DO Include
|
||||||
!batdetect2_notebook.ipynb
|
!batdetect2_notebook.ipynb
|
||||||
!batdetect2/models/checkpoints/*.pth.tar
|
!src/batdetect2/models/checkpoints/*.pth.tar
|
||||||
!tests/data/*.wav
|
!tests/data/*.wav
|
||||||
!notebooks/*.ipynb
|
!notebooks/*.ipynb
|
||||||
!tests/data/**/*.wav
|
!tests/data/**/*.wav
|
||||||
|
4
Makefile
4
Makefile
@ -1,7 +1,7 @@
|
|||||||
# Variables
|
# Variables
|
||||||
SOURCE_DIR = batdetect2
|
SOURCE_DIR = src
|
||||||
TESTS_DIR = tests
|
TESTS_DIR = tests
|
||||||
PYTHON_DIRS = batdetect2 tests
|
PYTHON_DIRS = src tests
|
||||||
DOCS_SOURCE = docs/source
|
DOCS_SOURCE = docs/source
|
||||||
DOCS_BUILD = docs/build
|
DOCS_BUILD = docs/build
|
||||||
HTML_COVERAGE_DIR = htmlcov
|
HTML_COVERAGE_DIR = htmlcov
|
||||||
|
@ -1,511 +0,0 @@
|
|||||||
"""Handles mapping between geometric ROIs and target representations.
|
|
||||||
|
|
||||||
This module defines the interface and provides implementation for converting
|
|
||||||
a sound event's Region of Interest (ROI), typically represented by a
|
|
||||||
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
|
|
||||||
suitable for use as a machine learning target. This usually involves:
|
|
||||||
|
|
||||||
1. Extracting a single reference point (time, frequency) from the geometry.
|
|
||||||
2. Calculating relevant size dimensions (e.g., duration/width,
|
|
||||||
bandwidth/height) and applying scaling factors.
|
|
||||||
|
|
||||||
It also provides the inverse operation: recovering an approximate geometric ROI
|
|
||||||
(like a `BoundingBox`) from a predicted reference point and predicted size
|
|
||||||
dimensions.
|
|
||||||
|
|
||||||
This logic is encapsulated within components adhering to the `ROITargetMapper`
|
|
||||||
protocol. Configuration for this mapping (e.g., which reference point to use,
|
|
||||||
scaling factors) is managed by the `ROIConfig`. This module separates the
|
|
||||||
*geometric* aspect of target definition from the *semantic* classification
|
|
||||||
handled in `batdetect2.targets.classes`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Literal, Optional, Protocol, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
|
|
||||||
# Closed set of anchor names accepted wherever this module takes a
# `position` argument or configuration value.
Positions = Literal[
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]

# Public names exported by this module.
__all__ = [
    "ROITargetMapper",
    "ROIConfig",
    "BBoxEncoder",
    "build_roi_mapper",
    "load_roi_mapper",
    "DEFAULT_POSITION",
    "SIZE_WIDTH",
    "SIZE_HEIGHT",
    "SIZE_ORDER",
    "DEFAULT_TIME_SCALE",
    "DEFAULT_FREQUENCY_SCALE",
]

SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""

SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""

SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""

# NOTE(review): presumably converts a duration in seconds to milliseconds;
# confirm the units against the model's expected size targets.
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""

# NOTE(review): 859.375 looks like a Hz-per-frequency-bin constant of the
# default spectrogram; confirm against the preprocessing configuration.
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""


DEFAULT_POSITION = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
|
|
||||||
|
|
||||||
|
|
||||||
class ROITargetMapper(Protocol):
    """Interface for translating between geometries and target encodings.

    An implementation turns a `soundevent.data.Geometry` into a reference
    point plus an array of scaled size values, and can rebuild an
    approximate geometry from such a pair.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size components, in the order produced by
        `get_roi_size` and consumed by `recover_roi`
        (e.g., ['width', 'height']).
    """

    dimension_names: List[str]

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> tuple[float, float]:
        """Return the reference point of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to locate (e.g., BoundingBox, Polygon).
        position : Positions, optional
            Anchor to use for this call only. When omitted, the mapper's
            own configured anchor applies.

        Returns
        -------
        Tuple[float, float]
            The reference point as (time, frequency), taken at the selected
            anchor (e.g., "center", "bottom-left").

        Raises
        ------
        ValueError
            If no position can be derived for this geometry type or
            anchor.
        """
        ...

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Return the scaled size values of a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Geometry to measure.

        Returns
        -------
        np.ndarray
            Scaled size values ordered as in `dimension_names`. For
            bounding boxes this is typically
            `[scaled_width, scaled_height]`.

        Raises
        ------
        TypeError, ValueError
            If no size can be computed for this geometry type.
        """
        ...

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Rebuild an approximate geometry from a point and size values.

        This is the inverse of `get_roi_position` / `get_roi_size`.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference point as (time, frequency).
        dims : np.ndarray
            Size values ordered as in `dimension_names`.
        position : Positions, optional
            Anchor to use for this call only. When omitted, the mapper's
            own configured anchor is used to place the geometry around
            `pos`.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If `dims` does not match `dimension_names` in length, or the
            geometry cannot be reconstructed.
        """
        ...
|
|
||||||
|
|
||||||
|
|
||||||
class ROIConfig(BaseConfig):
    """Settings that control how an ROI is encoded as a training target.

    The three fields are the ones `build_roi_mapper` forwards to
    `BBoxEncoder`.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Which point of the geometry (e.g., "center", "bottom-left") is used
        as the target location.
        See `soundevent.geometry.operations.Positions`.
    time_scale : float, default=1000.0
        Multiplier applied to the ROI duration (width) to obtain the size
        target. Must match model expectations.
    frequency_scale : float, default=1/859.375
        Multiplier applied to the ROI bandwidth (height) to obtain the size
        target. Must match model expectations.
    """

    position: Positions = DEFAULT_POSITION
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
|
||||||
|
|
||||||
|
|
||||||
class BBoxEncoder(ROITargetMapper):
    """`ROITargetMapper` implementation based on bounding boxes.

    Encodes a geometry as a reference point plus a scaled
    `[width, height]` pair, and decodes such a pair back into a
    `soundevent.data.BoundingBox`.

    Attributes
    ----------
    dimension_names : List[str]
        Output dimension names, `['width', 'height']`.
    position : Positions
        Default anchor (e.g., "center", "bottom-left") used when a call
        does not supply its own `position`.
    time_scale : float
        Multiplier applied to durations (width) when encoding, and divided
        out when decoding.
    frequency_scale : float
        Multiplier applied to bandwidths (height) when encoding, and
        divided out when decoding.
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        position: Positions = DEFAULT_POSITION,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Initialize the BBoxEncoder.

        Parameters
        ----------
        position : Positions, default="bottom-left"
            Default anchor within the bounding box.
        time_scale : float, default=1000.0
            Scaling factor for time duration (width).
        frequency_scale : float, default=1/859.375
            Scaling factor for frequency bandwidth (height).
        """
        self.position: Positions = position
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> Tuple[float, float]:
        """Extract the reference position from the geometry.

        Delegates to `soundevent.geometry.get_geometry_point`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry (e.g., BoundingBox).
        position : Positions, optional
            Anchor to use for this call instead of `self.position`.

        Returns
        -------
        Tuple[float, float]
            Reference position (time, frequency).
        """
        from soundevent import geometry

        position = position or self.position
        return geometry.get_geometry_point(geom, position=position)

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Calculate the scaled [width, height] from the geometry's bounds.

        Computes the bounds with `soundevent.geometry.compute_bounds`, then
        multiplies the duration by `time_scale` and the bandwidth by
        `frequency_scale`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry.

        Returns
        -------
        np.ndarray
            A 1D NumPy array: `[scaled_width, scaled_height]`.
        """
        from soundevent import geometry

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )
        return np.array(
            [
                (end_time - start_time) * self.time_scale,
                (high_freq - low_freq) * self.frequency_scale,
            ]
        )

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Recover a BoundingBox from a position and scaled dimensions.

        Divides the scaling factors back out of `dims` and builds a
        `soundevent.data.BoundingBox` placed around `pos` according to the
        selected anchor.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference position (time, frequency).
        dims : np.ndarray
            The *scaled* dimensions, ordered `[scaled_width, scaled_height]`.
        position : Positions, optional
            Anchor to use for this call instead of `self.position`.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `dims` does not have the expected shape (length 2).
        """
        position = position or self.position

        if dims.ndim != 1 or dims.shape[0] != 2:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({dims.shape = }) != ([2])"
            )

        width, height = dims
        return _build_bounding_box(
            pos,
            duration=float(width) / self.time_scale,
            bandwidth=float(height) / self.frequency_scale,
            # Pass the resolved anchor so a per-call override takes effect;
            # passing `self.position` here would silently discard it.
            position=position,
        )
|
|
||||||
|
|
||||||
|
|
||||||
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
    """Create the ROI mapper described by a configuration object.

    At present this always produces a `BBoxEncoder`.

    Parameters
    ----------
    config : ROIConfig
        Supplies the anchor position and the time/frequency scaling
        factors.

    Returns
    -------
    ROITargetMapper
        A `BBoxEncoder` initialised with the values held by `config`.
    """
    encoder = BBoxEncoder(
        position=config.position,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
    return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def load_roi_mapper(
    path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
    """Read an `ROIConfig` from a file and build the matching mapper.

    Shorthand for calling `load_config` with `schema=ROIConfig` and passing
    the result to `build_roi_mapper`.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to the nested section of the file that holds the
        ROI configuration. If None, the entire file content is used.

    Returns
    -------
    ROITargetMapper
        The mapper built from the loaded configuration.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        If the configuration file cannot be found, parsed, validated, or if
        the specified `field` is invalid.
    """
    return build_roi_mapper(
        load_config(path=path, schema=ROIConfig, field=field)
    )
|
|
||||||
|
|
||||||
|
|
||||||
# Anchor names accepted at runtime by `_build_bounding_box`, which raises
# ValueError for anything not listed here. Apart from the three hyphen-free
# names, which `_build_bounding_box` handles before splitting, every entry
# has the form "<vertical>-<horizontal>".
VALID_POSITIONS = [
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
    """Build a BoundingBox of a given size anchored at a reference point.

    Helper for `BBoxEncoder.recover_roi`. Works out
    [start_time, low_freq, end_time, high_freq] from where `pos` sits on the
    box (a corner, an edge midpoint, or the centre). Negative sizes are
    treated as zero and every coordinate is clamped to be non-negative.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position (time, frequency).
    duration : float
        The *unscaled* duration (width) of the box.
    bandwidth : float
        The *unscaled* frequency bandwidth (height) of the box.
    position : Positions, default="bottom-left"
        The part of the box that `pos` corresponds to. "centroid" and
        "point_on_surface" are treated the same as "center".

    Returns
    -------
    data.BoundingBox
        The constructed bounding box.

    Raises
    ------
    ValueError
        If `position` is not one of `VALID_POSITIONS`.
    """
    time, freq = (float(value) for value in pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    if position in ("center", "centroid", "point_on_surface"):
        half_duration = duration / 2
        half_bandwidth = bandwidth / 2
        centred = [
            max(time - half_duration, 0),
            max(freq - half_bandwidth, 0),
            max(time + half_duration, 0),
            max(freq + half_bandwidth, 0),
        ]
        return data.BoundingBox(coordinates=centred)

    if position not in VALID_POSITIONS:
        raise ValueError(
            f"Invalid position: {position}. "
            f"Valid options are: {VALID_POSITIONS}"
        )

    # Remaining names are "<vertical>-<horizontal>", e.g. "top-left".
    vertical, horizontal = position.split("-")

    start_by_anchor = {
        "left": time,
        "center": time - duration / 2,
        "right": time - duration,
    }
    low_by_anchor = {
        "bottom": freq,
        "center": freq - bandwidth / 2,
        "top": freq - bandwidth,
    }
    start_time = start_by_anchor[horizontal]
    low_freq = low_by_anchor[vertical]

    # The far edges are derived from the unclamped near edges before each
    # coordinate is clamped at zero.
    anchored = [
        max(0, start_time),
        max(0, low_freq),
        max(0, start_time + duration),
        max(0, low_freq + bandwidth),
    ]
    return data.BoundingBox(coordinates=anchored)
|
|
19
notebooks/signal_generation.py
Normal file
19
notebooks/signal_generation.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Marimo notebook script: builds an `App` and registers two cells on it.
import marimo

# Version string written into the file alongside the app definition.
__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    # Imports the preprocessor factory; the bare `return` hands nothing
    # back from this cell.
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    # Empty cell.
    return


if __name__ == "__main__":
    # Run the app when the file is executed as a script.
    app.run()
|
@ -65,8 +65,12 @@ build-backend = "hatchling.build"
|
|||||||
[project.scripts]
|
[project.scripts]
|
||||||
batdetect2 = "batdetect2.cli:cli"
|
batdetect2 = "batdetect2.cli:cli"
|
||||||
|
|
||||||
[tool.uv]
|
[dependency-groups]
|
||||||
dev-dependencies = [
|
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
||||||
|
marimo = [
|
||||||
|
"marimo>=0.12.2",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
"debugpy>=1.8.8",
|
"debugpy>=1.8.8",
|
||||||
"hypothesis>=6.118.7",
|
"hypothesis>=6.118.7",
|
||||||
"pytest>=7.2.2",
|
"pytest>=7.2.2",
|
||||||
@ -98,25 +102,16 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
|||||||
convention = "numpy"
|
convention = "numpy"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = ["batdetect2", "tests"]
|
include = ["src", "tests"]
|
||||||
pythonVersion = "3.9"
|
pythonVersion = "3.9"
|
||||||
pythonPlatform = "All"
|
pythonPlatform = "All"
|
||||||
exclude = [
|
exclude = [
|
||||||
"batdetect2/detector/",
|
"src/batdetect2/detector/",
|
||||||
"batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
"batdetect2/utils",
|
"src/batdetect2/utils",
|
||||||
"batdetect2/plotting",
|
"src/batdetect2/plotting",
|
||||||
"batdetect2/plot",
|
"src/batdetect2/plot",
|
||||||
"batdetect2/api",
|
"src/batdetect2/api",
|
||||||
"batdetect2/evaluate/legacy",
|
"src/batdetect2/evaluate/legacy",
|
||||||
"batdetect2/train/legacy",
|
"src/batdetect2/train/legacy",
|
||||||
]
|
|
||||||
|
|
||||||
[dependency-groups]
|
|
||||||
jupyter = [
|
|
||||||
"ipywidgets>=8.1.5",
|
|
||||||
"jupyter>=1.1.1",
|
|
||||||
]
|
|
||||||
marimo = [
|
|
||||||
"marimo>=0.12.2",
|
|
||||||
]
|
]
|
||||||
|
@ -189,8 +189,7 @@ def train_command(
|
|||||||
config=postprocess_config_loaded,
|
config=postprocess_config_loaded,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded postprocessor from file {path}",
|
"Loaded postprocessor from file {path}", path=postprocess_config
|
||||||
path=train_config,
|
|
||||||
)
|
)
|
||||||
except IOError:
|
except IOError:
|
||||||
logger.debug(
|
logger.debug(
|
@ -157,4 +157,4 @@ def load_config(
|
|||||||
if field:
|
if field:
|
||||||
config = get_object_field(config, field)
|
config = get_object_field(config, field)
|
||||||
|
|
||||||
return schema.model_validate(config)
|
return schema.model_validate(config or {})
|
@ -8,6 +8,7 @@ from batdetect2.data.annotations import (
|
|||||||
from batdetect2.data.datasets import (
|
from batdetect2.data.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
|
load_dataset_config,
|
||||||
load_dataset_from_config,
|
load_dataset_from_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,5 +20,6 @@ __all__ = [
|
|||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
"load_annotated_dataset",
|
"load_annotated_dataset",
|
||||||
"load_dataset",
|
"load_dataset",
|
||||||
|
"load_dataset_config",
|
||||||
"load_dataset_from_config",
|
"load_dataset_from_config",
|
||||||
]
|
]
|
@ -161,6 +161,11 @@ def insert_source_tag(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
    """Load a `DatasetConfig` from a configuration file.

    Thin wrapper around `load_config` with `schema=DatasetConfig`.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Forwarded to `load_config` to select a nested section of the file.
        If None, the whole file content is used.

    Returns
    -------
    The value `load_config` returns for `schema=DatasetConfig`.
    """
    dataset_config = load_config(path=path, schema=DatasetConfig, field=field)
    return dataset_config
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_from_config(
|
def load_dataset_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
@ -72,7 +72,7 @@ def iterate_over_sound_events(
|
|||||||
sound_event_annotation
|
sound_event_annotation
|
||||||
)
|
)
|
||||||
|
|
||||||
class_name = targets.encode(sound_event_annotation)
|
class_name = targets.encode_class(sound_event_annotation)
|
||||||
if class_name is None and exclude_generic:
|
if class_name is None and exclude_generic:
|
||||||
continue
|
continue
|
||||||
|
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.metrics import auc, roc_curve
|
from sklearn.metrics import auc, roc_curve
|
||||||
|
|
@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
|
|
||||||
gt_uuid = target.uuid if target is not None else None
|
gt_uuid = target.uuid if target is not None else None
|
||||||
gt_det = target is not None
|
gt_det = target is not None
|
||||||
gt_class = targets.encode(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
pred_score = float(prediction.detection_score) if prediction else 0
|
||||||
|
|
@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
|
|||||||
|
|
||||||
The pipeline involves several configurable steps, implemented in submodules:
|
The pipeline involves several configurable steps, implemented in submodules:
|
||||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||||
coordinates to raw model output arrays.
|
model output arrays.
|
||||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||||
(location and score) based on thresholds and score ranking (top-k).
|
(location and score) based on thresholds and score ranking (top-k).
|
||||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||||
@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
return [
|
return [
|
||||||
convert_xr_dataset_to_raw_prediction(
|
convert_xr_dataset_to_raw_prediction(
|
||||||
dataset,
|
dataset,
|
||||||
self.targets.recover_roi,
|
self.targets.decode_roi,
|
||||||
)
|
)
|
||||||
for dataset in detection_datasets
|
for dataset in detection_datasets
|
||||||
]
|
]
|
||||||
@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
convert_raw_predictions_to_clip_prediction(
|
convert_raw_predictions_to_clip_prediction(
|
||||||
prediction,
|
prediction,
|
||||||
clip,
|
clip,
|
||||||
sound_event_decoder=self.targets.decode,
|
sound_event_decoder=self.targets.decode_class,
|
||||||
generic_class_tags=self.targets.generic_class_tags,
|
generic_class_tags=self.targets.generic_class_tags,
|
||||||
classification_threshold=self.config.classification_threshold,
|
classification_threshold=self.config.classification_threshold,
|
||||||
)
|
)
|
@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
|||||||
It takes the structured detection data extracted by the `extraction` module
|
It takes the structured detection data extracted by the `extraction` module
|
||||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||||
class probabilities, and features for each detection point) and converts it
|
class probabilities, and features for each detection point) and converts it
|
||||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
into standardized prediction objects based on the `soundevent` data model.
|
||||||
model.
|
|
||||||
|
|
||||||
The process involves:
|
The process involves:
|
||||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||||
@ -33,7 +32,7 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.classes import SoundEventDecoder
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
@ -55,7 +54,7 @@ decoding.
|
|||||||
|
|
||||||
def convert_xr_dataset_to_raw_prediction(
|
def convert_xr_dataset_to_raw_prediction(
|
||||||
detection_dataset: xr.Dataset,
|
detection_dataset: xr.Dataset,
|
||||||
geometry_builder: GeometryBuilder,
|
geometry_decoder: GeometryDecoder,
|
||||||
) -> List[RawPrediction]:
|
) -> List[RawPrediction]:
|
||||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||||
|
|
||||||
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
output by `extract_detection_xr_dataset`. Expected variables include
|
output by `extract_detection_xr_dataset`. Expected variables include
|
||||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||||
Must have a 'detection' dimension.
|
Must have a 'detection' dimension.
|
||||||
geometry_builder : GeometryBuilder
|
geometry_decoder : GeometryDecoder
|
||||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||||
of dimensions, and returns the corresponding reconstructed
|
of dimensions, and returns the corresponding reconstructed
|
||||||
`soundevent.data.Geometry`.
|
`soundevent.data.Geometry`.
|
||||||
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
for det_num in range(detection_dataset.sizes["detection"]):
|
for det_num in range(detection_dataset.sizes["detection"]):
|
||||||
det_info = detection_dataset.sel(detection=det_num)
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
geom = geometry_builder(
|
# TODO: Maybe clean this up
|
||||||
|
highest_scoring_class = det_info.coords["category"][
|
||||||
|
det_info["classes"].argmax()
|
||||||
|
].item()
|
||||||
|
|
||||||
|
geom = geometry_decoder(
|
||||||
(det_info.time, det_info.frequency),
|
(det_info.time, det_info.frequency),
|
||||||
det_info.dimensions,
|
det_info.dimensions,
|
||||||
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.score,
|
detection_score=det_info.scores,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
class_scores=det_info.classes,
|
class_scores=det_info.classes,
|
||||||
features=det_info.features,
|
features=det_info.features,
|
@ -1,6 +1,6 @@
|
|||||||
"""Extracts candidate detection points from a model output heatmap.
|
"""Extracts candidate detection points from a model output heatmap.
|
||||||
|
|
||||||
This module implements a specific step within the BatDetect2 postprocessing
|
This module implements Step 3 within the BatDetect2 postprocessing
|
||||||
pipeline. Its primary function is to identify potential sound event locations
|
pipeline. Its primary function is to identify potential sound event locations
|
||||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||||
produced by the neural network (usually after Non-Maximum Suppression and
|
produced by the neural network (usually after Non-Maximum Suppression and
|
@ -1,9 +1,9 @@
|
|||||||
"""Extracts associated data for detected points from model output arrays.
|
"""Extracts associated data for detected points from model output arrays.
|
||||||
|
|
||||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
This module implements Step 4 in the BatDetect2 postprocessing pipeline.
|
||||||
pipeline. After candidate detection points (time, frequency, score) have been
|
After candidate detection points (time, frequency, score) have been identified,
|
||||||
identified, this module extracts the corresponding values from other raw model
|
this module extracts the corresponding values from other raw model output
|
||||||
output arrays, such as:
|
arrays, such as:
|
||||||
|
|
||||||
- Predicted bounding box sizes (width, height).
|
- Predicted bounding box sizes (width, height).
|
||||||
- Class probability scores for each defined target class.
|
- Class probability scores for each defined target class.
|
@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2
|
|||||||
system that deal with model predictions.
|
system that deal with model predictions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, List, NamedTuple, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.targets.types import Position, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RawPrediction",
|
"RawPrediction",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"GeometryBuilder",
|
"GeometryDecoder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
# TODO: update the docstring
|
||||||
"""Type alias for a function that recovers geometry from position and size.
|
class GeometryDecoder(Protocol):
|
||||||
|
"""Protocol for a callable that recovers geometry from position and size.
|
||||||
|
|
||||||
This callable takes:
|
This callable takes:
|
||||||
1. A position tuple `(time, frequency)`.
|
1. A position tuple `(time, frequency)`.
|
||||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
3. Optionally a class name of the highest scoring class. This is to accommodate
|
||||||
`BoundingBox`).
|
different ways of decoding geometry that depend on the predicted class.
|
||||||
"""
|
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||||
|
`BoundingBox`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, position: Position, size: Size, class_name: Optional[str] = None
|
||||||
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
class RawPrediction(NamedTuple):
|
class RawPrediction(NamedTuple):
|
@ -23,7 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -50,7 +50,8 @@ from batdetect2.targets.filtering import (
|
|||||||
load_filter_from_config,
|
load_filter_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
ROIConfig,
|
AnchorBBoxMapperConfig,
|
||||||
|
ROIMapperConfig,
|
||||||
ROITargetMapper,
|
ROITargetMapper,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
@ -59,11 +60,11 @@ from batdetect2.targets.terms import (
|
|||||||
TermInfo,
|
TermInfo,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
call_type,
|
call_type,
|
||||||
|
default_term_registry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
get_term_from_key,
|
get_term_from_key,
|
||||||
individual,
|
individual,
|
||||||
register_term,
|
register_term,
|
||||||
term_registry,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.transform import (
|
from batdetect2.targets.transform import (
|
||||||
DerivationRegistry,
|
DerivationRegistry,
|
||||||
@ -73,13 +74,13 @@ from batdetect2.targets.transform import (
|
|||||||
SoundEventTransformation,
|
SoundEventTransformation,
|
||||||
TransformConfig,
|
TransformConfig,
|
||||||
build_transformation_from_config,
|
build_transformation_from_config,
|
||||||
derivation_registry,
|
default_derivation_registry,
|
||||||
get_derivation,
|
get_derivation,
|
||||||
load_transformation_config,
|
load_transformation_config,
|
||||||
load_transformation_from_config,
|
load_transformation_from_config,
|
||||||
register_derivation,
|
register_derivation,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassesConfig",
|
"ClassesConfig",
|
||||||
@ -88,7 +89,7 @@ __all__ = [
|
|||||||
"FilterConfig",
|
"FilterConfig",
|
||||||
"FilterRule",
|
"FilterRule",
|
||||||
"MapValueRule",
|
"MapValueRule",
|
||||||
"ROIConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"ReplaceRule",
|
"ReplaceRule",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
@ -156,12 +157,12 @@ class TargetConfig(BaseConfig):
|
|||||||
omitted, default ROI mapping settings are used.
|
omitted, default ROI mapping settings are used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
filtering: Optional[FilterConfig] = None
|
filtering: FilterConfig = Field(default_factory=FilterConfig)
|
||||||
transforms: Optional[TransformConfig] = None
|
transforms: TransformConfig = Field(default_factory=TransformConfig)
|
||||||
classes: ClassesConfig = Field(
|
classes: ClassesConfig = Field(
|
||||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||||
)
|
)
|
||||||
roi: Optional[ROIConfig] = None
|
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
@ -240,6 +241,7 @@ class Targets(TargetProtocol):
|
|||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventFilter] = None,
|
filter_fn: Optional[SoundEventFilter] = None,
|
||||||
transform_fn: Optional[SoundEventTransformation] = None,
|
transform_fn: Optional[SoundEventTransformation] = None,
|
||||||
|
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Targets object.
|
"""Initialize the Targets object.
|
||||||
|
|
||||||
@ -272,6 +274,16 @@ class Targets(TargetProtocol):
|
|||||||
self._encode_fn = encode_fn
|
self._encode_fn = encode_fn
|
||||||
self._decode_fn = decode_fn
|
self._decode_fn = decode_fn
|
||||||
self._transform_fn = transform_fn
|
self._transform_fn = transform_fn
|
||||||
|
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||||
|
|
||||||
|
for class_name in self._roi_mapper_overrides:
|
||||||
|
if class_name not in self.class_names:
|
||||||
|
# TODO: improve this warning
|
||||||
|
logger.warning(
|
||||||
|
"The ROI mapper overrides contains a class ({class_name}) "
|
||||||
|
"not present in the class names.",
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the configured filter to a sound event annotation.
|
"""Apply the configured filter to a sound event annotation.
|
||||||
@ -291,7 +303,9 @@ class Targets(TargetProtocol):
|
|||||||
return True
|
return True
|
||||||
return self._filter_fn(sound_event)
|
return self._filter_fn(sound_event)
|
||||||
|
|
||||||
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
|
def encode_class(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> Optional[str]:
|
||||||
"""Encode a sound event annotation to its target class name.
|
"""Encode a sound event annotation to its target class name.
|
||||||
|
|
||||||
Applies the configured class definition rules (including priority)
|
Applies the configured class definition rules (including priority)
|
||||||
@ -312,7 +326,7 @@ class Targets(TargetProtocol):
|
|||||||
"""
|
"""
|
||||||
return self._encode_fn(sound_event)
|
return self._encode_fn(sound_event)
|
||||||
|
|
||||||
def decode(self, class_label: str) -> List[data.Tag]:
|
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||||
"""Decode a predicted class name back into representative tags.
|
"""Decode a predicted class name back into representative tags.
|
||||||
|
|
||||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||||
@ -352,9 +366,9 @@ class Targets(TargetProtocol):
|
|||||||
return self._transform_fn(sound_event)
|
return self._transform_fn(sound_event)
|
||||||
return sound_event
|
return sound_event
|
||||||
|
|
||||||
def get_position(
|
def encode_roi(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> tuple[float, float]:
|
) -> tuple[Position, Size]:
|
||||||
"""Extract the target reference position from the annotation's roi.
|
"""Extract the target reference position and size from the annotation's ROI.
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
Delegates to the `encode` method of the class-specific ROI mapper override, if one is configured, or of the default ROI mapper otherwise.
|
||||||
@ -374,50 +388,20 @@ class Targets(TargetProtocol):
|
|||||||
ValueError
|
ValueError
|
||||||
If the annotation lacks geometry.
|
If the annotation lacks geometry.
|
||||||
"""
|
"""
|
||||||
geom = sound_event.sound_event.geometry
|
class_name = self.encode_class(sound_event)
|
||||||
|
|
||||||
if geom is None:
|
if class_name in self._roi_mapper_overrides:
|
||||||
raise ValueError(
|
return self._roi_mapper_overrides[class_name].encode(
|
||||||
"Sound event has no geometry, cannot get its position."
|
sound_event.sound_event
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._roi_mapper.get_roi_position(geom)
|
return self._roi_mapper.encode(sound_event.sound_event)
|
||||||
|
|
||||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
def decode_roi(
|
||||||
"""Calculate the target size dimensions from the annotation's geometry.
|
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `get_roi_size` method, which
|
|
||||||
applies configured scaling factors.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation containing the geometry (ROI).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
NumPy array containing the size dimensions, matching the
|
|
||||||
order in `self.dimension_names` (e.g., `[width, height]`).
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the annotation lacks geometry.
|
|
||||||
"""
|
|
||||||
geom = sound_event.sound_event.geometry
|
|
||||||
|
|
||||||
if geom is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Sound event has no geometry, cannot get its size."
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._roi_mapper.get_roi_size(geom)
|
|
||||||
|
|
||||||
def recover_roi(
|
|
||||||
self,
|
self,
|
||||||
pos: tuple[float, float],
|
position: Position,
|
||||||
dims: np.ndarray,
|
size: Size,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
) -> data.Geometry:
|
) -> data.Geometry:
|
||||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||||
|
|
||||||
@ -438,7 +422,13 @@ class Targets(TargetProtocol):
|
|||||||
data.Geometry
|
data.Geometry
|
||||||
The reconstructed geometry (typically `BoundingBox`).
|
The reconstructed geometry (typically `BoundingBox`).
|
||||||
"""
|
"""
|
||||||
return self._roi_mapper.recover_roi(pos, dims)
|
if class_name in self._roi_mapper_overrides:
|
||||||
|
return self._roi_mapper_overrides[class_name].decode(
|
||||||
|
position,
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._roi_mapper.decode(position, size)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSES = [
|
DEFAULT_CLASSES = [
|
||||||
@ -493,10 +483,12 @@ DEFAULT_CLASSES = [
|
|||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||||
name="nyclei",
|
name="nyclei",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
),
|
),
|
||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||||
name="rhifer",
|
name="rhifer",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
),
|
),
|
||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Plecotus auritus")],
|
tags=[TagInfo(value="Plecotus auritus")],
|
||||||
@ -537,13 +529,14 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
classes=DEFAULT_CLASSES_CONFIG,
|
classes=DEFAULT_CLASSES_CONFIG,
|
||||||
|
roi=AnchorBBoxMapperConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_targets(
|
def build_targets(
|
||||||
config: Optional[TargetConfig] = None,
|
config: Optional[TargetConfig] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
@ -606,12 +599,17 @@ def build_targets(
|
|||||||
if config.transforms
|
if config.transforms
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
roi_mapper = build_roi_mapper(config.roi)
|
||||||
class_names = get_class_names_from_config(config.classes)
|
class_names = get_class_names_from_config(config.classes)
|
||||||
generic_class_tags = build_generic_class_tags(
|
generic_class_tags = build_generic_class_tags(
|
||||||
config.classes,
|
config.classes,
|
||||||
term_registry=term_registry,
|
term_registry=term_registry,
|
||||||
)
|
)
|
||||||
|
roi_overrides = {
|
||||||
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
|
for class_config in config.classes.classes
|
||||||
|
if class_config.roi is not None
|
||||||
|
}
|
||||||
|
|
||||||
return Targets(
|
return Targets(
|
||||||
filter_fn=filter_fn,
|
filter_fn=filter_fn,
|
||||||
@ -621,14 +619,15 @@ def build_targets(
|
|||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=generic_class_tags,
|
||||||
transform_fn=transform_fn,
|
transform_fn=transform_fn,
|
||||||
|
roi_mapper_overrides=roi_overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_targets(
|
def load_targets(
|
||||||
config_path: data.PathLike,
|
config_path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: DerivationRegistry = default_derivation_registry,
|
||||||
) -> Targets:
|
) -> Targets:
|
||||||
"""Load a Targets object directly from a configuration file.
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
@ -6,29 +6,27 @@ from pydantic import Field, field_validator
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.rois import ROIMapperConfig
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
GENERIC_CLASS_KEY,
|
GENERIC_CLASS_KEY,
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
|
default_term_registry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
term_registry,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SoundEventEncoder",
|
|
||||||
"SoundEventDecoder",
|
|
||||||
"TargetClass",
|
|
||||||
"ClassesConfig",
|
|
||||||
"load_classes_config",
|
|
||||||
"load_encoder_from_config",
|
|
||||||
"load_decoder_from_config",
|
|
||||||
"build_sound_event_encoder",
|
|
||||||
"build_sound_event_decoder",
|
|
||||||
"build_generic_class_tags",
|
|
||||||
"get_class_names_from_config",
|
|
||||||
"DEFAULT_SPECIES_LIST",
|
"DEFAULT_SPECIES_LIST",
|
||||||
|
"build_generic_class_tags",
|
||||||
|
"build_sound_event_decoder",
|
||||||
|
"build_sound_event_encoder",
|
||||||
|
"get_class_names_from_config",
|
||||||
|
"load_classes_config",
|
||||||
|
"load_decoder_from_config",
|
||||||
|
"load_encoder_from_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||||
"""Type alias for a sound event class encoder function.
|
"""Type alias for a sound event class encoder function.
|
||||||
|
|
||||||
@ -113,6 +111,7 @@ class TargetClass(BaseConfig):
|
|||||||
tags: List[TagInfo] = Field(min_length=1)
|
tags: List[TagInfo] = Field(min_length=1)
|
||||||
match_type: Literal["all", "any"] = Field(default="all")
|
match_type: Literal["all", "any"] = Field(default="all")
|
||||||
output_tags: Optional[List[TagInfo]] = None
|
output_tags: Optional[List[TagInfo]] = None
|
||||||
|
roi: Optional[ROIMapperConfig] = None
|
||||||
|
|
||||||
|
|
||||||
def _get_default_classes() -> List[TargetClass]:
|
def _get_default_classes() -> List[TargetClass]:
|
||||||
@ -235,7 +234,7 @@ class ClassesConfig(BaseConfig):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def _is_target_class(
|
def is_target_class(
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
tags: Set[data.Tag],
|
tags: Set[data.Tag],
|
||||||
match_all: bool = True,
|
match_all: bool = True,
|
||||||
@ -316,7 +315,7 @@ def _encode_with_multiple_classifiers(
|
|||||||
|
|
||||||
def build_sound_event_encoder(
|
def build_sound_event_encoder(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
"""Build a sound event encoder function from the classes configuration.
|
"""Build a sound event encoder function from the classes configuration.
|
||||||
|
|
||||||
@ -350,7 +349,7 @@ def build_sound_event_encoder(
|
|||||||
(
|
(
|
||||||
class_info.name,
|
class_info.name,
|
||||||
partial(
|
partial(
|
||||||
_is_target_class,
|
is_target_class,
|
||||||
tags={
|
tags={
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
for tag_info in class_info.tags
|
for tag_info in class_info.tags
|
||||||
@ -410,7 +409,7 @@ def _decode_class(
|
|||||||
|
|
||||||
def build_sound_event_decoder(
|
def build_sound_event_decoder(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Build a sound event decoder function from the classes configuration.
|
"""Build a sound event decoder function from the classes configuration.
|
||||||
@ -465,7 +464,7 @@ def build_sound_event_decoder(
|
|||||||
|
|
||||||
def build_generic_class_tags(
|
def build_generic_class_tags(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> List[data.Tag]:
|
) -> List[data.Tag]:
|
||||||
"""Extract and build the list of tags for the generic class from config.
|
"""Extract and build the list of tags for the generic class from config.
|
||||||
|
|
||||||
@ -530,7 +529,7 @@ def load_classes_config(
|
|||||||
def load_encoder_from_config(
|
def load_encoder_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
"""Load a class encoder function directly from a configuration file.
|
"""Load a class encoder function directly from a configuration file.
|
||||||
|
|
||||||
@ -571,7 +570,7 @@ def load_encoder_from_config(
|
|||||||
def load_decoder_from_config(
|
def load_decoder_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
) -> SoundEventDecoder:
|
) -> SoundEventDecoder:
|
||||||
"""Load a class decoder function directly from a configuration file.
|
"""Load a class decoder function directly from a configuration file.
|
@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
|
|||||||
TagInfo,
|
TagInfo,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
term_registry,
|
default_term_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -156,7 +156,7 @@ def equal_tags(
|
|||||||
|
|
||||||
def build_filter_from_rule(
|
def build_filter_from_rule(
|
||||||
rule: FilterRule,
|
rule: FilterRule,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
"""Creates a callable filter function from a single FilterRule.
|
"""Creates a callable filter function from a single FilterRule.
|
||||||
|
|
||||||
@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_sound_event_filter(
|
def build_sound_event_filter(
|
||||||
config: FilterConfig,
|
config: FilterConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
"""Builds a merged filter function from a FilterConfig object.
|
"""Builds a merged filter function from a FilterConfig object.
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ def load_filter_config(
|
|||||||
def load_filter_from_config(
|
def load_filter_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
"""Loads filter configuration from a file and builds the filter function.
|
"""Loads filter configuration from a file and builds the filter function.
|
||||||
|
|
684
src/batdetect2/targets/rois.py
Normal file
684
src/batdetect2/targets/rois.py
Normal file
@ -0,0 +1,684 @@
|
|||||||
|
"""Handles mapping between geometric ROIs and target representations.
|
||||||
|
|
||||||
|
This module defines a standardized interface (`ROITargetMapper`) for converting
|
||||||
|
a sound event's Region of Interest (ROI) into a target representation suitable
|
||||||
|
for machine learning models, and for decoding model outputs back into geometric
|
||||||
|
ROIs.
|
||||||
|
|
||||||
|
The core operations are:
|
||||||
|
1. **Encoding**: A `soundevent.data.SoundEvent` is mapped to a reference
|
||||||
|
`Position` (time, frequency) and a `Size` array. The method for
|
||||||
|
determining the position and size varies by the mapper implementation
|
||||||
|
(e.g., using a bounding box anchor or the point of peak energy).
|
||||||
|
2. **Decoding**: A `Position` and `Size` array are mapped back to an
|
||||||
|
approximate `soundevent.data.Geometry` (typically a `BoundingBox`).
|
||||||
|
|
||||||
|
This logic is encapsulated within specific mapper classes. Configuration for
|
||||||
|
each mapper (e.g., anchor point, scaling factors) is managed by a corresponding
|
||||||
|
Pydantic config object. The `ROIMapperConfig` type allows for flexibly
|
||||||
|
selecting and configuring the desired mapper. This module separates the
|
||||||
|
*geometric* aspect of target definition from *semantic* classification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.types import Position, Size
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Anchor",
|
||||||
|
"AnchorBBoxMapper",
|
||||||
|
"AnchorBBoxMapperConfig",
|
||||||
|
"DEFAULT_ANCHOR",
|
||||||
|
"DEFAULT_FREQUENCY_SCALE",
|
||||||
|
"DEFAULT_TIME_SCALE",
|
||||||
|
"PeakEnergyBBoxMapper",
|
||||||
|
"PeakEnergyBBoxMapperConfig",
|
||||||
|
"ROIMapperConfig",
|
||||||
|
"ROITargetMapper",
|
||||||
|
"SIZE_HEIGHT",
|
||||||
|
"SIZE_ORDER",
|
||||||
|
"SIZE_WIDTH",
|
||||||
|
"build_roi_mapper",
|
||||||
|
]
|
||||||
|
|
||||||
|
Anchor = Literal[
|
||||||
|
"bottom-left",
|
||||||
|
"bottom-right",
|
||||||
|
"top-left",
|
||||||
|
"top-right",
|
||||||
|
"center-left",
|
||||||
|
"center-right",
|
||||||
|
"top-center",
|
||||||
|
"bottom-center",
|
||||||
|
"center",
|
||||||
|
"centroid",
|
||||||
|
"point_on_surface",
|
||||||
|
]
|
||||||
|
|
||||||
|
SIZE_WIDTH = "width"
|
||||||
|
"""Standard name for the width/time dimension component ('width')."""
|
||||||
|
|
||||||
|
SIZE_HEIGHT = "height"
|
||||||
|
"""Standard name for the height/frequency dimension component ('height')."""
|
||||||
|
|
||||||
|
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
|
||||||
|
"""Standard order of dimensions for size arrays ([width, height])."""
|
||||||
|
|
||||||
|
DEFAULT_TIME_SCALE = 1000.0
|
||||||
|
"""Default scaling factor for time duration."""
|
||||||
|
|
||||||
|
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
|
||||||
|
"""Default scaling factor for frequency bandwidth."""
|
||||||
|
|
||||||
|
DEFAULT_ANCHOR = "bottom-left"
|
||||||
|
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||||
|
|
||||||
|
|
||||||
|
class ROITargetMapper(Protocol):
|
||||||
|
"""Protocol defining the interface for ROI-to-target mapping.
|
||||||
|
|
||||||
|
Specifies the `encode` and `decode` methods required for converting a
|
||||||
|
`soundevent.data.SoundEvent` into a target representation (a reference
|
||||||
|
position and a size vector) and for recovering an approximate ROI from that
|
||||||
|
representation.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
dimension_names : List[str]
|
||||||
|
A list containing the names of the dimensions in the `Size` array
|
||||||
|
returned by `encode` and expected by `decode`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dimension_names: List[str]
|
||||||
|
|
||||||
|
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
|
||||||
|
"""Encode a SoundEvent's geometry into a position and size.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEvent
|
||||||
|
The input sound event, which must have a geometry attribute.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[Position, Size]
|
||||||
|
A tuple containing:
|
||||||
|
- The reference position as (time, frequency) coordinates.
|
||||||
|
- A NumPy array with the calculated size dimensions.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the sound event does not have a geometry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def decode(self, position: Position, size: Size) -> data.Geometry:
|
||||||
|
"""Decode a position and size back into a geometric ROI.
|
||||||
|
|
||||||
|
Performs the inverse mapping: takes a reference position and size
|
||||||
|
dimensions and reconstructs a geometric representation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
position : Position
|
||||||
|
The reference position (time, frequency).
|
||||||
|
size : Size
|
||||||
|
NumPy array containing the size dimensions, matching the order
|
||||||
|
and meaning specified by `dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Geometry
|
||||||
|
The reconstructed geometry, typically a `BoundingBox`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the `size` array has an unexpected shape or if reconstruction
|
||||||
|
fails.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AnchorBBoxMapperConfig(BaseConfig):
|
||||||
|
"""Configuration for `AnchorBBoxMapper`.
|
||||||
|
|
||||||
|
Defines parameters for converting ROIs into targets using a fixed anchor
|
||||||
|
point on the bounding box.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : Literal["anchor_bbox"]
|
||||||
|
The unique identifier for this mapper type.
|
||||||
|
anchor : Anchor
|
||||||
|
Specifies the anchor point within the bounding box to use as the
|
||||||
|
target's reference position (e.g., "center", "bottom-left").
|
||||||
|
time_scale : float
|
||||||
|
Scaling factor applied to the time duration (width) of the ROI.
|
||||||
|
frequency_scale : float
|
||||||
|
Scaling factor applied to the frequency bandwidth (height) of the ROI.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["anchor_bbox"] = "anchor_bbox"
|
||||||
|
anchor: Anchor = DEFAULT_ANCHOR
|
||||||
|
time_scale: float = DEFAULT_TIME_SCALE
|
||||||
|
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||||
|
|
||||||
|
|
||||||
|
class AnchorBBoxMapper(ROITargetMapper):
    """ROI mapper built on a bounding-box anchor and scaled width/height.

    Implements the `ROITargetMapper` protocol by describing each ROI
    through the bounds of its geometry.

    **Encoding**: the `position` is the configured anchor point of the
    geometry (e.g., "bottom-left"); the `size` is a 2-element array with
    the scaled width and height of its bounds.

    **Decoding**: a `BoundingBox` is rebuilt from an anchor position and
    the scaled width/height.

    Attributes
    ----------
    dimension_names : List[str]
        Names of the size dimensions: `[SIZE_WIDTH, SIZE_HEIGHT]`.
    anchor : Anchor
        The anchor point type in use (e.g., "center", "bottom-left").
    time_scale : float
        Multiplier applied to the time dimension (width).
    frequency_scale : float
        Multiplier applied to the frequency dimension (height).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        anchor: Anchor = DEFAULT_ANCHOR,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Store the anchor type and the two scaling factors.

        Parameters
        ----------
        anchor : Anchor
            Which point of the bounding box is used as reference.
        time_scale : float
            Multiplier for the time duration (width).
        frequency_scale : float
            Multiplier for the frequency bandwidth (height).
        """
        self.anchor: Anchor = anchor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def encode(self, sound_event: data.SoundEvent) -> Tuple[Position, Size]:
        """Turn a sound event into an anchor position and a scaled size.

        Parameters
        ----------
        sound_event : data.SoundEvent
            Sound event whose geometry is encoded.

        Returns
        -------
        Tuple[Position, Size]
            The anchor position and the array
            `[scaled_width, scaled_height]`.

        Raises
        ------
        ValueError
            If the sound event has no geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry
        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        anchor_point = geometry.get_geometry_point(
            geom,
            position=self.anchor,
        )

        bounds = geometry.compute_bounds(geom)
        start_time, low_freq, end_time, high_freq = bounds

        scaled_width = (end_time - start_time) * self.time_scale
        scaled_height = (high_freq - low_freq) * self.frequency_scale
        return anchor_point, np.array([scaled_width, scaled_height])

    def decode(
        self,
        position: Position,
        size: Size,
    ) -> data.Geometry:
        """Rebuild a BoundingBox from an anchor position and scaled size.

        The width and height are divided by the configured scales and
        the box is placed relative to `position` according to the
        configured anchor.

        Parameters
        ----------
        position : Position
            Anchor position as (time, frequency).
        size : Size
            NumPy array holding the scaled `[width, height]`.

        Returns
        -------
        data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `size` is not a one-dimensional array of length 2.
        """
        if size.shape != (2,):
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({size.shape = }) != ([2])"
            )

        duration = float(size[0]) / self.time_scale
        bandwidth = float(size[1]) / self.frequency_scale
        return _build_bounding_box(
            position,
            duration=duration,
            bandwidth=bandwidth,
            anchor=self.anchor,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class PeakEnergyBBoxMapperConfig(BaseConfig):
    """Settings used to build a `PeakEnergyBBoxMapper`.

    Attributes
    ----------
    name : Literal["peak_energy_bbox"]
        Discriminator value identifying this mapper type.
    preprocessing : PreprocessingConfig
        Configuration of the spectrogram preprocessor used to locate the
        peak energy. A default `PreprocessingConfig` is created when
        omitted.
    loading_buffer : float
        Seconds added on each side of the ROI when loading audio, so
        that the peak is not distorted by boundary effects.
    time_scale : float
        Factor multiplied with the time dimensions.
    frequency_scale : float
        Factor multiplied with the frequency dimensions.
    """

    name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig,
    )
    loading_buffer: float = 0.01
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||||
|
|
||||||
|
|
||||||
|
class PeakEnergyBBoxMapper(ROITargetMapper):
    """Maps ROIs using the peak energy point and distances to edges.

    This class implements the `ROITargetMapper` protocol.

    **Encoding**: The `position` is the (time, frequency) coordinate of
    the point with the highest energy within the bounds of the sound
    event's geometry. The `size` is a 4-element array with the scaled
    distances from that point to the left, bottom, right, and top edges.

    **Decoding**: Reconstructs a `BoundingBox` by subtracting/adding the
    un-scaled distances from/to the peak energy point.

    Attributes
    ----------
    dimension_names : List[str]
        The output dimension names: `['left', 'bottom', 'right', 'top']`.
    preprocessor : PreprocessorProtocol
        The spectrogram preprocessor instance.
    time_scale : float
        The scaling factor for time-based distances.
    frequency_scale : float
        The scaling factor for frequency-based distances.
    loading_buffer : float
        The buffer (seconds) used for loading audio around the ROI.
    """

    dimension_names = ["left", "bottom", "right", "top"]

    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
        loading_buffer: float = 0.01,
    ):
        """Initialize the PeakEnergyBBoxMapper.

        Parameters
        ----------
        preprocessor : PreprocessorProtocol
            An initialized preprocessor for generating spectrograms.
        time_scale : float
            Scaling factor for time dimensions (left, right distances).
        frequency_scale : float
            Scaling factor for frequency dimensions (bottom, top
            distances).
        loading_buffer : float
            Buffer in seconds to add when loading audio clips.
        """
        self.preprocessor = preprocessor
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale
        self.loading_buffer = loading_buffer

    def encode(
        self,
        sound_event: data.SoundEvent,
    ) -> tuple[Position, Size]:
        """Encode a SoundEvent into a peak position and edge distances.

        Finds the peak energy coordinates within the bounds of the
        event's geometry and computes the scaled distances from that
        point to the edges of those bounds.

        Parameters
        ----------
        sound_event : data.SoundEvent
            The input sound event with a geometry and associated
            recording.

        Returns
        -------
        Tuple[Position, Size]
            A tuple of (peak_position, [left, bottom, right, top]
            distances).

        Raises
        ------
        ValueError
            If the sound event has no geometry.
        """
        from soundevent import geometry

        geom = sound_event.geometry

        if geom is None:
            raise ValueError(
                "Cannot encode the geometry of a sound event without geometry."
                f" Sound event: {sound_event}"
            )

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        time, freq = get_peak_energy_coordinates(
            recording=sound_event.recording,
            preprocessor=self.preprocessor,
            start_time=start_time,
            end_time=end_time,
            low_freq=low_freq,
            high_freq=high_freq,
            loading_buffer=self.loading_buffer,
        )

        size = np.array(
            [
                (time - start_time) * self.time_scale,
                (freq - low_freq) * self.frequency_scale,
                (end_time - time) * self.time_scale,
                (high_freq - freq) * self.frequency_scale,
            ]
        )

        return (time, freq), size

    def decode(self, position: Position, size: Size) -> data.Geometry:
        """Recover a BoundingBox from a peak position and edge distances.

        Negative distances are treated as zero, so the reconstructed box
        always contains the peak position.

        Parameters
        ----------
        position : Position
            The reference peak energy position (time, frequency).
        size : Size
            Array with the scaled distances [left, bottom, right, top].

        Returns
        -------
        data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `size` is not a one-dimensional array of length 4.
        """
        # Validate the shape up front, mirroring AnchorBBoxMapper.decode,
        # so malformed input fails with a clear ValueError instead of an
        # opaque unpacking/conversion error further down.
        size = np.asarray(size)
        if size.shape != (4,):
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({size.shape = }) != ([4])"
            )

        time, freq = position
        left, bottom, right, top = (max(0, float(value)) for value in size)

        return data.BoundingBox(
            coordinates=[
                time - left / self.time_scale,
                freq - bottom / self.frequency_scale,
                time + right / self.time_scale,
                freq + top / self.frequency_scale,
            ]
        )
|
||||||
|
|
||||||
|
|
||||||
|
ROIMapperConfig = Annotated[
|
||||||
|
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
"""A discriminated union of all supported ROI mapper configurations.
|
||||||
|
|
||||||
|
This type allows for selecting and configuring different `ROITargetMapper`
|
||||||
|
implementations by using the `name` field as a discriminator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_roi_mapper(
    config: Optional[ROIMapperConfig] = None,
) -> ROITargetMapper:
    """Create the `ROITargetMapper` described by a configuration object.

    Parameters
    ----------
    config : ROIMapperConfig, optional
        Configuration selecting the mapper type (through its `name`) and
        providing its parameters. When omitted, a default
        `AnchorBBoxMapperConfig` is used.

    Returns
    -------
    ROITargetMapper
        The mapper instance built from the configuration.

    Raises
    ------
    NotImplementedError
        If `config.name` does not match any known mapper.
    """
    if not config:
        config = AnchorBBoxMapperConfig()

    if config.name == "peak_energy_bbox":
        # This mapper needs a spectrogram preprocessor to find the peak.
        return PeakEnergyBBoxMapper(
            preprocessor=build_preprocessor(config.preprocessing),
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
            loading_buffer=config.loading_buffer,
        )

    if config.name == "anchor_bbox":
        return AnchorBBoxMapper(
            anchor=config.anchor,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )

    message = f"No ROI mapper of name '{config.name}' is implemented"
    raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|
||||||
|
# Anchor names accepted when rebuilding a bounding box. Hyphenated names
# read "<vertical>-<horizontal>"; the last three are treated as the box
# center by `_build_bounding_box`.
VALID_ANCHORS = [
    # corners
    "bottom-left", "bottom-right", "top-left", "top-right",
    # edge midpoints
    "center-left", "center-right", "top-center", "bottom-center",
    # center-like anchors
    "center", "centroid", "point_on_surface",
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    anchor: Anchor = DEFAULT_ANCHOR,
) -> data.BoundingBox:
    """Build a BoundingBox from a reference point, its extent and anchor.

    Computes the coordinates [start_time, low_freq, end_time, high_freq]
    of a box of the given `duration` and `bandwidth`, placed so that
    `pos` sits at the location named by `anchor`. Negative extents are
    treated as zero and every coordinate is clipped at zero.

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position (time, frequency).
    duration : float
        The *unscaled* duration (width) of the bounding box.
    bandwidth : float
        The *unscaled* frequency bandwidth (height) of the bounding box.
    anchor : Anchor
        The part of the bounding box that `pos` corresponds to.

    Returns
    -------
    data.BoundingBox
        The constructed bounding box.

    Raises
    ------
    ValueError
        If `anchor` is not one of `VALID_ANCHORS`.
    """
    time, freq = (float(value) for value in pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)

    # Center-like anchors: spread the box symmetrically around `pos`.
    if anchor in ["center", "centroid", "point_on_surface"]:
        half_duration = duration / 2
        half_bandwidth = bandwidth / 2
        centered = [
            max(time - half_duration, 0),
            max(freq - half_bandwidth, 0),
            max(time + half_duration, 0),
            max(freq + half_bandwidth, 0),
        ]
        return data.BoundingBox(coordinates=centered)

    if anchor not in VALID_ANCHORS:
        raise ValueError(
            f"Invalid anchor: {anchor}. Valid options are: {VALID_ANCHORS}"
        )

    # Remaining anchors are spelled "<vertical>-<horizontal>".
    vertical, horizontal = anchor.split("-")

    start_by_side = {
        "left": time,
        "center": time - duration / 2,
        "right": time - duration,
    }
    low_by_side = {
        "bottom": freq,
        "center": freq - bandwidth / 2,
        "top": freq - bandwidth,
    }
    start_time = start_by_side[horizontal]
    low_freq = low_by_side[vertical]

    end_time = start_time + duration
    high_freq = low_freq + bandwidth

    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, end_time),
            max(0, high_freq),
        ]
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_peak_energy_coordinates(
    recording: data.Recording,
    preprocessor: PreprocessorProtocol,
    start_time: float = 0,
    end_time: Optional[float] = None,
    low_freq: float = 0,
    high_freq: Optional[float] = None,
    loading_buffer: float = 0.05,
) -> Position:
    """Locate the highest-valued point of a spectrogram region.

    A clip covering the requested time range (plus `loading_buffer` on
    each side, limited to the recording) is preprocessed into a
    spectrogram. The requested time-frequency region is then selected
    from it and the (time, frequency) coordinates of its maximum value
    are returned.

    Parameters
    ----------
    recording : data.Recording
        The recording to analyze.
    preprocessor : PreprocessorProtocol
        Converts the audio clip into a spectrogram.
    start_time : float, default=0
        Start time of the region of interest.
    end_time : float, optional
        End time of the region of interest. Defaults to, and is capped
        at, the recording duration.
    low_freq : float, default=0
        Low frequency of the region of interest. Raised to the
        preprocessor's `min_freq` if below it.
    high_freq : float, optional
        High frequency of the region of interest. Defaults to half the
        recording samplerate and is capped at the preprocessor's
        `max_freq`.
    loading_buffer : float, default=0.05
        Seconds added around the time range when loading the clip, to
        mitigate border effects from transformations like STFT.

    Returns
    -------
    Position
        A (time, frequency) tuple for the peak energy location.
    """
    total_duration = recording.duration
    if end_time is None:
        end_time = total_duration
    else:
        end_time = min(end_time, total_duration)

    if high_freq is None:
        high_freq = recording.samplerate / 2

    # Load a slightly larger clip than requested, kept inside the recording.
    buffered_clip = data.Clip(
        recording=recording,
        start_time=max(0, start_time - loading_buffer),
        end_time=min(total_duration, end_time + loading_buffer),
    )
    spec = preprocessor.preprocess_clip(buffered_clip)

    # Restrict the frequency range to what the preprocessor covers.
    low_freq = max(low_freq, preprocessor.min_freq)
    high_freq = min(high_freq, preprocessor.max_freq)

    time_range = slice(start_time, end_time)
    freq_range = slice(low_freq, high_freq)
    region = spec.sel(time=time_range, frequency=freq_range)

    peak_index = region.argmax(dim=["time", "frequency"])
    peak = region.isel(peak_index)  # type: ignore
    peak_time: float = peak.time.item()
    peak_freq: float = peak.frequency.item()
    return peak_time, peak_freq
|
@ -230,11 +230,12 @@ class TermRegistry(Mapping[str, data.Term]):
|
|||||||
del self._terms[key]
|
del self._terms[key]
|
||||||
|
|
||||||
|
|
||||||
term_registry = TermRegistry(
|
default_term_registry = TermRegistry(
|
||||||
terms=dict(
|
terms=dict(
|
||||||
[
|
[
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||||
("event", call_type),
|
("event", call_type),
|
||||||
|
("species", terms.scientific_name),
|
||||||
("individual", individual),
|
("individual", individual),
|
||||||
("data_source", data_source),
|
("data_source", data_source),
|
||||||
(GENERIC_CLASS_KEY, generic_class),
|
(GENERIC_CLASS_KEY, generic_class),
|
||||||
@ -252,7 +253,7 @@ is explicitly passed.
|
|||||||
|
|
||||||
def get_term_from_key(
|
def get_term_from_key(
|
||||||
key: str,
|
key: str,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: Optional[TermRegistry] = None,
|
||||||
) -> data.Term:
|
) -> data.Term:
|
||||||
"""Convenience function to retrieve a term by key from a registry.
|
"""Convenience function to retrieve a term by key from a registry.
|
||||||
|
|
||||||
@ -277,10 +278,13 @@ def get_term_from_key(
|
|||||||
KeyError
|
KeyError
|
||||||
If the key is not found in the specified registry.
|
If the key is not found in the specified registry.
|
||||||
"""
|
"""
|
||||||
|
term_registry = term_registry or default_term_registry
|
||||||
return term_registry.get_term(key)
|
return term_registry.get_term(key)
|
||||||
|
|
||||||
|
|
||||||
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
def get_term_keys(
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> List[str]:
|
||||||
"""Convenience function to get all registered keys from a registry.
|
"""Convenience function to get all registered keys from a registry.
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
Uses the global default registry unless a specific `term_registry`
|
||||||
@ -299,7 +303,9 @@ def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
|||||||
return term_registry.get_keys()
|
return term_registry.get_keys()
|
||||||
|
|
||||||
|
|
||||||
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
|
def get_terms(
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
|
) -> List[data.Term]:
|
||||||
"""Convenience function to get all registered terms from a registry.
|
"""Convenience function to get all registered terms from a registry.
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
Uses the global default registry unless a specific `term_registry`
|
||||||
@ -342,7 +348,7 @@ class TagInfo(BaseModel):
|
|||||||
|
|
||||||
def get_tag_from_info(
|
def get_tag_from_info(
|
||||||
tag_info: TagInfo,
|
tag_info: TagInfo,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: Optional[TermRegistry] = None,
|
||||||
) -> data.Tag:
|
) -> data.Tag:
|
||||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||||
|
|
||||||
@ -368,6 +374,7 @@ def get_tag_from_info(
|
|||||||
If the term key specified in `tag_info.key` is not found
|
If the term key specified in `tag_info.key` is not found
|
||||||
in the registry.
|
in the registry.
|
||||||
"""
|
"""
|
||||||
|
term_registry = term_registry or default_term_registry
|
||||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||||
return data.Tag(term=term, value=tag_info.value)
|
return data.Tag(term=term, value=tag_info.value)
|
||||||
|
|
||||||
@ -439,7 +446,7 @@ class TermConfig(BaseModel):
|
|||||||
def load_terms_from_config(
|
def load_terms_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> Dict[str, data.Term]:
|
) -> Dict[str, data.Term]:
|
||||||
"""Loads term definitions from a configuration file and registers them.
|
"""Loads term definitions from a configuration file and registers them.
|
||||||
|
|
||||||
@ -490,6 +497,6 @@ def load_terms_from_config(
|
|||||||
|
|
||||||
|
|
||||||
def register_term(
|
def register_term(
|
||||||
key: str, term: data.Term, registry: TermRegistry = term_registry
|
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
||||||
) -> None:
|
) -> None:
|
||||||
registry.add_term(key, term)
|
registry.add_term(key, term)
|
@ -21,9 +21,6 @@ from batdetect2.targets.terms import (
|
|||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
get_term_from_key,
|
get_term_from_key,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import (
|
|
||||||
term_registry as default_term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DerivationRegistry",
|
"DerivationRegistry",
|
||||||
@ -34,7 +31,7 @@ __all__ = [
|
|||||||
"TransformConfig",
|
"TransformConfig",
|
||||||
"build_transform_from_rule",
|
"build_transform_from_rule",
|
||||||
"build_transformation_from_config",
|
"build_transformation_from_config",
|
||||||
"derivation_registry",
|
"default_derivation_registry",
|
||||||
"get_derivation",
|
"get_derivation",
|
||||||
"load_transformation_config",
|
"load_transformation_config",
|
||||||
"load_transformation_from_config",
|
"load_transformation_from_config",
|
||||||
@ -398,7 +395,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
|
|||||||
return list(self._derivations.values())
|
return list(self._derivations.values())
|
||||||
|
|
||||||
|
|
||||||
derivation_registry = DerivationRegistry()
|
default_derivation_registry = DerivationRegistry()
|
||||||
"""Global instance of the DerivationRegistry.
|
"""Global instance of the DerivationRegistry.
|
||||||
|
|
||||||
Register custom derivation functions here to make them available by key
|
Register custom derivation functions here to make them available by key
|
||||||
@ -409,7 +406,7 @@ in `DeriveTagRule` configuration.
|
|||||||
def get_derivation(
|
def get_derivation(
|
||||||
key: str,
|
key: str,
|
||||||
import_derivation: bool = False,
|
import_derivation: bool = False,
|
||||||
registry: DerivationRegistry = derivation_registry,
|
registry: Optional[DerivationRegistry] = None,
|
||||||
):
|
):
|
||||||
"""Retrieve a derivation function by key, optionally importing it.
|
"""Retrieve a derivation function by key, optionally importing it.
|
||||||
|
|
||||||
@ -443,6 +440,8 @@ def get_derivation(
|
|||||||
AttributeError
|
AttributeError
|
||||||
If dynamic import fails because the function name isn't in the module.
|
If dynamic import fails because the function name isn't in the module.
|
||||||
"""
|
"""
|
||||||
|
registry = registry or default_derivation_registry
|
||||||
|
|
||||||
if not import_derivation or key in registry:
|
if not import_derivation or key in registry:
|
||||||
return registry.get_derivation(key)
|
return registry.get_derivation(key)
|
||||||
|
|
||||||
@ -458,10 +457,16 @@ def get_derivation(
|
|||||||
) from err
|
) from err
|
||||||
|
|
||||||
|
|
||||||
|
TranformationRule = Annotated[
|
||||||
|
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||||
|
Field(discriminator="rule_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_transform_from_rule(
|
def build_transform_from_rule(
|
||||||
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
rule: TranformationRule,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
term_registry: Optional[TermRegistry] = None,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a specific SoundEventTransformation function from a rule config.
|
"""Build a specific SoundEventTransformation function from a rule config.
|
||||||
|
|
||||||
@ -559,8 +564,8 @@ def build_transform_from_rule(
|
|||||||
|
|
||||||
def build_transformation_from_config(
|
def build_transformation_from_config(
|
||||||
config: TransformConfig,
|
config: TransformConfig,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
term_registry: Optional[TermRegistry] = None,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a composite transformation function from a TransformConfig.
|
"""Build a composite transformation function from a TransformConfig.
|
||||||
|
|
||||||
@ -581,6 +586,7 @@ def build_transformation_from_config(
|
|||||||
SoundEventTransformation
|
SoundEventTransformation
|
||||||
A single function that applies all configured transformations in order.
|
A single function that applies all configured transformations in order.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
transforms = [
|
transforms = [
|
||||||
build_transform_from_rule(
|
build_transform_from_rule(
|
||||||
rule,
|
rule,
|
||||||
@ -590,14 +596,16 @@ def build_transformation_from_config(
|
|||||||
for rule in config.rules
|
for rule in config.rules
|
||||||
]
|
]
|
||||||
|
|
||||||
def transformation(
|
return partial(apply_sequence_of_transforms, transforms=transforms)
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
for transform in transforms:
|
|
||||||
sound_event_annotation = transform(sound_event_annotation)
|
|
||||||
return sound_event_annotation
|
|
||||||
|
|
||||||
return transformation
|
|
||||||
|
def apply_sequence_of_transforms(
    sound_event_annotation: data.SoundEventAnnotation,
    transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
    """Run an annotation through a list of transformations, in order.

    Each transformation receives the output of the previous one; the
    first receives `sound_event_annotation`. With an empty list the
    input annotation is returned as is.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation to transform.
    transforms : list[SoundEventTransformation]
        Transformations to apply, first to last.

    Returns
    -------
    data.SoundEventAnnotation
        The result of the last transformation.
    """
    current = sound_event_annotation
    for step in transforms:
        current = step(current)
    return current
|
||||||
|
|
||||||
|
|
||||||
def load_transformation_config(
|
def load_transformation_config(
|
||||||
@ -631,8 +639,8 @@ def load_transformation_config(
|
|||||||
def load_transformation_from_config(
|
def load_transformation_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
term_registry: Optional[TermRegistry] = None,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Load transformation config from a file and build the final function.
|
"""Load transformation config from a file and build the final function.
|
||||||
|
|
||||||
@ -677,7 +685,7 @@ def load_transformation_from_config(
|
|||||||
def register_derivation(
|
def register_derivation(
|
||||||
key: str,
|
key: str,
|
||||||
derivation: Derivation,
|
derivation: Derivation,
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
derivation_registry: Optional[DerivationRegistry] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register a new derivation function in the global registry.
|
"""Register a new derivation function in the global registry.
|
||||||
|
|
||||||
@ -696,4 +704,5 @@ def register_derivation(
|
|||||||
KeyError
|
KeyError
|
||||||
If a derivation function with the same key is already registered.
|
If a derivation function with the same key is already registered.
|
||||||
"""
|
"""
|
||||||
|
derivation_registry = derivation_registry or default_derivation_registry
|
||||||
derivation_registry.register(key, derivation)
|
derivation_registry.register(key, derivation)
|
@ -19,8 +19,16 @@ from soundevent import data
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
|
"Position",
|
||||||
|
"Size",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Position = tuple[float, float]
|
||||||
|
"""A tuple representing (time, frequency) coordinates."""
|
||||||
|
|
||||||
|
Size = np.ndarray
|
||||||
|
"""A NumPy array representing the size dimensions of a target."""
|
||||||
|
|
||||||
|
|
||||||
class TargetProtocol(Protocol):
|
class TargetProtocol(Protocol):
|
||||||
"""Protocol defining the interface for the target definition pipeline.
|
"""Protocol defining the interface for the target definition pipeline.
|
||||||
@ -102,7 +110,7 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def encode(
|
def encode_class(
|
||||||
self,
|
self,
|
||||||
sound_event: data.SoundEventAnnotation,
|
sound_event: data.SoundEventAnnotation,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
@ -123,7 +131,7 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def decode(self, class_label: str) -> List[data.Tag]:
|
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||||
"""Decode a predicted class name back into representative tags.
|
"""Decode a predicted class name back into representative tags.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -147,9 +155,9 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_position(
|
def encode_roi(
|
||||||
self, sound_event: data.SoundEventAnnotation
|
self, sound_event: data.SoundEventAnnotation
|
||||||
) -> tuple[float, float]:
|
) -> tuple[Position, Size]:
|
||||||
"""Extract the target reference position from the annotation's geometry.
|
"""Extract the target reference position from the annotation's geometry.
|
||||||
|
|
||||||
Calculates the `(time, frequency)` coordinate representing the primary
|
Calculates the `(time, frequency)` coordinate representing the primary
|
||||||
@ -173,36 +181,12 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
# TODO: Update docstrings
|
||||||
"""Calculate the target size dimensions from the annotation's geometry.
|
def decode_roi(
|
||||||
|
self,
|
||||||
Computes the relevant physical size (e.g., duration/width,
|
position: Position,
|
||||||
bandwidth/height from a bounding box) to produce
|
size: Size,
|
||||||
the numerical target values expected by the model.
|
class_name: Optional[str] = None,
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation containing the geometry (ROI) to process.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
A NumPy array containing the size dimensions, matching the
|
|
||||||
order specified by the `dimension_names` attribute (e.g.,
|
|
||||||
`[width, height]`).
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the annotation lacks geometry or if the size cannot be computed.
|
|
||||||
TypeError
|
|
||||||
If geometry type is unsupported.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def recover_roi(
|
|
||||||
self, pos: tuple[float, float], dims: np.ndarray
|
|
||||||
) -> data.Geometry:
|
) -> data.Geometry:
|
||||||
"""Recover the ROI geometry from a position and dimensions.
|
"""Recover the ROI geometry from a position and dimensions.
|
||||||
|
|
||||||
@ -217,6 +201,8 @@ class TargetProtocol(Protocol):
|
|||||||
dims : np.ndarray
|
dims : np.ndarray
|
||||||
The NumPy array containing the dimensions (e.g., predicted
|
The NumPy array containing the dimensions (e.g., predicted
|
||||||
by the model), corresponding to the order in `dimension_names`.
|
by the model), corresponding to the order in `dimension_names`.
|
||||||
|
class_name: str
|
||||||
|
class
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
@ -97,7 +97,7 @@ def _is_in_subclip(
|
|||||||
start_time: float,
|
start_time: float,
|
||||||
end_time: float,
|
end_time: float,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
time, _ = targets.get_position(sound_event_annotation)
|
(time, _), _ = targets.encode_roi(sound_event_annotation)
|
||||||
return start_time <= time <= end_time
|
return start_time <= time <= end_time
|
||||||
|
|
||||||
|
|
@ -138,7 +138,7 @@ def generate_clip_label(
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
|
"Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
|
||||||
uuid=clip_annotation.uuid,
|
uuid=clip_annotation.uuid,
|
||||||
num=len(clip_annotation.sound_events)
|
num=len(clip_annotation.sound_events),
|
||||||
)
|
)
|
||||||
|
|
||||||
sound_events = []
|
sound_events = []
|
||||||
@ -260,7 +260,7 @@ def generate_heatmaps(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Get the position of the sound event
|
# Get the position of the sound event
|
||||||
time, frequency = targets.get_position(sound_event_annotation)
|
(time, frequency), size = targets.encode_roi(sound_event_annotation)
|
||||||
|
|
||||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||||
try:
|
try:
|
||||||
@ -280,8 +280,6 @@ def generate_heatmaps(
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
size = targets.get_size(sound_event_annotation)
|
|
||||||
|
|
||||||
size_heatmap = arrays.set_value_at_pos(
|
size_heatmap = arrays.set_value_at_pos(
|
||||||
size_heatmap,
|
size_heatmap,
|
||||||
size,
|
size,
|
||||||
@ -291,7 +289,7 @@ def generate_heatmaps(
|
|||||||
|
|
||||||
# Get the class name of the sound event
|
# Get the class name of the sound event
|
||||||
try:
|
try:
|
||||||
class_name = targets.encode(sound_event_annotation)
|
class_name = targets.encode_class(sound_event_annotation)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping annotation %s: Unexpected error while encoding "
|
"Skipping annotation %s: Unexpected error while encoding "
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user