diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index 0530aab..bb32aab 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -1,25 +1,29 @@ """Main entry point for the BatDetect2 Target Definition subsystem. This package (`batdetect2.targets`) provides the tools and configurations -necessary to define precisely what the BatDetect2 model should learn to detect -and classify from audio data. It involves several conceptual steps, managed -through configuration files and culminating in executable functions: +necessary to define precisely what the BatDetect2 model should learn to detect, +classify, and localize from audio data. It involves several conceptual steps, +managed through configuration files and culminating in an executable pipeline: -1. **Terms (`.terms`)**: Defining a controlled vocabulary for annotation tags. +1. **Terms (`.terms`)**: Defining vocabulary for annotation tags. 2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations. -3. **Transformation (`.transform`)**: Modifying tags (e.g., standardization, +3. **Transformation (`.transform`)**: Modifying tags (standardization, derivation). -4. **Class Definition (`.classes`)**: Mapping tags to specific target class - names (encoding) and defining how predicted class names map back to tags - (decoding). +4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to + target position and size representations, and back. +5. **Class Definition (`.classes`)**: Mapping tags to target class names + (encoding) and mapping predicted names back to tags (decoding). This module exposes the key components for users to configure and utilize this target definition pipeline, primarily through the `TargetConfig` data structure -and the `Targets` class, which encapsulates the configured processing steps. +and the `Targets` class (implementing `TargetProtocol`), which encapsulates the +configured processing steps. The main way to create a functional `Targets` +object is via the `build_targets` or `load_targets` functions. """ from typing import List, Optional +import numpy as np from soundevent import data from batdetect2.configs import BaseConfig, load_config @@ -28,9 +32,9 @@ from batdetect2.targets.classes import ( SoundEventDecoder, SoundEventEncoder, TargetClass, - build_decoder_from_config, - build_encoder_from_config, - build_generic_class_tags_from_config, + build_generic_class_tags, + build_sound_event_decoder, + build_sound_event_encoder, get_class_names_from_config, load_classes_config, load_decoder_from_config, @@ -40,10 +44,15 @@ from batdetect2.targets.filtering import ( FilterConfig, FilterRule, SoundEventFilter, - build_filter_from_config, + build_sound_event_filter, load_filter_config, load_filter_from_config, ) +from batdetect2.targets.rois import ( + ROIConfig, + ROITargetMapper, + build_roi_mapper, +) from batdetect2.targets.terms import ( TagInfo, TermInfo, @@ -69,6 +78,7 @@ from batdetect2.targets.transform import ( load_transformation_from_config, register_derivation, ) +from batdetect2.targets.types import TargetProtocol __all__ = [ "ClassesConfig", @@ -76,6 +86,8 @@ __all__ = [ "FilterConfig", "FilterRule", "MapValueRule", + "ROIConfig", + "ROITargetMapper", "ReplaceRule", "SoundEventDecoder", "SoundEventEncoder", @@ -84,13 +96,15 @@ __all__ = [ "TagInfo", "TargetClass", "TargetConfig", + "TargetProtocol", "Targets", "TermInfo", "TransformConfig", - "build_decoder_from_config", - "build_encoder_from_config", - "build_filter_from_config", - "build_generic_class_tags_from_config", + "build_sound_event_decoder", + "build_sound_event_encoder", + "build_sound_event_filter", + "build_generic_class_tags", + "build_roi_mapper", "build_transformation_from_config", "call_type", "get_class_names_from_config", @@ -114,29 +128,36 @@ __all__ = [ class TargetConfig(BaseConfig): """Unified configuration for the entire target definition pipeline. - This model aggregates the configurations for the optional filtering and - transformation steps, and the mandatory class definition step. It serves as - the primary input for building a complete `Targets` processing object. + This model aggregates the configurations for semantic processing (filtering, + transformation, class definition) and geometric processing (ROI mapping). + It serves as the primary input for building a complete `Targets` object + via `build_targets` or `load_targets`. Attributes ---------- filtering : FilterConfig, optional - Configuration for filtering sound event annotations. If None or - omitted, no filtering is applied. + Configuration for filtering sound event annotations based on tags. + If None or omitted, no filtering is applied. transforms : TransformConfig, optional - Configuration for transforming annotation tags. If None or omitted, no - transformations are applied. + Configuration for transforming annotation tags + (mapping, derivation, etc.). If None or omitted, no tag transformations + are applied. classes : ClassesConfig - Configuration defining the specific target classes, their matching - rules, decoding rules (`output_tags`), and the generic class - definition. This section is mandatory. + Configuration defining the specific target classes, their tag matching + rules for encoding, their representative tags for decoding + (`output_tags`), and the definition of the generic class tags. + This section is mandatory. + roi : ROIConfig, optional + Configuration defining how geometric ROIs (e.g., bounding boxes) are + mapped to target representations (reference point, scaled size). + Controls `position`, `time_scale`, `frequency_scale`. If None or + omitted, default ROI mapping settings are used. """ filtering: Optional[FilterConfig] = None - transforms: Optional[TransformConfig] = None - classes: ClassesConfig + roi: Optional[ROIConfig] = None def load_target_config( @@ -177,34 +198,40 @@ def load_target_config( return load_config(path=path, schema=TargetConfig, field=field) -class Targets: +class Targets(TargetProtocol): """Encapsulates the complete configured target definition pipeline. - This class holds the functions for filtering, transforming, encoding, and - decoding annotations based on a loaded `TargetConfig`. It provides a - high-level interface to apply these steps and access relevant metadata - like class names and generic class tags. + This class implements the `TargetProtocol`, holding the configured + functions for filtering, transforming, encoding (tags to class name), + decoding (class name to tags), and mapping ROIs (geometry to position/size + and back). It provides a high-level interface to apply these steps and + access relevant metadata like class names and dimension names. - Instances are typically created using the `Targets.from_config` or - `Targets.from_file` classmethods. + Instances are typically created using the `build_targets` factory function + or the `load_targets` convenience loader. Attributes ---------- - class_names : list[str] + class_names : List[str] An ordered list of the unique names of the specific target classes defined in the configuration. generic_class_tags : List[data.Tag] A list of `soundevent.data.Tag` objects representing the configured - generic class (e.g., the default 'Bat' class). + generic class category (used when no specific class matches). + dimension_names : List[str] + The names of the size dimensions handled by the ROI mapper + (e.g., ['width', 'height']). """ - class_names: list[str] + class_names: List[str] generic_class_tags: List[data.Tag] + dimension_names: List[str] def __init__( self, encode_fn: SoundEventEncoder, decode_fn: SoundEventDecoder, + roi_mapper: ROITargetMapper, class_names: list[str], generic_class_tags: List[data.Tag], filter_fn: Optional[SoundEventFilter] = None, @@ -212,26 +239,31 @@ class Targets: ): """Initialize the Targets object. + Note: This constructor is typically called internally by the + `build_targets` factory function. + Parameters ---------- encode_fn : SoundEventEncoder - The configured function to encode annotations to class names. + Configured function to encode annotations to class names. decode_fn : SoundEventDecoder - The configured function to decode class names to tags. + Configured function to decode class names to tags. + roi_mapper : ROITargetMapper + Configured object for mapping geometry to/from position/size. class_names : list[str] - The ordered list of specific target class names. + Ordered list of specific target class names. generic_class_tags : List[data.Tag] - The list of tags representing the generic class. + List of tags representing the generic class. filter_fn : SoundEventFilter, optional - The configured function to filter annotations. Defaults to None (no - filtering). + Configured function to filter annotations. Defaults to None. transform_fn : SoundEventTransformation, optional - The configured function to transform annotation tags. Defaults to - None (no transformation). + Configured function to transform annotation tags. Defaults to None. """ self.class_names = class_names self.generic_class_tags = generic_class_tags + self.dimension_names = roi_mapper.dimension_names + self._roi_mapper = roi_mapper self._filter_fn = filter_fn self._encode_fn = encode_fn self._decode_fn = decode_fn @@ -316,133 +348,223 @@ class Targets: return self._transform_fn(sound_event) return sound_event - @classmethod - def from_config( - cls, - config: TargetConfig, - term_registry: TermRegistry = term_registry, - derivation_registry: DerivationRegistry = derivation_registry, - ) -> "Targets": - """Build a Targets object from a loaded TargetConfig. + def get_position( + self, sound_event: data.SoundEventAnnotation + ) -> tuple[float, float]: + """Extract the target reference position from the annotation's roi. - This factory method takes the unified configuration object and - constructs all the necessary functional components (filter, transform, - encoder, decoder) and extracts metadata (class names, generic tags) to - create a fully configured `Targets` instance. + Delegates to the internal ROI mapper's `get_roi_position` method. Parameters ---------- - config : TargetConfig - The loaded and validated unified target configuration object. - term_registry : TermRegistry, optional - The TermRegistry instance to use for resolving term keys. Defaults - to the global `batdetect2.targets.terms.term_registry`. - derivation_registry : DerivationRegistry, optional - The DerivationRegistry instance to use for resolving derivation - function names. Defaults to the global - `batdetect2.targets.transform.derivation_registry`. + sound_event : data.SoundEventAnnotation + The annotation containing the geometry (ROI). Returns ------- - Targets - An initialized `Targets` object ready for use. + Tuple[float, float] + The reference position `(time, frequency)`. Raises ------ - KeyError - If term keys or derivation function keys specified in the `config` - are not found in their respective registries. - ImportError, AttributeError, TypeError - If dynamic import of a derivation function fails (when configured). + ValueError + If the annotation lacks geometry. """ - filter_fn = ( - build_filter_from_config( - config.filtering, - term_registry=term_registry, + geom = sound_event.sound_event.geometry + + if geom is None: + raise ValueError( + "Sound event has no geometry, cannot get its position." ) - if config.filtering - else None - ) - encode_fn = build_encoder_from_config( - config.classes, - term_registry=term_registry, - ) - decode_fn = build_decoder_from_config( - config.classes, - term_registry=term_registry, - ) - transform_fn = ( - build_transformation_from_config( - config.transforms, - term_registry=term_registry, - derivation_registry=derivation_registry, - ) - if config.transforms - else None - ) - class_names = get_class_names_from_config(config.classes) - generic_class_tags = build_generic_class_tags_from_config( - config.classes, - term_registry=term_registry, - ) - return cls( - filter_fn=filter_fn, - encode_fn=encode_fn, - decode_fn=decode_fn, - class_names=class_names, - generic_class_tags=generic_class_tags, - transform_fn=transform_fn, - ) + return self._roi_mapper.get_roi_position(geom) - @classmethod - def from_file( - cls, - config_path: data.PathLike, - field: Optional[str] = None, - term_registry: TermRegistry = term_registry, - derivation_registry: DerivationRegistry = derivation_registry, - ) -> "Targets": - """Load a Targets object directly from a configuration file. + def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: + """Calculate the target size dimensions from the annotation's geometry. - This convenience factory method loads the `TargetConfig` from the - specified file path and then calls `Targets.from_config` to build - the fully initialized `Targets` object. + Delegates to the internal ROI mapper's `get_roi_size` method, which + applies configured scaling factors. Parameters ---------- - config_path : data.PathLike - Path to the configuration file (e.g., YAML). - field : str, optional - Dot-separated path to a nested section within the file containing - the target configuration. If None, the entire file content is used. - term_registry : TermRegistry, optional - The TermRegistry instance to use. Defaults to the global default. - derivation_registry : DerivationRegistry, optional - The DerivationRegistry instance to use. Defaults to the global - default. + sound_event : data.SoundEventAnnotation + The annotation containing the geometry (ROI). Returns ------- - Targets - An initialized `Targets` object ready for use. + np.ndarray + NumPy array containing the size dimensions, matching the + order in `self.dimension_names` (e.g., `[width, height]`). Raises ------ - FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, - TypeError - Errors raised during file loading, validation, or extraction via - `load_target_config`. - KeyError, ImportError, AttributeError, TypeError - Errors raised during the build process by `Targets.from_config` - (e.g., missing keys in registries, failed imports). + ValueError + If the annotation lacks geometry. """ - config = load_target_config( - config_path, - field=field, + geom = sound_event.sound_event.geometry + + if geom is None: + raise ValueError( + "Sound event has no geometry, cannot get its size." + ) + + return self._roi_mapper.get_roi_size(geom) + + def recover_roi( + self, + pos: tuple[float, float], + dims: np.ndarray, + ) -> data.Geometry: + """Recover an approximate geometric ROI from a position and dimensions. + + Delegates to the internal ROI mapper's `recover_roi` method, which + un-scales the dimensions and reconstructs the geometry (typically a + `BoundingBox`). + + Parameters + ---------- + pos : Tuple[float, float] + The reference position `(time, frequency)`. + dims : np.ndarray + NumPy array with size dimensions (e.g., from model prediction), + matching the order in `self.dimension_names`. + + Returns + ------- + data.Geometry + The reconstructed geometry (typically `BoundingBox`). + """ + return self._roi_mapper.recover_roi(pos, dims) + + +def build_targets( + config: TargetConfig, + term_registry: TermRegistry = term_registry, + derivation_registry: DerivationRegistry = derivation_registry, +) -> Targets: + """Build a Targets object from a loaded TargetConfig. + + This factory function takes the unified `TargetConfig` and constructs all + necessary functional components (filter, transform, encoder, + decoder, ROI mapper) by calling their respective builder functions. It also + extracts metadata (class names, generic tags, dimension names) to create + and return a fully initialized `Targets` instance, ready to process + annotations. + + Parameters + ---------- + config : TargetConfig + The loaded and validated unified target configuration object. + term_registry : TermRegistry, optional + The TermRegistry instance to use for resolving term keys. Defaults + to the global `batdetect2.targets.terms.term_registry`. + derivation_registry : DerivationRegistry, optional + The DerivationRegistry instance to use for resolving derivation + function names. Defaults to the global + `batdetect2.targets.transform.derivation_registry`. + + Returns + ------- + Targets + An initialized `Targets` object ready for use. + + Raises + ------ + KeyError + If term keys or derivation function keys specified in the `config` + are not found in their respective registries. + ImportError, AttributeError, TypeError + If dynamic import of a derivation function fails (when configured). + """ + filter_fn = ( + build_sound_event_filter( + config.filtering, + term_registry=term_registry, ) - return cls.from_config( - config, + if config.filtering + else None + ) + encode_fn = build_sound_event_encoder( + config.classes, + term_registry=term_registry, + ) + decode_fn = build_sound_event_decoder( + config.classes, + term_registry=term_registry, + ) + transform_fn = ( + build_transformation_from_config( + config.transforms, term_registry=term_registry, derivation_registry=derivation_registry, ) + if config.transforms + else None + ) + roi_mapper = build_roi_mapper(config.roi or ROIConfig()) + class_names = get_class_names_from_config(config.classes) + generic_class_tags = build_generic_class_tags( + config.classes, + term_registry=term_registry, + ) + + return Targets( + filter_fn=filter_fn, + encode_fn=encode_fn, + decode_fn=decode_fn, + class_names=class_names, + roi_mapper=roi_mapper, + generic_class_tags=generic_class_tags, + transform_fn=transform_fn, + ) + + +def load_targets( + config_path: data.PathLike, + field: Optional[str] = None, + term_registry: TermRegistry = term_registry, + derivation_registry: DerivationRegistry = derivation_registry, +) -> Targets: + """Load a Targets object directly from a configuration file. + + This convenience factory method loads the `TargetConfig` from the + specified file path and then calls `Targets.from_config` to build + the fully initialized `Targets` object. + + Parameters + ---------- + config_path : data.PathLike + Path to the configuration file (e.g., YAML). + field : str, optional + Dot-separated path to a nested section within the file containing + the target configuration. If None, the entire file content is used. + term_registry : TermRegistry, optional + The TermRegistry instance to use. Defaults to the global default. + derivation_registry : DerivationRegistry, optional + The DerivationRegistry instance to use. Defaults to the global + default. + + Returns + ------- + Targets + An initialized `Targets` object ready for use. + + Raises + ------ + FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, + TypeError + Errors raised during file loading, validation, or extraction via + `load_target_config`. + KeyError, ImportError, AttributeError, TypeError + Errors raised during the build process by `Targets.from_config` + (e.g., missing keys in registries, failed imports). + """ + config = load_target_config( + config_path, + field=field, + ) + return build_targets( + config, + term_registry=term_registry, + derivation_registry=derivation_registry, + ) diff --git a/batdetect2/targets/classes.py b/batdetect2/targets/classes.py index 79b5563..54d0d81 100644 --- a/batdetect2/targets/classes.py +++ b/batdetect2/targets/classes.py @@ -22,9 +22,9 @@ __all__ = [ "load_classes_config", "load_encoder_from_config", "load_decoder_from_config", - "build_encoder_from_config", - "build_decoder_from_config", - "build_generic_class_tags_from_config", + "build_sound_event_encoder", + "build_sound_event_decoder", + "build_generic_class_tags", "get_class_names_from_config", "DEFAULT_SPECIES_LIST", ] @@ -314,7 +314,7 @@ def _encode_with_multiple_classifiers( return None -def build_encoder_from_config( +def build_sound_event_encoder( config: ClassesConfig, term_registry: TermRegistry = term_registry, ) -> SoundEventEncoder: @@ -408,7 +408,7 @@ def _decode_class( return mapping[name] -def build_decoder_from_config( +def build_sound_event_decoder( config: ClassesConfig, term_registry: TermRegistry = term_registry, raise_on_unmapped: bool = False, @@ -463,7 +463,7 @@ def build_decoder_from_config( ) -def build_generic_class_tags_from_config( +def build_generic_class_tags( config: ClassesConfig, term_registry: TermRegistry = term_registry, ) -> List[data.Tag]: @@ -565,7 +565,7 @@ def load_encoder_from_config( provided `term_registry` during the build process. """ config = load_classes_config(path, field=field) - return build_encoder_from_config(config, term_registry=term_registry) + return build_sound_event_encoder(config, term_registry=term_registry) def load_decoder_from_config( @@ -611,7 +611,7 @@ def load_decoder_from_config( provided `term_registry` during the build process. """ config = load_classes_config(path, field=field) - return build_decoder_from_config( + return build_sound_event_decoder( config, term_registry=term_registry, raise_on_unmapped=raise_on_unmapped, diff --git a/batdetect2/targets/filtering.py b/batdetect2/targets/filtering.py index ff9f53a..8869050 100644 --- a/batdetect2/targets/filtering.py +++ b/batdetect2/targets/filtering.py @@ -17,7 +17,7 @@ __all__ = [ "FilterConfig", "FilterRule", "SoundEventFilter", - "build_filter_from_config", + "build_sound_event_filter", "build_filter_from_rule", "load_filter_config", "load_filter_from_config", @@ -241,7 +241,7 @@ class FilterConfig(BaseConfig): rules: List[FilterRule] = Field(default_factory=list) -def build_filter_from_config( +def build_sound_event_filter( config: FilterConfig, term_registry: TermRegistry = term_registry, ) -> SoundEventFilter: @@ -312,4 +312,4 @@ def load_filter_from_config( The final merged filter function ready to be used. """ config = load_filter_config(path=path, field=field) - return build_filter_from_config(config, term_registry=term_registry) + return build_sound_event_filter(config, term_registry=term_registry) diff --git a/batdetect2/targets/types.py b/batdetect2/targets/types.py index 1a884bb..5bb260e 100644 --- a/batdetect2/targets/types.py +++ b/batdetect2/targets/types.py @@ -1,19 +1,20 @@ """Defines the core interface (Protocol) for the target definition pipeline. -This module specifies the standard structure and methods expected from an object -that encapsulates the configured logic for processing sound event annotations -within the `batdetect2.targets` system. +This module specifies the standard structure, attributes, and methods expected +from an object that encapsulates the complete configured logic for processing +sound event annotations within the `batdetect2.targets` system. -The main component defined here is the `TargetEncoder` protocol. This protocol -acts as a contract, ensuring that components responsible for applying -filtering, transformations, encoding annotations to class names, and decoding -class names back to tags can be interacted with in a consistent manner -throughout BatDetect2. It also defines essential metadata attributes expected -from implementations. +The main component defined here is the `TargetProtocol`. This protocol acts as +a contract for the entire target definition process, covering semantic aspects +(filtering, tag transformation, class encoding/decoding) as well as geometric +aspects (mapping regions of interest to target positions and sizes). It ensures +that components responsible for these tasks can be interacted with consistently +throughout BatDetect2. """ from typing import List, Optional, Protocol +import numpy as np from soundevent import data __all__ = [ @@ -26,18 +27,30 @@ class TargetProtocol(Protocol): This protocol outlines the standard attributes and methods for an object that encapsulates the complete, configured process for handling sound event - annotations to determine their target class for model training, and for - interpreting model predictions back into annotation tags. + annotations (both tags and geometry). It defines how to: + - Filter relevant annotations. + - Transform annotation tags. + - Encode an annotation into a specific target class name. + - Decode a class name back into representative tags. + - Extract a target reference position from an annotation's geometry (ROI). + - Calculate target size dimensions from an annotation's geometry. + - Recover an approximate geometry (ROI) from a position and size + dimensions. + + Implementations of this protocol bundle all configured logic for these + steps. Attributes ---------- class_names : List[str] An ordered list of the unique names of the specific target classes - defined by the configuration represented by this object. + defined by the configuration. generic_class_tags : List[data.Tag] - A list of `soundevent.data.Tag` objects representing the - generic class category (e.g., the default 'Bat' class tags used when - no specific class matches). + A list of `soundevent.data.Tag` objects representing the configured + generic class category (e.g., used when no specific class matches). + dimension_names : List[str] + A list containing the names of the size dimensions returned by + `get_size` and expected by `recover_roi` (e.g., ['width', 'height']). """ class_names: List[str] @@ -46,6 +59,9 @@ class TargetProtocol(Protocol): generic_class_tags: List[data.Tag] """List of tags representing the generic (unclassified) category.""" + dimension_names: List[str] + """Names of the size dimensions (e.g., ['width', 'height']).""" + def filter(self, sound_event: data.SoundEventAnnotation) -> bool: """Apply the filter to a sound event annotation. @@ -100,10 +116,10 @@ class TargetProtocol(Protocol): Returns ------- str or None - The string name of the matched target class if the annotation matches - a specific class definition. Returns None if the annotation does not - match any specific class rule (indicating it may belong to a generic - category or should be handled differently downstream). + The string name of the matched target class if the annotation + matches a specific class definition. Returns None if the annotation + does not match any specific class rule (indicating it may belong + to a generic category or should be handled differently downstream). """ ... @@ -130,3 +146,88 @@ class TargetProtocol(Protocol): found in the configured mapping and error raising is enabled. """ ... + + def get_position( + self, sound_event: data.SoundEventAnnotation + ) -> tuple[float, float]: + """Extract the target reference position from the annotation's geometry. + + Calculates the `(time, frequency)` coordinate representing the primary + location of the sound event. + + Parameters + ---------- + sound_event : data.SoundEventAnnotation + The annotation containing the geometry (ROI) to process. + + Returns + ------- + Tuple[float, float] + The calculated reference position `(time, frequency)`. + + Raises + ------ + ValueError + If the annotation lacks geometry or if the position cannot be + calculated for the geometry type or configured reference point. + """ + ... + + def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray: + """Calculate the target size dimensions from the annotation's geometry. + + Computes the relevant physical size (e.g., duration/width, + bandwidth/height from a bounding box) to produce + the numerical target values expected by the model. + + Parameters + ---------- + sound_event : data.SoundEventAnnotation + The annotation containing the geometry (ROI) to process. + + Returns + ------- + np.ndarray + A NumPy array containing the size dimensions, matching the + order specified by the `dimension_names` attribute (e.g., + `[width, height]`). + + Raises + ------ + ValueError + If the annotation lacks geometry or if the size cannot be computed. + TypeError + If geometry type is unsupported. + """ + ... + + def recover_roi( + self, pos: tuple[float, float], dims: np.ndarray + ) -> data.Geometry: + """Recover the ROI geometry from a position and dimensions. + + Performs the inverse mapping of `get_position` and `get_size`. It takes + a reference position `(time, frequency)` and an array of size + dimensions and reconstructs an approximate geometric representation. + + Parameters + ---------- + pos : Tuple[float, float] + The reference position `(time, frequency)`. + dims : np.ndarray + The NumPy array containing the dimensions (e.g., predicted + by the model), corresponding to the order in `dimension_names`. + + Returns + ------- + soundevent.data.Geometry + The reconstructed geometry. + + Raises + ------ + ValueError + If the number of provided `dims` does not match `dimension_names`, + if dimensions are invalid (e.g., negative after unscaling), or + if reconstruction fails based on the configured position type. + """ + ... diff --git a/docs/source/targets/rois.md b/docs/source/targets/rois.md new file mode 100644 index 0000000..bbf406e --- /dev/null +++ b/docs/source/targets/rois.md @@ -0,0 +1,85 @@ +## Defining Target Geometry: Mapping Sound Event Regions + +### Introduction + +In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`). +However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is. + +Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram. +This ROI contains detailed spatial information (start/end time, low/high frequency). + +This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training. + +### From ROI to Model Targets: Position & Size + +BatDetect2 does not directly predict a full bounding box. +Instead, it is trained to predict: + +1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram. +2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency). + +This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box. +It also handles the reverse process – converting predicted positions and sizes back into bounding boxes for visualization or analysis. + +### Configuring the ROI Mapping + +You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file). +These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key. + +Here are the key settings: + +- **`position`**: + + - **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`). + - **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training. + Different choices might slightly influence model learning. + The default (`"bottom-left"`) is often a good starting point. + - **Example Value:** `position: "center"` + +- **`time_scale`**: + + - **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap). + - **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning. + For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds. + - **Important Considerations:** + - You can often set this value based on the units you prefer the model to work with internally. + However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training. + - The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings. + Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change. + - Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds). + BatDetect2 generally handles this consistency for you automatically when using the full pipeline. + - **Example Value:** `time_scale: 1000.0` + +- **`frequency_scale`**: + - **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict. + - **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale. + - **Important Considerations:** + - Same as for `time_scale`. + - **Example Value:** `frequency_scale: 0.00116` + +**Example YAML Configuration:** + +```yaml +# Inside your main configuration file (e.g., training_config.yaml) + +targets: # Top-level key for target definition + # ... filtering settings ... + # ... transforms settings ... + # ... classes settings ... + + # --- ROI Mapping Settings --- + roi: + position: "bottom-left" # Reference point (e.g., "center", "bottom-left") + time_scale: 1000.0 # e.g., Model predicts width in ms + frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling) +``` + +### Decoding Size Predictions + +These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly. +When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions. + +### Outcome + +By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model. +Using consistent scales that are appropriate for your data and potentially beneficial for training stability allows the model to effectively learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is, and enables meaningful interpretation of the model's spatial and size predictions. diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index 69c30b5..c75c4b2 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -13,9 +13,9 @@ from batdetect2.targets.classes import ( _get_default_class_name, _get_default_classes, _is_target_class, - build_decoder_from_config, - build_encoder_from_config, - build_generic_class_tags_from_config, + build_sound_event_decoder, + build_sound_event_encoder, + build_generic_class_tags, get_class_names_from_config, load_classes_config, load_decoder_from_config, @@ -231,7 +231,7 @@ def test_build_encoder_from_config( ) ] ) - encoder = build_encoder_from_config( + encoder = build_sound_event_encoder( config, term_registry=sample_term_registry, ) @@ -239,7 +239,7 @@ def test_build_encoder_from_config( assert result == "pippip" config = ClassesConfig(classes=[]) - encoder = build_encoder_from_config( + encoder = build_sound_event_encoder( config, term_registry=sample_term_registry, ) @@ -315,7 +315,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry): ], generic_class=[TagInfo(key="order", value="Chiroptera")], ) - decoder = build_decoder_from_config( + decoder = build_sound_event_decoder( config, term_registry=sample_term_registry ) tags = decoder("pippip") @@ -335,7 +335,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry): ], generic_class=[TagInfo(key="order", value="Chiroptera")], ) - decoder = build_decoder_from_config( + decoder = build_sound_event_decoder( config, term_registry=sample_term_registry ) tags = decoder("pippip") @@ -344,14 +344,14 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry): assert tags[0].value == "Pipistrellus pipistrellus" # Test raise_on_unmapped=True - decoder = build_decoder_from_config( + decoder = build_sound_event_decoder( config, term_registry=sample_term_registry, raise_on_unmapped=True ) with pytest.raises(ValueError): decoder("unknown_class") # Test raise_on_unmapped=False - decoder = build_decoder_from_config( + decoder = build_sound_event_decoder( config, term_registry=sample_term_registry, raise_on_unmapped=False ) tags = decoder("unknown_class") @@ -402,7 +402,7 @@ def test_build_generic_class_tags_from_config( TagInfo(key="call_type", value="Echolocation"), ], ) - generic_tags = build_generic_class_tags_from_config( + generic_tags = build_generic_class_tags( config, term_registry=sample_term_registry ) assert len(generic_tags) == 2 diff --git a/tests/test_targets/test_filtering.py b/tests/test_targets/test_filtering.py index 7bfb397..426266c 100644 --- a/tests/test_targets/test_filtering.py +++ b/tests/test_targets/test_filtering.py @@ -7,7 +7,7 @@ from soundevent import data from batdetect2.targets.filtering import ( FilterConfig, FilterRule, - build_filter_from_config, + build_sound_event_filter, build_filter_from_rule, contains_tags, does_not_have_tags, @@ -121,7 +121,7 @@ def test_build_filter_from_config(create_annotation): FilterRule(match_type="any", tags=[TagInfo(value="tag2")]), ] ) - filter_from_config = build_filter_from_config(config) + filter_from_config = build_sound_event_filter(config) annotation_pass = create_annotation(["tag1", "tag2"]) assert filter_from_config(annotation_pass)