mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Update TargetProtocol and related to include rois
This commit is contained in:
parent
9410112e41
commit
6236e78414
@ -1,25 +1,29 @@
|
|||||||
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
||||||
|
|
||||||
This package (`batdetect2.targets`) provides the tools and configurations
|
This package (`batdetect2.targets`) provides the tools and configurations
|
||||||
necessary to define precisely what the BatDetect2 model should learn to detect
|
necessary to define precisely what the BatDetect2 model should learn to detect,
|
||||||
and classify from audio data. It involves several conceptual steps, managed
|
classify, and localize from audio data. It involves several conceptual steps,
|
||||||
through configuration files and culminating in executable functions:
|
managed through configuration files and culminating in an executable pipeline:
|
||||||
|
|
||||||
1. **Terms (`.terms`)**: Defining a controlled vocabulary for annotation tags.
|
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
|
||||||
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
||||||
3. **Transformation (`.transform`)**: Modifying tags (e.g., standardization,
|
3. **Transformation (`.transform`)**: Modifying tags (standardization,
|
||||||
derivation).
|
derivation).
|
||||||
4. **Class Definition (`.classes`)**: Mapping tags to specific target class
|
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
|
||||||
names (encoding) and defining how predicted class names map back to tags
|
target position and size representations, and back.
|
||||||
(decoding).
|
5. **Class Definition (`.classes`)**: Mapping tags to target class names
|
||||||
|
(encoding) and mapping predicted names back to tags (decoding).
|
||||||
|
|
||||||
This module exposes the key components for users to configure and utilize this
|
This module exposes the key components for users to configure and utilize this
|
||||||
target definition pipeline, primarily through the `TargetConfig` data structure
|
target definition pipeline, primarily through the `TargetConfig` data structure
|
||||||
and the `Targets` class, which encapsulates the configured processing steps.
|
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
|
||||||
|
configured processing steps. The main way to create a functional `Targets`
|
||||||
|
object is via the `build_targets` or `load_targets` functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
@ -28,9 +32,9 @@ from batdetect2.targets.classes import (
|
|||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
TargetClass,
|
TargetClass,
|
||||||
build_decoder_from_config,
|
build_generic_class_tags,
|
||||||
build_encoder_from_config,
|
build_sound_event_decoder,
|
||||||
build_generic_class_tags_from_config,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
load_classes_config,
|
load_classes_config,
|
||||||
load_decoder_from_config,
|
load_decoder_from_config,
|
||||||
@ -40,10 +44,15 @@ from batdetect2.targets.filtering import (
|
|||||||
FilterConfig,
|
FilterConfig,
|
||||||
FilterRule,
|
FilterRule,
|
||||||
SoundEventFilter,
|
SoundEventFilter,
|
||||||
build_filter_from_config,
|
build_sound_event_filter,
|
||||||
load_filter_config,
|
load_filter_config,
|
||||||
load_filter_from_config,
|
load_filter_from_config,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.rois import (
|
||||||
|
ROIConfig,
|
||||||
|
ROITargetMapper,
|
||||||
|
build_roi_mapper,
|
||||||
|
)
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermInfo,
|
TermInfo,
|
||||||
@ -69,6 +78,7 @@ from batdetect2.targets.transform import (
|
|||||||
load_transformation_from_config,
|
load_transformation_from_config,
|
||||||
register_derivation,
|
register_derivation,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassesConfig",
|
"ClassesConfig",
|
||||||
@ -76,6 +86,8 @@ __all__ = [
|
|||||||
"FilterConfig",
|
"FilterConfig",
|
||||||
"FilterRule",
|
"FilterRule",
|
||||||
"MapValueRule",
|
"MapValueRule",
|
||||||
|
"ROIConfig",
|
||||||
|
"ROITargetMapper",
|
||||||
"ReplaceRule",
|
"ReplaceRule",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
@ -84,13 +96,15 @@ __all__ = [
|
|||||||
"TagInfo",
|
"TagInfo",
|
||||||
"TargetClass",
|
"TargetClass",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
|
"TargetProtocol",
|
||||||
"Targets",
|
"Targets",
|
||||||
"TermInfo",
|
"TermInfo",
|
||||||
"TransformConfig",
|
"TransformConfig",
|
||||||
"build_decoder_from_config",
|
"build_sound_event_decoder",
|
||||||
"build_encoder_from_config",
|
"build_sound_event_encoder",
|
||||||
"build_filter_from_config",
|
"build_sound_event_filter",
|
||||||
"build_generic_class_tags_from_config",
|
"build_generic_class_tags",
|
||||||
|
"build_roi_mapper",
|
||||||
"build_transformation_from_config",
|
"build_transformation_from_config",
|
||||||
"call_type",
|
"call_type",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
@ -114,29 +128,36 @@ __all__ = [
|
|||||||
class TargetConfig(BaseConfig):
|
class TargetConfig(BaseConfig):
|
||||||
"""Unified configuration for the entire target definition pipeline.
|
"""Unified configuration for the entire target definition pipeline.
|
||||||
|
|
||||||
This model aggregates the configurations for the optional filtering and
|
This model aggregates the configurations for semantic processing (filtering,
|
||||||
transformation steps, and the mandatory class definition step. It serves as
|
transformation, class definition) and geometric processing (ROI mapping).
|
||||||
the primary input for building a complete `Targets` processing object.
|
It serves as the primary input for building a complete `Targets` object
|
||||||
|
via `build_targets` or `load_targets`.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
filtering : FilterConfig, optional
|
filtering : FilterConfig, optional
|
||||||
Configuration for filtering sound event annotations. If None or
|
Configuration for filtering sound event annotations based on tags.
|
||||||
omitted, no filtering is applied.
|
If None or omitted, no filtering is applied.
|
||||||
transforms : TransformConfig, optional
|
transforms : TransformConfig, optional
|
||||||
Configuration for transforming annotation tags. If None or omitted, no
|
Configuration for transforming annotation tags
|
||||||
transformations are applied.
|
(mapping, derivation, etc.). If None or omitted, no tag transformations
|
||||||
|
are applied.
|
||||||
classes : ClassesConfig
|
classes : ClassesConfig
|
||||||
Configuration defining the specific target classes, their matching
|
Configuration defining the specific target classes, their tag matching
|
||||||
rules, decoding rules (`output_tags`), and the generic class
|
rules for encoding, their representative tags for decoding
|
||||||
definition. This section is mandatory.
|
(`output_tags`), and the definition of the generic class tags.
|
||||||
|
This section is mandatory.
|
||||||
|
roi : ROIConfig, optional
|
||||||
|
Configuration defining how geometric ROIs (e.g., bounding boxes) are
|
||||||
|
mapped to target representations (reference point, scaled size).
|
||||||
|
Controls `position`, `time_scale`, `frequency_scale`. If None or
|
||||||
|
omitted, default ROI mapping settings are used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
filtering: Optional[FilterConfig] = None
|
filtering: Optional[FilterConfig] = None
|
||||||
|
|
||||||
transforms: Optional[TransformConfig] = None
|
transforms: Optional[TransformConfig] = None
|
||||||
|
|
||||||
classes: ClassesConfig
|
classes: ClassesConfig
|
||||||
|
roi: Optional[ROIConfig] = None
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
@ -177,34 +198,40 @@ def load_target_config(
|
|||||||
return load_config(path=path, schema=TargetConfig, field=field)
|
return load_config(path=path, schema=TargetConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
class Targets:
|
class Targets(TargetProtocol):
|
||||||
"""Encapsulates the complete configured target definition pipeline.
|
"""Encapsulates the complete configured target definition pipeline.
|
||||||
|
|
||||||
This class holds the functions for filtering, transforming, encoding, and
|
This class implements the `TargetProtocol`, holding the configured
|
||||||
decoding annotations based on a loaded `TargetConfig`. It provides a
|
functions for filtering, transforming, encoding (tags to class name),
|
||||||
high-level interface to apply these steps and access relevant metadata
|
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||||
like class names and generic class tags.
|
and back). It provides a high-level interface to apply these steps and
|
||||||
|
access relevant metadata like class names and dimension names.
|
||||||
|
|
||||||
Instances are typically created using the `Targets.from_config` or
|
Instances are typically created using the `build_targets` factory function
|
||||||
`Targets.from_file` classmethods.
|
or the `load_targets` convenience loader.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
class_names : list[str]
|
class_names : List[str]
|
||||||
An ordered list of the unique names of the specific target classes
|
An ordered list of the unique names of the specific target classes
|
||||||
defined in the configuration.
|
defined in the configuration.
|
||||||
generic_class_tags : List[data.Tag]
|
generic_class_tags : List[data.Tag]
|
||||||
A list of `soundevent.data.Tag` objects representing the configured
|
A list of `soundevent.data.Tag` objects representing the configured
|
||||||
generic class (e.g., the default 'Bat' class).
|
generic class category (used when no specific class matches).
|
||||||
|
dimension_names : List[str]
|
||||||
|
The names of the size dimensions handled by the ROI mapper
|
||||||
|
(e.g., ['width', 'height']).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class_names: list[str]
|
class_names: List[str]
|
||||||
generic_class_tags: List[data.Tag]
|
generic_class_tags: List[data.Tag]
|
||||||
|
dimension_names: List[str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encode_fn: SoundEventEncoder,
|
encode_fn: SoundEventEncoder,
|
||||||
decode_fn: SoundEventDecoder,
|
decode_fn: SoundEventDecoder,
|
||||||
|
roi_mapper: ROITargetMapper,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventFilter] = None,
|
filter_fn: Optional[SoundEventFilter] = None,
|
||||||
@ -212,26 +239,31 @@ class Targets:
|
|||||||
):
|
):
|
||||||
"""Initialize the Targets object.
|
"""Initialize the Targets object.
|
||||||
|
|
||||||
|
Note: This constructor is typically called internally by the
|
||||||
|
`build_targets` factory function.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
encode_fn : SoundEventEncoder
|
encode_fn : SoundEventEncoder
|
||||||
The configured function to encode annotations to class names.
|
Configured function to encode annotations to class names.
|
||||||
decode_fn : SoundEventDecoder
|
decode_fn : SoundEventDecoder
|
||||||
The configured function to decode class names to tags.
|
Configured function to decode class names to tags.
|
||||||
|
roi_mapper : ROITargetMapper
|
||||||
|
Configured object for mapping geometry to/from position/size.
|
||||||
class_names : list[str]
|
class_names : list[str]
|
||||||
The ordered list of specific target class names.
|
Ordered list of specific target class names.
|
||||||
generic_class_tags : List[data.Tag]
|
generic_class_tags : List[data.Tag]
|
||||||
The list of tags representing the generic class.
|
List of tags representing the generic class.
|
||||||
filter_fn : SoundEventFilter, optional
|
filter_fn : SoundEventFilter, optional
|
||||||
The configured function to filter annotations. Defaults to None (no
|
Configured function to filter annotations. Defaults to None.
|
||||||
filtering).
|
|
||||||
transform_fn : SoundEventTransformation, optional
|
transform_fn : SoundEventTransformation, optional
|
||||||
The configured function to transform annotation tags. Defaults to
|
Configured function to transform annotation tags. Defaults to None.
|
||||||
None (no transformation).
|
|
||||||
"""
|
"""
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.generic_class_tags = generic_class_tags
|
self.generic_class_tags = generic_class_tags
|
||||||
|
self.dimension_names = roi_mapper.dimension_names
|
||||||
|
|
||||||
|
self._roi_mapper = roi_mapper
|
||||||
self._filter_fn = filter_fn
|
self._filter_fn = filter_fn
|
||||||
self._encode_fn = encode_fn
|
self._encode_fn = encode_fn
|
||||||
self._decode_fn = decode_fn
|
self._decode_fn = decode_fn
|
||||||
@ -316,133 +348,223 @@ class Targets:
|
|||||||
return self._transform_fn(sound_event)
|
return self._transform_fn(sound_event)
|
||||||
return sound_event
|
return sound_event
|
||||||
|
|
||||||
@classmethod
|
def get_position(
|
||||||
def from_config(
|
self, sound_event: data.SoundEventAnnotation
|
||||||
cls,
|
) -> tuple[float, float]:
|
||||||
config: TargetConfig,
|
"""Extract the target reference position from the annotation's roi.
|
||||||
term_registry: TermRegistry = term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
|
||||||
) -> "Targets":
|
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
|
||||||
|
|
||||||
This factory method takes the unified configuration object and
|
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||||
constructs all the necessary functional components (filter, transform,
|
|
||||||
encoder, decoder) and extracts metadata (class names, generic tags) to
|
|
||||||
create a fully configured `Targets` instance.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config : TargetConfig
|
sound_event : data.SoundEventAnnotation
|
||||||
The loaded and validated unified target configuration object.
|
The annotation containing the geometry (ROI).
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use for resolving term keys. Defaults
|
|
||||||
to the global `batdetect2.targets.terms.term_registry`.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
|
||||||
The DerivationRegistry instance to use for resolving derivation
|
|
||||||
function names. Defaults to the global
|
|
||||||
`batdetect2.targets.transform.derivation_registry`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Targets
|
Tuple[float, float]
|
||||||
An initialized `Targets` object ready for use.
|
The reference position `(time, frequency)`.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
KeyError
|
ValueError
|
||||||
If term keys or derivation function keys specified in the `config`
|
If the annotation lacks geometry.
|
||||||
are not found in their respective registries.
|
|
||||||
ImportError, AttributeError, TypeError
|
|
||||||
If dynamic import of a derivation function fails (when configured).
|
|
||||||
"""
|
"""
|
||||||
filter_fn = (
|
geom = sound_event.sound_event.geometry
|
||||||
build_filter_from_config(
|
|
||||||
config.filtering,
|
if geom is None:
|
||||||
term_registry=term_registry,
|
raise ValueError(
|
||||||
|
"Sound event has no geometry, cannot get its position."
|
||||||
)
|
)
|
||||||
if config.filtering
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
encode_fn = build_encoder_from_config(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
decode_fn = build_decoder_from_config(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
transform_fn = (
|
|
||||||
build_transformation_from_config(
|
|
||||||
config.transforms,
|
|
||||||
term_registry=term_registry,
|
|
||||||
derivation_registry=derivation_registry,
|
|
||||||
)
|
|
||||||
if config.transforms
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
class_names = get_class_names_from_config(config.classes)
|
|
||||||
generic_class_tags = build_generic_class_tags_from_config(
|
|
||||||
config.classes,
|
|
||||||
term_registry=term_registry,
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
return self._roi_mapper.get_roi_position(geom)
|
||||||
filter_fn=filter_fn,
|
|
||||||
encode_fn=encode_fn,
|
|
||||||
decode_fn=decode_fn,
|
|
||||||
class_names=class_names,
|
|
||||||
generic_class_tags=generic_class_tags,
|
|
||||||
transform_fn=transform_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||||
def from_file(
|
"""Calculate the target size dimensions from the annotation's geometry.
|
||||||
cls,
|
|
||||||
config_path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
term_registry: TermRegistry = term_registry,
|
|
||||||
derivation_registry: DerivationRegistry = derivation_registry,
|
|
||||||
) -> "Targets":
|
|
||||||
"""Load a Targets object directly from a configuration file.
|
|
||||||
|
|
||||||
This convenience factory method loads the `TargetConfig` from the
|
Delegates to the internal ROI mapper's `get_roi_size` method, which
|
||||||
specified file path and then calls `Targets.from_config` to build
|
applies configured scaling factors.
|
||||||
the fully initialized `Targets` object.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config_path : data.PathLike
|
sound_event : data.SoundEventAnnotation
|
||||||
Path to the configuration file (e.g., YAML).
|
The annotation containing the geometry (ROI).
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing
|
|
||||||
the target configuration. If None, the entire file content is used.
|
|
||||||
term_registry : TermRegistry, optional
|
|
||||||
The TermRegistry instance to use. Defaults to the global default.
|
|
||||||
derivation_registry : DerivationRegistry, optional
|
|
||||||
The DerivationRegistry instance to use. Defaults to the global
|
|
||||||
default.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Targets
|
np.ndarray
|
||||||
An initialized `Targets` object ready for use.
|
NumPy array containing the size dimensions, matching the
|
||||||
|
order in `self.dimension_names` (e.g., `[width, height]`).
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
ValueError
|
||||||
TypeError
|
If the annotation lacks geometry.
|
||||||
Errors raised during file loading, validation, or extraction via
|
|
||||||
`load_target_config`.
|
|
||||||
KeyError, ImportError, AttributeError, TypeError
|
|
||||||
Errors raised during the build process by `Targets.from_config`
|
|
||||||
(e.g., missing keys in registries, failed imports).
|
|
||||||
"""
|
"""
|
||||||
config = load_target_config(
|
geom = sound_event.sound_event.geometry
|
||||||
config_path,
|
|
||||||
field=field,
|
if geom is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Sound event has no geometry, cannot get its size."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._roi_mapper.get_roi_size(geom)
|
||||||
|
|
||||||
|
def recover_roi(
|
||||||
|
self,
|
||||||
|
pos: tuple[float, float],
|
||||||
|
dims: np.ndarray,
|
||||||
|
) -> data.Geometry:
|
||||||
|
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||||
|
|
||||||
|
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||||
|
un-scales the dimensions and reconstructs the geometry (typically a
|
||||||
|
`BoundingBox`).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
The reference position `(time, frequency)`.
|
||||||
|
dims : np.ndarray
|
||||||
|
NumPy array with size dimensions (e.g., from model prediction),
|
||||||
|
matching the order in `self.dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.Geometry
|
||||||
|
The reconstructed geometry (typically `BoundingBox`).
|
||||||
|
"""
|
||||||
|
return self._roi_mapper.recover_roi(pos, dims)
|
||||||
|
|
||||||
|
|
||||||
|
def build_targets(
|
||||||
|
config: TargetConfig,
|
||||||
|
term_registry: TermRegistry = term_registry,
|
||||||
|
derivation_registry: DerivationRegistry = derivation_registry,
|
||||||
|
) -> Targets:
|
||||||
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
|
This factory function takes the unified `TargetConfig` and constructs all
|
||||||
|
necessary functional components (filter, transform, encoder,
|
||||||
|
decoder, ROI mapper) by calling their respective builder functions. It also
|
||||||
|
extracts metadata (class names, generic tags, dimension names) to create
|
||||||
|
and return a fully initialized `Targets` instance, ready to process
|
||||||
|
annotations.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : TargetConfig
|
||||||
|
The loaded and validated unified target configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to use for resolving term keys. Defaults
|
||||||
|
to the global `batdetect2.targets.terms.term_registry`.
|
||||||
|
derivation_registry : DerivationRegistry, optional
|
||||||
|
The DerivationRegistry instance to use for resolving derivation
|
||||||
|
function names. Defaults to the global
|
||||||
|
`batdetect2.targets.transform.derivation_registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Targets
|
||||||
|
An initialized `Targets` object ready for use.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If term keys or derivation function keys specified in the `config`
|
||||||
|
are not found in their respective registries.
|
||||||
|
ImportError, AttributeError, TypeError
|
||||||
|
If dynamic import of a derivation function fails (when configured).
|
||||||
|
"""
|
||||||
|
filter_fn = (
|
||||||
|
build_sound_event_filter(
|
||||||
|
config.filtering,
|
||||||
|
term_registry=term_registry,
|
||||||
)
|
)
|
||||||
return cls.from_config(
|
if config.filtering
|
||||||
config,
|
else None
|
||||||
|
)
|
||||||
|
encode_fn = build_sound_event_encoder(
|
||||||
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
decode_fn = build_sound_event_decoder(
|
||||||
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
transform_fn = (
|
||||||
|
build_transformation_from_config(
|
||||||
|
config.transforms,
|
||||||
term_registry=term_registry,
|
term_registry=term_registry,
|
||||||
derivation_registry=derivation_registry,
|
derivation_registry=derivation_registry,
|
||||||
)
|
)
|
||||||
|
if config.transforms
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
||||||
|
class_names = get_class_names_from_config(config.classes)
|
||||||
|
generic_class_tags = build_generic_class_tags(
|
||||||
|
config.classes,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Targets(
|
||||||
|
filter_fn=filter_fn,
|
||||||
|
encode_fn=encode_fn,
|
||||||
|
decode_fn=decode_fn,
|
||||||
|
class_names=class_names,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
generic_class_tags=generic_class_tags,
|
||||||
|
transform_fn=transform_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_targets(
|
||||||
|
config_path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = term_registry,
|
||||||
|
derivation_registry: DerivationRegistry = derivation_registry,
|
||||||
|
) -> Targets:
|
||||||
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
||||||
|
This convenience factory method loads the `TargetConfig` from the
|
||||||
|
specified file path and then calls `Targets.from_config` to build
|
||||||
|
the fully initialized `Targets` object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config_path : data.PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing
|
||||||
|
the target configuration. If None, the entire file content is used.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance to use. Defaults to the global default.
|
||||||
|
derivation_registry : DerivationRegistry, optional
|
||||||
|
The DerivationRegistry instance to use. Defaults to the global
|
||||||
|
default.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Targets
|
||||||
|
An initialized `Targets` object ready for use.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
||||||
|
TypeError
|
||||||
|
Errors raised during file loading, validation, or extraction via
|
||||||
|
`load_target_config`.
|
||||||
|
KeyError, ImportError, AttributeError, TypeError
|
||||||
|
Errors raised during the build process by `Targets.from_config`
|
||||||
|
(e.g., missing keys in registries, failed imports).
|
||||||
|
"""
|
||||||
|
config = load_target_config(
|
||||||
|
config_path,
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
|
return build_targets(
|
||||||
|
config,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
@ -22,9 +22,9 @@ __all__ = [
|
|||||||
"load_classes_config",
|
"load_classes_config",
|
||||||
"load_encoder_from_config",
|
"load_encoder_from_config",
|
||||||
"load_decoder_from_config",
|
"load_decoder_from_config",
|
||||||
"build_encoder_from_config",
|
"build_sound_event_encoder",
|
||||||
"build_decoder_from_config",
|
"build_sound_event_decoder",
|
||||||
"build_generic_class_tags_from_config",
|
"build_generic_class_tags",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
"DEFAULT_SPECIES_LIST",
|
"DEFAULT_SPECIES_LIST",
|
||||||
]
|
]
|
||||||
@ -314,7 +314,7 @@ def _encode_with_multiple_classifiers(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_encoder_from_config(
|
def build_sound_event_encoder(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> SoundEventEncoder:
|
) -> SoundEventEncoder:
|
||||||
@ -408,7 +408,7 @@ def _decode_class(
|
|||||||
return mapping[name]
|
return mapping[name]
|
||||||
|
|
||||||
|
|
||||||
def build_decoder_from_config(
|
def build_sound_event_decoder(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
raise_on_unmapped: bool = False,
|
raise_on_unmapped: bool = False,
|
||||||
@ -463,7 +463,7 @@ def build_decoder_from_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_generic_class_tags_from_config(
|
def build_generic_class_tags(
|
||||||
config: ClassesConfig,
|
config: ClassesConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> List[data.Tag]:
|
) -> List[data.Tag]:
|
||||||
@ -565,7 +565,7 @@ def load_encoder_from_config(
|
|||||||
provided `term_registry` during the build process.
|
provided `term_registry` during the build process.
|
||||||
"""
|
"""
|
||||||
config = load_classes_config(path, field=field)
|
config = load_classes_config(path, field=field)
|
||||||
return build_encoder_from_config(config, term_registry=term_registry)
|
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||||
|
|
||||||
|
|
||||||
def load_decoder_from_config(
|
def load_decoder_from_config(
|
||||||
@ -611,7 +611,7 @@ def load_decoder_from_config(
|
|||||||
provided `term_registry` during the build process.
|
provided `term_registry` during the build process.
|
||||||
"""
|
"""
|
||||||
config = load_classes_config(path, field=field)
|
config = load_classes_config(path, field=field)
|
||||||
return build_decoder_from_config(
|
return build_sound_event_decoder(
|
||||||
config,
|
config,
|
||||||
term_registry=term_registry,
|
term_registry=term_registry,
|
||||||
raise_on_unmapped=raise_on_unmapped,
|
raise_on_unmapped=raise_on_unmapped,
|
||||||
|
@ -17,7 +17,7 @@ __all__ = [
|
|||||||
"FilterConfig",
|
"FilterConfig",
|
||||||
"FilterRule",
|
"FilterRule",
|
||||||
"SoundEventFilter",
|
"SoundEventFilter",
|
||||||
"build_filter_from_config",
|
"build_sound_event_filter",
|
||||||
"build_filter_from_rule",
|
"build_filter_from_rule",
|
||||||
"load_filter_config",
|
"load_filter_config",
|
||||||
"load_filter_from_config",
|
"load_filter_from_config",
|
||||||
@ -241,7 +241,7 @@ class FilterConfig(BaseConfig):
|
|||||||
rules: List[FilterRule] = Field(default_factory=list)
|
rules: List[FilterRule] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def build_filter_from_config(
|
def build_sound_event_filter(
|
||||||
config: FilterConfig,
|
config: FilterConfig,
|
||||||
term_registry: TermRegistry = term_registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> SoundEventFilter:
|
) -> SoundEventFilter:
|
||||||
@ -312,4 +312,4 @@ def load_filter_from_config(
|
|||||||
The final merged filter function ready to be used.
|
The final merged filter function ready to be used.
|
||||||
"""
|
"""
|
||||||
config = load_filter_config(path=path, field=field)
|
config = load_filter_config(path=path, field=field)
|
||||||
return build_filter_from_config(config, term_registry=term_registry)
|
return build_sound_event_filter(config, term_registry=term_registry)
|
||||||
|
@ -1,19 +1,20 @@
|
|||||||
"""Defines the core interface (Protocol) for the target definition pipeline.
|
"""Defines the core interface (Protocol) for the target definition pipeline.
|
||||||
|
|
||||||
This module specifies the standard structure and methods expected from an object
|
This module specifies the standard structure, attributes, and methods expected
|
||||||
that encapsulates the configured logic for processing sound event annotations
|
from an object that encapsulates the complete configured logic for processing
|
||||||
within the `batdetect2.targets` system.
|
sound event annotations within the `batdetect2.targets` system.
|
||||||
|
|
||||||
The main component defined here is the `TargetEncoder` protocol. This protocol
|
The main component defined here is the `TargetProtocol`. This protocol acts as
|
||||||
acts as a contract, ensuring that components responsible for applying
|
a contract for the entire target definition process, covering semantic aspects
|
||||||
filtering, transformations, encoding annotations to class names, and decoding
|
(filtering, tag transformation, class encoding/decoding) as well as geometric
|
||||||
class names back to tags can be interacted with in a consistent manner
|
aspects (mapping regions of interest to target positions and sizes). It ensures
|
||||||
throughout BatDetect2. It also defines essential metadata attributes expected
|
that components responsible for these tasks can be interacted with consistently
|
||||||
from implementations.
|
throughout BatDetect2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -26,18 +27,30 @@ class TargetProtocol(Protocol):
|
|||||||
|
|
||||||
This protocol outlines the standard attributes and methods for an object
|
This protocol outlines the standard attributes and methods for an object
|
||||||
that encapsulates the complete, configured process for handling sound event
|
that encapsulates the complete, configured process for handling sound event
|
||||||
annotations to determine their target class for model training, and for
|
annotations (both tags and geometry). It defines how to:
|
||||||
interpreting model predictions back into annotation tags.
|
- Filter relevant annotations.
|
||||||
|
- Transform annotation tags.
|
||||||
|
- Encode an annotation into a specific target class name.
|
||||||
|
- Decode a class name back into representative tags.
|
||||||
|
- Extract a target reference position from an annotation's geometry (ROI).
|
||||||
|
- Calculate target size dimensions from an annotation's geometry.
|
||||||
|
- Recover an approximate geometry (ROI) from a position and size
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
Implementations of this protocol bundle all configured logic for these
|
||||||
|
steps.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
class_names : List[str]
|
class_names : List[str]
|
||||||
An ordered list of the unique names of the specific target classes
|
An ordered list of the unique names of the specific target classes
|
||||||
defined by the configuration represented by this object.
|
defined by the configuration.
|
||||||
generic_class_tags : List[data.Tag]
|
generic_class_tags : List[data.Tag]
|
||||||
A list of `soundevent.data.Tag` objects representing the
|
A list of `soundevent.data.Tag` objects representing the configured
|
||||||
generic class category (e.g., the default 'Bat' class tags used when
|
generic class category (e.g., used when no specific class matches).
|
||||||
no specific class matches).
|
dimension_names : List[str]
|
||||||
|
A list containing the names of the size dimensions returned by
|
||||||
|
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class_names: List[str]
|
class_names: List[str]
|
||||||
@ -46,6 +59,9 @@ class TargetProtocol(Protocol):
|
|||||||
generic_class_tags: List[data.Tag]
|
generic_class_tags: List[data.Tag]
|
||||||
"""List of tags representing the generic (unclassified) category."""
|
"""List of tags representing the generic (unclassified) category."""
|
||||||
|
|
||||||
|
dimension_names: List[str]
|
||||||
|
"""Names of the size dimensions (e.g., ['width', 'height'])."""
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the filter to a sound event annotation.
|
"""Apply the filter to a sound event annotation.
|
||||||
|
|
||||||
@ -100,10 +116,10 @@ class TargetProtocol(Protocol):
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str or None
|
str or None
|
||||||
The string name of the matched target class if the annotation matches
|
The string name of the matched target class if the annotation
|
||||||
a specific class definition. Returns None if the annotation does not
|
matches a specific class definition. Returns None if the annotation
|
||||||
match any specific class rule (indicating it may belong to a generic
|
does not match any specific class rule (indicating it may belong
|
||||||
category or should be handled differently downstream).
|
to a generic category or should be handled differently downstream).
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -130,3 +146,88 @@ class TargetProtocol(Protocol):
|
|||||||
found in the configured mapping and error raising is enabled.
|
found in the configured mapping and error raising is enabled.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_position(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
"""Extract the target reference position from the annotation's geometry.
|
||||||
|
|
||||||
|
Calculates the `(time, frequency)` coordinate representing the primary
|
||||||
|
location of the sound event.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation containing the geometry (ROI) to process.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[float, float]
|
||||||
|
The calculated reference position `(time, frequency)`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the annotation lacks geometry or if the position cannot be
|
||||||
|
calculated for the geometry type or configured reference point.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||||
|
"""Calculate the target size dimensions from the annotation's geometry.
|
||||||
|
|
||||||
|
Computes the relevant physical size (e.g., duration/width,
|
||||||
|
bandwidth/height from a bounding box) to produce
|
||||||
|
the numerical target values expected by the model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation containing the geometry (ROI) to process.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
A NumPy array containing the size dimensions, matching the
|
||||||
|
order specified by the `dimension_names` attribute (e.g.,
|
||||||
|
`[width, height]`).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the annotation lacks geometry or if the size cannot be computed.
|
||||||
|
TypeError
|
||||||
|
If geometry type is unsupported.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def recover_roi(
|
||||||
|
self, pos: tuple[float, float], dims: np.ndarray
|
||||||
|
) -> data.Geometry:
|
||||||
|
"""Recover the ROI geometry from a position and dimensions.
|
||||||
|
|
||||||
|
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
||||||
|
a reference position `(time, frequency)` and an array of size
|
||||||
|
dimensions and reconstructs an approximate geometric representation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
The reference position `(time, frequency)`.
|
||||||
|
dims : np.ndarray
|
||||||
|
The NumPy array containing the dimensions (e.g., predicted
|
||||||
|
by the model), corresponding to the order in `dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
soundevent.data.Geometry
|
||||||
|
The reconstructed geometry.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the number of provided `dims` does not match `dimension_names`,
|
||||||
|
if dimensions are invalid (e.g., negative after unscaling), or
|
||||||
|
if reconstruction fails based on the configured position type.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
85
docs/source/targets/rois.md
Normal file
85
docs/source/targets/rois.md
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
## Defining Target Geometry: Mapping Sound Event Regions
|
||||||
|
|
||||||
|
### Introduction
|
||||||
|
|
||||||
|
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
|
||||||
|
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
|
||||||
|
|
||||||
|
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
|
||||||
|
This ROI contains detailed spatial information (start/end time, low/high frequency).
|
||||||
|
|
||||||
|
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
|
||||||
|
|
||||||
|
### From ROI to Model Targets: Position & Size
|
||||||
|
|
||||||
|
BatDetect2 does not directly predict a full bounding box.
|
||||||
|
Instead, it is trained to predict:
|
||||||
|
|
||||||
|
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
|
||||||
|
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
|
||||||
|
|
||||||
|
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
|
||||||
|
It also handles the reverse process – converting predicted positions and sizes back into bounding boxes for visualization or analysis.
|
||||||
|
|
||||||
|
### Configuring the ROI Mapping
|
||||||
|
|
||||||
|
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
|
||||||
|
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
|
||||||
|
|
||||||
|
Here are the key settings:
|
||||||
|
|
||||||
|
- **`position`**:
|
||||||
|
|
||||||
|
- **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
|
||||||
|
- **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
|
||||||
|
Different choices might slightly influence model learning.
|
||||||
|
The default (`"bottom-left"`) is often a good starting point.
|
||||||
|
- **Example Value:** `position: "center"`
|
||||||
|
|
||||||
|
- **`time_scale`**:
|
||||||
|
|
||||||
|
- **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
|
||||||
|
- **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
|
||||||
|
For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
|
||||||
|
- **Important Considerations:**
|
||||||
|
- You can often set this value based on the units you prefer the model to work with internally.
|
||||||
|
However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
|
||||||
|
- The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
|
||||||
|
Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
|
||||||
|
- Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
|
||||||
|
BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
|
||||||
|
- **Example Value:** `time_scale: 1000.0`
|
||||||
|
|
||||||
|
- **`frequency_scale`**:
|
||||||
|
- **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
|
||||||
|
- **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
|
||||||
|
- **Important Considerations:**
|
||||||
|
- Same as for `time_scale`.
|
||||||
|
- **Example Value:** `frequency_scale: 0.00116`
|
||||||
|
|
||||||
|
**Example YAML Configuration:**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Inside your main configuration file (e.g., training_config.yaml)
|
||||||
|
|
||||||
|
targets: # Top-level key for target definition
|
||||||
|
# ... filtering settings ...
|
||||||
|
# ... transforms settings ...
|
||||||
|
# ... classes settings ...
|
||||||
|
|
||||||
|
# --- ROI Mapping Settings ---
|
||||||
|
roi:
|
||||||
|
position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
|
||||||
|
time_scale: 1000.0 # e.g., Model predicts width in ms
|
||||||
|
frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Decoding Size Predictions
|
||||||
|
|
||||||
|
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
|
||||||
|
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
|
||||||
|
|
||||||
|
### Outcome
|
||||||
|
|
||||||
|
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
|
||||||
|
Using consistent scales that are appropriate for your data and potentially beneficial for training stability allows the model to effectively learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is, and enables meaningful interpretation of the model's spatial and size predictions.
|
@ -13,9 +13,9 @@ from batdetect2.targets.classes import (
|
|||||||
_get_default_class_name,
|
_get_default_class_name,
|
||||||
_get_default_classes,
|
_get_default_classes,
|
||||||
_is_target_class,
|
_is_target_class,
|
||||||
build_decoder_from_config,
|
build_sound_event_decoder,
|
||||||
build_encoder_from_config,
|
build_sound_event_encoder,
|
||||||
build_generic_class_tags_from_config,
|
build_generic_class_tags,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
load_classes_config,
|
load_classes_config,
|
||||||
load_decoder_from_config,
|
load_decoder_from_config,
|
||||||
@ -231,7 +231,7 @@ def test_build_encoder_from_config(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
encoder = build_encoder_from_config(
|
encoder = build_sound_event_encoder(
|
||||||
config,
|
config,
|
||||||
term_registry=sample_term_registry,
|
term_registry=sample_term_registry,
|
||||||
)
|
)
|
||||||
@ -239,7 +239,7 @@ def test_build_encoder_from_config(
|
|||||||
assert result == "pippip"
|
assert result == "pippip"
|
||||||
|
|
||||||
config = ClassesConfig(classes=[])
|
config = ClassesConfig(classes=[])
|
||||||
encoder = build_encoder_from_config(
|
encoder = build_sound_event_encoder(
|
||||||
config,
|
config,
|
||||||
term_registry=sample_term_registry,
|
term_registry=sample_term_registry,
|
||||||
)
|
)
|
||||||
@ -315,7 +315,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
|||||||
],
|
],
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
)
|
)
|
||||||
decoder = build_decoder_from_config(
|
decoder = build_sound_event_decoder(
|
||||||
config, term_registry=sample_term_registry
|
config, term_registry=sample_term_registry
|
||||||
)
|
)
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
@ -335,7 +335,7 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
|||||||
],
|
],
|
||||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
)
|
)
|
||||||
decoder = build_decoder_from_config(
|
decoder = build_sound_event_decoder(
|
||||||
config, term_registry=sample_term_registry
|
config, term_registry=sample_term_registry
|
||||||
)
|
)
|
||||||
tags = decoder("pippip")
|
tags = decoder("pippip")
|
||||||
@ -344,14 +344,14 @@ def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
|||||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
# Test raise_on_unmapped=True
|
# Test raise_on_unmapped=True
|
||||||
decoder = build_decoder_from_config(
|
decoder = build_sound_event_decoder(
|
||||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decoder("unknown_class")
|
decoder("unknown_class")
|
||||||
|
|
||||||
# Test raise_on_unmapped=False
|
# Test raise_on_unmapped=False
|
||||||
decoder = build_decoder_from_config(
|
decoder = build_sound_event_decoder(
|
||||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||||
)
|
)
|
||||||
tags = decoder("unknown_class")
|
tags = decoder("unknown_class")
|
||||||
@ -402,7 +402,7 @@ def test_build_generic_class_tags_from_config(
|
|||||||
TagInfo(key="call_type", value="Echolocation"),
|
TagInfo(key="call_type", value="Echolocation"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
generic_tags = build_generic_class_tags_from_config(
|
generic_tags = build_generic_class_tags(
|
||||||
config, term_registry=sample_term_registry
|
config, term_registry=sample_term_registry
|
||||||
)
|
)
|
||||||
assert len(generic_tags) == 2
|
assert len(generic_tags) == 2
|
||||||
|
@ -7,7 +7,7 @@ from soundevent import data
|
|||||||
from batdetect2.targets.filtering import (
|
from batdetect2.targets.filtering import (
|
||||||
FilterConfig,
|
FilterConfig,
|
||||||
FilterRule,
|
FilterRule,
|
||||||
build_filter_from_config,
|
build_sound_event_filter,
|
||||||
build_filter_from_rule,
|
build_filter_from_rule,
|
||||||
contains_tags,
|
contains_tags,
|
||||||
does_not_have_tags,
|
does_not_have_tags,
|
||||||
@ -121,7 +121,7 @@ def test_build_filter_from_config(create_annotation):
|
|||||||
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
filter_from_config = build_filter_from_config(config)
|
filter_from_config = build_sound_event_filter(config)
|
||||||
|
|
||||||
annotation_pass = create_annotation(["tag1", "tag2"])
|
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||||
assert filter_from_config(annotation_pass)
|
assert filter_from_config(annotation_pass)
|
||||||
|
Loading…
Reference in New Issue
Block a user