Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-30 15:12:06 +02:00)

Compare commits: a462beaeb8619256a1f995f5594c503d4d1bc24b ... ebad489cb196929532e52091f77504127e7e3e87
No commits in common. "a462beaeb8619256a1f995f5594c503d4d1bc24b" and "ebad489cb196929532e52091f77504127e7e3e87" have entirely different histories.

.gitignore (vendored): 2 lines changed
@@ -107,7 +107,7 @@ experiments/*
 
 # DO Include
 !batdetect2_notebook.ipynb
-!src/batdetect2/models/checkpoints/*.pth.tar
+!batdetect2/models/checkpoints/*.pth.tar
 !tests/data/*.wav
 !notebooks/*.ipynb
 !tests/data/**/*.wav
Makefile: 4 lines changed
@@ -1,7 +1,7 @@
 # Variables
-SOURCE_DIR = src
+SOURCE_DIR = batdetect2
 TESTS_DIR = tests
-PYTHON_DIRS = src tests
+PYTHON_DIRS = batdetect2 tests
 DOCS_SOURCE = docs/source
 DOCS_BUILD = docs/build
 HTML_COVERAGE_DIR = htmlcov
@@ -189,7 +189,8 @@ def train_command(
 config=postprocess_config_loaded,
 )
 logger.debug(
-"Loaded postprocessor from file {path}", path=postprocess_config
+"Loaded postprocessor from file {path}",
+path=train_config,
 )
 except IOError:
 logger.debug(
@@ -157,4 +157,4 @@ def load_config(
 if field:
 config = get_object_field(config, field)
 
-return schema.model_validate(config or {})
+return schema.model_validate(config)
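The only difference in this hunk is the `or {}` fallback. Under pydantic v2 semantics (which the `batdetect2.configs` schemas appear to build on), that fallback decides what happens when the requested `field` is absent and `config` ends up `None`. A minimal sketch with a stand-in schema, independent of the repository's own config classes:

import sys
from pydantic import BaseModel, ValidationError


class DummyConfig(BaseModel):
    """Stand-in schema with a default; not part of batdetect2."""

    threshold: float = 0.5


missing = None  # what an absent config section could reduce to

# Left-hand variant: fall back to an empty dict, so field defaults are used.
print(DummyConfig.model_validate(missing or {}))  # threshold=0.5

# Right-hand variant: validate the raw value, so a missing section fails loudly.
try:
    DummyConfig.model_validate(missing)
except ValidationError as err:
    print("missing section rejected:", err.error_count(), "error(s)", file=sys.stderr)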
@@ -8,7 +8,6 @@ from batdetect2.data.annotations import (
 from batdetect2.data.datasets import (
 DatasetConfig,
 load_dataset,
-load_dataset_config,
 load_dataset_from_config,
 )
 
@@ -20,6 +19,5 @@ __all__ = [
 "DatasetConfig",
 "load_annotated_dataset",
 "load_dataset",
-"load_dataset_config",
 "load_dataset_from_config",
 ]
@@ -161,11 +161,6 @@ def insert_source_tag(
 )
 
 
-# TODO: add documentation
-def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
-return load_config(path=path, schema=DatasetConfig, field=field)
-
-
 def load_dataset_from_config(
 path: data.PathLike,
 field: Optional[str] = None,
@@ -72,7 +72,7 @@ def iterate_over_sound_events(
 sound_event_annotation
 )
 
-class_name = targets.encode_class(sound_event_annotation)
+class_name = targets.encode(sound_event_annotation)
 if class_name is None and exclude_generic:
 continue
 
@@ -1,3 +1,4 @@
+
 import numpy as np
 from sklearn.metrics import auc, roc_curve
 
@@ -40,7 +40,7 @@ def match_sound_events_and_raw_predictions(
 
 gt_uuid = target.uuid if target is not None else None
 gt_det = target is not None
-gt_class = targets.encode_class(target) if target is not None else None
+gt_class = targets.encode(target) if target is not None else None
 
 pred_score = float(prediction.detection_score) if prediction else 0
 
@@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
 
 The pipeline involves several configurable steps, implemented in submodules:
 1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
-2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
-model output arrays.
+2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
+coordinates to raw model output arrays.
 3. Detection Extraction (`.detection`): Identifies candidate detection points
 (location and score) based on thresholds and score ranking (top-k).
 4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
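The four numbered steps above describe a linear flow from raw heatmaps to per-detection data. The following is a self-contained toy walk-through of that flow using plain numpy/xarray; it is not the API of the `.nms`, `.remapping`, `.detection` or `.extraction` submodules, only an illustration of what each documented step does.

import numpy as np
import xarray as xr


def toy_postprocess(heatmap: np.ndarray, threshold: float = 0.3, top_k: int = 5) -> xr.Dataset:
    """Toy version of the four documented steps on a [frequency, time] score array."""
    # 1. Non-Maximum Suppression: keep only scores that are the maximum of
    #    their 3x3 neighbourhood (a crude stand-in for `.nms`).
    padded = np.pad(heatmap, 1)
    shifted = np.stack(
        [
            padded[i : i + heatmap.shape[0], j : j + heatmap.shape[1]]
            for i in range(3)
            for j in range(3)
        ]
    )
    suppressed = np.where(heatmap >= shifted.max(axis=0), heatmap, 0.0)

    # 2. Coordinate Remapping: attach real-world time/frequency coordinates.
    scores = xr.DataArray(
        suppressed,
        dims=("frequency", "time"),
        coords={
            "frequency": np.linspace(10_000, 120_000, heatmap.shape[0]),
            "time": np.linspace(0.0, 1.0, heatmap.shape[1]),
        },
    )

    # 3. Detection Extraction: threshold the peaks, then keep the top-k scores.
    flat = scores.stack(detection=("frequency", "time"))
    flat = flat.isel(detection=np.flatnonzero(flat.values > threshold))
    flat = flat.isel(detection=np.argsort(-flat.values)[:top_k])

    # 4. Data Extraction: package the selected points as a per-detection dataset.
    return xr.Dataset({"scores": flat})


print(toy_postprocess(np.random.default_rng(0).random((16, 32))))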
@@ -526,7 +526,7 @@ class Postprocessor(PostprocessorProtocol):
 return [
 convert_xr_dataset_to_raw_prediction(
 dataset,
-self.targets.decode_roi,
+self.targets.recover_roi,
 )
 for dataset in detection_datasets
 ]
@@ -558,7 +558,7 @@ class Postprocessor(PostprocessorProtocol):
 convert_raw_predictions_to_clip_prediction(
 prediction,
 clip,
-sound_event_decoder=self.targets.decode_class,
+sound_event_decoder=self.targets.decode,
 generic_class_tags=self.targets.generic_class_tags,
 classification_threshold=self.config.classification_threshold,
 )
@@ -4,7 +4,8 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
 It takes the structured detection data extracted by the `extraction` module
 (typically an `xarray.Dataset` containing scores, positions, predicted sizes,
 class probabilities, and features for each detection point) and converts it
-into standardized prediction objects based on the `soundevent` data model.
+into meaningful, standardized prediction objects based on the `soundevent` data
+model.
 
 The process involves:
 1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
@@ -32,7 +33,7 @@ import numpy as np
 import xarray as xr
 from soundevent import data
 
-from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
+from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
 from batdetect2.targets.classes import SoundEventDecoder
 from batdetect2.utils.arrays import iterate_over_array
 
@@ -54,7 +55,7 @@ decoding.
 
 def convert_xr_dataset_to_raw_prediction(
 detection_dataset: xr.Dataset,
-geometry_decoder: GeometryDecoder,
+geometry_builder: GeometryBuilder,
 ) -> List[RawPrediction]:
 """Convert an xarray.Dataset of detections to RawPrediction objects.
 
@@ -71,7 +72,7 @@ def convert_xr_dataset_to_raw_prediction(
 output by `extract_detection_xr_dataset`. Expected variables include
 'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
 Must have a 'detection' dimension.
-geometry_decoder : GeometryDecoder
+geometry_builder : GeometryBuilder
 A function that takes a position tuple `(time, freq)` and a NumPy array
 of dimensions, and returns the corresponding reconstructed
 `soundevent.data.Geometry`.
@@ -95,20 +96,14 @@ def convert_xr_dataset_to_raw_prediction(
 for det_num in range(detection_dataset.sizes["detection"]):
 det_info = detection_dataset.sel(detection=det_num)
 
-# TODO: Maybe clean this up
-highest_scoring_class = det_info.coords["category"][
-det_info["classes"].argmax()
-].item()
-
-geom = geometry_decoder(
+geom = geometry_builder(
 (det_info.time, det_info.frequency),
 det_info.dimensions,
-class_name=highest_scoring_class,
 )
 
 detections.append(
 RawPrediction(
-detection_score=det_info.scores,
+detection_score=det_info.score,
 geometry=geom,
 class_scores=det_info.classes,
 features=det_info.features,
@@ -1,6 +1,6 @@
 """Extracts candidate detection points from a model output heatmap.
 
-This module implements Step 3 within the BatDetect2 postprocessing
+This module implements a specific step within the BatDetect2 postprocessing
 pipeline. Its primary function is to identify potential sound event locations
 by finding peaks (local maxima or high-scoring points) in the detection heatmap
 produced by the neural network (usually after Non-Maximum Suppression and
@@ -1,9 +1,9 @@
 """Extracts associated data for detected points from model output arrays.
 
-This module implements a Step 4 in the BatDetect2 postprocessing pipeline.
-After candidate detection points (time, frequency, score) have been identified,
-this module extracts the corresponding values from other raw model output
-arrays, such as:
+This module implements a key step (Step 4) in the BatDetect2 postprocessing
+pipeline. After candidate detection points (time, frequency, score) have been
+identified, this module extracts the corresponding values from other raw model
+output arrays, such as:
 
 - Predicted bounding box sizes (width, height).
 - Class probability scores for each defined target class.
@@ -11,37 +11,30 @@ modularity and consistent interaction between different parts of the BatDetect2
 system that deal with model predictions.
 """
 
-from typing import List, NamedTuple, Optional, Protocol
+from typing import Callable, List, NamedTuple, Protocol
 
+import numpy as np
 import xarray as xr
 from soundevent import data
 
 from batdetect2.models.types import ModelOutput
-from batdetect2.targets.types import Position, Size
 
 __all__ = [
 "RawPrediction",
 "PostprocessorProtocol",
-"GeometryDecoder",
+"GeometryBuilder",
 ]
 
 
-# TODO: update the docstring
-class GeometryDecoder(Protocol):
-"""Type alias for a function that recovers geometry from position and size.
+GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
+"""Type alias for a function that recovers geometry from position and size.
 
 This callable takes:
 1. A position tuple `(time, frequency)`.
 2. A NumPy array of size dimensions (e.g., `[width, height]`).
-3. Optionally a class name of the highest scoring class. This is to accomodate
-different ways of decoding geometry that depend on the predicted class.
 It should return the reconstructed `soundevent.data.Geometry` (typically a
 `BoundingBox`).
 """
-
-def __call__(
-self, position: Position, size: Size, class_name: Optional[str] = None
-) -> data.Geometry: ...
 
 
 class RawPrediction(NamedTuple):
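On the right-hand side, `GeometryBuilder` is a plain callable rather than a protocol carrying an optional class name. A minimal sketch of a callable that satisfies the alias, treating the position as the bottom-left corner of a box (the default anchor used by the new `batdetect2.targets.rois` module). This is an illustration only, not the mapper the library actually wires in, and it ignores the size scaling applied by the real ROI mapper.

import numpy as np
from soundevent import data


def bottom_left_box_builder(
    position: tuple[float, float],
    dimensions: np.ndarray,
) -> data.Geometry:
    """Toy GeometryBuilder: interpret `position` as the bottom-left corner."""
    time, freq = position
    width, height = (float(d) for d in dimensions)
    return data.BoundingBox(
        coordinates=[time, freq, time + width, freq + height]
    )


# The kind of call convert_xr_dataset_to_raw_prediction makes per detection.
box = bottom_left_box_builder((0.10, 30_000.0), np.array([0.02, 25_000.0]))
print(box.coordinates)  # [0.1, 30000.0, 0.12, 55000.0]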
@@ -23,7 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
 
 from typing import List, Optional
 
-from loguru import logger
+import numpy as np
 from pydantic import Field
 from soundevent import data
 
@@ -50,8 +50,7 @@ from batdetect2.targets.filtering import (
 load_filter_from_config,
 )
 from batdetect2.targets.rois import (
-AnchorBBoxMapperConfig,
-ROIMapperConfig,
+ROIConfig,
 ROITargetMapper,
 build_roi_mapper,
 )
@@ -60,11 +59,11 @@ from batdetect2.targets.terms import (
 TermInfo,
 TermRegistry,
 call_type,
-default_term_registry,
 get_tag_from_info,
 get_term_from_key,
 individual,
 register_term,
+term_registry,
 )
 from batdetect2.targets.transform import (
 DerivationRegistry,
@@ -74,13 +73,13 @@ from batdetect2.targets.transform import (
 SoundEventTransformation,
 TransformConfig,
 build_transformation_from_config,
-default_derivation_registry,
+derivation_registry,
 get_derivation,
 load_transformation_config,
 load_transformation_from_config,
 register_derivation,
 )
-from batdetect2.targets.types import Position, Size, TargetProtocol
+from batdetect2.targets.types import TargetProtocol
 
 __all__ = [
 "ClassesConfig",
@@ -89,7 +88,7 @@ __all__ = [
 "FilterConfig",
 "FilterRule",
 "MapValueRule",
-"AnchorBBoxMapperConfig",
+"ROIConfig",
 "ROITargetMapper",
 "ReplaceRule",
 "SoundEventDecoder",
@@ -157,12 +156,12 @@ class TargetConfig(BaseConfig):
 omitted, default ROI mapping settings are used.
 """
 
-filtering: FilterConfig = Field(default_factory=FilterConfig)
-transforms: TransformConfig = Field(default_factory=TransformConfig)
+filtering: Optional[FilterConfig] = None
+transforms: Optional[TransformConfig] = None
 classes: ClassesConfig = Field(
 default_factory=lambda: DEFAULT_CLASSES_CONFIG
 )
-roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
+roi: Optional[ROIConfig] = None
 
 
 def load_target_config(
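On the right-hand side, `filtering`, `transforms` and `roi` are optional and default to None; the fallbacks are supplied at build time instead (see the `build_targets` hunk further down, which does `build_roi_mapper(config.roi or ROIConfig())`). A short sketch of that behaviour, assuming `TargetConfig` is importable from `batdetect2.targets` where it is defined:

from batdetect2.targets import TargetConfig
from batdetect2.targets.rois import ROIConfig

# An empty config is valid: None fields are filled in later by build_targets.
config = TargetConfig()
assert config.roi is None and config.filtering is None

# Supplying an explicit ROI mapping simply bypasses that fallback.
config = TargetConfig(roi=ROIConfig(position="center", time_scale=1000.0))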
@@ -241,7 +240,6 @@ class Targets(TargetProtocol):
 generic_class_tags: List[data.Tag],
 filter_fn: Optional[SoundEventFilter] = None,
 transform_fn: Optional[SoundEventTransformation] = None,
-roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
 ):
 """Initialize the Targets object.
 
@@ -274,16 +272,6 @@ class Targets(TargetProtocol):
 self._encode_fn = encode_fn
 self._decode_fn = decode_fn
 self._transform_fn = transform_fn
-self._roi_mapper_overrides = roi_mapper_overrides or {}
-
-for class_name in self._roi_mapper_overrides:
-if class_name not in self.class_names:
-# TODO: improve this warning
-logger.warning(
-"The ROI mapper overrides contains a class ({class_name}) "
-"not present in the class names.",
-class_name=class_name,
-)
 
 def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
 """Apply the configured filter to a sound event annotation.
@@ -303,9 +291,7 @@ class Targets(TargetProtocol):
 return True
 return self._filter_fn(sound_event)
 
-def encode_class(
-self, sound_event: data.SoundEventAnnotation
-) -> Optional[str]:
+def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
 """Encode a sound event annotation to its target class name.
 
 Applies the configured class definition rules (including priority)
@@ -326,7 +312,7 @@ class Targets(TargetProtocol):
 """
 return self._encode_fn(sound_event)
 
-def decode_class(self, class_label: str) -> List[data.Tag]:
+def decode(self, class_label: str) -> List[data.Tag]:
 """Decode a predicted class name back into representative tags.
 
 Uses the configured mapping (based on `TargetClass.output_tags` or
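The two sides only disagree on the method names (`encode_class`/`decode_class` versus `encode`/`decode`). A short usage sketch against the right-hand naming; `annotation` is a placeholder for a `soundevent.data.SoundEventAnnotation` obtained elsewhere (e.g. from an annotated dataset), not something constructed here:

from batdetect2.targets import build_targets

targets = build_targets()  # default configuration

class_name = targets.encode(annotation)   # e.g. "nyclei", or None for the generic class
if class_name is not None:
    tags = targets.decode(class_name)     # representative soundevent tags
else:
    tags = targets.generic_class_tags     # fall back to the generic "bat" tags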
@@ -366,9 +352,9 @@ class Targets(TargetProtocol):
 return self._transform_fn(sound_event)
 return sound_event
 
-def encode_roi(
+def get_position(
 self, sound_event: data.SoundEventAnnotation
-) -> tuple[Position, Size]:
+) -> tuple[float, float]:
 """Extract the target reference position from the annotation's roi.
 
 Delegates to the internal ROI mapper's `get_roi_position` method.
@@ -388,20 +374,50 @@ class Targets(TargetProtocol):
 ValueError
 If the annotation lacks geometry.
 """
-class_name = self.encode_class(sound_event)
+geom = sound_event.sound_event.geometry
 
-if class_name in self._roi_mapper_overrides:
-return self._roi_mapper_overrides[class_name].encode(
-sound_event.sound_event
-)
+if geom is None:
+raise ValueError(
+"Sound event has no geometry, cannot get its position."
+)
 
-return self._roi_mapper.encode(sound_event.sound_event)
+return self._roi_mapper.get_roi_position(geom)
 
-def decode_roi(
+def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
+"""Calculate the target size dimensions from the annotation's geometry.
+
+Delegates to the internal ROI mapper's `get_roi_size` method, which
+applies configured scaling factors.
+
+Parameters
+----------
+sound_event : data.SoundEventAnnotation
+The annotation containing the geometry (ROI).
+
+Returns
+-------
+np.ndarray
+NumPy array containing the size dimensions, matching the
+order in `self.dimension_names` (e.g., `[width, height]`).
+
+Raises
+------
+ValueError
+If the annotation lacks geometry.
+"""
+geom = sound_event.sound_event.geometry
+
+if geom is None:
+raise ValueError(
+"Sound event has no geometry, cannot get its size."
+)
+
+return self._roi_mapper.get_roi_size(geom)
+
+def recover_roi(
 self,
-position: Position,
-size: Size,
-class_name: Optional[str] = None,
+pos: tuple[float, float],
+dims: np.ndarray,
 ) -> data.Geometry:
 """Recover an approximate geometric ROI from a position and dimensions.
 
@@ -422,13 +438,7 @@ class Targets(TargetProtocol):
 data.Geometry
 The reconstructed geometry (typically `BoundingBox`).
 """
-if class_name in self._roi_mapper_overrides:
-return self._roi_mapper_overrides[class_name].decode(
-position,
-size,
-)
-
-return self._roi_mapper.decode(position, size)
+return self._roi_mapper.recover_roi(pos, dims)
 
 
 DEFAULT_CLASSES = [
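On the right-hand side the geometric interface of `Targets` is split into `get_position`, `get_size` and `recover_roi`. A sketch of the round trip they imply; `targets` is built from the default configuration and `annotation` again stands in for a `soundevent.data.SoundEventAnnotation` whose sound event carries a bounding-box geometry:

from batdetect2.targets import build_targets

targets = build_targets()

# Forward mapping: annotation geometry -> training target.
pos = targets.get_position(annotation)   # (time, frequency) reference point
size = targets.get_size(annotation)      # scaled np.ndarray, e.g. [width, height]

# Inverse mapping: predicted point + predicted size -> approximate geometry.
roi = targets.recover_roi(pos, size)
print(type(roi).__name__)  # typically "BoundingBox"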
@@ -483,12 +493,10 @@ DEFAULT_CLASSES = [
 TargetClass(
 tags=[TagInfo(value="Nyctalus leisleri")],
 name="nyclei",
-roi=AnchorBBoxMapperConfig(anchor="top-left"),
 ),
 TargetClass(
 tags=[TagInfo(value="Rhinolophus ferrumequinum")],
 name="rhifer",
-roi=AnchorBBoxMapperConfig(anchor="top-left"),
 ),
 TargetClass(
 tags=[TagInfo(value="Plecotus auritus")],
@@ -529,14 +537,13 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
 ]
 ),
 classes=DEFAULT_CLASSES_CONFIG,
-roi=AnchorBBoxMapperConfig(),
 )
 
 
 def build_targets(
 config: Optional[TargetConfig] = None,
-term_registry: TermRegistry = default_term_registry,
-derivation_registry: DerivationRegistry = default_derivation_registry,
+term_registry: TermRegistry = term_registry,
+derivation_registry: DerivationRegistry = derivation_registry,
 ) -> Targets:
 """Build a Targets object from a loaded TargetConfig.
 
@@ -599,17 +606,12 @@ def build_targets(
 if config.transforms
 else None
 )
-roi_mapper = build_roi_mapper(config.roi)
+roi_mapper = build_roi_mapper(config.roi or ROIConfig())
 class_names = get_class_names_from_config(config.classes)
 generic_class_tags = build_generic_class_tags(
 config.classes,
 term_registry=term_registry,
 )
-roi_overrides = {
-class_config.name: build_roi_mapper(class_config.roi)
-for class_config in config.classes.classes
-if class_config.roi is not None
-}
 
 return Targets(
 filter_fn=filter_fn,
@@ -619,15 +621,14 @@ def build_targets(
 roi_mapper=roi_mapper,
 generic_class_tags=generic_class_tags,
 transform_fn=transform_fn,
-roi_mapper_overrides=roi_overrides,
 )
 
 
 def load_targets(
 config_path: data.PathLike,
 field: Optional[str] = None,
-term_registry: TermRegistry = default_term_registry,
-derivation_registry: DerivationRegistry = default_derivation_registry,
+term_registry: TermRegistry = term_registry,
+derivation_registry: DerivationRegistry = derivation_registry,
 ) -> Targets:
 """Load a Targets object directly from a configuration file.
 
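Both sides expose the same two entry points for constructing a `Targets` object. A brief usage sketch; the file name and `field` value below are illustrative placeholders, not files that ship with the repository:

from batdetect2.targets import build_targets, load_targets

# Programmatic construction: falls back to the packaged default configuration.
targets = build_targets()

# Or load everything from a configuration file; `field` selects a nested
# section inside that file (both values here are hypothetical).
targets = load_targets("config.yaml", field="targets")
print(targets.class_names)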
@@ -6,27 +6,29 @@ from pydantic import Field, field_validator
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets.rois import ROIMapperConfig
 from batdetect2.targets.terms import (
 GENERIC_CLASS_KEY,
 TagInfo,
 TermRegistry,
-default_term_registry,
 get_tag_from_info,
+term_registry,
 )
 
 __all__ = [
-"DEFAULT_SPECIES_LIST",
-"build_generic_class_tags",
-"build_sound_event_decoder",
-"build_sound_event_encoder",
-"get_class_names_from_config",
+"SoundEventEncoder",
+"SoundEventDecoder",
+"TargetClass",
+"ClassesConfig",
 "load_classes_config",
-"load_decoder_from_config",
 "load_encoder_from_config",
+"load_decoder_from_config",
+"build_sound_event_encoder",
+"build_sound_event_decoder",
+"build_generic_class_tags",
+"get_class_names_from_config",
+"DEFAULT_SPECIES_LIST",
 ]
 
 
 SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
 """Type alias for a sound event class encoder function.
 
@@ -111,7 +113,6 @@ class TargetClass(BaseConfig):
 tags: List[TagInfo] = Field(min_length=1)
 match_type: Literal["all", "any"] = Field(default="all")
 output_tags: Optional[List[TagInfo]] = None
-roi: Optional[ROIMapperConfig] = None
 
 
 def _get_default_classes() -> List[TargetClass]:
@@ -234,7 +235,7 @@ class ClassesConfig(BaseConfig):
 return v
 
 
-def is_target_class(
+def _is_target_class(
 sound_event_annotation: data.SoundEventAnnotation,
 tags: Set[data.Tag],
 match_all: bool = True,
@@ -315,7 +316,7 @@ def _encode_with_multiple_classifiers(
 
 def build_sound_event_encoder(
 config: ClassesConfig,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> SoundEventEncoder:
 """Build a sound event encoder function from the classes configuration.
 
@@ -349,7 +350,7 @@ def build_sound_event_encoder(
 (
 class_info.name,
 partial(
-is_target_class,
+_is_target_class,
 tags={
 get_tag_from_info(tag_info, term_registry=term_registry)
 for tag_info in class_info.tags
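`build_sound_event_encoder` turns a `ClassesConfig` into a plain callable (a `SoundEventEncoder`) that maps an annotation to a class name or None. A sketch of building one from an explicit configuration; the `ClassesConfig(classes=[...])` constructor arguments are inferred from the defaults shown elsewhere in this compare (notably `TargetClass(tags=[TagInfo(value=...)], name=...)` and `config.classes.classes`) and may not list every field the real schema accepts:

from batdetect2.targets.classes import (
    ClassesConfig,
    TargetClass,
    build_sound_event_encoder,
)
from batdetect2.targets.terms import TagInfo

config = ClassesConfig(
    classes=[
        TargetClass(name="nyclei", tags=[TagInfo(value="Nyctalus leisleri")]),
        TargetClass(name="rhifer", tags=[TagInfo(value="Rhinolophus ferrumequinum")]),
    ],
)

encoder = build_sound_event_encoder(config)
# `annotation` would be a soundevent.data.SoundEventAnnotation; the encoder
# returns the first matching class name, or None for the generic class.
# class_name = encoder(annotation)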
@@ -409,7 +410,7 @@ def _decode_class(
 
 def build_sound_event_decoder(
 config: ClassesConfig,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 raise_on_unmapped: bool = False,
 ) -> SoundEventDecoder:
 """Build a sound event decoder function from the classes configuration.
 
@@ -464,7 +465,7 @@ def build_sound_event_decoder(
 
 def build_generic_class_tags(
 config: ClassesConfig,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> List[data.Tag]:
 """Extract and build the list of tags for the generic class from config.
 
@@ -529,7 +530,7 @@ def load_classes_config(
 def load_encoder_from_config(
 path: data.PathLike,
 field: Optional[str] = None,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> SoundEventEncoder:
 """Load a class encoder function directly from a configuration file.
 
@@ -570,7 +571,7 @@ def load_encoder_from_config(
 def load_decoder_from_config(
 path: data.PathLike,
 field: Optional[str] = None,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 raise_on_unmapped: bool = False,
 ) -> SoundEventDecoder:
 """Load a class decoder function directly from a configuration file.
@@ -10,7 +10,7 @@ from batdetect2.targets.terms import (
 TagInfo,
 TermRegistry,
 get_tag_from_info,
-default_term_registry,
+term_registry,
 )
 
 __all__ = [
@@ -156,7 +156,7 @@ def equal_tags(
 
 def build_filter_from_rule(
 rule: FilterRule,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> SoundEventFilter:
 """Creates a callable filter function from a single FilterRule.
 
@@ -243,7 +243,7 @@ class FilterConfig(BaseConfig):
 
 def build_sound_event_filter(
 config: FilterConfig,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> SoundEventFilter:
 """Builds a merged filter function from a FilterConfig object.
 
@@ -291,7 +291,7 @@ def load_filter_config(
 def load_filter_from_config(
 path: data.PathLike,
 field: Optional[str] = None,
-term_registry: TermRegistry = default_term_registry,
+term_registry: TermRegistry = term_registry,
 ) -> SoundEventFilter:
 """Loads filter configuration from a file and builds the filter function.
 
batdetect2/targets/rois.py (new file, 511 lines added)
@@ -0,0 +1,511 @@
"""Handles mapping between geometric ROIs and target representations.

This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:

1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
   bandwidth/height) and applying scaling factors.

It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.

This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
"""

from typing import List, Literal, Optional, Protocol, Tuple

import numpy as np
from soundevent import data

from batdetect2.configs import BaseConfig, load_config

Positions = Literal[
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]

__all__ = [
    "ROITargetMapper",
    "ROIConfig",
    "BBoxEncoder",
    "build_roi_mapper",
    "load_roi_mapper",
    "DEFAULT_POSITION",
    "SIZE_WIDTH",
    "SIZE_HEIGHT",
    "SIZE_ORDER",
    "DEFAULT_TIME_SCALE",
    "DEFAULT_FREQUENCY_SCALE",
]

SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""

SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""

SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""

DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""

DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""


DEFAULT_POSITION = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""


class ROITargetMapper(Protocol):
    """Protocol defining the interface for ROI-to-target mapping.

    Specifies the methods required for converting a geometric region of interest
    (`soundevent.data.Geometry`) into a target representation (reference point
    and scaled dimensions) and for recovering an approximate ROI from that
    representation.

    Attributes
    ----------
    dimension_names : List[str]
        A list containing the names of the dimensions returned by
        `get_roi_size` and expected by `recover_roi`
        (e.g., ['width', 'height']).
    """

    dimension_names: List[str]

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> tuple[float, float]:
        """Extract the reference position from a geometry.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            The input geometry (e.g., BoundingBox, Polygon).
        position : Positions, optional
            Overrides the default `position` configured for the mapper.
            If provided, this position will be used instead of the mapper's
            internal default.

        Returns
        -------
        Tuple[float, float]
            The calculated reference position as (time, frequency) coordinates,
            based on the implementing class's configuration (e.g., "center",
            "bottom-left").

        Raises
        ------
        ValueError
            If the position cannot be calculated for the given geometry type
            or configured reference point.
        """
        ...

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Calculate the scaled target dimensions from a geometry.

        Computes the relevant size measures.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            The input geometry.

        Returns
        -------
        np.ndarray
            A NumPy array containing the scaled dimensions corresponding to
            `dimension_names`. For bounding boxes, typically contains
            `[scaled_width, scaled_height]`.

        Raises
        ------
        TypeError, ValueError
            If the size cannot be computed for the given geometry type.
        """
        ...

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Recover an approximate ROI from a position and target dimensions.

        Performs the inverse mapping: takes a reference position and the
        predicted dimensions and reconstructs a geometric representation.

        Parameters
        ----------
        pos : Tuple[float, float]
            The reference position (time, frequency).
        dims : np.ndarray
            NumPy array containing the dimensions, matching the order
            specified by `dimension_names`.
        position : Positions, optional
            Overrides the default `position` configured for the mapper.
            If provided, this position will be used instead of the mapper's
            internal default when reconstructing the roi geometry.

        Returns
        -------
        soundevent.data.Geometry
            The reconstructed geometry.

        Raises
        ------
        ValueError
            If the number of provided dimensions `dims` does not match
            `dimension_names` or if reconstruction fails.
        """
        ...


class ROIConfig(BaseConfig):
    """Configuration for mapping Regions of Interest (ROIs).

    Defines parameters controlling how geometric ROIs are converted into
    target representations (reference points and scaled sizes).

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Specifies the reference point within the geometry (e.g., bounding box)
        to use as the target location (e.g., "center", "bottom-left").
        See `soundevent.geometry.operations.Positions`.
    time_scale : float, default=1000.0
        Scaling factor applied to the time duration (width) of the ROI
        when calculating the target size representation. Must match model
        expectations.
    frequency_scale : float, default=1/859.375
        Scaling factor applied to the frequency bandwidth (height) of the ROI
        when calculating the target size representation. Must match model
        expectations.
    """

    position: Positions = DEFAULT_POSITION
    time_scale: float = DEFAULT_TIME_SCALE
    frequency_scale: float = DEFAULT_FREQUENCY_SCALE


class BBoxEncoder(ROITargetMapper):
    """Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.

    This class implements the ROI mapping protocol primarily for
    `soundevent.data.BoundingBox` geometry. It extracts reference points,
    calculates scaled width/height, and recovers bounding boxes based on
    configured position and scaling factors.

    Attributes
    ----------
    dimension_names : List[str]
        Specifies the output dimension names as ['width', 'height'].
    position : Positions
        The configured reference point type (e.g., "center", "bottom-left").
    time_scale : float
        The configured scaling factor for the time dimension (width).
    frequency_scale : float
        The configured scaling factor for the frequency dimension (height).
    """

    dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]

    def __init__(
        self,
        position: Positions = DEFAULT_POSITION,
        time_scale: float = DEFAULT_TIME_SCALE,
        frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
    ):
        """Initialize the BBoxEncoder.

        Parameters
        ----------
        position : Positions, default="bottom-left"
            Reference point type within the bounding box.
        time_scale : float, default=1000.0
            Scaling factor for time duration (width).
        frequency_scale : float, default=1/859.375
            Scaling factor for frequency bandwidth (height).
        """
        self.position: Positions = position
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def get_roi_position(
        self,
        geom: data.Geometry,
        position: Optional[Positions] = None,
    ) -> Tuple[float, float]:
        """Extract the configured reference position from the geometry.

        Uses `soundevent.geometry.get_geometry_point`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry (e.g., BoundingBox).
        position : Positions, optional
            Overrides the default `position` configured for the encoder.
            If provided, this position will be used instead of `self.position`.

        Returns
        -------
        Tuple[float, float]
            Reference position (time, frequency).
        """
        from soundevent import geometry

        position = position or self.position
        return geometry.get_geometry_point(geom, position=position)

    def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
        """Calculate the scaled [width, height] from the geometry's bounds.

        Computes the bounding box, extracts duration and bandwidth, and applies
        the configured `time_scale` and `frequency_scale`.

        Parameters
        ----------
        geom : soundevent.data.Geometry
            Input geometry.

        Returns
        -------
        np.ndarray
            A 1D NumPy array: `[scaled_width, scaled_height]`.
        """
        from soundevent import geometry

        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )
        return np.array(
            [
                (end_time - start_time) * self.time_scale,
                (high_freq - low_freq) * self.frequency_scale,
            ]
        )

    def recover_roi(
        self,
        pos: tuple[float, float],
        dims: np.ndarray,
        position: Optional[Positions] = None,
    ) -> data.Geometry:
        """Recover a BoundingBox from a position and scaled dimensions.

        Un-scales the input dimensions using the configured factors and
        reconstructs a `soundevent.data.BoundingBox` centered or anchored at
        the given reference `pos` according to the configured `position` type.

        Parameters
        ----------
        pos : Tuple[float, float]
            Reference position (time, frequency).
        dims : np.ndarray
            NumPy array containing the *scaled* dimensions, expected order is
            [scaled_width, scaled_height].
        position : Positions, optional
            Overrides the default `position` configured for the encoder.
            If provided, this position will be used instead of `self.position`
            when reconstructing the bounding box.

        Returns
        -------
        soundevent.data.BoundingBox
            The reconstructed bounding box.

        Raises
        ------
        ValueError
            If `dims` does not have the expected shape (length 2).
        """
        position = position or self.position

        if dims.ndim != 1 or dims.shape[0] != 2:
            raise ValueError(
                "Dimension array does not have the expected shape. "
                f"({dims.shape = }) != ([2])"
            )

        width, height = dims
        return _build_bounding_box(
            pos,
            duration=float(width) / self.time_scale,
            bandwidth=float(height) / self.frequency_scale,
            position=self.position,
        )


def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
    """Factory function to create an ROITargetMapper from configuration.

    Currently creates a `BBoxEncoder` instance based on the provided
    `ROIConfig`.

    Parameters
    ----------
    config : ROIConfig
        Configuration object specifying ROI mapping parameters.

    Returns
    -------
    ROITargetMapper
        An initialized `BBoxEncoder` instance configured with the settings
        from `config`.
    """
    return BBoxEncoder(
        position=config.position,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )


def load_roi_mapper(
    path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
    """Load ROI mapping configuration from a file and build the mapper.

    Convenience function that loads an `ROIConfig` from the specified file
    (and optional field) and then uses `build_roi_mapper` to create the
    corresponding `ROITargetMapper` instance.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        ROI configuration. If None, the entire file content is used.

    Returns
    -------
    ROITargetMapper
        An initialized ROI mapper instance based on the configuration file.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        If the configuration file cannot be found, parsed, validated, or if
        the specified `field` is invalid.
    """
    config = load_config(path=path, schema=ROIConfig, field=field)
    return build_roi_mapper(config)


VALID_POSITIONS = [
    "bottom-left",
    "bottom-right",
    "top-left",
    "top-right",
    "center-left",
    "center-right",
    "top-center",
    "bottom-center",
    "center",
    "centroid",
    "point_on_surface",
]


def _build_bounding_box(
    pos: tuple[float, float],
    duration: float,
    bandwidth: float,
    position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
    """Construct a BoundingBox from a reference point, size, and position type.

    Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
    coordinates [start_time, low_freq, end_time, high_freq] based on where
    the input `pos` (time, freq) is located relative to the box (e.g.,
    center, corner).

    Parameters
    ----------
    pos : Tuple[float, float]
        Reference position (time, frequency).
    duration : float
        The required *unscaled* duration (width) of the bounding box.
    bandwidth : float
        The required *unscaled* frequency bandwidth (height) of the bounding
        box.
    position : Positions, default="bottom-left"
        Specifies which part of the bounding box the input `pos` corresponds to.

    Returns
    -------
    data.BoundingBox
        The constructed bounding box object.

    Raises
    ------
    ValueError
        If `position` is not a recognized value or format.
    """
    time, freq = map(float, pos)
    duration = max(0, duration)
    bandwidth = max(0, bandwidth)
    if position in ["center", "centroid", "point_on_surface"]:
        return data.BoundingBox(
            coordinates=[
                max(time - duration / 2, 0),
                max(freq - bandwidth / 2, 0),
                max(time + duration / 2, 0),
                max(freq + bandwidth / 2, 0),
            ]
        )

    if position not in VALID_POSITIONS:
        raise ValueError(
            f"Invalid position: {position}. "
            f"Valid options are: {VALID_POSITIONS}"
        )

    y, x = position.split("-")

    start_time = {
        "left": time,
        "center": time - duration / 2,
        "right": time - duration,
    }[x]

    low_freq = {
        "bottom": freq,
        "center": freq - bandwidth / 2,
        "top": freq - bandwidth,
    }[y]

    return data.BoundingBox(
        coordinates=[
            max(0, start_time),
            max(0, low_freq),
            max(0, start_time + duration),
            max(0, low_freq + bandwidth),
        ]
    )
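A small runnable sketch of the encode/recover round trip implemented by this new module, using the default "bottom-left" anchor and default scales; the box coordinates are arbitrary example numbers and the exact output of `get_geometry_point` is assumed to be the geometry's bottom-left corner for a bounding box:

import numpy as np
from soundevent import data

from batdetect2.targets.rois import (
    DEFAULT_FREQUENCY_SCALE,
    DEFAULT_TIME_SCALE,
    BBoxEncoder,
)

encoder = BBoxEncoder()  # position="bottom-left", default scales

box = data.BoundingBox(coordinates=[0.10, 30_000.0, 0.12, 55_000.0])

pos = encoder.get_roi_position(box)   # (0.10, 30000.0): the bottom-left corner
size = encoder.get_roi_size(box)      # scaled [width, height]
assert np.allclose(
    size,
    [0.02 * DEFAULT_TIME_SCALE, 25_000.0 * DEFAULT_FREQUENCY_SCALE],
)

# The inverse un-scales the dimensions and rebuilds a box anchored at `pos`.
recovered = encoder.recover_roi(pos, size)
print(recovered.coordinates)  # approximately [0.10, 30000.0, 0.12, 55000.0]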
@@ -230,12 +230,11 @@ class TermRegistry(Mapping[str, data.Term]):
 del self._terms[key]
 
 
-default_term_registry = TermRegistry(
+term_registry = TermRegistry(
 terms=dict(
 [
 *getmembers(terms, lambda x: isinstance(x, data.Term)),
 ("event", call_type),
-("species", terms.scientific_name),
 ("individual", individual),
 ("data_source", data_source),
 (GENERIC_CLASS_KEY, generic_class),
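Both sides keep a module-level global registry (named `default_term_registry` on one side and `term_registry` on the other) that the `term_registry=` keyword arguments throughout this compare default to. A short sketch of registering and resolving terms against the right-hand naming; the "genus" key is only an example, and the `soundevent.data.Term` field names used here are an assumption about that library's model:

from soundevent import data

from batdetect2.targets.terms import (
    TagInfo,
    get_tag_from_info,
    get_term_from_key,
    register_term,
)

# Add a custom term under a new key (key, name and label are illustrative).
register_term(
    "genus",
    data.Term(name="biology:genus", label="Genus", definition="Genus of the individual."),
)

# Keys registered by default include "event", "individual" and "data_source".
event_term = get_term_from_key("event")

# TagInfo references a term by key; get_tag_from_info resolves it to a Tag.
tag = get_tag_from_info(TagInfo(key="genus", value="Nyctalus"))
print(event_term.label, tag.value)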
@@ -253,7 +252,7 @@ is explicitly passed.
 
 def get_term_from_key(
 key: str,
-term_registry: Optional[TermRegistry] = None,
+term_registry: TermRegistry = term_registry,
 ) -> data.Term:
 """Convenience function to retrieve a term by key from a registry.
 
@@ -278,13 +277,10 @@ def get_term_from_key(
 KeyError
 If the key is not found in the specified registry.
 """
-term_registry = term_registry or default_term_registry
 return term_registry.get_term(key)
 
 
-def get_term_keys(
-term_registry: TermRegistry = default_term_registry,
-) -> List[str]:
+def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
 """Convenience function to get all registered keys from a registry.
 
 Uses the global default registry unless a specific `term_registry`
@@ -303,9 +299,7 @@ def get_term_keys(
 return term_registry.get_keys()
 
 
-def get_terms(
-term_registry: TermRegistry = default_term_registry,
-) -> List[data.Term]:
+def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
 """Convenience function to get all registered terms from a registry.
 
 Uses the global default registry unless a specific `term_registry`
@ -348,7 +342,7 @@ class TagInfo(BaseModel):
|
|||||||
|
|
||||||
def get_tag_from_info(
|
def get_tag_from_info(
|
||||||
tag_info: TagInfo,
|
tag_info: TagInfo,
|
||||||
term_registry: Optional[TermRegistry] = None,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> data.Tag:
|
) -> data.Tag:
|
||||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||||
|
|
||||||
@ -374,7 +368,6 @@ def get_tag_from_info(
|
|||||||
If the term key specified in `tag_info.key` is not found
|
If the term key specified in `tag_info.key` is not found
|
||||||
in the registry.
|
in the registry.
|
||||||
"""
|
"""
|
||||||
term_registry = term_registry or default_term_registry
|
|
||||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||||
return data.Tag(term=term, value=tag_info.value)
|
return data.Tag(term=term, value=tag_info.value)
|
||||||
|
|
||||||
@ -446,7 +439,7 @@ class TermConfig(BaseModel):
|
|||||||
def load_terms_from_config(
|
def load_terms_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = default_term_registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> Dict[str, data.Term]:
|
) -> Dict[str, data.Term]:
|
||||||
"""Loads term definitions from a configuration file and registers them.
|
"""Loads term definitions from a configuration file and registers them.
|
||||||
|
|
||||||
@ -497,6 +490,6 @@ def load_terms_from_config(
|
|||||||
|
|
||||||
|
|
||||||
def register_term(
|
def register_term(
|
||||||
key: str, term: data.Term, registry: TermRegistry = default_term_registry
|
key: str, term: data.Term, registry: TermRegistry = term_registry
|
||||||
) -> None:
|
) -> None:
|
||||||
registry.add_term(key, term)
|
registry.add_term(key, term)
|
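As a usage note for the registry helpers touched above (a sketch based only on the functions visible in these hunks, with illustrative term fields; check soundevent's data.Term model for the exact schema):

from soundevent import data

from batdetect2.targets.terms import (
    TagInfo,
    get_tag_from_info,
    get_term_keys,
    register_term,
)

# Register an extra term under a custom key in the module-level registry.
site_term = data.Term(
    name="example:recording_site",  # assumed field names, illustrative values
    label="Recording Site",
    definition="Identifier of the site where the recording was made.",
)
register_term("site", site_term)

# The key is now resolvable through the default registry without passing it around.
assert "site" in get_term_keys()

# Build a soundevent Tag from a (key, value) pair via the same registry.
tag = get_tag_from_info(TagInfo(key="site", value="meadow-01"))
print(tag.term.label, tag.value)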
@@ -21,6 +21,9 @@ from batdetect2.targets.terms import (
     get_tag_from_info,
     get_term_from_key,
 )
+from batdetect2.targets.terms import (
+    term_registry as default_term_registry,
+)

 __all__ = [
     "DerivationRegistry",
@@ -31,7 +34,7 @@ __all__ = [
     "TransformConfig",
     "build_transform_from_rule",
     "build_transformation_from_config",
-    "default_derivation_registry",
+    "derivation_registry",
     "get_derivation",
     "load_transformation_config",
     "load_transformation_from_config",
@@ -395,7 +398,7 @@ class DerivationRegistry(Mapping[str, Derivation]):
         return list(self._derivations.values())


-default_derivation_registry = DerivationRegistry()
+derivation_registry = DerivationRegistry()
 """Global instance of the DerivationRegistry.

 Register custom derivation functions here to make them available by key
@@ -406,7 +409,7 @@ in `DeriveTagRule` configuration.
 def get_derivation(
     key: str,
     import_derivation: bool = False,
-    registry: Optional[DerivationRegistry] = None,
+    registry: DerivationRegistry = derivation_registry,
 ):
     """Retrieve a derivation function by key, optionally importing it.

@@ -440,8 +443,6 @@ def get_derivation(
     AttributeError
         If dynamic import fails because the function name isn't in the module.
     """
-    registry = registry or default_derivation_registry
-
     if not import_derivation or key in registry:
         return registry.get_derivation(key)

@@ -457,16 +458,10 @@ def get_derivation(
         ) from err


-TranformationRule = Annotated[
-    Union[ReplaceRule, MapValueRule, DeriveTagRule],
-    Field(discriminator="rule_type"),
-]
-
-
 def build_transform_from_rule(
-    rule: TranformationRule,
-    derivation_registry: Optional[DerivationRegistry] = None,
-    term_registry: Optional[TermRegistry] = None,
+    rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
+    derivation_registry: DerivationRegistry = derivation_registry,
+    term_registry: TermRegistry = default_term_registry,
 ) -> SoundEventTransformation:
     """Build a specific SoundEventTransformation function from a rule config.

@@ -564,8 +559,8 @@ def build_transform_from_rule(

 def build_transformation_from_config(
     config: TransformConfig,
-    derivation_registry: Optional[DerivationRegistry] = None,
-    term_registry: Optional[TermRegistry] = None,
+    derivation_registry: DerivationRegistry = derivation_registry,
+    term_registry: TermRegistry = default_term_registry,
 ) -> SoundEventTransformation:
     """Build a composite transformation function from a TransformConfig.

@@ -586,7 +581,6 @@ def build_transformation_from_config(
     SoundEventTransformation
         A single function that applies all configured transformations in order.
     """
-
     transforms = [
         build_transform_from_rule(
             rule,
@@ -596,16 +590,14 @@ def build_transformation_from_config(
         for rule in config.rules
     ]

-    return partial(apply_sequence_of_transforms, transforms=transforms)
-
-
-def apply_sequence_of_transforms(
-    sound_event_annotation: data.SoundEventAnnotation,
-    transforms: list[SoundEventTransformation],
-) -> data.SoundEventAnnotation:
-    for transform in transforms:
-        sound_event_annotation = transform(sound_event_annotation)
-    return sound_event_annotation
+    def transformation(
+        sound_event_annotation: data.SoundEventAnnotation,
+    ) -> data.SoundEventAnnotation:
+        for transform in transforms:
+            sound_event_annotation = transform(sound_event_annotation)
+        return sound_event_annotation
+
+    return transformation


 def load_transformation_config(
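Both sides of the last hunk compose the per-rule transforms into a single callable: the left side binds the list with functools.partial onto a module-level helper, the right side closes over it with a nested function. A minimal, self-contained sketch of the two equivalent styles using generic string transforms (not the repository's types):

from functools import partial
from typing import Callable, List

Transform = Callable[[str], str]


def apply_all(value: str, transforms: List[Transform]) -> str:
    # Apply each transform in order, feeding the output of one into the next.
    for transform in transforms:
        value = transform(value)
    return value


def make_pipeline(transforms: List[Transform]) -> Transform:
    # Closure style: the transform list is captured by the nested function.
    def pipeline(value: str) -> str:
        for transform in transforms:
            value = transform(value)
        return value

    return pipeline


steps: List[Transform] = [str.strip, str.lower]
pipeline_a = partial(apply_all, transforms=steps)  # partial style
pipeline_b = make_pipeline(steps)                  # closure style
assert pipeline_a("  Myotis  ") == pipeline_b("  Myotis  ") == "myotis"

One practical difference: a partial over a module-level function can be pickled (useful with multiprocessing data loaders), while a local closure cannot; whether that motivated the change is not stated in the diff.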
@@ -639,8 +631,8 @@ def load_transformation_config(
 def load_transformation_from_config(
     path: data.PathLike,
     field: Optional[str] = None,
-    derivation_registry: Optional[DerivationRegistry] = None,
-    term_registry: Optional[TermRegistry] = None,
+    derivation_registry: DerivationRegistry = derivation_registry,
+    term_registry: TermRegistry = default_term_registry,
 ) -> SoundEventTransformation:
     """Load transformation config from a file and build the final function.

@@ -685,7 +677,7 @@ def load_transformation_from_config(
 def register_derivation(
     key: str,
     derivation: Derivation,
-    derivation_registry: Optional[DerivationRegistry] = None,
+    derivation_registry: DerivationRegistry = derivation_registry,
 ) -> None:
     """Register a new derivation function in the global registry.

@@ -704,5 +696,4 @@ def register_derivation(
     KeyError
         If a derivation function with the same key is already registered.
     """
-    derivation_registry = derivation_registry or default_derivation_registry
     derivation_registry.register(key, derivation)
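For context on register_derivation above (a sketch, not repository code): it publishes a function under a key so that a DeriveTagRule in a configuration file can refer to it by name. The exact Derivation signature is not shown in these hunks, and the import path is assumed to be batdetect2.targets.transform; the example treats a derivation as a plain callable mapping one tag value string to another.

from batdetect2.targets.transform import register_derivation  # assumed module path


def genus_from_species(value: str) -> str:
    # Assumed str -> str derivation: "Myotis myotis" -> "Myotis".
    return value.split(" ")[0]


# Make the function available to DeriveTagRule configs under this key.
register_derivation("genus_from_species", genus_from_species)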
@@ -19,16 +19,8 @@ from soundevent import data

 __all__ = [
     "TargetProtocol",
-    "Position",
-    "Size",
 ]

-Position = tuple[float, float]
-"""A tuple representing (time, frequency) coordinates."""
-
-Size = np.ndarray
-"""A NumPy array representing the size dimensions of a target."""
-

 class TargetProtocol(Protocol):
     """Protocol defining the interface for the target definition pipeline.
@@ -110,7 +102,7 @@ class TargetProtocol(Protocol):
         """
         ...

-    def encode_class(
+    def encode(
         self,
         sound_event: data.SoundEventAnnotation,
     ) -> Optional[str]:
@@ -131,7 +123,7 @@ class TargetProtocol(Protocol):
         """
         ...

-    def decode_class(self, class_label: str) -> List[data.Tag]:
+    def decode(self, class_label: str) -> List[data.Tag]:
         """Decode a predicted class name back into representative tags.

         Parameters
@@ -155,9 +147,9 @@ class TargetProtocol(Protocol):
         """
         ...

-    def encode_roi(
+    def get_position(
         self, sound_event: data.SoundEventAnnotation
-    ) -> tuple[Position, Size]:
+    ) -> tuple[float, float]:
         """Extract the target reference position from the annotation's geometry.

         Calculates the `(time, frequency)` coordinate representing the primary
@@ -181,12 +173,36 @@ class TargetProtocol(Protocol):
         """
         ...

-    # TODO: Update docstrings
-    def decode_roi(
-        self,
-        position: Position,
-        size: Size,
-        class_name: Optional[str] = None,
+    def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
+        """Calculate the target size dimensions from the annotation's geometry.
+
+        Computes the relevant physical size (e.g., duration/width,
+        bandwidth/height from a bounding box) to produce
+        the numerical target values expected by the model.
+
+        Parameters
+        ----------
+        sound_event : data.SoundEventAnnotation
+            The annotation containing the geometry (ROI) to process.
+
+        Returns
+        -------
+        np.ndarray
+            A NumPy array containing the size dimensions, matching the
+            order specified by the `dimension_names` attribute (e.g.,
+            `[width, height]`).
+
+        Raises
+        ------
+        ValueError
+            If the annotation lacks geometry or if the size cannot be computed.
+        TypeError
+            If geometry type is unsupported.
+        """
+        ...
+
+    def recover_roi(
+        self, pos: tuple[float, float], dims: np.ndarray
     ) -> data.Geometry:
         """Recover the ROI geometry from a position and dimensions.

@@ -201,8 +217,6 @@ class TargetProtocol(Protocol):
         dims : np.ndarray
             The NumPy array containing the dimensions (e.g., predicted
             by the model), corresponding to the order in `dimension_names`.
-        class_name: str
-            class

         Returns
         -------
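To make the ROI half of the protocol concrete, here is a hand-written sketch of how get_position, get_size and recover_roi are expected to round-trip for a bounding-box geometry. It operates directly on the geometry for brevity (the protocol methods take SoundEventAnnotation objects) and assumes the [start_time, low_freq, end_time, high_freq] coordinate order and a bottom-left anchor; it is not the repository's implementation.

import numpy as np
from soundevent import data


def get_position(geom: data.BoundingBox) -> tuple[float, float]:
    # Anchor at the bottom-left corner: (start time, low frequency).
    start_time, low_freq, _, _ = geom.coordinates
    return start_time, low_freq


def get_size(geom: data.BoundingBox) -> np.ndarray:
    start_time, low_freq, end_time, high_freq = geom.coordinates
    # Order assumed to match dimension_names, e.g. [width, height].
    return np.array([end_time - start_time, high_freq - low_freq])


def recover_roi(pos: tuple[float, float], dims: np.ndarray) -> data.Geometry:
    time, freq = pos
    width, height = dims
    return data.BoundingBox(coordinates=[time, freq, time + width, freq + height])


box = data.BoundingBox(coordinates=[0.10, 20_000.0, 0.14, 45_000.0])
recovered = recover_roi(get_position(box), get_size(box))
assert np.allclose(recovered.coordinates, box.coordinates)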
@@ -97,7 +97,7 @@ def _is_in_subclip(
     start_time: float,
     end_time: float,
 ) -> bool:
-    (time, _), _ = targets.encode_roi(sound_event_annotation)
+    time, _ = targets.get_position(sound_event_annotation)
     return start_time <= time <= end_time


@@ -138,7 +138,7 @@ def generate_clip_label(
     logger.debug(
         "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
         uuid=clip_annotation.uuid,
-        num=len(clip_annotation.sound_events),
+        num=len(clip_annotation.sound_events)
     )

     sound_events = []
@@ -260,7 +260,7 @@ def generate_heatmaps(
             continue

         # Get the position of the sound event
-        (time, frequency), size = targets.encode_roi(sound_event_annotation)
+        time, frequency = targets.get_position(sound_event_annotation)

         # Set 1.0 at the position of the sound event in the detection heatmap
         try:
@@ -280,6 +280,8 @@ def generate_heatmaps(
             )
             continue

+        size = targets.get_size(sound_event_annotation)
+
         size_heatmap = arrays.set_value_at_pos(
             size_heatmap,
             size,
@@ -289,7 +291,7 @@ def generate_heatmaps(

         # Get the class name of the sound event
         try:
-            class_name = targets.encode_class(sound_event_annotation)
+            class_name = targets.encode(sound_event_annotation)
         except ValueError as e:
             logger.warning(
                 "Skipping annotation %s: Unexpected error while encoding "
|
|||||||
import marimo
|
|
||||||
|
|
||||||
__generated_with = "0.13.15"
|
|
||||||
app = marimo.App(width="medium")
|
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
|
||||||
def _():
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@app.cell
|
|
||||||
def _():
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run()
|
|
Some files were not shown because too many files have changed in this diff.