# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2025-06-29 22:51:58 +02:00 (709 lines, 22 KiB, Python)
import importlib
|
|
from functools import partial
|
|
from typing import (
|
|
Annotated,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Mapping,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
from pydantic import Field
|
|
from soundevent import data
|
|
|
|
from batdetect2.configs import BaseConfig, load_config
|
|
from batdetect2.targets.terms import (
|
|
TagInfo,
|
|
TermRegistry,
|
|
get_tag_from_info,
|
|
get_term_from_key,
|
|
)
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "DerivationRegistry",
    "DeriveTagRule",
    "MapValueRule",
    "ReplaceRule",
    "SoundEventTransformation",
    "TransformConfig",
    "build_transform_from_rule",
    "build_transformation_from_config",
    "default_derivation_registry",
    "get_derivation",
    "load_transformation_config",
    "load_transformation_from_config",
    "register_derivation",
]
|
|
|
|
SoundEventTransformation = Callable[
    [data.SoundEventAnnotation],
    data.SoundEventAnnotation,
]
"""Signature shared by every sound event transformation.

A transformation receives one sound event annotation and hands back a
sound event annotation, which may differ from the input. Implementations
are expected to return a copy instead of mutating the annotation they
were given.
"""
|
|
|
|
|
|
Derivation = Callable[[str], str]
"""Signature of a derivation function.

A derivation receives one string (usually the value of a tag) and returns
the string derived from it.
"""
|
|
|
|
|
|
class MapValueRule(BaseConfig):
    """Rule that rewrites selected values of one term.

    Tags carrying the source term whose value appears as a key of
    `value_mapping` are swapped for a new tag holding the mapped value.
    The new tag may use a different term. Handy for normalising or
    grouping tag values.

    Attributes
    ----------
    rule_type : Literal["map_value"]
        Discriminator identifying this kind of rule.
    source_term_key : str
        Key (as registered in the `TermRegistry`) of the term whose tag
        values are looked up in `value_mapping`.
    value_mapping : Dict[str, str]
        Original value -> replacement value. Tags whose value is not a key
        of this dictionary are left alone.
    target_term_key : str, optional
        Key (as registered in the `TermRegistry`) of the term given to the
        output tag. When None (the default) the output tag keeps the
        source term.
    """

    rule_type: Literal["map_value"] = "map_value"
    source_term_key: str
    value_mapping: Dict[str, str]
    target_term_key: Optional[str] = None
|
|
|
|
|
|
def map_value_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source_term: data.Term,
    target_term: data.Term,
    mapping: Dict[str, str],
) -> data.SoundEventAnnotation:
    """Rewrite tag values of an annotation according to a mapping.

    Every tag whose term equals `source_term` and whose value is a key of
    `mapping` is substituted by a freshly built tag carrying `target_term`
    and the mapped value. All remaining tags pass through untouched and
    the original tag order is preserved.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose tags are processed.
    source_term : data.Term
        Term a tag must have to be considered for mapping.
    target_term : data.Term
        Term assigned to the tags created by the mapping.
    mapping : Dict[str, str]
        Original value -> replacement value.

    Returns
    -------
    data.SoundEventAnnotation
        Copy of the annotation (via `model_copy`) with the new tag list.
    """

    def _remap(tag):
        # Only tags of the source term with a mapped value are rewritten.
        if tag.term == source_term and tag.value in mapping:
            return data.Tag(term=target_term, value=mapping[tag.value])
        return tag

    remapped = [_remap(tag) for tag in sound_event_annotation.tags]
    return sound_event_annotation.model_copy(update={"tags": remapped})
|
|
|
|
|
|
class DeriveTagRule(BaseConfig):
    """Rule that derives a new tag from the value of an existing tag.

    The function named by `derivation_function` is applied to the value of
    each tag carrying the source term, and a tag holding the result is
    added under the target term.

    Attributes
    ----------
    rule_type : Literal["derive_tag"]
        Discriminator identifying this kind of rule.
    source_term_key : str
        Key (as registered in the `TermRegistry`) of the term whose tag
        values are fed to the derivation function.
    derivation_function : str
        Name of the derivation function. Either a key registered in a
        `DerivationRegistry` or, when `import_derivation` is True, a full
        dotted Python path such as
        `'my_module.my_submodule.my_function'`.
    target_term_key : str, optional
        Key (as registered in the `TermRegistry`) of the term given to the
        derived tag. When None (the default) the derived tag reuses the
        source term.
    import_derivation : bool, default=False
        When True and the name is not found in the registry,
        `derivation_function` is treated as an import path and loaded
        dynamically; the function must be importable in the running
        environment.
    keep_source : bool, default=True
        When True the tag used as input stays in the tag list next to the
        derived tag; when False it is dropped.
    """

    rule_type: Literal["derive_tag"] = "derive_tag"
    source_term_key: str
    derivation_function: str
    target_term_key: Optional[str] = None
    import_derivation: bool = False
    keep_source: bool = True
|
|
|
|
|
|
def derivation_tag_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source_term: data.Term,
    target_term: data.Term,
    derivation: Derivation,
    keep_source: bool = True,
) -> data.SoundEventAnnotation:
    """Add derived tags to an annotation.

    For each tag whose term equals `source_term`, `derivation` is called
    with the tag's value and a new tag with `target_term` and the result
    is placed right after the position of the source tag. The source tag
    itself is retained only when `keep_source` is true. Tags of any other
    term are copied over as they are.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose tags are processed.
    source_term : data.Term
        Term of the tags used as input for the derivation.
    target_term : data.Term
        Term assigned to the derived tags.
    derivation : Derivation
        Function applied to the value of each source tag.
    keep_source : bool, default=True
        Whether source tags remain in the output.

    Returns
    -------
    data.SoundEventAnnotation
        Copy of the annotation (via `model_copy`) with the new tag list.
    """
    collected = []

    for tag in sound_event_annotation.tags:
        is_source = tag.term == source_term

        # Non-source tags always survive; source tags only on request.
        if keep_source or not is_source:
            collected.append(tag)

        if is_source:
            derived_value = derivation(tag.value)
            collected.append(data.Tag(term=target_term, value=derived_value))

    return sound_event_annotation.model_copy(update={"tags": collected})
|
|
|
|
|
|
class ReplaceRule(BaseConfig):
    """Rule that swaps one exact tag for another.

    A tag matches only when both its term and its value equal those of
    `original`; each match is substituted by `replacement`.

    Attributes
    ----------
    rule_type : Literal["replace"]
        Discriminator identifying this kind of rule.
    original : TagInfo
        Tag to look for, given by value and term key.
    replacement : TagInfo
        Tag put in place of the original, given by value and term key.
    """

    rule_type: Literal["replace"] = "replace"
    original: TagInfo
    replacement: TagInfo
|
|
|
|
|
|
def replace_tag_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source: data.Tag,
    target: data.Tag,
) -> data.SoundEventAnnotation:
    """Swap every occurrence of one tag for another.

    Each tag of the annotation that compares equal to `source` is
    substituted by `target`; all other tags are kept, and tag order is
    preserved.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose tags are processed.
    source : data.Tag
        Tag to look for.
    target : data.Tag
        Tag inserted wherever `source` is found.

    Returns
    -------
    data.SoundEventAnnotation
        Copy of the annotation (via `model_copy`) with the new tag list.
    """
    swapped = [
        target if tag == source else tag
        for tag in sound_event_annotation.tags
    ]
    return sound_event_annotation.model_copy(update={"tags": swapped})
|
|
|
|
|
|
class TransformConfig(BaseConfig):
    """Configuration holding an ordered list of transformation rules.

    Attributes
    ----------
    rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
        Rules to apply, in list order; each rule operates on the result of
        the previous one. The concrete rule model is selected through the
        `rule_type` discriminator field. Defaults to an empty list.
    """

    rules: List[
        Annotated[
            Union[ReplaceRule, MapValueRule, DeriveTagRule],
            Field(discriminator="rule_type"),
        ]
    ] = Field(default_factory=list)
|
|
|
|
|
|
class DerivationRegistry(Mapping[str, Derivation]):
    """Read-only mapping of names to derivation functions.

    A derivation function turns one string value into another and is what
    `DeriveTagRule` refers to by name. Functions are added with
    `register` and looked up by key afterwards.

    Because this is a `Mapping`, an instance with no registered functions
    has length 0 and is therefore falsy.
    """

    def __init__(self):
        """Create a registry with no derivations."""
        self._derivations: Dict[str, Derivation] = {}

    def __getitem__(self, key: str) -> Derivation:
        """Return the derivation stored under `key`."""
        return self._derivations[key]

    def __len__(self) -> int:
        """Return how many derivations are registered."""
        return len(self._derivations)

    def __iter__(self):
        """Iterate over the registered keys."""
        return iter(self._derivations)

    def register(self, key: str, derivation: Derivation) -> None:
        """Store a derivation function under a new key.

        Parameters
        ----------
        key : str
            Key under which the function is stored; must not be in use.
        derivation : Derivation
            Callable taking a str and returning a str.

        Raises
        ------
        KeyError
            If `key` is already registered.
        """
        if key not in self._derivations:
            self._derivations[key] = derivation
            return

        raise KeyError(
            f"A derivation with the provided key {key} already exists"
        )

    def get_derivation(self, key: str) -> Derivation:
        """Look up a derivation function by key.

        Parameters
        ----------
        key : str
            Key of the wanted derivation function.

        Returns
        -------
        Derivation
            The function registered under `key`.

        Raises
        ------
        KeyError
            If nothing is registered under `key`.
        """
        try:
            found = self._derivations[key]
        except KeyError as err:
            raise KeyError(
                f"No derivation with key {key} is registered."
            ) from err
        return found

    def get_keys(self) -> List[str]:
        """Return the registered keys as a list.

        Returns
        -------
        List[str]
            Keys of all registered functions.
        """
        return list(self._derivations)

    def get_derivations(self) -> List[Derivation]:
        """Return the registered functions as a list.

        Returns
        -------
        List[Derivation]
            All registered derivation functions.
        """
        return [*self._derivations.values()]
|
|
|
|
|
|
default_derivation_registry = DerivationRegistry()
"""Module-wide `DerivationRegistry` instance.

Derivation functions registered here can be referred to by key from a
`DeriveTagRule` configuration.
"""
|
|
|
|
|
|
def get_derivation(
    key: str,
    import_derivation: bool = False,
    registry: Optional[DerivationRegistry] = None,
) -> Derivation:
    """Retrieve a derivation function by key, optionally importing it.

    The `registry` is consulted first. If the key is not registered and
    `import_derivation` is True, `key` is interpreted as a full dotted
    Python path (e.g. 'my_module.submodule.my_func') and the function is
    imported dynamically.

    Parameters
    ----------
    key : str
        The key or Python path of the derivation function.
    import_derivation : bool, default=False
        If True, attempt a dynamic import when the key is not registered.
    registry : DerivationRegistry, optional
        The registry to check first. Defaults to the global
        `default_derivation_registry`. An explicitly provided registry is
        always honoured, even when it is empty.

    Returns
    -------
    Derivation
        The requested derivation function.

    Raises
    ------
    KeyError
        If the key is not registered and `import_derivation` is False, if
        the key is not a dotted path, or if importing the module fails.
    AttributeError
        If the module was imported but has no attribute with the
        requested function name.
    """
    # A `DerivationRegistry` is a Mapping and thus falsy when empty, so
    # `registry or default` would silently ignore an empty custom registry.
    if registry is None:
        registry = default_derivation_registry

    if not import_derivation or key in registry:
        return registry.get_derivation(key)

    module_path, dot, func_name = key.rpartition(".")

    # Without a module part there is nothing to import; report it with the
    # documented exception type instead of an unpacking ValueError.
    if not dot or not module_path:
        raise KeyError(
            f"Unable to load derivation '{key}'. Expected a dotted Python "
            "path such as 'package.module.function'."
        )

    try:
        module = importlib.import_module(module_path)
        return getattr(module, func_name)
    except ImportError as err:
        raise KeyError(
            f"Unable to load derivation '{key}'. Check the path and ensure "
            "it points to a valid callable function in an importable module."
        ) from err
|
|
|
|
|
|
# Discriminated union of all rule models; pydantic picks the concrete model
# from the `rule_type` field. NOTE: the identifier is spelled
# "TranformationRule" (sic); it is kept as-is because other code refers to it.
TranformationRule = Annotated[
    Union[ReplaceRule, MapValueRule, DeriveTagRule],
    Field(discriminator="rule_type"),
]
|
|
|
|
|
|
def build_transform_from_rule(
    rule: TranformationRule,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Turn one rule configuration into a transformation callable.

    Dispatches on `rule.rule_type`, resolves the terms/tags (and, for
    `derive_tag` rules, the derivation function) the rule refers to, and
    returns the matching transform function with those arguments bound
    through `functools.partial`.

    Parameters
    ----------
    rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
        Configuration of a single transformation rule.
    derivation_registry : DerivationRegistry, optional
        Registry passed on to `get_derivation` for `derive_tag` rules.
    term_registry : TermRegistry, optional
        Registry passed on to `get_tag_from_info` / `get_term_from_key`
        when resolving term keys.

    Returns
    -------
    SoundEventTransformation
        Callable applying the rule to a sound event annotation.

    Raises
    ------
    KeyError
        If required term keys or derivation keys are not found.
    ValueError
        If the rule has an unknown `rule_type`.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function fails.
    """
    if rule.rule_type == "replace":
        original_tag = get_tag_from_info(
            rule.original,
            term_registry=term_registry,
        )
        replacement_tag = get_tag_from_info(
            rule.replacement,
            term_registry=term_registry,
        )
        return partial(
            replace_tag_transform,
            source=original_tag,
            target=replacement_tag,
        )

    if rule.rule_type == "derive_tag":
        source_term = get_term_from_key(
            rule.source_term_key,
            term_registry=term_registry,
        )
        # Without an explicit target key the derived tag reuses the source
        # term.
        target_term = source_term
        if rule.target_term_key:
            target_term = get_term_from_key(
                rule.target_term_key,
                term_registry=term_registry,
            )
        derivation = get_derivation(
            key=rule.derivation_function,
            import_derivation=rule.import_derivation,
            registry=derivation_registry,
        )
        return partial(
            derivation_tag_transform,
            source_term=source_term,
            target_term=target_term,
            derivation=derivation,
            keep_source=rule.keep_source,
        )

    if rule.rule_type == "map_value":
        source_term = get_term_from_key(
            rule.source_term_key,
            term_registry=term_registry,
        )
        # Same fallback as above: no target key means "keep the term".
        target_term = source_term
        if rule.target_term_key:
            target_term = get_term_from_key(
                rule.target_term_key,
                term_registry=term_registry,
            )
        return partial(
            map_value_transform,
            source_term=source_term,
            target_term=target_term,
            mapping=rule.value_mapping,
        )

    # Unknown rule type. Pydantic validation should normally prevent this,
    # but fail loudly in case an unvalidated object gets here.
    valid_options = ["replace", "derive_tag", "map_value"]
    raise ValueError(
        f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
        f"Valid options are: {valid_options}"
    )
|
|
|
|
|
|
def build_transformation_from_config(
    config: TransformConfig,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Build one composite transformation from a `TransformConfig`.

    Each rule in `config.rules` is converted with
    `build_transform_from_rule`; the returned callable runs the resulting
    transformations one after another, in rule order.

    Parameters
    ----------
    config : TransformConfig
        Configuration holding the list of transformation rules.
    derivation_registry : DerivationRegistry, optional
        Forwarded to `build_transform_from_rule` for every rule.
    term_registry : TermRegistry, optional
        Forwarded to `build_transform_from_rule` for every rule.

    Returns
    -------
    SoundEventTransformation
        Single callable applying all configured transformations in order.
    """
    steps = []
    for rule in config.rules:
        step = build_transform_from_rule(
            rule,
            derivation_registry=derivation_registry,
            term_registry=term_registry,
        )
        steps.append(step)

    return partial(apply_sequence_of_transforms, transforms=steps)
|
|
|
|
|
|
def apply_sequence_of_transforms(
    sound_event_annotation: data.SoundEventAnnotation,
    transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
    """Run an annotation through a list of transformations, in order.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation fed to the first transformation.
    transforms : list[SoundEventTransformation]
        Transformations to apply; each receives the previous one's output.

    Returns
    -------
    data.SoundEventAnnotation
        Output of the last transformation, or the input itself when
        `transforms` is empty.
    """
    current = sound_event_annotation
    for step in transforms:
        current = step(current)
    return current
|
|
|
|
|
|
def load_transformation_config(
    path: data.PathLike, field: Optional[str] = None
) -> TransformConfig:
    """Read a `TransformConfig` from a configuration file.

    Delegates to `load_config` with `TransformConfig` as the schema.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file (YAML).
    field : str, optional
        Key under which the transformation configuration is nested in the
        file, if any. Defaults to None.

    Returns
    -------
    TransformConfig
        The loaded and validated transformation configuration.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the file contents do not match the TransformConfig schema.
    """
    loaded = load_config(path=path, schema=TransformConfig, field=field)
    return loaded
|
|
|
|
|
|
def load_transformation_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Load a transformation config file and build the composite transform.

    Convenience wrapper: reads the configuration with
    `load_transformation_config` and passes it to
    `build_transformation_from_config`.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file (YAML).
    field : str, optional
        Key under which the transformation configuration is nested in the
        file, if any. Defaults to None.
    derivation_registry : DerivationRegistry, optional
        Forwarded to `build_transformation_from_config`.
    term_registry : TermRegistry, optional
        Forwarded to `build_transformation_from_config`.

    Returns
    -------
    SoundEventTransformation
        Composite transformation applying all configured rules in order.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the file contents do not match the TransformConfig schema.
    KeyError
        If term keys or derivation keys named in the config cannot be
        resolved while building.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function named in the config
        fails.
    """
    transform_config = load_transformation_config(path=path, field=field)
    transformation = build_transformation_from_config(
        transform_config,
        derivation_registry=derivation_registry,
        term_registry=term_registry,
    )
    return transformation
|
|
|
|
|
|
def register_derivation(
    key: str,
    derivation: Derivation,
    derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
    """Register a new derivation function in a registry.

    Parameters
    ----------
    key : str
        The unique key to associate with the derivation function.
    derivation : Derivation
        The callable derivation function (takes str, returns str).
    derivation_registry : DerivationRegistry, optional
        The registry instance to register the derivation function with.
        Defaults to the global `default_derivation_registry`. An
        explicitly provided registry is always used, even when it is
        still empty.

    Raises
    ------
    KeyError
        If a derivation function with the same key is already registered.
    """
    # A `DerivationRegistry` is a Mapping and thus falsy while empty; an
    # `or` fallback would redirect the first registration on a fresh custom
    # registry into the global one. Compare against None explicitly.
    if derivation_registry is None:
        derivation_registry = default_derivation_registry

    derivation_registry.register(key, derivation)
|