mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Update TargetProtocol and related to include rois
This commit is contained in:
parent
9410112e41
commit
6236e78414
@ -1,25 +1,29 @@
|
||||
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
||||
|
||||
This package (`batdetect2.targets`) provides the tools and configurations
|
||||
necessary to define precisely what the BatDetect2 model should learn to detect
|
||||
and classify from audio data. It involves several conceptual steps, managed
|
||||
through configuration files and culminating in executable functions:
|
||||
necessary to define precisely what the BatDetect2 model should learn to detect,
|
||||
classify, and localize from audio data. It involves several conceptual steps,
|
||||
managed through configuration files and culminating in an executable pipeline:
|
||||
|
||||
1. **Terms (`.terms`)**: Defining a controlled vocabulary for annotation tags.
|
||||
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
|
||||
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
||||
3. **Transformation (`.transform`)**: Modifying tags (e.g., standardization,
|
||||
3. **Transformation (`.transform`)**: Modifying tags (standardization,
|
||||
derivation).
|
||||
4. **Class Definition (`.classes`)**: Mapping tags to specific target class
|
||||
names (encoding) and defining how predicted class names map back to tags
|
||||
(decoding).
|
||||
4. **ROI Mapping (`.rois`)**: Defining how annotation geometry (ROIs) maps to
|
||||
target position and size representations, and back.
|
||||
5. **Class Definition (`.classes`)**: Mapping tags to target class names
|
||||
(encoding) and mapping predicted names back to tags (decoding).
|
||||
|
||||
This module exposes the key components for users to configure and utilize this
|
||||
target definition pipeline, primarily through the `TargetConfig` data structure
|
||||
and the `Targets` class, which encapsulates the configured processing steps.
|
||||
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
|
||||
configured processing steps. The main way to create a functional `Targets`
|
||||
object is via the `build_targets` or `load_targets` functions.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
@ -28,9 +32,9 @@ from batdetect2.targets.classes import (
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
TargetClass,
|
||||
build_decoder_from_config,
|
||||
build_encoder_from_config,
|
||||
build_generic_class_tags_from_config,
|
||||
build_generic_class_tags,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
@ -40,10 +44,15 @@ from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
SoundEventFilter,
|
||||
build_filter_from_config,
|
||||
build_sound_event_filter,
|
||||
load_filter_config,
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
ROIConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermInfo,
|
||||
@ -69,6 +78,7 @@ from batdetect2.targets.transform import (
|
||||
load_transformation_from_config,
|
||||
register_derivation,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassesConfig",
|
||||
@ -76,6 +86,8 @@ __all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"MapValueRule",
|
||||
"ROIConfig",
|
||||
"ROITargetMapper",
|
||||
"ReplaceRule",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
@ -84,13 +96,15 @@ __all__ = [
|
||||
"TagInfo",
|
||||
"TargetClass",
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"TermInfo",
|
||||
"TransformConfig",
|
||||
"build_decoder_from_config",
|
||||
"build_encoder_from_config",
|
||||
"build_filter_from_config",
|
||||
"build_generic_class_tags_from_config",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_filter",
|
||||
"build_generic_class_tags",
|
||||
"build_roi_mapper",
|
||||
"build_transformation_from_config",
|
||||
"call_type",
|
||||
"get_class_names_from_config",
|
||||
@ -114,29 +128,36 @@ __all__ = [
|
||||
class TargetConfig(BaseConfig):
|
||||
"""Unified configuration for the entire target definition pipeline.
|
||||
|
||||
This model aggregates the configurations for the optional filtering and
|
||||
transformation steps, and the mandatory class definition step. It serves as
|
||||
the primary input for building a complete `Targets` processing object.
|
||||
This model aggregates the configurations for semantic processing (filtering,
|
||||
transformation, class definition) and geometric processing (ROI mapping).
|
||||
It serves as the primary input for building a complete `Targets` object
|
||||
via `build_targets` or `load_targets`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filtering : FilterConfig, optional
|
||||
Configuration for filtering sound event annotations. If None or
|
||||
omitted, no filtering is applied.
|
||||
Configuration for filtering sound event annotations based on tags.
|
||||
If None or omitted, no filtering is applied.
|
||||
transforms : TransformConfig, optional
|
||||
Configuration for transforming annotation tags. If None or omitted, no
|
||||
transformations are applied.
|
||||
Configuration for transforming annotation tags
|
||||
(mapping, derivation, etc.). If None or omitted, no tag transformations
|
||||
are applied.
|
||||
classes : ClassesConfig
|
||||
Configuration defining the specific target classes, their matching
|
||||
rules, decoding rules (`output_tags`), and the generic class
|
||||
definition. This section is mandatory.
|
||||
Configuration defining the specific target classes, their tag matching
|
||||
rules for encoding, their representative tags for decoding
|
||||
(`output_tags`), and the definition of the generic class tags.
|
||||
This section is mandatory.
|
||||
roi : ROIConfig, optional
|
||||
Configuration defining how geometric ROIs (e.g., bounding boxes) are
|
||||
mapped to target representations (reference point, scaled size).
|
||||
Controls `position`, `time_scale`, `frequency_scale`. If None or
|
||||
omitted, default ROI mapping settings are used.
|
||||
"""
|
||||
|
||||
filtering: Optional[FilterConfig] = None
|
||||
|
||||
transforms: Optional[TransformConfig] = None
|
||||
|
||||
classes: ClassesConfig
|
||||
roi: Optional[ROIConfig] = None
|
||||
|
||||
|
||||
def load_target_config(
|
||||
@ -177,34 +198,40 @@ def load_target_config(
|
||||
return load_config(path=path, schema=TargetConfig, field=field)
|
||||
|
||||
|
||||
class Targets:
|
||||
class Targets(TargetProtocol):
|
||||
"""Encapsulates the complete configured target definition pipeline.
|
||||
|
||||
This class holds the functions for filtering, transforming, encoding, and
|
||||
decoding annotations based on a loaded `TargetConfig`. It provides a
|
||||
high-level interface to apply these steps and access relevant metadata
|
||||
like class names and generic class tags.
|
||||
This class implements the `TargetProtocol`, holding the configured
|
||||
functions for filtering, transforming, encoding (tags to class name),
|
||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||
and back). It provides a high-level interface to apply these steps and
|
||||
access relevant metadata like class names and dimension names.
|
||||
|
||||
Instances are typically created using the `Targets.from_config` or
|
||||
`Targets.from_file` classmethods.
|
||||
Instances are typically created using the `build_targets` factory function
|
||||
or the `load_targets` convenience loader.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : list[str]
|
||||
class_names : List[str]
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined in the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class (e.g., the default 'Bat' class).
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
The names of the size dimensions handled by the ROI mapper
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: list[str]
|
||||
class_names: List[str]
|
||||
generic_class_tags: List[data.Tag]
|
||||
dimension_names: List[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encode_fn: SoundEventEncoder,
|
||||
decode_fn: SoundEventDecoder,
|
||||
roi_mapper: ROITargetMapper,
|
||||
class_names: list[str],
|
||||
generic_class_tags: List[data.Tag],
|
||||
filter_fn: Optional[SoundEventFilter] = None,
|
||||
@ -212,26 +239,31 @@ class Targets:
|
||||
):
|
||||
"""Initialize the Targets object.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_targets` factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encode_fn : SoundEventEncoder
|
||||
The configured function to encode annotations to class names.
|
||||
Configured function to encode annotations to class names.
|
||||
decode_fn : SoundEventDecoder
|
||||
The configured function to decode class names to tags.
|
||||
Configured function to decode class names to tags.
|
||||
roi_mapper : ROITargetMapper
|
||||
Configured object for mapping geometry to/from position/size.
|
||||
class_names : list[str]
|
||||
The ordered list of specific target class names.
|
||||
Ordered list of specific target class names.
|
||||
generic_class_tags : List[data.Tag]
|
||||
The list of tags representing the generic class.
|
||||
List of tags representing the generic class.
|
||||
filter_fn : SoundEventFilter, optional
|
||||
The configured function to filter annotations. Defaults to None (no
|
||||
filtering).
|
||||
Configured function to filter annotations. Defaults to None.
|
||||
transform_fn : SoundEventTransformation, optional
|
||||
The configured function to transform annotation tags. Defaults to
|
||||
None (no transformation).
|
||||
Configured function to transform annotation tags. Defaults to None.
|
||||
"""
|
||||
self.class_names = class_names
|
||||
self.generic_class_tags = generic_class_tags
|
||||
self.dimension_names = roi_mapper.dimension_names
|
||||
|
||||
self._roi_mapper = roi_mapper
|
||||
self._filter_fn = filter_fn
|
||||
self._encode_fn = encode_fn
|
||||
self._decode_fn = decode_fn
|
||||
@ -316,19 +348,108 @@ class Targets:
|
||||
return self._transform_fn(sound_event)
|
||||
return sound_event
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
def get_position(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
"""Extract the target reference position from the annotation's ROI.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its position."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_position(geom)
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_size` method, which
|
||||
applies configured scaling factors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
NumPy array containing the size dimensions, matching the
|
||||
order in `self.dimension_names` (e.g., `[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its size."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_size(geom)
|
||||
|
||||
def recover_roi(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||
un-scales the dimensions and reconstructs the geometry (typically a
|
||||
`BoundingBox`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
dims : np.ndarray
|
||||
NumPy array with size dimensions (e.g., from model prediction),
|
||||
matching the order in `self.dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
return self._roi_mapper.recover_roi(pos, dims)
|
||||
|
||||
|
||||
def build_targets(
|
||||
config: TargetConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
) -> "Targets":
|
||||
) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
This factory method takes the unified configuration object and
|
||||
constructs all the necessary functional components (filter, transform,
|
||||
encoder, decoder) and extracts metadata (class names, generic tags) to
|
||||
create a fully configured `Targets` instance.
|
||||
This factory function takes the unified `TargetConfig` and constructs all
|
||||
necessary functional components (filter, transform, encoder,
|
||||
decoder, ROI mapper) by calling their respective builder functions. It also
|
||||
extracts metadata (class names, generic tags, dimension names) to create
|
||||
and return a fully initialized `Targets` instance, ready to process
|
||||
annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -356,18 +477,18 @@ class Targets:
|
||||
If dynamic import of a derivation function fails (when configured).
|
||||
"""
|
||||
filter_fn = (
|
||||
build_filter_from_config(
|
||||
build_sound_event_filter(
|
||||
config.filtering,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if config.filtering
|
||||
else None
|
||||
)
|
||||
encode_fn = build_encoder_from_config(
|
||||
encode_fn = build_sound_event_encoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
decode_fn = build_decoder_from_config(
|
||||
decode_fn = build_sound_event_decoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
@ -380,29 +501,30 @@ class Targets:
|
||||
if config.transforms
|
||||
else None
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags_from_config(
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
||||
return cls(
|
||||
return Targets(
|
||||
filter_fn=filter_fn,
|
||||
encode_fn=encode_fn,
|
||||
decode_fn=decode_fn,
|
||||
class_names=class_names,
|
||||
roi_mapper=roi_mapper,
|
||||
generic_class_tags=generic_class_tags,
|
||||
transform_fn=transform_fn,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls,
|
||||
|
||||
def load_targets(
|
||||
config_path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
) -> "Targets":
|
||||
) -> Targets:
|
||||
"""Load a Targets object directly from a configuration file.
|
||||
|
||||
This convenience factory function loads the `TargetConfig` from the
|
||||
@ -441,7 +563,7 @@ class Targets:
|
||||
config_path,
|
||||
field=field,
|
||||
)
|
||||
return cls.from_config(
|
||||
return build_targets(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
|
@ -22,9 +22,9 @@ __all__ = [
|
||||
"load_classes_config",
|
||||
"load_encoder_from_config",
|
||||
"load_decoder_from_config",
|
||||
"build_encoder_from_config",
|
||||
"build_decoder_from_config",
|
||||
"build_generic_class_tags_from_config",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_decoder",
|
||||
"build_generic_class_tags",
|
||||
"get_class_names_from_config",
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
]
|
||||
@ -314,7 +314,7 @@ def _encode_with_multiple_classifiers(
|
||||
return None
|
||||
|
||||
|
||||
def build_encoder_from_config(
|
||||
def build_sound_event_encoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
@ -408,7 +408,7 @@ def _decode_class(
|
||||
return mapping[name]
|
||||
|
||||
|
||||
def build_decoder_from_config(
|
||||
def build_sound_event_decoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
@ -463,7 +463,7 @@ def build_decoder_from_config(
|
||||
)
|
||||
|
||||
|
||||
def build_generic_class_tags_from_config(
|
||||
def build_generic_class_tags(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> List[data.Tag]:
|
||||
@ -565,7 +565,7 @@ def load_encoder_from_config(
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_encoder_from_config(config, term_registry=term_registry)
|
||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||
|
||||
|
||||
def load_decoder_from_config(
|
||||
@ -611,7 +611,7 @@ def load_decoder_from_config(
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_decoder_from_config(
|
||||
return build_sound_event_decoder(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
raise_on_unmapped=raise_on_unmapped,
|
||||
|
@ -17,7 +17,7 @@ __all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"SoundEventFilter",
|
||||
"build_filter_from_config",
|
||||
"build_sound_event_filter",
|
||||
"build_filter_from_rule",
|
||||
"load_filter_config",
|
||||
"load_filter_from_config",
|
||||
@ -241,7 +241,7 @@ class FilterConfig(BaseConfig):
|
||||
rules: List[FilterRule] = Field(default_factory=list)
|
||||
|
||||
|
||||
def build_filter_from_config(
|
||||
def build_sound_event_filter(
|
||||
config: FilterConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventFilter:
|
||||
@ -312,4 +312,4 @@ def load_filter_from_config(
|
||||
The final merged filter function ready to be used.
|
||||
"""
|
||||
config = load_filter_config(path=path, field=field)
|
||||
return build_filter_from_config(config, term_registry=term_registry)
|
||||
return build_sound_event_filter(config, term_registry=term_registry)
|
||||
|
@ -1,19 +1,20 @@
|
||||
"""Defines the core interface (Protocol) for the target definition pipeline.
|
||||
|
||||
This module specifies the standard structure and methods expected from an object
|
||||
that encapsulates the configured logic for processing sound event annotations
|
||||
within the `batdetect2.targets` system.
|
||||
This module specifies the standard structure, attributes, and methods expected
|
||||
from an object that encapsulates the complete configured logic for processing
|
||||
sound event annotations within the `batdetect2.targets` system.
|
||||
|
||||
The main component defined here is the `TargetEncoder` protocol. This protocol
|
||||
acts as a contract, ensuring that components responsible for applying
|
||||
filtering, transformations, encoding annotations to class names, and decoding
|
||||
class names back to tags can be interacted with in a consistent manner
|
||||
throughout BatDetect2. It also defines essential metadata attributes expected
|
||||
from implementations.
|
||||
The main component defined here is the `TargetProtocol`. This protocol acts as
|
||||
a contract for the entire target definition process, covering semantic aspects
|
||||
(filtering, tag transformation, class encoding/decoding) as well as geometric
|
||||
aspects (mapping regions of interest to target positions and sizes). It ensures
|
||||
that components responsible for these tasks can be interacted with consistently
|
||||
throughout BatDetect2.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
@ -26,18 +27,30 @@ class TargetProtocol(Protocol):
|
||||
|
||||
This protocol outlines the standard attributes and methods for an object
|
||||
that encapsulates the complete, configured process for handling sound event
|
||||
annotations to determine their target class for model training, and for
|
||||
interpreting model predictions back into annotation tags.
|
||||
annotations (both tags and geometry). It defines how to:
|
||||
- Filter relevant annotations.
|
||||
- Transform annotation tags.
|
||||
- Encode an annotation into a specific target class name.
|
||||
- Decode a class name back into representative tags.
|
||||
- Extract a target reference position from an annotation's geometry (ROI).
|
||||
- Calculate target size dimensions from an annotation's geometry.
|
||||
- Recover an approximate geometry (ROI) from a position and size
|
||||
dimensions.
|
||||
|
||||
Implementations of this protocol bundle all configured logic for these
|
||||
steps.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : List[str]
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined by the configuration represented by this object.
|
||||
defined by the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
A list of `soundevent.data.Tag` objects representing the
|
||||
generic class category (e.g., the default 'Bat' class tags used when
|
||||
no specific class matches).
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
A list containing the names of the size dimensions returned by
|
||||
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: List[str]
|
||||
@ -46,6 +59,9 @@ class TargetProtocol(Protocol):
|
||||
generic_class_tags: List[data.Tag]
|
||||
"""List of tags representing the generic (unclassified) category."""
|
||||
|
||||
dimension_names: List[str]
|
||||
"""Names of the size dimensions (e.g., ['width', 'height'])."""
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the filter to a sound event annotation.
|
||||
|
||||
@ -100,10 +116,10 @@ class TargetProtocol(Protocol):
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The string name of the matched target class if the annotation matches
|
||||
a specific class definition. Returns None if the annotation does not
|
||||
match any specific class rule (indicating it may belong to a generic
|
||||
category or should be handled differently downstream).
|
||||
The string name of the matched target class if the annotation
|
||||
matches a specific class definition. Returns None if the annotation
|
||||
does not match any specific class rule (indicating it may belong
|
||||
to a generic category or should be handled differently downstream).
|
||||
"""
|
||||
...
|
||||
|
||||
@ -130,3 +146,88 @@ class TargetProtocol(Protocol):
|
||||
found in the configured mapping and error raising is enabled.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_position(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
"""Extract the target reference position from the annotation's geometry.
|
||||
|
||||
Calculates the `(time, frequency)` coordinate representing the primary
|
||||
location of the sound event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI) to process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
The calculated reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry or if the position cannot be
|
||||
calculated for the geometry type or configured reference point.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Computes the relevant physical size (e.g., duration/width,
|
||||
bandwidth/height from a bounding box) to produce
|
||||
the numerical target values expected by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI) to process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A NumPy array containing the size dimensions, matching the
|
||||
order specified by the `dimension_names` attribute (e.g.,
|
||||
`[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry or if the size cannot be computed.
|
||||
TypeError
|
||||
If geometry type is unsupported.
|
||||
"""
|
||||
...
|
||||
|
||||
def recover_roi(
|
||||
self, pos: tuple[float, float], dims: np.ndarray
|
||||
) -> data.Geometry:
|
||||
"""Recover the ROI geometry from a position and dimensions.
|
||||
|
||||
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
||||
a reference position `(time, frequency)` and an array of size
|
||||
dimensions and reconstructs an approximate geometric representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
dims : np.ndarray
|
||||
The NumPy array containing the dimensions (e.g., predicted
|
||||
by the model), corresponding to the order in `dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Geometry
|
||||
The reconstructed geometry.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of provided `dims` does not match `dimension_names`,
|
||||
if dimensions are invalid (e.g., negative after unscaling), or
|
||||
if reconstruction fails based on the configured position type.
|
||||
"""
|
||||
...
|
||||
|
85
docs/source/targets/rois.md
Normal file
85
docs/source/targets/rois.md
Normal file
@ -0,0 +1,85 @@
|
||||
## Defining Target Geometry: Mapping Sound Event Regions
|
||||
|
||||
### Introduction
|
||||
|
||||
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
|
||||
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
|
||||
|
||||
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
|
||||
This ROI contains detailed spatial information (start/end time, low/high frequency).
|
||||
|
||||
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
|
||||
|
||||
### From ROI to Model Targets: Position & Size
|
||||
|
||||
BatDetect2 does not directly predict a full bounding box.
|
||||
Instead, it is trained to predict:
|
||||
|
||||
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
|
||||
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
|
||||
|
||||
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
|
||||
It also handles the reverse process – converting predicted positions and sizes back into bounding boxes for visualization or analysis.
|
||||
|
||||
### Configuring the ROI Mapping
|
||||
|
||||
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
|
||||
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
|
||||
|
||||
Here are the key settings:
|
||||
|
||||
- **`position`**:
|
||||
|
||||
- **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
|
||||
- **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
|
||||
Different choices might slightly influence model learning.
|
||||
The default (`"bottom-left"`) is often a good starting point.
|
||||
- **Example Value:** `position: "center"`
|
||||
|
||||
- **`time_scale`**:
|
||||
|
||||
- **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
|
||||
- **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
|
||||
For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
|
||||
- **Important Considerations:**
|
||||
- You can often set this value based on the units you prefer the model to work with internally.
|
||||
However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
|
||||
- The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
|
||||
Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
|
||||
- Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
|
||||
BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
|
||||
- **Example Value:** `time_scale: 1000.0`
|
||||
|
||||
- **`frequency_scale`**:
|
||||
- **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
|
||||
- **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
|
||||
- **Important Considerations:**
|
||||
- Same as for `time_scale`.
|
||||
- **Example Value:** `frequency_scale: 0.00116`
|
||||
|
||||
**Example YAML Configuration:**
|
||||
|
||||
```yaml
|
||||
# Inside your main configuration file (e.g., training_config.yaml)
|
||||
|
||||
targets: # Top-level key for target definition
|
||||
# ... filtering settings ...
|
||||
# ... transforms settings ...
|
||||
# ... classes settings ...
|
||||
|
||||
# --- ROI Mapping Settings ---
|
||||
roi:
|
||||
position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
|
||||
time_scale: 1000.0 # e.g., Model predicts width in ms
|
||||
frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
|
||||
```
|
||||
|
||||
### Decoding Size Predictions
|
||||
|
||||
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
|
||||
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
|
||||
|
||||
### Outcome
|
||||
|
||||
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
|
||||
Using consistent scales that are appropriate for your data, and potentially beneficial for training stability, allows the model to learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is. It also enables meaningful interpretation of the model's spatial and size predictions.
|
@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
|
||||
_get_default_class_name,
|
||||
_get_default_classes,
|
||||
_is_target_class,
|
||||
build_decoder_from_config,
|
||||
build_encoder_from_config,
|
||||
build_generic_class_tags_from_config,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
build_generic_class_tags,
|
||||
get_class_names_from_config,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
@ -231,7 +231,7 @@ def test_build_encoder_from_config(
|
||||
)
|
||||
]
|
||||
)
|
||||
encoder = build_encoder_from_config(
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
@ -239,7 +239,7 @@ def test_build_encoder_from_config(
|
||||
assert result == "pippip"
|
||||
|
||||
config = ClassesConfig(classes=[])
|
||||
encoder = build_encoder_from_config(
|
||||
encoder = build_sound_event_encoder(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
@ -315,7 +315,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_decoder_from_config(
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
@ -335,7 +335,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_decoder_from_config(
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
@ -344,14 +344,14 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||
|
||||
# Test raise_on_unmapped=True
|
||||
decoder = build_decoder_from_config(
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
decoder("unknown_class")
|
||||
|
||||
# Test raise_on_unmapped=False
|
||||
decoder = build_decoder_from_config(
|
||||
decoder = build_sound_event_decoder(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||
)
|
||||
tags = decoder("unknown_class")
|
||||
@ -402,7 +402,7 @@ def test_build_generic_class_tags_from_config(
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
generic_tags = build_generic_class_tags_from_config(
|
||||
generic_tags = build_generic_class_tags(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
assert len(generic_tags) == 2
|
||||
|
@ -7,7 +7,7 @@ from soundevent import data
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
build_filter_from_config,
|
||||
build_sound_event_filter,
|
||||
build_filter_from_rule,
|
||||
contains_tags,
|
||||
does_not_have_tags,
|
||||
@ -121,7 +121,7 @@ def test_build_filter_from_config(create_annotation):
|
||||
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
||||
]
|
||||
)
|
||||
filter_from_config = build_filter_from_config(config)
|
||||
filter_from_config = build_sound_event_filter(config)
|
||||
|
||||
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||
assert filter_from_config(annotation_pass)
|
||||
|
Loading…
Reference in New Issue
Block a user