mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Add target classes module
This commit is contained in:
parent
02d4779207
commit
af48c33307
@ -24,22 +24,53 @@ from batdetect2.targets.terms import (
|
|||||||
TermInfo,
|
TermInfo,
|
||||||
call_type,
|
call_type,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
|
get_term_from_key,
|
||||||
individual,
|
individual,
|
||||||
|
register_term,
|
||||||
|
term_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.transform import (
|
||||||
|
DerivationRegistry,
|
||||||
|
DeriveTagRule,
|
||||||
|
MapValueRule,
|
||||||
|
ReplaceRule,
|
||||||
|
SoundEventTransformation,
|
||||||
|
TransformConfig,
|
||||||
|
build_transform_from_rule,
|
||||||
|
build_transformation_from_config,
|
||||||
|
derivation_registry,
|
||||||
|
get_derivation,
|
||||||
|
load_transformation_config,
|
||||||
|
load_transformation_from_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"DerivationRegistry",
|
||||||
|
"DeriveTagRule",
|
||||||
"HeatmapsConfig",
|
"HeatmapsConfig",
|
||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
|
"MapValueRule",
|
||||||
|
"ReplaceRule",
|
||||||
|
"SoundEventTransformation",
|
||||||
"TagInfo",
|
"TagInfo",
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TermInfo",
|
"TermInfo",
|
||||||
|
"TransformConfig",
|
||||||
"build_decoder",
|
"build_decoder",
|
||||||
"build_target_encoder",
|
"build_target_encoder",
|
||||||
|
"build_transform_from_rule",
|
||||||
|
"build_transformation_from_config",
|
||||||
"call_type",
|
"call_type",
|
||||||
|
"derivation_registry",
|
||||||
"generate_heatmaps",
|
"generate_heatmaps",
|
||||||
"get_class_names",
|
"get_class_names",
|
||||||
|
"get_derivation",
|
||||||
"get_tag_from_info",
|
"get_tag_from_info",
|
||||||
"individual",
|
"individual",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
|
"load_transformation_config",
|
||||||
|
"load_transformation_from_config",
|
||||||
|
"register_term",
|
||||||
|
"term_registry",
|
||||||
]
|
]
|
||||||
|
300
batdetect2/targets/classes.py
Normal file
300
batdetect2/targets/classes.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
from collections import Counter
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, List, Literal, Optional, Set
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
|
get_tag_from_info,
|
||||||
|
term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SoundEventEncoder",
|
||||||
|
"TargetClass",
|
||||||
|
"ClassesConfig",
|
||||||
|
"load_classes_config",
|
||||||
|
"build_encoder_from_config",
|
||||||
|
"load_encoder_from_config",
|
||||||
|
"get_class_names_from_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||||
|
"""Type alias for a sound event class encoder function.
|
||||||
|
|
||||||
|
An encoder function takes a sound event annotation and returns the string name
|
||||||
|
of the target class it belongs to, based on a predefined set of rules.
|
||||||
|
If the annotation does not match any defined target class according to the
|
||||||
|
rules, the function returns None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TargetClass(BaseConfig):
|
||||||
|
"""Defines the criteria for assigning an annotation to a specific class.
|
||||||
|
|
||||||
|
Each instance represents one potential output class for the classification
|
||||||
|
model. It specifies the class name and the tag conditions an annotation
|
||||||
|
must meet to be assigned this class label.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The unique name assigned to this target class (e.g., 'pippip',
|
||||||
|
'myodau', 'noise'). This name will be used as the label during model
|
||||||
|
training and output. Should be unique across all TargetClass
|
||||||
|
definitions in a configuration.
|
||||||
|
tag : List[TagInfo]
|
||||||
|
A list of one or more tags (defined using `TagInfo`) that an annotation
|
||||||
|
must possess to potentially match this class.
|
||||||
|
match_type : Literal["all", "any"], default="all"
|
||||||
|
Determines how the `tag` list is evaluated:
|
||||||
|
- "all": The annotation must have *all* the tags listed in the `tag`
|
||||||
|
field to match this class definition.
|
||||||
|
- "any": The annotation must have *at least one* of the tags listed
|
||||||
|
in the `tag` field to match this class definition.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
tags: List[TagInfo] = Field(default_factory=list, min_length=1)
|
||||||
|
match_type: Literal["all", "any"] = Field(default="all")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassesConfig(BaseConfig):
|
||||||
|
"""Configuration model holding the list of target class definitions.
|
||||||
|
|
||||||
|
The order of `TargetClass` objects in the `classes` list defines the
|
||||||
|
priority for classification. When encoding an annotation, the system checks
|
||||||
|
against the class definitions in this sequence and assigns the name of the
|
||||||
|
*first* matching class.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
classes : List[TargetClass]
|
||||||
|
An ordered list of target class definitions. The order determines
|
||||||
|
matching priority (first match wins).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If validation fails (e.g., non-unique class names).
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
It is crucial that the `name` attribute of each `TargetClass` in the
|
||||||
|
`classes` list is unique. This configuration includes a validator to
|
||||||
|
enforce this uniqueness.
|
||||||
|
"""
|
||||||
|
|
||||||
|
classes: List[TargetClass] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("classes")
|
||||||
|
def check_unique_class_names(cls, v: List[TargetClass]):
|
||||||
|
"""Ensure all defined class names are unique."""
|
||||||
|
names = [c.name for c in v]
|
||||||
|
|
||||||
|
if len(names) != len(set(names)):
|
||||||
|
name_counts = Counter(names)
|
||||||
|
duplicates = [
|
||||||
|
name for name, count in name_counts.items() if count > 1
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
"Class names must be unique. Found duplicates: "
|
||||||
|
f"{', '.join(duplicates)}"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def load_classes_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> ClassesConfig:
|
||||||
|
"""Load the target classes configuration from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the classes configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ClassesConfig
|
||||||
|
The loaded and validated classes configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the ClassesConfig schema
|
||||||
|
or if class names are not unique.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=ClassesConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def is_target_class(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
match_all: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a sound event annotation matches a set of required tags.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
required_tags : Set[data.Tag]
|
||||||
|
A set of `soundevent.data.Tag` objects that define the class criteria.
|
||||||
|
match_all : bool, default=True
|
||||||
|
If True, checks if *all* `required_tags` are present in the
|
||||||
|
annotation's tags (subset check). If False, checks if *at least one*
|
||||||
|
of the `required_tags` is present (intersection check).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation meets the tag criteria, False otherwise.
|
||||||
|
"""
|
||||||
|
annotation_tags = set(sound_event_annotation.tags)
|
||||||
|
|
||||||
|
if match_all:
|
||||||
|
return tags <= annotation_tags
|
||||||
|
|
||||||
|
return bool(tags & annotation_tags)
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||||
|
"""Extract the list of class names from a ClassesConfig object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ClassesConfig
|
||||||
|
The loaded classes configuration object.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[str]
|
||||||
|
An ordered list of unique class names defined in the configuration.
|
||||||
|
"""
|
||||||
|
return [class_info.name for class_info in config.classes]
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder_from_config(
|
||||||
|
config: ClassesConfig,
|
||||||
|
term_registry: TermRegistry = term_registry,
|
||||||
|
) -> SoundEventEncoder:
|
||||||
|
"""Build a sound event encoder function from the classes configuration.
|
||||||
|
|
||||||
|
The returned encoder function iterates through the class definitions in the
|
||||||
|
order specified in the config. It assigns an annotation the name of the
|
||||||
|
first class definition it matches.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : ClassesConfig
|
||||||
|
The loaded and validated classes configuration object.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used to look up term keys specified in the
|
||||||
|
`TagInfo` objects within the configuration. Defaults to the global
|
||||||
|
`batdetect2.targets.terms.registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventEncoder
|
||||||
|
A callable function that takes a `SoundEventAnnotation` and returns
|
||||||
|
an optional string representing the matched class name, or None if no
|
||||||
|
class matches.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration is not found in the
|
||||||
|
provided `term_registry`.
|
||||||
|
"""
|
||||||
|
binary_classifiers = [
|
||||||
|
(
|
||||||
|
class_info.name,
|
||||||
|
partial(
|
||||||
|
is_target_class,
|
||||||
|
tags={
|
||||||
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
|
for tag_info in class_info.tags
|
||||||
|
},
|
||||||
|
match_all=class_info.match_type == "all",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for class_info in config.classes
|
||||||
|
]
|
||||||
|
|
||||||
|
def encoder(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Assign a class name to an annotation based on configured rules.
|
||||||
|
|
||||||
|
Iterates through pre-compiled classifiers in priority order. Returns
|
||||||
|
the name of the first matching class, or None if no match is found.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to encode.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str or None
|
||||||
|
The name of the matched class, or None.
|
||||||
|
"""
|
||||||
|
for class_name, classifier in binary_classifiers:
|
||||||
|
if classifier(sound_event_annotation):
|
||||||
|
return class_name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def load_encoder_from_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
term_registry: TermRegistry = term_registry,
|
||||||
|
) -> SoundEventEncoder:
|
||||||
|
"""Load a class encoder function directly from a configuration file.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the `ClassesConfig`
|
||||||
|
from a file and building the final `SoundEventEncoder` function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
If the classes configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
term_registry : TermRegistry, optional
|
||||||
|
The TermRegistry instance used for term lookups. Defaults to the
|
||||||
|
global `batdetect2.targets.terms.registry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventEncoder
|
||||||
|
The final encoder function ready to classify annotations.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the config file structure does not match the ClassesConfig schema
|
||||||
|
or if class names are not unique.
|
||||||
|
KeyError
|
||||||
|
If a term key specified in the configuration is not found in the
|
||||||
|
provided `term_registry` during the build process.
|
||||||
|
"""
|
||||||
|
config = load_classes_config(path, field=field)
|
||||||
|
return build_encoder_from_config(config, term_registry=term_registry)
|
@ -143,7 +143,7 @@ class TermRegistry(Mapping[str, data.Term]):
|
|||||||
raise KeyError(
|
raise KeyError(
|
||||||
"No term found for key "
|
"No term found for key "
|
||||||
f"'{key}'. Ensure it is registered or loaded. "
|
f"'{key}'. Ensure it is registered or loaded. "
|
||||||
"Use `get_term_keys()` to list available terms."
|
f"Available keys: {', '.join(self.get_keys())}"
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
def add_custom_term(
|
def add_custom_term(
|
||||||
@ -215,8 +215,11 @@ class TermRegistry(Mapping[str, data.Term]):
|
|||||||
"""
|
"""
|
||||||
return list(self._terms.values())
|
return list(self._terms.values())
|
||||||
|
|
||||||
|
def remove_key(self, key: str) -> None:
|
||||||
|
del self._terms[key]
|
||||||
|
|
||||||
registry = TermRegistry(
|
|
||||||
|
term_registry = TermRegistry(
|
||||||
terms=dict(
|
terms=dict(
|
||||||
[
|
[
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||||
@ -237,7 +240,7 @@ is explicitly passed.
|
|||||||
|
|
||||||
def get_term_from_key(
|
def get_term_from_key(
|
||||||
key: str,
|
key: str,
|
||||||
term_registry: TermRegistry = registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> data.Term:
|
) -> data.Term:
|
||||||
"""Convenience function to retrieve a term by key from a registry.
|
"""Convenience function to retrieve a term by key from a registry.
|
||||||
|
|
||||||
@ -265,7 +268,7 @@ def get_term_from_key(
|
|||||||
return term_registry.get_term(key)
|
return term_registry.get_term(key)
|
||||||
|
|
||||||
|
|
||||||
def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
|
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
||||||
"""Convenience function to get all registered keys from a registry.
|
"""Convenience function to get all registered keys from a registry.
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
Uses the global default registry unless a specific `term_registry`
|
||||||
@ -284,7 +287,7 @@ def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
|
|||||||
return term_registry.get_keys()
|
return term_registry.get_keys()
|
||||||
|
|
||||||
|
|
||||||
def get_terms(term_registry: TermRegistry = registry) -> List[data.Term]:
|
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
|
||||||
"""Convenience function to get all registered terms from a registry.
|
"""Convenience function to get all registered terms from a registry.
|
||||||
|
|
||||||
Uses the global default registry unless a specific `term_registry`
|
Uses the global default registry unless a specific `term_registry`
|
||||||
@ -327,7 +330,7 @@ class TagInfo(BaseModel):
|
|||||||
|
|
||||||
def get_tag_from_info(
|
def get_tag_from_info(
|
||||||
tag_info: TagInfo,
|
tag_info: TagInfo,
|
||||||
term_registry: TermRegistry = registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> data.Tag:
|
) -> data.Tag:
|
||||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||||
|
|
||||||
@ -424,7 +427,7 @@ class TermConfig(BaseModel):
|
|||||||
def load_terms_from_config(
|
def load_terms_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
term_registry: TermRegistry = registry,
|
term_registry: TermRegistry = term_registry,
|
||||||
) -> Dict[str, data.Term]:
|
) -> Dict[str, data.Term]:
|
||||||
"""Loads term definitions from a configuration file and registers them.
|
"""Loads term definitions from a configuration file and registers them.
|
||||||
|
|
||||||
@ -472,3 +475,9 @@ def load_terms_from_config(
|
|||||||
)
|
)
|
||||||
for info in data.terms
|
for info in data.terms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_term(
|
||||||
|
key: str, term: data.Term, registry: TermRegistry = term_registry
|
||||||
|
) -> None:
|
||||||
|
registry.add_term(key, term)
|
||||||
|
@ -1,6 +1,15 @@
|
|||||||
import importlib
|
import importlib
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Dict, List, Literal, Mapping, Optional, Union
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -8,9 +17,13 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
TagInfo,
|
TagInfo,
|
||||||
|
TermRegistry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
get_term_from_key,
|
get_term_from_key,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.terms import (
|
||||||
|
term_registry as default_term_registry,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SoundEventTransformation",
|
"SoundEventTransformation",
|
||||||
@ -282,9 +295,13 @@ class TransformConfig(BaseConfig):
|
|||||||
discriminates between the different rule models.
|
discriminates between the different rule models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rules: List[Union[ReplaceRule, MapValueRule, DeriveTagRule]] = Field(
|
rules: List[
|
||||||
|
Annotated[
|
||||||
|
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||||
|
Field(discriminator="rule_type"),
|
||||||
|
]
|
||||||
|
] = Field(
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
discriminator="rule_type",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -442,7 +459,8 @@ def get_derivation(
|
|||||||
|
|
||||||
def build_transform_from_rule(
|
def build_transform_from_rule(
|
||||||
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||||
registry: DerivationRegistry = derivation_registry,
|
derivation_registry: DerivationRegistry = derivation_registry,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a specific SoundEventTransformation function from a rule config.
|
"""Build a specific SoundEventTransformation function from a rule config.
|
||||||
|
|
||||||
@ -473,21 +491,33 @@ def build_transform_from_rule(
|
|||||||
If dynamic import of a derivation function fails.
|
If dynamic import of a derivation function fails.
|
||||||
"""
|
"""
|
||||||
if rule.rule_type == "replace":
|
if rule.rule_type == "replace":
|
||||||
source = get_tag_from_info(rule.original)
|
source = get_tag_from_info(
|
||||||
target = get_tag_from_info(rule.replacement)
|
rule.original,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
target = get_tag_from_info(
|
||||||
|
rule.replacement,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
return partial(replace_tag_transform, source=source, target=target)
|
return partial(replace_tag_transform, source=source, target=target)
|
||||||
|
|
||||||
if rule.rule_type == "derive_tag":
|
if rule.rule_type == "derive_tag":
|
||||||
source_term = get_term_from_key(rule.source_term_key)
|
source_term = get_term_from_key(
|
||||||
|
rule.source_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
target_term = (
|
target_term = (
|
||||||
get_term_from_key(rule.target_term_key)
|
get_term_from_key(
|
||||||
|
rule.target_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
if rule.target_term_key
|
if rule.target_term_key
|
||||||
else source_term
|
else source_term
|
||||||
)
|
)
|
||||||
derivation = get_derivation(
|
derivation = get_derivation(
|
||||||
key=rule.derivation_function,
|
key=rule.derivation_function,
|
||||||
import_derivation=rule.import_derivation,
|
import_derivation=rule.import_derivation,
|
||||||
registry=registry,
|
registry=derivation_registry,
|
||||||
)
|
)
|
||||||
return partial(
|
return partial(
|
||||||
derivation_tag_transform,
|
derivation_tag_transform,
|
||||||
@ -498,9 +528,15 @@ def build_transform_from_rule(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if rule.rule_type == "map_value":
|
if rule.rule_type == "map_value":
|
||||||
source_term = get_term_from_key(rule.source_term_key)
|
source_term = get_term_from_key(
|
||||||
|
rule.source_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
target_term = (
|
target_term = (
|
||||||
get_term_from_key(rule.target_term_key)
|
get_term_from_key(
|
||||||
|
rule.target_term_key,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
if rule.target_term_key
|
if rule.target_term_key
|
||||||
else source_term
|
else source_term
|
||||||
)
|
)
|
||||||
@ -522,6 +558,8 @@ def build_transform_from_rule(
|
|||||||
|
|
||||||
def build_transformation_from_config(
|
def build_transformation_from_config(
|
||||||
config: TransformConfig,
|
config: TransformConfig,
|
||||||
|
derivation_registry: DerivationRegistry = derivation_registry,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Build a composite transformation function from a TransformConfig.
|
"""Build a composite transformation function from a TransformConfig.
|
||||||
|
|
||||||
@ -542,7 +580,14 @@ def build_transformation_from_config(
|
|||||||
SoundEventTransformation
|
SoundEventTransformation
|
||||||
A single function that applies all configured transformations in order.
|
A single function that applies all configured transformations in order.
|
||||||
"""
|
"""
|
||||||
transforms = [build_transform_from_rule(rule) for rule in config.rules]
|
transforms = [
|
||||||
|
build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
for rule in config.rules
|
||||||
|
]
|
||||||
|
|
||||||
def transformation(
|
def transformation(
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
@ -583,7 +628,10 @@ def load_transformation_config(
|
|||||||
|
|
||||||
|
|
||||||
def load_transformation_from_config(
|
def load_transformation_from_config(
|
||||||
path: data.PathLike, field: Optional[str] = None
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
derivation_registry: DerivationRegistry = derivation_registry,
|
||||||
|
term_registry: TermRegistry = default_term_registry,
|
||||||
) -> SoundEventTransformation:
|
) -> SoundEventTransformation:
|
||||||
"""Load transformation config from a file and build the final function.
|
"""Load transformation config from a file and build the final function.
|
||||||
|
|
||||||
@ -618,4 +666,8 @@ def load_transformation_from_config(
|
|||||||
fails.
|
fails.
|
||||||
"""
|
"""
|
||||||
config = load_transformation_config(path=path, field=field)
|
config = load_transformation_config(path=path, field=field)
|
||||||
return build_transformation_from_config(config)
|
return build_transformation_from_config(
|
||||||
|
config,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
148
docs/targets/classes.md
Normal file
148
docs/targets/classes.md
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
## Step 4: Defining Target Classes for Training
|
||||||
|
|
||||||
|
### Purpose and Context
|
||||||
|
|
||||||
|
You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags).
|
||||||
|
Now, it's time to tell `batdetect2` **exactly what categories (classes) your model should learn to identify**.
|
||||||
|
|
||||||
|
This step involves defining rules that map the final tags on your sound event annotations to specific **class names** (like `pippip`, `myodau`, or `noise`).
|
||||||
|
These class names are the labels the machine learning model will be trained to predict.
|
||||||
|
Getting this definition right is essential for successful model training.
|
||||||
|
|
||||||
|
### How it Works: Defining Classes with Rules
|
||||||
|
|
||||||
|
You define your target classes in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`.
|
||||||
|
This section contains a **list** of class definitions.
|
||||||
|
Each item in the list defines one specific class your model should learn.
|
||||||
|
|
||||||
|
### Defining a Single Class
|
||||||
|
|
||||||
|
Each class definition rule requires a few key pieces of information:
|
||||||
|
|
||||||
|
1. `name`: **(Required)** This is the unique, simple name you want to give this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `echolocation_noise`).
|
||||||
|
This is the label the model will actually use.
|
||||||
|
Choose names that are clear and distinct.
|
||||||
|
**Each class name must be unique.**
|
||||||
|
2. `tags`: **(Required)** This is a list containing one or more specific tags that identify annotations belonging to this class.
|
||||||
|
Remember, each tag is specified using its term `key` (like `species` or `sound_type`, defaulting to `class` if omitted) and its specific `value` (like `Pipistrellus pipistrellus` or `Echolocation`).
|
||||||
|
3. `match_type`: **(Optional, defaults to `"all"`)** This tells the system how to use the list of tags you provided in the `tag` field:
|
||||||
|
- `"all"`: An annotation must have **ALL** of the tags listed in the `tags` section to be considered part of this class.
|
||||||
|
(This is the default if you don't specify `match_type`).
|
||||||
|
- `"any"`: An annotation only needs to have **AT LEAST ONE** of the tags listed in the `tags` section to be considered part of this class.
|
||||||
|
|
||||||
|
**Example: Defining two specific bat species classes**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# In your main configuration file
|
||||||
|
classes:
|
||||||
|
# Definition for the first class
|
||||||
|
- name: pippip # Simple name for Pipistrellus pipistrellus
|
||||||
|
tags:
|
||||||
|
- key: species # Term key (could also default to 'class')
|
||||||
|
value: Pipistrellus pipistrellus # Specific tag value
|
||||||
|
# match_type defaults to "all" (which is fine for a single tag)
|
||||||
|
|
||||||
|
# Definition for the second class
|
||||||
|
- name: myodau # Simple name for Myotis daubentonii
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis daubentonii
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example: Defining a class requiring multiple conditions (`match_type: "all"`)**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
classes:
|
||||||
|
- name: high_quality_pippip # Name for high-quality P. pip calls
|
||||||
|
match_type: all # Annotation must match BOTH tags below
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- key: quality # Assumes 'quality' term key exists
|
||||||
|
value: Good
|
||||||
|
```
|
||||||
|
|
||||||
|
**Example: Defining a class matching multiple alternative tags (`match_type: "any"`)**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
classes:
|
||||||
|
- name: pipistrelle # Name for any Pipistrellus species in this list
|
||||||
|
match_type: any # Annotation must match AT LEAST ONE tag below
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pygmaeus
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus nathusii
|
||||||
|
```
|
||||||
|
|
||||||
|
### Handling Overlap: Priority Order Matters!
|
||||||
|
|
||||||
|
Sometimes, an annotation might have tags that match the rules for _more than one_ class definition.
|
||||||
|
For example, an annotation tagged `species: Pipistrellus pipistrellus` would match both a specific `'pippip'` class rule and a broader `'pipistrelle'` genus rule (like the examples above) if both were defined.
|
||||||
|
|
||||||
|
How does `batdetect2` decide which class name to assign? It uses the **order of the class definitions in your configuration list**.
|
||||||
|
|
||||||
|
- The system checks an annotation against your class rules one by one, starting from the **top** of the `classes` list and moving down.
|
||||||
|
- As soon as it finds a rule that the annotation matches, it assigns that rule's `name` to the annotation and **stops checking** further rules for that annotation.
|
||||||
|
- **The first match wins!**
|
||||||
|
|
||||||
|
Therefore, you should generally place your **most specific rules before more general rules** if you want the specific category to take precedence.
|
||||||
|
|
||||||
|
**Example: Prioritizing Species over Noise**
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
classes:
|
||||||
|
# --- Specific Species Rules (Checked First) ---
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
|
||||||
|
- name: myodau
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis daubentonii
|
||||||
|
|
||||||
|
# --- General Noise Rule (Checked Last) ---
|
||||||
|
- name: noise # Catch-all for anything tagged as Noise
|
||||||
|
match_type: any # Match if any noise tag is present
|
||||||
|
tags:
|
||||||
|
- key: sound_type # Assume 'sound_type' term key exists
|
||||||
|
value: Noise
|
||||||
|
- key: quality # Assume 'quality' term key exists
|
||||||
|
value: Low # Maybe low quality is also considered noise for training
|
||||||
|
```
|
||||||
|
|
||||||
|
In this example, an annotation tagged with `species: Myotis daubentonii` _and_ `quality: Low` would be assigned the class name `myodau` because that rule comes first in the list.
|
||||||
|
It would not be assigned `noise`, even though it also matches the second condition of the noise rule.
|
||||||
|
|
||||||
|
Okay, that's a very important clarification about how BatDetect2 handles sounds that don't match specific class definitions.
|
||||||
|
Let's refine that section to accurately reflect this behavior.
|
||||||
|
|
||||||
|
### What if No Class Matches? (The Generic "Bat" Class)
|
||||||
|
|
||||||
|
It's important to understand what happens if a sound event annotation passes through the filtering (Step 2) and transformation (Step 3) steps, but its final set of tags doesn't match _any_ of the specific class definitions you've listed in this section.
|
||||||
|
|
||||||
|
These annotations are **not ignored** during training.
|
||||||
|
Instead, they are typically assigned to a **generic "relevant sound" class**.
|
||||||
|
Think of this as a category for sounds that you considered important enough to keep after filtering, but which don't fit into one of your specific target classes for detailed classification (like a particular species).
|
||||||
|
This generic class is distinct from background noise.
|
||||||
|
|
||||||
|
In BatDetect2, this default generic class is often referred to as the **"Bat"** class.
|
||||||
|
The goal is generally that all relevant bat echolocation calls that pass the initial filtering should fall into _either_ one of your specific defined classes (like `pippip` or `myodau`) _or_ this generic "Bat" class.
|
||||||
|
|
||||||
|
**In summary:**
|
||||||
|
|
||||||
|
- Sounds passing **filtering** are considered relevant.
|
||||||
|
- If a relevant sound matches one of your **specific class rules** (in priority order), it gets that specific class label.
|
||||||
|
- If a relevant sound does **not** match any specific class rule, it gets the **generic "Bat" class** label.
|
||||||
|
|
||||||
|
**Crucially:** If you want certain types of sounds (even if they are bat calls) to be **completely excluded** from the training process altogether (not even included in the generic "Bat" class), you **must remove them using rules in the Filtering step (Step 2)**.
|
||||||
|
Any sound annotation that makes it past filtering _will_ be used in training, either under one of your specific classes or the generic one.
|
||||||
|
|
||||||
|
### Outcome
|
||||||
|
|
||||||
|
By defining this list of prioritized class rules, you provide `batdetect2` with a clear procedure to assign a specific target label (your class `name`) to each relevant sound event annotation based on its tags.
|
||||||
|
This labelled data is exactly what the model needs for training (Step 5).
|
361
tests/test_targets/test_transform.py
Normal file
361
tests/test_targets/test_transform.py
Normal file
@ -0,0 +1,361 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import (
|
||||||
|
DeriveTagRule,
|
||||||
|
MapValueRule,
|
||||||
|
ReplaceRule,
|
||||||
|
TagInfo,
|
||||||
|
TransformConfig,
|
||||||
|
build_transform_from_rule,
|
||||||
|
build_transformation_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.terms import TermRegistry
|
||||||
|
from batdetect2.targets.transform import DerivationRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term_registry():
|
||||||
|
return TermRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def derivation_registry():
|
||||||
|
return DerivationRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term1(term_registry: TermRegistry) -> data.Term:
|
||||||
|
return term_registry.add_custom_term(key="term1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def term2(term_registry: TermRegistry) -> data.Term:
|
||||||
|
return term_registry.add_custom_term(key="term2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def annotation(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
term1: data.Term,
|
||||||
|
) -> data.SoundEventAnnotation:
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_no_match(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"other_value": "value2"},
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term2: data.Term,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
rule = ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term1", value="value1"),
|
||||||
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_rule_no_match(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
rule = ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term1", value="wrong_value"),
|
||||||
|
replacement=TagInfo(key="term2", value="value2"),
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].key == "term1"
|
||||||
|
assert transformed_annotation.tags[0].term != term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transformation_from_config(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
config = TransformConfig(
|
||||||
|
rules=[
|
||||||
|
MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
),
|
||||||
|
ReplaceRule(
|
||||||
|
rule_type="replace",
|
||||||
|
original=TagInfo(key="term2", value="value2"),
|
||||||
|
replacement=TagInfo(key="term3", value="value3"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
term_registry.add_custom_term("term2")
|
||||||
|
term_registry.add_custom_term("term3")
|
||||||
|
transform = build_transformation_from_config(
|
||||||
|
config,
|
||||||
|
term_registry=term_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform(annotation)
|
||||||
|
assert transformed_annotation.tags[0].key == "term1"
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_keep_source_false(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
keep_source=False,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 1
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_target_term(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
target_term_key="term2",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term2
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_import_derivation(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
# Create a dummy derivation function in a temporary file
|
||||||
|
derivation_module_path = (
|
||||||
|
tmp_path / "temp_derivation.py"
|
||||||
|
) # Changed to /tmp since /home/santiago is not writable
|
||||||
|
derivation_module_path.write_text(
|
||||||
|
"""
|
||||||
|
def my_imported_derivation(x: str) -> str:
|
||||||
|
return x + "_imported"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# Ensure the temporary file is importable by adding its directory to sys.path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(0, str(tmp_path))
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="temp_derivation.my_imported_derivation",
|
||||||
|
import_derivation=True,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_imported"
|
||||||
|
|
||||||
|
# Clean up the temporary file and sys.path
|
||||||
|
sys.path.remove(str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="nonexistent_derivation",
|
||||||
|
)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transform_from_rule_invalid_rule_type():
|
||||||
|
class InvalidRule:
|
||||||
|
rule_type = "invalid"
|
||||||
|
|
||||||
|
rule = InvalidRule() # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
build_transform_from_rule(rule) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_target_term(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term2: data.Term,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
target_term_key="term2",
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term2
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_map_value_rule_target_term_none(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
rule = MapValueRule(
|
||||||
|
rule_type="map_value",
|
||||||
|
source_term_key="term1",
|
||||||
|
value_mapping={"value1": "value2"},
|
||||||
|
target_term_key=None,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_derive_tag_rule_target_term_none(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
term_registry: TermRegistry,
|
||||||
|
derivation_registry: DerivationRegistry,
|
||||||
|
term1: data.Term,
|
||||||
|
):
|
||||||
|
def derivation_func(x: str) -> str:
|
||||||
|
return x + "_derived"
|
||||||
|
|
||||||
|
derivation_registry.register("my_derivation", derivation_func)
|
||||||
|
|
||||||
|
rule = DeriveTagRule(
|
||||||
|
rule_type="derive_tag",
|
||||||
|
source_term_key="term1",
|
||||||
|
derivation_function="my_derivation",
|
||||||
|
target_term_key=None,
|
||||||
|
)
|
||||||
|
transform_fn = build_transform_from_rule(
|
||||||
|
rule,
|
||||||
|
term_registry=term_registry,
|
||||||
|
derivation_registry=derivation_registry,
|
||||||
|
)
|
||||||
|
transformed_annotation = transform_fn(annotation)
|
||||||
|
|
||||||
|
assert len(transformed_annotation.tags) == 2
|
||||||
|
assert transformed_annotation.tags[0].term == term1
|
||||||
|
assert transformed_annotation.tags[0].value == "value1"
|
||||||
|
assert transformed_annotation.tags[1].term == term1
|
||||||
|
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_transformation_from_config_empty(
|
||||||
|
annotation: data.SoundEventAnnotation,
|
||||||
|
):
|
||||||
|
config = TransformConfig(rules=[])
|
||||||
|
transform = build_transformation_from_config(config)
|
||||||
|
transformed_annotation = transform(annotation)
|
||||||
|
assert transformed_annotation == annotation
|
Loading…
Reference in New Issue
Block a user