Update TargetProtocol and related to include rois

This commit is contained in:
mbsantiago 2025-04-19 20:26:18 +01:00
parent 9410112e41
commit 6236e78414
7 changed files with 498 additions and 190 deletions

View File

@ -1,25 +1,29 @@
"""Main entry point for the BatDetect2 Target Definition subsystem.
This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect
and classify from audio data. It involves several conceptual steps, managed
through configuration files and culminating in executable functions:
necessary to define precisely what the BatDetect2 model should learn to detect,
classify, and localize from audio data. It involves several conceptual steps,
managed through configuration files and culminating in an executable pipeline:
1. **Terms (`.terms`)**: Defining a controlled vocabulary for annotation tags.
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (e.g., standardization,
3. **Transformation (`.transform`)**: Modifying tags (standardization,
derivation).
4. **Class Definition (`.classes`)**: Mapping tags to specific target class
names (encoding) and defining how predicted class names map back to tags
(decoding).
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
target position and size representations, and back.
5. **Class Definition (`.classes`)**: Mapping tags to target class names
(encoding) and mapping predicted names back to tags (decoding).
This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class, which encapsulates the configured processing steps.
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
from typing import List, Optional
import numpy as np
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
@ -28,9 +32,9 @@ from batdetect2.targets.classes import (
SoundEventDecoder,
SoundEventEncoder,
TargetClass,
build_decoder_from_config,
build_encoder_from_config,
build_generic_class_tags_from_config,
build_generic_class_tags,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
@ -40,10 +44,15 @@ from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
SoundEventFilter,
build_filter_from_config,
build_sound_event_filter,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.rois import (
ROIConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.terms import (
TagInfo,
TermInfo,
@ -69,6 +78,7 @@ from batdetect2.targets.transform import (
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClassesConfig",
@ -76,6 +86,8 @@ __all__ = [
"FilterConfig",
"FilterRule",
"MapValueRule",
"ROIConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
"SoundEventEncoder",
@ -84,13 +96,15 @@ __all__ = [
"TagInfo",
"TargetClass",
"TargetConfig",
"TargetProtocol",
"Targets",
"TermInfo",
"TransformConfig",
"build_decoder_from_config",
"build_encoder_from_config",
"build_filter_from_config",
"build_generic_class_tags_from_config",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_sound_event_filter",
"build_generic_class_tags",
"build_roi_mapper",
"build_transformation_from_config",
"call_type",
"get_class_names_from_config",
@ -114,29 +128,36 @@ __all__ = [
class TargetConfig(BaseConfig):
"""Unified configuration for the entire target definition pipeline.
This model aggregates the configurations for the optional filtering and
transformation steps, and the mandatory class definition step. It serves as
the primary input for building a complete `Targets` processing object.
This model aggregates the configurations for semantic processing (filtering,
transformation, class definition) and geometric processing (ROI mapping).
It serves as the primary input for building a complete `Targets` object
via `build_targets` or `load_targets`.
Attributes
----------
filtering : FilterConfig, optional
Configuration for filtering sound event annotations. If None or
omitted, no filtering is applied.
Configuration for filtering sound event annotations based on tags.
If None or omitted, no filtering is applied.
transforms : TransformConfig, optional
Configuration for transforming annotation tags. If None or omitted, no
transformations are applied.
Configuration for transforming annotation tags
(mapping, derivation, etc.). If None or omitted, no tag transformations
are applied.
classes : ClassesConfig
Configuration defining the specific target classes, their matching
rules, decoding rules (`output_tags`), and the generic class
definition. This section is mandatory.
Configuration defining the specific target classes, their tag matching
rules for encoding, their representative tags for decoding
(`output_tags`), and the definition of the generic class tags.
This section is mandatory.
roi : ROIConfig, optional
Configuration defining how geometric ROIs (e.g., bounding boxes) are
mapped to target representations (reference point, scaled size).
Controls `position`, `time_scale`, `frequency_scale`. If None or
omitted, default ROI mapping settings are used.
"""
filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None
classes: ClassesConfig
roi: Optional[ROIConfig] = None
def load_target_config(
@ -177,34 +198,40 @@ def load_target_config(
return load_config(path=path, schema=TargetConfig, field=field)
class Targets:
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class holds the functions for filtering, transforming, encoding, and
decoding annotations based on a loaded `TargetConfig`. It provides a
high-level interface to apply these steps and access relevant metadata
like class names and generic class tags.
This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `Targets.from_config` or
`Targets.from_file` classmethods.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : list[str]
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the configured
generic class (e.g., the default 'Bat' class).
generic class category (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: list[str]
class_names: List[str]
generic_class_tags: List[data.Tag]
dimension_names: List[str]
def __init__(
self,
encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str],
generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None,
@ -212,26 +239,31 @@ class Targets:
):
"""Initialize the Targets object.
Note: This constructor is typically called internally by the
`build_targets` factory function.
Parameters
----------
encode_fn : SoundEventEncoder
The configured function to encode annotations to class names.
Configured function to encode annotations to class names.
decode_fn : SoundEventDecoder
The configured function to decode class names to tags.
Configured function to decode class names to tags.
roi_mapper : ROITargetMapper
Configured object for mapping geometry to/from position/size.
class_names : list[str]
The ordered list of specific target class names.
Ordered list of specific target class names.
generic_class_tags : List[data.Tag]
The list of tags representing the generic class.
List of tags representing the generic class.
filter_fn : SoundEventFilter, optional
The configured function to filter annotations. Defaults to None (no
filtering).
Configured function to filter annotations. Defaults to None.
transform_fn : SoundEventTransformation, optional
The configured function to transform annotation tags. Defaults to
None (no transformation).
Configured function to transform annotation tags. Defaults to None.
"""
self.class_names = class_names
self.generic_class_tags = generic_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper
self._filter_fn = filter_fn
self._encode_fn = encode_fn
self._decode_fn = decode_fn
@ -316,133 +348,223 @@ class Targets:
return self._transform_fn(sound_event)
return sound_event
@classmethod
def from_config(
cls,
config: TargetConfig,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> "Targets":
"""Build a Targets object from a loaded TargetConfig.
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's roi.
This factory method takes the unified configuration object and
constructs all the necessary functional components (filter, transform,
encoder, decoder) and extracts metadata (class names, generic tags) to
create a fully configured `Targets` instance.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
Targets
An initialized `Targets` object ready for use.
Tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
ValueError
If the annotation lacks geometry.
"""
filter_fn = (
build_filter_from_config(
config.filtering,
term_registry=term_registry,
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its position."
)
if config.filtering
else None
)
encode_fn = build_encoder_from_config(
config.classes,
term_registry=term_registry,
)
decode_fn = build_decoder_from_config(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags_from_config(
config.classes,
term_registry=term_registry,
)
return cls(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
)
return self._roi_mapper.get_roi_position(geom)
@classmethod
def from_file(
cls,
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> "Targets":
"""Load a Targets object directly from a configuration file.
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
This convenience factory method loads the `TargetConfig` from the
specified file path and then calls `Targets.from_config` to build
the fully initialized `Targets` object.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
Targets
An initialized `Targets` object ready for use.
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
Errors raised during the build process by `Targets.from_config`
(e.g., missing keys in registries, failed imports).
ValueError
If the annotation lacks geometry.
"""
config = load_target_config(
config_path,
field=field,
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which
un-scales the dimensions and reconstructs the geometry (typically a
`BoundingBox`).
Parameters
----------
pos : Tuple[float, float]
The reference position `(time, frequency)`.
dims : np.ndarray
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
return self._roi_mapper.recover_roi(pos, dims)
def build_targets(
config: TargetConfig,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
necessary functional components (filter, transform, encoder,
decoder, ROI mapper) by calling their respective builder functions. It also
extracts metadata (class names, generic tags, dimension names) to create
and return a fully initialized `Targets` instance, ready to process
annotations.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
filter_fn = (
build_sound_event_filter(
config.filtering,
term_registry=term_registry,
)
return cls.from_config(
config,
if config.filtering
else None
)
encode_fn = build_sound_event_encoder(
config.classes,
term_registry=term_registry,
)
decode_fn = build_sound_event_decoder(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
return Targets(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
This convenience function loads the `TargetConfig` from the
specified file path and then calls `build_targets` to build
the fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
Errors raised during the build process by `build_targets`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(
config,
term_registry=term_registry,
derivation_registry=derivation_registry,
)

View File

@ -22,9 +22,9 @@ __all__ = [
"load_classes_config",
"load_encoder_from_config",
"load_decoder_from_config",
"build_encoder_from_config",
"build_decoder_from_config",
"build_generic_class_tags_from_config",
"build_sound_event_encoder",
"build_sound_event_decoder",
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST",
]
@ -314,7 +314,7 @@ def _encode_with_multiple_classifiers(
return None
def build_encoder_from_config(
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
@ -408,7 +408,7 @@ def _decode_class(
return mapping[name]
def build_decoder_from_config(
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False,
@ -463,7 +463,7 @@ def build_decoder_from_config(
)
def build_generic_class_tags_from_config(
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
) -> List[data.Tag]:
@ -565,7 +565,7 @@ def load_encoder_from_config(
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_encoder_from_config(config, term_registry=term_registry)
return build_sound_event_encoder(config, term_registry=term_registry)
def load_decoder_from_config(
@ -611,7 +611,7 @@ def load_decoder_from_config(
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_decoder_from_config(
return build_sound_event_decoder(
config,
term_registry=term_registry,
raise_on_unmapped=raise_on_unmapped,

View File

@ -17,7 +17,7 @@ __all__ = [
"FilterConfig",
"FilterRule",
"SoundEventFilter",
"build_filter_from_config",
"build_sound_event_filter",
"build_filter_from_rule",
"load_filter_config",
"load_filter_from_config",
@ -241,7 +241,7 @@ class FilterConfig(BaseConfig):
rules: List[FilterRule] = Field(default_factory=list)
def build_filter_from_config(
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
@ -312,4 +312,4 @@ def load_filter_from_config(
The final merged filter function ready to be used.
"""
config = load_filter_config(path=path, field=field)
return build_filter_from_config(config, term_registry=term_registry)
return build_sound_event_filter(config, term_registry=term_registry)

View File

@ -1,19 +1,20 @@
"""Defines the core interface (Protocol) for the target definition pipeline.
This module specifies the standard structure and methods expected from an object
that encapsulates the configured logic for processing sound event annotations
within the `batdetect2.targets` system.
This module specifies the standard structure, attributes, and methods expected
from an object that encapsulates the complete configured logic for processing
sound event annotations within the `batdetect2.targets` system.
The main component defined here is the `TargetEncoder` protocol. This protocol
acts as a contract, ensuring that components responsible for applying
filtering, transformations, encoding annotations to class names, and decoding
class names back to tags can be interacted with in a consistent manner
throughout BatDetect2. It also defines essential metadata attributes expected
from implementations.
The main component defined here is the `TargetProtocol`. This protocol acts as
a contract for the entire target definition process, covering semantic aspects
(filtering, tag transformation, class encoding/decoding) as well as geometric
aspects (mapping regions of interest to target positions and sizes). It ensures
that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from typing import List, Optional, Protocol
import numpy as np
from soundevent import data
__all__ = [
@ -26,18 +27,30 @@ class TargetProtocol(Protocol):
This protocol outlines the standard attributes and methods for an object
that encapsulates the complete, configured process for handling sound event
annotations to determine their target class for model training, and for
interpreting model predictions back into annotation tags.
annotations (both tags and geometry). It defines how to:
- Filter relevant annotations.
- Transform annotation tags.
- Encode an annotation into a specific target class name.
- Decode a class name back into representative tags.
- Extract a target reference position from an annotation's geometry (ROI).
- Calculate target size dimensions from an annotation's geometry.
- Recover an approximate geometry (ROI) from a position and size
dimensions.
Implementations of this protocol bundle all configured logic for these
steps.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined by the configuration represented by this object.
defined by the configuration.
generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the
generic class category (e.g., the default 'Bat' class tags used when
no specific class matches).
A list of `soundevent.data.Tag` objects representing the configured
generic class category (e.g., used when no specific class matches).
dimension_names : List[str]
A list containing the names of the size dimensions returned by
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
"""
class_names: List[str]
@ -46,6 +59,9 @@ class TargetProtocol(Protocol):
generic_class_tags: List[data.Tag]
"""List of tags representing the generic (unclassified) category."""
dimension_names: List[str]
"""Names of the size dimensions (e.g., ['width', 'height'])."""
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the filter to a sound event annotation.
@ -100,10 +116,10 @@ class TargetProtocol(Protocol):
Returns
-------
str or None
The string name of the matched target class if the annotation matches
a specific class definition. Returns None if the annotation does not
match any specific class rule (indicating it may belong to a generic
category or should be handled differently downstream).
The string name of the matched target class if the annotation
matches a specific class definition. Returns None if the annotation
does not match any specific class rule (indicating it may belong
to a generic category or should be handled differently downstream).
"""
...
@ -130,3 +146,88 @@ class TargetProtocol(Protocol):
found in the configured mapping and error raising is enabled.
"""
...
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary
location of the sound event.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
Tuple[float, float]
The calculated reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry or if the position cannot be
calculated for the geometry type or configured reference point.
"""
...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions.
Performs the inverse mapping of `get_position` and `get_size`. It takes
a reference position `(time, frequency)` and an array of size
dimensions and reconstructs an approximate geometric representation.
Parameters
----------
pos : Tuple[float, float]
The reference position `(time, frequency)`.
dims : np.ndarray
The NumPy array containing the dimensions (e.g., predicted
by the model), corresponding to the order in `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
Raises
------
ValueError
If the number of provided `dims` does not match `dimension_names`,
if dimensions are invalid (e.g., negative after unscaling), or
if reconstruction fails based on the configured position type.
"""
...

View File

@ -0,0 +1,85 @@
## Defining Target Geometry: Mapping Sound Event Regions
### Introduction
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
This ROI contains detailed spatial information (start/end time, low/high frequency).
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
### From ROI to Model Targets: Position & Size
BatDetect2 does not directly predict a full bounding box.
Instead, it is trained to predict:
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
It also handles the reverse process: converting predicted positions and sizes back into bounding boxes for visualization or analysis.
### Configuring the ROI Mapping
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
Here are the key settings:
- **`position`**:
  - **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
  - **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
    Different choices might slightly influence model learning.
    The default (`"bottom-left"`) is often a good starting point.
  - **Example Value:** `position: "center"`
- **`time_scale`**:
  - **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
  - **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
    For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
  - **Important Considerations:**
    - You can often set this value based on the units you prefer the model to work with internally.
      However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
    - The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
      Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
    - Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
      BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
  - **Example Value:** `time_scale: 1000.0`
- **`frequency_scale`**:
  - **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
  - **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
  - **Important Considerations:**
    - The same considerations described for `time_scale` apply here.
  - **Example Value:** `frequency_scale: 0.00116`
**Example YAML Configuration:**
```yaml
# Inside your main configuration file (e.g., training_config.yaml)
targets: # Top-level key for target definition
  # ... filtering settings ...
  # ... transforms settings ...
  # ... classes settings ...
  # --- ROI Mapping Settings ---
  roi:
    position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
    time_scale: 1000.0 # e.g., Model predicts width in ms
    frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
```
### Decoding Size Predictions
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
### Outcome
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
Using consistent scales that are appropriate for your data, and potentially beneficial for training stability, allows the model to effectively learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is.
It also enables meaningful interpretation of the model's spatial and size predictions.

View File

@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
_get_default_class_name,
_get_default_classes,
_is_target_class,
build_decoder_from_config,
build_encoder_from_config,
build_generic_class_tags_from_config,
build_sound_event_decoder,
build_sound_event_encoder,
build_generic_class_tags,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
@ -231,7 +231,7 @@ def test_build_encoder_from_config(
)
]
)
encoder = build_encoder_from_config(
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
@ -239,7 +239,7 @@ def test_build_encoder_from_config(
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_encoder_from_config(
encoder = build_sound_event_encoder(
config,
term_registry=sample_term_registry,
)
@ -315,7 +315,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_decoder_from_config(
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
@ -335,7 +335,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_decoder_from_config(
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
@ -344,14 +344,14 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_decoder_from_config(
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_decoder_from_config(
decoder = build_sound_event_decoder(
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
tags = decoder("unknown_class")
@ -402,7 +402,7 @@ def test_build_generic_class_tags_from_config(
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags_from_config(
generic_tags = build_generic_class_tags(
config, term_registry=sample_term_registry
)
assert len(generic_tags) == 2

View File

@ -7,7 +7,7 @@ from soundevent import data
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
build_filter_from_config,
build_sound_event_filter,
build_filter_from_rule,
contains_tags,
does_not_have_tags,
@ -121,7 +121,7 @@ def test_build_filter_from_config(create_annotation):
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
]
)
filter_from_config = build_filter_from_config(config)
filter_from_config = build_sound_event_filter(config)
annotation_pass = create_annotation(["tag1", "tag2"])
assert filter_from_config(annotation_pass)