Update TargetProtocol and related to include rois

This commit is contained in:
mbsantiago 2025-04-19 20:26:18 +01:00
parent 9410112e41
commit 6236e78414
7 changed files with 498 additions and 190 deletions

View File

@ -1,25 +1,29 @@
"""Main entry point for the BatDetect2 Target Definition subsystem. """Main entry point for the BatDetect2 Target Definition subsystem.
This package (`batdetect2.targets`) provides the tools and configurations This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect necessary to define precisely what the BatDetect2 model should learn to detect,
and classify from audio data. It involves several conceptual steps, managed classify, and localize from audio data. It involves several conceptual steps,
through configuration files and culminating in executable functions: managed through configuration files and culminating in an executable pipeline:
1. **Terms (`.terms`)**: Defining a controlled vocabulary for annotation tags. 1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations. 2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (e.g., standardization, 3. **Transformation (`.transform`)**: Modifying tags (standardization,
derivation). derivation).
4. **Class Definition (`.classes`)**: Mapping tags to specific target class 4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
names (encoding) and defining how predicted class names map back to tags target position and size representations, and back.
(decoding). 5. **Class Definition (`.classes`)**: Mapping tags to target class names
(encoding) and mapping predicted names back to tags (decoding).
This module exposes the key components for users to configure and utilize this This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class, which encapsulates the configured processing steps. and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
""" """
from typing import List, Optional from typing import List, Optional
import numpy as np
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
@ -28,9 +32,9 @@ from batdetect2.targets.classes import (
SoundEventDecoder, SoundEventDecoder,
SoundEventEncoder, SoundEventEncoder,
TargetClass, TargetClass,
build_decoder_from_config, build_generic_class_tags,
build_encoder_from_config, build_sound_event_decoder,
build_generic_class_tags_from_config, build_sound_event_encoder,
get_class_names_from_config, get_class_names_from_config,
load_classes_config, load_classes_config,
load_decoder_from_config, load_decoder_from_config,
@ -40,10 +44,15 @@ from batdetect2.targets.filtering import (
FilterConfig, FilterConfig,
FilterRule, FilterRule,
SoundEventFilter, SoundEventFilter,
build_filter_from_config, build_sound_event_filter,
load_filter_config, load_filter_config,
load_filter_from_config, load_filter_from_config,
) )
from batdetect2.targets.rois import (
ROIConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.terms import ( from batdetect2.targets.terms import (
TagInfo, TagInfo,
TermInfo, TermInfo,
@ -69,6 +78,7 @@ from batdetect2.targets.transform import (
load_transformation_from_config, load_transformation_from_config,
register_derivation, register_derivation,
) )
from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"ClassesConfig", "ClassesConfig",
@ -76,6 +86,8 @@ __all__ = [
"FilterConfig", "FilterConfig",
"FilterRule", "FilterRule",
"MapValueRule", "MapValueRule",
"ROIConfig",
"ROITargetMapper",
"ReplaceRule", "ReplaceRule",
"SoundEventDecoder", "SoundEventDecoder",
"SoundEventEncoder", "SoundEventEncoder",
@ -84,13 +96,15 @@ __all__ = [
"TagInfo", "TagInfo",
"TargetClass", "TargetClass",
"TargetConfig", "TargetConfig",
"TargetProtocol",
"Targets", "Targets",
"TermInfo", "TermInfo",
"TransformConfig", "TransformConfig",
"build_decoder_from_config", "build_sound_event_decoder",
"build_encoder_from_config", "build_sound_event_encoder",
"build_filter_from_config", "build_sound_event_filter",
"build_generic_class_tags_from_config", "build_generic_class_tags",
"build_roi_mapper",
"build_transformation_from_config", "build_transformation_from_config",
"call_type", "call_type",
"get_class_names_from_config", "get_class_names_from_config",
@ -114,29 +128,36 @@ __all__ = [
class TargetConfig(BaseConfig): class TargetConfig(BaseConfig):
"""Unified configuration for the entire target definition pipeline. """Unified configuration for the entire target definition pipeline.
This model aggregates the configurations for the optional filtering and This model aggregates the configurations for semantic processing (filtering,
transformation steps, and the mandatory class definition step. It serves as transformation, class definition) and geometric processing (ROI mapping).
the primary input for building a complete `Targets` processing object. It serves as the primary input for building a complete `Targets` object
via `build_targets` or `load_targets`.
Attributes Attributes
---------- ----------
filtering : FilterConfig, optional filtering : FilterConfig, optional
Configuration for filtering sound event annotations. If None or Configuration for filtering sound event annotations based on tags.
omitted, no filtering is applied. If None or omitted, no filtering is applied.
transforms : TransformConfig, optional transforms : TransformConfig, optional
Configuration for transforming annotation tags. If None or omitted, no Configuration for transforming annotation tags
transformations are applied. (mapping, derivation, etc.). If None or omitted, no tag transformations
are applied.
classes : ClassesConfig classes : ClassesConfig
Configuration defining the specific target classes, their matching Configuration defining the specific target classes, their tag matching
rules, decoding rules (`output_tags`), and the generic class rules for encoding, their representative tags for decoding
definition. This section is mandatory. (`output_tags`), and the definition of the generic class tags.
This section is mandatory.
roi : ROIConfig, optional
Configuration defining how geometric ROIs (e.g., bounding boxes) are
mapped to target representations (reference point, scaled size).
Controls `position`, `time_scale`, `frequency_scale`. If None or
omitted, default ROI mapping settings are used.
""" """
filtering: Optional[FilterConfig] = None filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None transforms: Optional[TransformConfig] = None
classes: ClassesConfig classes: ClassesConfig
roi: Optional[ROIConfig] = None
def load_target_config( def load_target_config(
@ -177,34 +198,40 @@ def load_target_config(
return load_config(path=path, schema=TargetConfig, field=field) return load_config(path=path, schema=TargetConfig, field=field)
class Targets: class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline. """Encapsulates the complete configured target definition pipeline.
This class holds the functions for filtering, transforming, encoding, and This class implements the `TargetProtocol`, holding the configured
decoding annotations based on a loaded `TargetConfig`. It provides a functions for filtering, transforming, encoding (tags to class name),
high-level interface to apply these steps and access relevant metadata decoding (class name to tags), and mapping ROIs (geometry to position/size
like class names and generic class tags. and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `Targets.from_config` or Instances are typically created using the `build_targets` factory function
`Targets.from_file` classmethods. or the `load_targets` convenience loader.
Attributes Attributes
---------- ----------
class_names : list[str] class_names : List[str]
An ordered list of the unique names of the specific target classes An ordered list of the unique names of the specific target classes
defined in the configuration. defined in the configuration.
generic_class_tags : List[data.Tag] generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the configured A list of `soundevent.data.Tag` objects representing the configured
generic class (e.g., the default 'Bat' class). generic class category (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
""" """
class_names: list[str] class_names: List[str]
generic_class_tags: List[data.Tag] generic_class_tags: List[data.Tag]
dimension_names: List[str]
def __init__( def __init__(
self, self,
encode_fn: SoundEventEncoder, encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder, decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str], class_names: list[str],
generic_class_tags: List[data.Tag], generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None, filter_fn: Optional[SoundEventFilter] = None,
@ -212,26 +239,31 @@ class Targets:
): ):
"""Initialize the Targets object. """Initialize the Targets object.
Note: This constructor is typically called internally by the
`build_targets` factory function.
Parameters Parameters
---------- ----------
encode_fn : SoundEventEncoder encode_fn : SoundEventEncoder
The configured function to encode annotations to class names. Configured function to encode annotations to class names.
decode_fn : SoundEventDecoder decode_fn : SoundEventDecoder
The configured function to decode class names to tags. Configured function to decode class names to tags.
roi_mapper : ROITargetMapper
Configured object for mapping geometry to/from position/size.
class_names : list[str] class_names : list[str]
The ordered list of specific target class names. Ordered list of specific target class names.
generic_class_tags : List[data.Tag] generic_class_tags : List[data.Tag]
The list of tags representing the generic class. List of tags representing the generic class.
filter_fn : SoundEventFilter, optional filter_fn : SoundEventFilter, optional
The configured function to filter annotations. Defaults to None (no Configured function to filter annotations. Defaults to None.
filtering).
transform_fn : SoundEventTransformation, optional transform_fn : SoundEventTransformation, optional
The configured function to transform annotation tags. Defaults to Configured function to transform annotation tags. Defaults to None.
None (no transformation).
""" """
self.class_names = class_names self.class_names = class_names
self.generic_class_tags = generic_class_tags self.generic_class_tags = generic_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper
self._filter_fn = filter_fn self._filter_fn = filter_fn
self._encode_fn = encode_fn self._encode_fn = encode_fn
self._decode_fn = decode_fn self._decode_fn = decode_fn
@ -316,133 +348,223 @@ class Targets:
return self._transform_fn(sound_event) return self._transform_fn(sound_event)
return sound_event return sound_event
@classmethod def get_position(
def from_config( self, sound_event: data.SoundEventAnnotation
cls, ) -> tuple[float, float]:
config: TargetConfig, """Extract the target reference position from the annotation's roi.
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> "Targets":
"""Build a Targets object from a loaded TargetConfig.
This factory method takes the unified configuration object and Delegates to the internal ROI mapper's `get_roi_position` method.
constructs all the necessary functional components (filter, transform,
encoder, decoder) and extracts metadata (class names, generic tags) to
create a fully configured `Targets` instance.
Parameters Parameters
---------- ----------
config : TargetConfig sound_event : data.SoundEventAnnotation
The loaded and validated unified target configuration object. The annotation containing the geometry (ROI).
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns Returns
------- -------
Targets Tuple[float, float]
An initialized `Targets` object ready for use. The reference position `(time, frequency)`.
Raises Raises
------ ------
KeyError ValueError
If term keys or derivation function keys specified in the `config` If the annotation lacks geometry.
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
""" """
filter_fn = ( geom = sound_event.sound_event.geometry
build_filter_from_config(
config.filtering, if geom is None:
term_registry=term_registry, raise ValueError(
"Sound event has no geometry, cannot get its position."
) )
if config.filtering
else None
)
encode_fn = build_encoder_from_config(
config.classes,
term_registry=term_registry,
)
decode_fn = build_decoder_from_config(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags_from_config(
config.classes,
term_registry=term_registry,
)
return cls( return self._roi_mapper.get_roi_position(geom)
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
)
@classmethod def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
def from_file( """Calculate the target size dimensions from the annotation's geometry.
cls,
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> "Targets":
"""Load a Targets object directly from a configuration file.
This convenience factory method loads the `TargetConfig` from the Delegates to the internal ROI mapper's `get_roi_size` method, which
specified file path and then calls `Targets.from_config` to build applies configured scaling factors.
the fully initialized `Targets` object.
Parameters Parameters
---------- ----------
config_path : data.PathLike sound_event : data.SoundEventAnnotation
Path to the configuration file (e.g., YAML). The annotation containing the geometry (ROI).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
Returns Returns
------- -------
Targets np.ndarray
An initialized `Targets` object ready for use. NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises Raises
------ ------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, ValueError
TypeError If the annotation lacks geometry.
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
Errors raised during the build process by `Targets.from_config`
(e.g., missing keys in registries, failed imports).
""" """
config = load_target_config( geom = sound_event.sound_event.geometry
config_path,
field=field, if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
def recover_roi(
    self,
    pos: tuple[float, float],
    dims: np.ndarray,
) -> data.Geometry:
    """Rebuild an approximate ROI geometry from a position and its sizes.

    This is a thin wrapper: both arguments are forwarded unchanged to the
    `recover_roi` method of the ROI mapper supplied at construction time,
    and its result is returned as-is. The mapper is responsible for
    un-scaling the dimensions and reconstructing the geometry (typically
    a `BoundingBox`).

    Parameters
    ----------
    pos : tuple[float, float]
        The reference position `(time, frequency)`.
    dims : np.ndarray
        Size dimensions (e.g., from a model prediction), given in the
        same order as `self.dimension_names`.

    Returns
    -------
    data.Geometry
        The geometry produced by the ROI mapper (typically a
        `BoundingBox`).
    """
    mapper = self._roi_mapper
    geometry = mapper.recover_roi(pos, dims)
    return geometry
def build_targets(
config: TargetConfig,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
necessary functional components (filter, transform, encoder,
decoder, ROI mapper) by calling their respective builder functions. It also
extracts metadata (class names, generic tags, dimension names) to create
and return a fully initialized `Targets` instance, ready to process
annotations.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
filter_fn = (
build_sound_event_filter(
config.filtering,
term_registry=term_registry,
) )
return cls.from_config( if config.filtering
config, else None
)
encode_fn = build_sound_event_encoder(
config.classes,
term_registry=term_registry,
)
decode_fn = build_sound_event_decoder(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry, term_registry=term_registry,
derivation_registry=derivation_registry, derivation_registry=derivation_registry,
) )
if config.transforms
else None
)
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
return Targets(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
)
def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = term_registry,
    derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
    """Load a Targets object directly from a configuration file.

    This convenience function loads the `TargetConfig` from the specified
    file path using `load_target_config` and then calls `build_targets`
    to construct the fully initialized `Targets` object.

    Parameters
    ----------
    config_path : data.PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.
    term_registry : TermRegistry, optional
        The TermRegistry instance to use. Defaults to the global default.
    derivation_registry : DerivationRegistry, optional
        The DerivationRegistry instance to use. Defaults to the global
        default.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        Errors raised during file loading, validation, or extraction via
        `load_target_config`.
    KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
        (e.g., missing keys in registries, failed imports).
    """
    # Step 1: read and validate the configuration section from disk.
    config = load_target_config(
        config_path,
        field=field,
    )
    # Step 2: hand the validated config to the shared factory so that
    # file-based and in-memory construction follow the same code path.
    return build_targets(
        config,
        term_registry=term_registry,
        derivation_registry=derivation_registry,
    )

View File

@ -22,9 +22,9 @@ __all__ = [
"load_classes_config", "load_classes_config",
"load_encoder_from_config", "load_encoder_from_config",
"load_decoder_from_config", "load_decoder_from_config",
"build_encoder_from_config", "build_sound_event_encoder",
"build_decoder_from_config", "build_sound_event_decoder",
"build_generic_class_tags_from_config", "build_generic_class_tags",
"get_class_names_from_config", "get_class_names_from_config",
"DEFAULT_SPECIES_LIST", "DEFAULT_SPECIES_LIST",
] ]
@ -314,7 +314,7 @@ def _encode_with_multiple_classifiers(
return None return None
def build_encoder_from_config( def build_sound_event_encoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder: ) -> SoundEventEncoder:
@ -408,7 +408,7 @@ def _decode_class(
return mapping[name] return mapping[name]
def build_decoder_from_config( def build_sound_event_decoder(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False, raise_on_unmapped: bool = False,
@ -463,7 +463,7 @@ def build_decoder_from_config(
) )
def build_generic_class_tags_from_config( def build_generic_class_tags(
config: ClassesConfig, config: ClassesConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = term_registry,
) -> List[data.Tag]: ) -> List[data.Tag]:
@ -565,7 +565,7 @@ def load_encoder_from_config(
provided `term_registry` during the build process. provided `term_registry` during the build process.
""" """
config = load_classes_config(path, field=field) config = load_classes_config(path, field=field)
return build_encoder_from_config(config, term_registry=term_registry) return build_sound_event_encoder(config, term_registry=term_registry)
def load_decoder_from_config( def load_decoder_from_config(
@ -611,7 +611,7 @@ def load_decoder_from_config(
provided `term_registry` during the build process. provided `term_registry` during the build process.
""" """
config = load_classes_config(path, field=field) config = load_classes_config(path, field=field)
return build_decoder_from_config( return build_sound_event_decoder(
config, config,
term_registry=term_registry, term_registry=term_registry,
raise_on_unmapped=raise_on_unmapped, raise_on_unmapped=raise_on_unmapped,

View File

@ -17,7 +17,7 @@ __all__ = [
"FilterConfig", "FilterConfig",
"FilterRule", "FilterRule",
"SoundEventFilter", "SoundEventFilter",
"build_filter_from_config", "build_sound_event_filter",
"build_filter_from_rule", "build_filter_from_rule",
"load_filter_config", "load_filter_config",
"load_filter_from_config", "load_filter_from_config",
@ -241,7 +241,7 @@ class FilterConfig(BaseConfig):
rules: List[FilterRule] = Field(default_factory=list) rules: List[FilterRule] = Field(default_factory=list)
def build_filter_from_config( def build_sound_event_filter(
config: FilterConfig, config: FilterConfig,
term_registry: TermRegistry = term_registry, term_registry: TermRegistry = term_registry,
) -> SoundEventFilter: ) -> SoundEventFilter:
@ -312,4 +312,4 @@ def load_filter_from_config(
The final merged filter function ready to be used. The final merged filter function ready to be used.
""" """
config = load_filter_config(path=path, field=field) config = load_filter_config(path=path, field=field)
return build_filter_from_config(config, term_registry=term_registry) return build_sound_event_filter(config, term_registry=term_registry)

View File

@ -1,19 +1,20 @@
"""Defines the core interface (Protocol) for the target definition pipeline. """Defines the core interface (Protocol) for the target definition pipeline.
This module specifies the standard structure and methods expected from an object This module specifies the standard structure, attributes, and methods expected
that encapsulates the configured logic for processing sound event annotations from an object that encapsulates the complete configured logic for processing
within the `batdetect2.targets` system. sound event annotations within the `batdetect2.targets` system.
The main component defined here is the `TargetEncoder` protocol. This protocol The main component defined here is the `TargetProtocol`. This protocol acts as
acts as a contract, ensuring that components responsible for applying a contract for the entire target definition process, covering semantic aspects
filtering, transformations, encoding annotations to class names, and decoding (filtering, tag transformation, class encoding/decoding) as well as geometric
class names back to tags can be interacted with in a consistent manner aspects (mapping regions of interest to target positions and sizes). It ensures
throughout BatDetect2. It also defines essential metadata attributes expected that components responsible for these tasks can be interacted with consistently
from implementations. throughout BatDetect2.
""" """
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
import numpy as np
from soundevent import data from soundevent import data
__all__ = [ __all__ = [
@ -26,18 +27,30 @@ class TargetProtocol(Protocol):
This protocol outlines the standard attributes and methods for an object This protocol outlines the standard attributes and methods for an object
that encapsulates the complete, configured process for handling sound event that encapsulates the complete, configured process for handling sound event
annotations to determine their target class for model training, and for annotations (both tags and geometry). It defines how to:
interpreting model predictions back into annotation tags. - Filter relevant annotations.
- Transform annotation tags.
- Encode an annotation into a specific target class name.
- Decode a class name back into representative tags.
- Extract a target reference position from an annotation's geometry (ROI).
- Calculate target size dimensions from an annotation's geometry.
- Recover an approximate geometry (ROI) from a position and size
dimensions.
Implementations of this protocol bundle all configured logic for these
steps.
Attributes Attributes
---------- ----------
class_names : List[str] class_names : List[str]
An ordered list of the unique names of the specific target classes An ordered list of the unique names of the specific target classes
defined by the configuration represented by this object. defined by the configuration.
generic_class_tags : List[data.Tag] generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the A list of `soundevent.data.Tag` objects representing the configured
generic class category (e.g., the default 'Bat' class tags used when generic class category (e.g., used when no specific class matches).
no specific class matches). dimension_names : List[str]
A list containing the names of the size dimensions returned by
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
""" """
class_names: List[str] class_names: List[str]
@ -46,6 +59,9 @@ class TargetProtocol(Protocol):
generic_class_tags: List[data.Tag] generic_class_tags: List[data.Tag]
"""List of tags representing the generic (unclassified) category.""" """List of tags representing the generic (unclassified) category."""
dimension_names: List[str]
"""Names of the size dimensions (e.g., ['width', 'height'])."""
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the filter to a sound event annotation. """Apply the filter to a sound event annotation.
@ -100,10 +116,10 @@ class TargetProtocol(Protocol):
Returns Returns
------- -------
str or None str or None
The string name of the matched target class if the annotation matches The string name of the matched target class if the annotation
a specific class definition. Returns None if the annotation does not matches a specific class definition. Returns None if the annotation
match any specific class rule (indicating it may belong to a generic does not match any specific class rule (indicating it may belong
category or should be handled differently downstream). to a generic category or should be handled differently downstream).
""" """
... ...
@ -130,3 +146,88 @@ class TargetProtocol(Protocol):
found in the configured mapping and error raising is enabled. found in the configured mapping and error raising is enabled.
""" """
... ...
def get_position(
    self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
    """Extract the target reference position from the annotation geometry.

    Calculates the `(time, frequency)` coordinate representing the primary
    location of the sound event.

    Parameters
    ----------
    sound_event : data.SoundEventAnnotation
        The annotation containing the geometry (ROI) to process.

    Returns
    -------
    tuple[float, float]
        The calculated reference position, ordered `(time, frequency)`.

    Raises
    ------
    ValueError
        If the annotation lacks geometry or if the position cannot be
        calculated for the geometry type or configured reference point.
    """
    ...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
    """Calculate the target size dimensions from the annotation's geometry.

    Computes the relevant physical size (e.g., duration/width and
    bandwidth/height of a bounding box) to produce the numerical target
    values expected by the model.

    Parameters
    ----------
    sound_event : data.SoundEventAnnotation
        The annotation containing the geometry (ROI) to process.

    Returns
    -------
    np.ndarray
        A NumPy array containing the size dimensions, in the order
        specified by the `dimension_names` attribute (e.g.,
        `[width, height]`).

    Raises
    ------
    ValueError
        If the annotation lacks geometry or if the size cannot be
        computed.
    TypeError
        If the geometry type is unsupported.
    """
    ...
def recover_roi(
    self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
    """Recover the ROI geometry from a position and dimensions.

    Performs the inverse mapping of `get_position` and `get_size`. It takes
    a reference position `(time, frequency)` and an array of size
    dimensions and reconstructs an approximate geometric representation.

    Parameters
    ----------
    pos : tuple[float, float]
        The reference position `(time, frequency)`.
    dims : np.ndarray
        The NumPy array containing the dimensions (e.g., predicted
        by the model), corresponding to the order in `dimension_names`.

    Returns
    -------
    data.Geometry
        The reconstructed (approximate) geometry.

    Raises
    ------
    ValueError
        If the number of provided `dims` does not match `dimension_names`,
        if dimensions are invalid (e.g., negative after unscaling), or
        if reconstruction fails based on the configured position type.
    """
    ...

View File

@ -0,0 +1,85 @@
## Defining Target Geometry: Mapping Sound Event Regions
### Introduction
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
This ROI contains detailed spatial information (start/end time, low/high frequency).
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
### From ROI to Model Targets: Position & Size
BatDetect2 does not directly predict a full bounding box.
Instead, it is trained to predict:
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
It also handles the reverse process: converting predicted positions and sizes back into bounding boxes for visualization or analysis.
### Configuring the ROI Mapping
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
Here are the key settings:
- **`position`**:
- **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
- **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
Different choices might slightly influence model learning.
The default (`"bottom-left"`) is often a good starting point.
- **Example Value:** `position: "center"`
- **`time_scale`**:
- **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
- **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
- **Important Considerations:**
- You can often set this value based on the units you prefer the model to work with internally.
However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
- The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
- Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
- **Example Value:** `time_scale: 1000.0`
- **`frequency_scale`**:
- **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
- **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
- **Important Considerations:**
- Same as for `time_scale`.
- **Example Value:** `frequency_scale: 0.00116`
**Example YAML Configuration:**
```yaml
# Inside your main configuration file (e.g., training_config.yaml)
targets: # Top-level key for target definition
# ... filtering settings ...
# ... transforms settings ...
# ... classes settings ...
# --- ROI Mapping Settings ---
roi:
position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
time_scale: 1000.0 # e.g., Model predicts width in ms
frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
```
### Decoding Size Predictions
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
### Outcome
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
Using consistent scales that suit your data, and that may also benefit training stability, allows the model to learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is. It also enables meaningful interpretation of the model's spatial and size predictions.

View File

@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
_get_default_class_name, _get_default_class_name,
_get_default_classes, _get_default_classes,
_is_target_class, _is_target_class,
build_decoder_from_config, build_sound_event_decoder,
build_encoder_from_config, build_sound_event_encoder,
build_generic_class_tags_from_config, build_generic_class_tags,
get_class_names_from_config, get_class_names_from_config,
load_classes_config, load_classes_config,
load_decoder_from_config, load_decoder_from_config,
@ -231,7 +231,7 @@ def test_build_encoder_from_config(
) )
] ]
) )
encoder = build_encoder_from_config( encoder = build_sound_event_encoder(
config, config,
term_registry=sample_term_registry, term_registry=sample_term_registry,
) )
@ -239,7 +239,7 @@ def test_build_encoder_from_config(
assert result == "pippip" assert result == "pippip"
config = ClassesConfig(classes=[]) config = ClassesConfig(classes=[])
encoder = build_encoder_from_config( encoder = build_sound_event_encoder(
config, config,
term_registry=sample_term_registry, term_registry=sample_term_registry,
) )
@ -315,7 +315,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_decoder_from_config( decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry config, term_registry=sample_term_registry
) )
tags = decoder("pippip") tags = decoder("pippip")
@ -335,7 +335,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
], ],
generic_class=[TagInfo(key="order", value="Chiroptera")], generic_class=[TagInfo(key="order", value="Chiroptera")],
) )
decoder = build_decoder_from_config( decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry config, term_registry=sample_term_registry
) )
tags = decoder("pippip") tags = decoder("pippip")
@ -344,14 +344,14 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
assert tags[0].value == "Pipistrellus pipistrellus" assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True # Test raise_on_unmapped=True
decoder = build_decoder_from_config( decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=True config, term_registry=sample_term_registry, raise_on_unmapped=True
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
decoder("unknown_class") decoder("unknown_class")
# Test raise_on_unmapped=False # Test raise_on_unmapped=False
decoder = build_decoder_from_config( decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=False config, term_registry=sample_term_registry, raise_on_unmapped=False
) )
tags = decoder("unknown_class") tags = decoder("unknown_class")
@ -402,7 +402,7 @@ def test_build_generic_class_tags_from_config(
TagInfo(key="call_type", value="Echolocation"), TagInfo(key="call_type", value="Echolocation"),
], ],
) )
generic_tags = build_generic_class_tags_from_config( generic_tags = build_generic_class_tags(
config, term_registry=sample_term_registry config, term_registry=sample_term_registry
) )
assert len(generic_tags) == 2 assert len(generic_tags) == 2

View File

@ -7,7 +7,7 @@ from soundevent import data
from batdetect2.targets.filtering import ( from batdetect2.targets.filtering import (
FilterConfig, FilterConfig,
FilterRule, FilterRule,
build_filter_from_config, build_sound_event_filter,
build_filter_from_rule, build_filter_from_rule,
contains_tags, contains_tags,
does_not_have_tags, does_not_have_tags,
@ -121,7 +121,7 @@ def test_build_filter_from_config(create_annotation):
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]), FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
] ]
) )
filter_from_config = build_filter_from_config(config) filter_from_config = build_sound_event_filter(config)
annotation_pass = create_annotation(["tag1", "tag2"]) annotation_pass = create_annotation(["tag1", "tag2"])
assert filter_from_config(annotation_pass) assert filter_from_config(annotation_pass)