# batdetect2/src/batdetect2/targets/transform.py
# 2025-06-21 13:48:40 +01:00
#
# 709 lines
# 22 KiB
# Python

import importlib
from functools import partial
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
get_term_from_key,
)
# Public names of this module, i.e. what a star-import of it exposes.
# Helper callables defined further down that are not listed here (for
# example the individual `*_transform` functions) are not star-exported.
__all__ = [
    "DerivationRegistry",
    "DeriveTagRule",
    "MapValueRule",
    "ReplaceRule",
    "SoundEventTransformation",
    "TransformConfig",
    "build_transform_from_rule",
    "build_transformation_from_config",
    "default_derivation_registry",
    "get_derivation",
    "load_transformation_config",
    "load_transformation_from_config",
    "register_derivation",
]
# Signature shared by every transformation this module builds: the callables
# returned by `build_transform_from_rule` and
# `build_transformation_from_config` are annotated with this alias.
SoundEventTransformation = Callable[
    [data.SoundEventAnnotation], data.SoundEventAnnotation
]
"""Type alias for a sound event transformation function.
A function that accepts a sound event annotation object and returns a
(potentially) modified sound event annotation object. Transformations
should generally return a copy of the annotation rather than modifying
it in place.
"""
# Derivations are stored by key in a `DerivationRegistry` and applied to a
# tag's value by `derivation_tag_transform`.
Derivation = Callable[[str], str]
"""Type alias for a derivation function.
A function that accepts a single string (typically a tag value) and returns
a new string (the derived value).
"""
class MapValueRule(BaseConfig):
    """Configuration for mapping specific values of a source term.

    This rule replaces tags matching a specific term and one of the
    original values with a new tag (potentially having a different term)
    containing the corresponding replacement value. Useful for standardizing
    or grouping tag values.

    Tags of the source term whose value is not a key of `value_mapping`,
    and tags of any other term, are left unchanged.

    Attributes
    ----------
    rule_type : Literal["map_value"]
        Discriminator field identifying this rule type.
    source_term_key : str
        The key (registered in `TermRegistry`) of the term whose tags' values
        should be checked against the `value_mapping`.
    value_mapping : Dict[str, str]
        A dictionary mapping original string values to replacement string
        values. Only tags whose value is a key in this dictionary will be
        affected.
    target_term_key : str, optional
        The key (registered in `TermRegistry`) for the term of the *output*
        tag. If None (default) or otherwise falsy (e.g. an empty string),
        the output tag uses the same term as the source
        (`source_term_key`). If provided, the term of the affected tag is
        changed to this target term upon replacement.
    """

    rule_type: Literal["map_value"] = "map_value"
    source_term_key: str
    value_mapping: Dict[str, str]
    target_term_key: Optional[str] = None
def map_value_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source_term: data.Term,
    target_term: data.Term,
    mapping: Dict[str, str],
) -> data.SoundEventAnnotation:
    """Rewrite the values of matching tags using a lookup table.

    Every tag of the annotation is visited in order. A tag is rewritten only
    when its term equals `source_term` *and* its value is a key of
    `mapping`; in that case it is swapped for a freshly built tag carrying
    `target_term` and the mapped value. All other tags pass through
    untouched, so the number and order of tags is preserved.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation whose tags are inspected.
    source_term : data.Term
        Term that a tag must have to be eligible for mapping.
    target_term : data.Term
        Term assigned to the tags created by the mapping.
    mapping : Dict[str, str]
        Lookup table from original tag values to replacement values.

    Returns
    -------
    data.SoundEventAnnotation
        The result of `model_copy` on the input annotation with its `tags`
        field replaced by the rewritten tag list.
    """
    remapped_tags = [
        tag
        if tag.term != source_term or tag.value not in mapping
        else data.Tag(term=target_term, value=mapping[tag.value])
        for tag in sound_event_annotation.tags
    ]
    return sound_event_annotation.model_copy(update={"tags": remapped_tags})
class DeriveTagRule(BaseConfig):
    """Configuration for deriving a new tag from an existing tag's value.

    This rule applies a specified function (`derivation_function`) to the
    value of tags matching the `source_term_key`. It then adds a new tag
    with the `target_term_key` and the derived value.

    Attributes
    ----------
    rule_type : Literal["derive_tag"]
        Discriminator field identifying this rule type.
    source_term_key : str
        The key (registered in `TermRegistry`) of the term whose tag values
        will be used as input to the derivation function.
    derivation_function : str
        The name/key identifying the derivation function to use. This can be
        a key registered in the `DerivationRegistry` or, if
        `import_derivation` is True, a full Python path like
        `'my_module.my_submodule.my_function'`.
    target_term_key : str, optional
        The key (registered in `TermRegistry`) for the term of the new tag
        that will be created with the derived value. If None (default) or
        otherwise falsy, the derived tag uses the same term as the source
        (`source_term_key`). Combined with `keep_source=False` this
        amounts to an in-place value transformation.
    import_derivation : bool, default=False
        If True, treat `derivation_function` as a Python import path and
        attempt to dynamically import it if not found in the registry (the
        registry is always consulted first). Requires the function to be
        accessible in the Python environment.
    keep_source : bool, default=True
        If True, the original source tag (whose value was used for derivation)
        is kept in the annotation's tag list, immediately before the newly
        derived tag. If False, the original source tag is removed.
    """

    rule_type: Literal["derive_tag"] = "derive_tag"
    source_term_key: str
    derivation_function: str
    target_term_key: Optional[str] = None
    import_derivation: bool = False
    keep_source: bool = True
def derivation_tag_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source_term: data.Term,
    target_term: data.Term,
    derivation: Derivation,
    keep_source: bool = True,
) -> data.SoundEventAnnotation:
    """Add tags computed from the values of existing tags.

    Walks the annotation's tags in order. Tags whose term differs from
    `source_term` are copied over as they are. For each tag that does carry
    `source_term`, `derivation` is called with the tag's value and a new
    tag with `target_term` and the returned value is emitted. When
    `keep_source` is true the original tag is emitted first, directly
    followed by the derived one; otherwise only the derived tag is emitted.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation whose tags are processed.
    source_term : data.Term
        Term identifying the tags that feed the derivation.
    target_term : data.Term
        Term given to every derived tag.
    derivation : Derivation
        Callable mapping a source tag value to the derived value.
    keep_source : bool, default=True
        Whether the source tag stays in the output next to its derived tag.

    Returns
    -------
    data.SoundEventAnnotation
        The result of `model_copy` on the input annotation with its `tags`
        field replaced by the processed tag list.
    """
    output_tags = []
    for current in sound_event_annotation.tags:
        if current.term != source_term:
            # Unrelated tag: pass through untouched.
            output_tags.append(current)
        else:
            if keep_source:
                output_tags.append(current)
            derived = data.Tag(
                term=target_term,
                value=derivation(current.value),
            )
            output_tags.append(derived)
    return sound_event_annotation.model_copy(update={"tags": output_tags})
class ReplaceRule(BaseConfig):
    """Configuration for exactly replacing one specific tag with another.

    This rule looks for an exact match of the `original` tag (both term and
    value) and replaces it with the specified `replacement` tag. Every
    occurrence of the original tag in an annotation is replaced; tags that
    do not compare equal to it are left unchanged.

    Attributes
    ----------
    rule_type : Literal["replace"]
        Discriminator field identifying this rule type.
    original : TagInfo
        The exact tag to search for, defined using its value and term key.
    replacement : TagInfo
        The tag to substitute in place of the original tag, defined using
        its value and term key.
    """

    rule_type: Literal["replace"] = "replace"
    original: TagInfo
    replacement: TagInfo
def replace_tag_transform(
    sound_event_annotation: data.SoundEventAnnotation,
    source: data.Tag,
    target: data.Tag,
) -> data.SoundEventAnnotation:
    """Swap every occurrence of one tag for another.

    Each tag of the annotation is compared (with ``==``) against `source`.
    Tags that compare equal are substituted by `target`; all others are
    kept. Order and number of tags are preserved.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation whose tags are inspected.
    source : data.Tag
        The tag to look for.
    target : data.Tag
        The tag put in place of each match.

    Returns
    -------
    data.SoundEventAnnotation
        The result of `model_copy` on the input annotation with its `tags`
        field replaced by the substituted tag list.
    """
    swapped_tags = [
        target if existing == source else existing
        for existing in sound_event_annotation.tags
    ]
    return sound_event_annotation.model_copy(update={"tags": swapped_tags})
class TransformConfig(BaseConfig):
    """Configuration model for defining a sequence of transformation rules.

    Attributes
    ----------
    rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
        A list of transformation rules to apply. The rules are applied
        sequentially in the order they appear in the list. The output of
        one rule becomes the input for the next. The `rule_type` field
        discriminates between the different rule models. Defaults to an
        empty list, i.e. no transformation rules.
    """

    # Each element is a tagged union; the `rule_type` value selects which of
    # the three rule models the element is validated as.
    rules: List[
        Annotated[
            Union[ReplaceRule, MapValueRule, DeriveTagRule],
            Field(discriminator="rule_type"),
        ]
    ] = Field(
        default_factory=list,
    )
class DerivationRegistry(Mapping[str, Derivation]):
    """Keyed collection of derivation functions.

    A derivation is a callable that receives a string and returns a string;
    `DeriveTagRule` refers to derivations by the key under which they were
    registered here. The registry exposes the `Mapping` interface (lookup by
    key, `len`, iteration over keys); entries are added with `register`.

    Notes
    -----
    Because this is a `Mapping` with `__len__`, a registry without any
    entries evaluates as false in a boolean context.
    """

    def __init__(self):
        """Create a registry that holds no derivations yet."""
        self._derivations: Dict[str, Derivation] = {}

    def __getitem__(self, key: str) -> Derivation:
        """Return the derivation stored under `key` (plain dict lookup)."""
        return self._derivations[key]

    def __len__(self) -> int:
        """Return how many derivations are currently registered."""
        return len(self._derivations)

    def __iter__(self):
        """Iterate over the registered keys."""
        return iter(self._derivations)

    def register(self, key: str, derivation: Derivation) -> None:
        """Store a derivation under a key that is not yet in use.

        Parameters
        ----------
        key : str
            Key under which the derivation will be retrievable.
        derivation : Derivation
            Callable taking a string and returning a string.

        Raises
        ------
        KeyError
            If `key` is already present; existing entries are never
            overwritten.
        """
        key_taken = key in self._derivations
        if key_taken:
            raise KeyError(
                f"A derivation with the provided key {key} already exists"
            )
        self._derivations[key] = derivation

    def get_derivation(self, key: str) -> Derivation:
        """Look up a derivation, raising a descriptive error when missing.

        Parameters
        ----------
        key : str
            Key of the derivation to fetch.

        Returns
        -------
        Derivation
            The callable registered under `key`.

        Raises
        ------
        KeyError
            If nothing is registered under `key`. The original lookup error
            is attached as the cause.
        """
        try:
            found = self._derivations[key]
        except KeyError as err:
            raise KeyError(
                f"No derivation with key {key} is registered."
            ) from err
        return found

    def get_keys(self) -> List[str]:
        """Return the registered keys as a new list.

        Returns
        -------
        List[str]
            Keys in the order they were registered.
        """
        return list(self._derivations)

    def get_derivations(self) -> List[Derivation]:
        """Return the registered derivation callables as a new list.

        Returns
        -------
        List[Derivation]
            Callables in the order they were registered.
        """
        return [self._derivations[key] for key in self._derivations]
# Module-level registry created empty at import time. The module-level
# `get_derivation` and `register_derivation` functions fall back to this
# instance when the caller does not supply a registry.
default_derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
in `DeriveTagRule` configuration.
"""
def get_derivation(
    key: str,
    import_derivation: bool = False,
    registry: Optional[DerivationRegistry] = None,
):
    """Retrieve a derivation function by key, optionally importing it.

    The registry is always consulted first. Only when the key is not
    registered and `import_derivation` is True is the key interpreted as a
    dotted Python path (e.g. ``'my_module.submodule.my_func'``): everything
    before the last dot is imported as a module and the part after it is
    fetched from that module as an attribute.

    Parameters
    ----------
    key : str
        The registry key or dotted Python path of the derivation function.
    import_derivation : bool, default=False
        If True, attempt a dynamic import when `key` is not in the registry.
    registry : DerivationRegistry, optional
        The registry to check first. Defaults to the module-level
        `default_derivation_registry` when None. An explicitly supplied
        registry is always honoured, even if it is empty.

    Returns
    -------
    Derivation
        The requested derivation function.

    Raises
    ------
    KeyError
        If the key is not in the registry and `import_derivation` is False;
        if `key` has no module part (it contains no dot, or starts with its
        only dot) and so cannot be an import path; or if importing the
        module part of the path fails.
    AttributeError
        If the module is imported successfully but does not define the
        requested attribute.
    """
    # Compare against None explicitly: DerivationRegistry is a Mapping, so an
    # empty registry is falsy and `registry or default` would silently swap a
    # caller-supplied (but still empty) registry for the global one.
    if registry is None:
        registry = default_derivation_registry

    if not import_derivation or key in registry:
        return registry.get_derivation(key)

    # Split on the last dot: "pkg.mod.func" -> ("pkg.mod", "func").
    module_path, _, func_name = key.rpartition(".")
    if not module_path:
        # No module part means the key cannot be an import path. Report it
        # the same way as any other unresolvable key instead of failing with
        # an unrelated unpacking/empty-module error.
        raise KeyError(
            f"Unable to load derivation '{key}'. Check the path and ensure "
            "it points to a valid callable function in an importable module."
        )

    try:
        module = importlib.import_module(module_path)
        return getattr(module, func_name)
    except ImportError as err:
        raise KeyError(
            f"Unable to load derivation '{key}'. Check the path and ensure "
            "it points to a valid callable function in an importable module."
        ) from err
# Discriminated union of the supported rule models; `rule_type` selects the
# concrete model. This is the same union that `TransformConfig.rules` spells
# out inline.
# NOTE(review): the alias name is misspelled ("Tranformation"). It is used
# under this name by `build_transform_from_rule`, so renaming it needs a
# coordinated change (and possibly a backward-compatible alias).
TranformationRule = Annotated[
    Union[ReplaceRule, MapValueRule, DeriveTagRule],
    Field(discriminator="rule_type"),
]
def _resolve_source_and_target_terms(rule, term_registry):
    """Look up the source term of `rule` and the term for its output tags.

    `rule` must expose `source_term_key` and `target_term_key`. When the
    target key is falsy the source term doubles as the target term.
    """
    source_term = get_term_from_key(
        rule.source_term_key,
        term_registry=term_registry,
    )
    if not rule.target_term_key:
        return source_term, source_term
    target_term = get_term_from_key(
        rule.target_term_key,
        term_registry=term_registry,
    )
    return source_term, target_term


def build_transform_from_rule(
    rule: TranformationRule,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Turn a single rule configuration into a callable transformation.

    Dispatches on `rule.rule_type`, resolves the tags, terms and derivation
    function the rule refers to, and binds them onto the matching transform
    function with `functools.partial`. The returned callable only needs the
    annotation to transform.

    Parameters
    ----------
    rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
        The configuration object for a single transformation rule.
    derivation_registry : DerivationRegistry, optional
        Registry passed on to `get_derivation` when building a
        `DeriveTagRule`. Not used for the other rule types.
    term_registry : TermRegistry, optional
        Registry passed on to `get_tag_from_info` / `get_term_from_key`
        when resolving the tags and term keys named by the rule.

    Returns
    -------
    SoundEventTransformation
        A callable that applies the specified rule to a
        SoundEventAnnotation.

    Raises
    ------
    ValueError
        If `rule.rule_type` is none of "replace", "derive_tag" or
        "map_value".

    Notes
    -----
    Exceptions raised while resolving tags, terms or the derivation function
    (for example a `KeyError` for an unregistered derivation key) are not
    caught here and propagate to the caller.
    """
    kind = rule.rule_type

    if kind == "replace":
        return partial(
            replace_tag_transform,
            source=get_tag_from_info(
                rule.original,
                term_registry=term_registry,
            ),
            target=get_tag_from_info(
                rule.replacement,
                term_registry=term_registry,
            ),
        )

    if kind == "derive_tag":
        source_term, target_term = _resolve_source_and_target_terms(
            rule, term_registry
        )
        return partial(
            derivation_tag_transform,
            source_term=source_term,
            target_term=target_term,
            derivation=get_derivation(
                key=rule.derivation_function,
                import_derivation=rule.import_derivation,
                registry=derivation_registry,
            ),
            keep_source=rule.keep_source,
        )

    if kind == "map_value":
        source_term, target_term = _resolve_source_and_target_terms(
            rule, term_registry
        )
        return partial(
            map_value_transform,
            source_term=source_term,
            target_term=target_term,
            mapping=rule.value_mapping,
        )

    # Unknown rule type. Pydantic validation should already reject these, so
    # this is a defensive fallback for rule objects built some other way.
    valid_options = ["replace", "derive_tag", "map_value"]
    raise ValueError(
        f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
        f"Valid options are: {valid_options}"
    )
def build_transformation_from_config(
    config: TransformConfig,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Compose all rules of a configuration into one transformation.

    Each rule in `config.rules` is converted with
    `build_transform_from_rule`, in list order. The resulting callables are
    bound to `apply_sequence_of_transforms`, so calling the returned
    function runs them one after another on an annotation.

    Parameters
    ----------
    config : TransformConfig
        The configuration object containing the list of transformation
        rules.
    derivation_registry : DerivationRegistry, optional
        Forwarded unchanged to `build_transform_from_rule` for every rule.
    term_registry : TermRegistry, optional
        Forwarded unchanged to `build_transform_from_rule` for every rule.

    Returns
    -------
    SoundEventTransformation
        A single function that applies all configured transformations in
        order.
    """
    steps = []
    for rule in config.rules:
        step = build_transform_from_rule(
            rule,
            derivation_registry=derivation_registry,
            term_registry=term_registry,
        )
        steps.append(step)
    return partial(apply_sequence_of_transforms, transforms=steps)
def apply_sequence_of_transforms(
    sound_event_annotation: data.SoundEventAnnotation,
    transforms: list[SoundEventTransformation],
) -> data.SoundEventAnnotation:
    """Pipe an annotation through a list of transformations.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation fed into the first transformation.
    transforms : list[SoundEventTransformation]
        Transformations to apply, in order. Each one receives the output of
        the previous one.

    Returns
    -------
    data.SoundEventAnnotation
        The output of the last transformation, or the input annotation
        itself when `transforms` is empty.
    """
    current = sound_event_annotation
    for step in transforms:
        current = step(current)
    return current
def load_transformation_config(
    path: data.PathLike, field: Optional[str] = None
) -> TransformConfig:
    """Read a `TransformConfig` from a configuration file.

    Thin wrapper around `load_config` that fixes the schema to
    `TransformConfig`.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file.
    field : str, optional
        Passed through to `load_config`; names the key under which the
        transformation configuration is nested in the file. Defaults to
        None.

    Returns
    -------
    TransformConfig
        The configuration object returned by `load_config`.

    Notes
    -----
    Errors raised by `load_config` are not handled here. Presumably these
    are `FileNotFoundError` for a missing file and
    `pydantic.ValidationError` for content that does not match the schema;
    confirm against `load_config`.
    """
    loaded = load_config(path=path, schema=TransformConfig, field=field)
    return loaded
def load_transformation_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    derivation_registry: Optional[DerivationRegistry] = None,
    term_registry: Optional[TermRegistry] = None,
) -> SoundEventTransformation:
    """Load a transformation config file and build its transformation.

    Convenience function: loads the configuration with
    `load_transformation_config` and hands it straight to
    `build_transformation_from_config`.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file.
    field : str, optional
        Key under which the transformation configuration is nested in the
        file, if any. Defaults to None.
    derivation_registry : DerivationRegistry, optional
        Forwarded to `build_transformation_from_config`.
    term_registry : TermRegistry, optional
        Forwarded to `build_transformation_from_config`.

    Returns
    -------
    SoundEventTransformation
        The composite transformation that applies all configured rules in
        order.

    Notes
    -----
    No exception is handled here: anything raised while loading the file or
    while building the individual rule transformations (e.g. a `KeyError`
    for an unregistered derivation key) propagates to the caller.
    """
    return build_transformation_from_config(
        load_transformation_config(path=path, field=field),
        derivation_registry=derivation_registry,
        term_registry=term_registry,
    )
def register_derivation(
    key: str,
    derivation: Derivation,
    derivation_registry: Optional[DerivationRegistry] = None,
) -> None:
    """Register a new derivation function in a registry.

    Parameters
    ----------
    key : str
        The unique key to associate with the derivation function.
    derivation : Derivation
        The callable derivation function (takes str, returns str).
    derivation_registry : DerivationRegistry, optional
        The registry instance to register the derivation function with.
        Defaults to the module-level `default_derivation_registry` when
        None. An explicitly supplied registry is always used, even when it
        is still empty.

    Raises
    ------
    KeyError
        If a derivation function with the same key is already registered in
        the chosen registry.
    """
    # Compare against None explicitly: DerivationRegistry is a Mapping, so a
    # freshly created (empty) registry is falsy. With `registry or default`
    # the first registration into a custom registry would silently land in
    # the global registry instead.
    if derivation_registry is None:
        derivation_registry = default_derivation_registry
    derivation_registry.register(key, derivation)