Add target classes module

This commit is contained in:
mbsantiago 2025-04-15 07:32:58 +01:00
parent 02d4779207
commit af48c33307
6 changed files with 922 additions and 21 deletions

View File

@ -24,22 +24,53 @@ from batdetect2.targets.terms import (
TermInfo,
call_type,
get_tag_from_info,
get_term_from_key,
individual,
register_term,
term_registry,
)
from batdetect2.targets.transform import (
DerivationRegistry,
DeriveTagRule,
MapValueRule,
ReplaceRule,
SoundEventTransformation,
TransformConfig,
build_transform_from_rule,
build_transformation_from_config,
derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
)
__all__ = [
"DerivationRegistry",
"DeriveTagRule",
"HeatmapsConfig",
"LabelConfig",
"MapValueRule",
"ReplaceRule",
"SoundEventTransformation",
"TagInfo",
"TargetConfig",
"TermInfo",
"TransformConfig",
"build_decoder",
"build_target_encoder",
"build_transform_from_rule",
"build_transformation_from_config",
"call_type",
"derivation_registry",
"generate_heatmaps",
"get_class_names",
"get_derivation",
"get_tag_from_info",
"individual",
"load_label_config",
"load_target_config",
"load_transformation_config",
"load_transformation_from_config",
"register_term",
"term_registry",
]

View File

@ -0,0 +1,300 @@
from collections import Counter
from functools import partial
from typing import Callable, List, Literal, Optional, Set
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
)
__all__ = [
"SoundEventEncoder",
"TargetClass",
"ClassesConfig",
"load_classes_config",
"build_encoder_from_config",
"load_encoder_from_config",
"get_class_names_from_config",
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
An encoder function takes a sound event annotation and returns the string name
of the target class it belongs to, based on a predefined set of rules.
If the annotation does not match any defined target class according to the
rules, the function returns None.
"""
class TargetClass(BaseConfig):
"""Defines the criteria for assigning an annotation to a specific class.
Each instance represents one potential output class for the classification
model. It specifies the class name and the tag conditions an annotation
must meet to be assigned this class label.
Attributes
----------
name : str
The unique name assigned to this target class (e.g., 'pippip',
'myodau', 'noise'). This name will be used as the label during model
training and output. Should be unique across all TargetClass
definitions in a configuration.
tag : List[TagInfo]
A list of one or more tags (defined using `TagInfo`) that an annotation
must possess to potentially match this class.
match_type : Literal["all", "any"], default="all"
Determines how the `tag` list is evaluated:
- "all": The annotation must have *all* the tags listed in the `tag`
field to match this class definition.
- "any": The annotation must have *at least one* of the tags listed
in the `tag` field to match this class definition.
"""
name: str
tags: List[TagInfo] = Field(default_factory=list, min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
class ClassesConfig(BaseConfig):
"""Configuration model holding the list of target class definitions.
The order of `TargetClass` objects in the `classes` list defines the
priority for classification. When encoding an annotation, the system checks
against the class definitions in this sequence and assigns the name of the
*first* matching class.
Attributes
----------
classes : List[TargetClass]
An ordered list of target class definitions. The order determines
matching priority (first match wins).
Raises
------
ValueError
If validation fails (e.g., non-unique class names).
Notes
-----
It is crucial that the `name` attribute of each `TargetClass` in the
`classes` list is unique. This configuration includes a validator to
enforce this uniqueness.
"""
classes: List[TargetClass] = Field(default_factory=list)
@field_validator("classes")
def check_unique_class_names(cls, v: List[TargetClass]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_classes_config(
path: data.PathLike,
field: Optional[str] = None,
) -> ClassesConfig:
"""Load the target classes configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
ClassesConfig
The loaded and validated classes configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
"""
return load_config(path, schema=ClassesConfig, field=field)
def is_target_class(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
match_all: bool = True,
) -> bool:
"""Check if a sound event annotation matches a set of required tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
required_tags : Set[data.Tag]
A set of `soundevent.data.Tag` objects that define the class criteria.
match_all : bool, default=True
If True, checks if *all* `required_tags` are present in the
annotation's tags (subset check). If False, checks if *at least one*
of the `required_tags` is present (intersection check).
Returns
-------
bool
True if the annotation meets the tag criteria, False otherwise.
"""
annotation_tags = set(sound_event_annotation.tags)
if match_all:
return tags <= annotation_tags
return bool(tags & annotation_tags)
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
"""Extract the list of class names from a ClassesConfig object.
Parameters
----------
config : ClassesConfig
The loaded classes configuration object.
Returns
-------
List[str]
An ordered list of unique class names defined in the configuration.
"""
return [class_info.name for class_info in config.classes]
def build_encoder_from_config(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
The returned encoder function iterates through the class definitions in the
order specified in the config. It assigns an annotation the name of the
first class definition it matches.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys specified in the
`TagInfo` objects within the configuration. Defaults to the global
`batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
A callable function that takes a `SoundEventAnnotation` and returns
an optional string representing the matched class name, or None if no
class matches.
Raises
------
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry`.
"""
binary_classifiers = [
(
class_info.name,
partial(
is_target_class,
tags={
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags
},
match_all=class_info.match_type == "all",
),
)
for class_info in config.classes
]
def encoder(
sound_event_annotation: data.SoundEventAnnotation,
) -> Optional[str]:
"""Assign a class name to an annotation based on configured rules.
Iterates through pre-compiled classifiers in priority order. Returns
the name of the first matching class, or None if no match is found.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to encode.
Returns
-------
str or None
The name of the matched class, or None.
"""
for class_name, classifier in binary_classifiers:
if classifier(sound_event_annotation):
return class_name
return None
return encoder
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventEncoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
global `batdetect2.targets.terms.registry`.
Returns
-------
SoundEventEncoder
The final encoder function ready to classify annotations.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_encoder_from_config(config, term_registry=term_registry)

View File

@ -143,7 +143,7 @@ class TermRegistry(Mapping[str, data.Term]):
raise KeyError(
"No term found for key "
f"'{key}'. Ensure it is registered or loaded. "
"Use `get_term_keys()` to list available terms."
f"Available keys: {', '.join(self.get_keys())}"
) from err
def add_custom_term(
@ -215,8 +215,11 @@ class TermRegistry(Mapping[str, data.Term]):
"""
return list(self._terms.values())
def remove_key(self, key: str) -> None:
del self._terms[key]
registry = TermRegistry(
term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
@ -237,7 +240,7 @@ is explicitly passed.
def get_term_from_key(
key: str,
term_registry: TermRegistry = registry,
term_registry: TermRegistry = term_registry,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
@ -265,7 +268,7 @@ def get_term_from_key(
return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
@ -284,7 +287,7 @@ def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
return term_registry.get_keys()
def get_terms(term_registry: TermRegistry = registry) -> List[data.Term]:
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry`
@ -327,7 +330,7 @@ class TagInfo(BaseModel):
def get_tag_from_info(
tag_info: TagInfo,
term_registry: TermRegistry = registry,
term_registry: TermRegistry = term_registry,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
@ -424,7 +427,7 @@ class TermConfig(BaseModel):
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = registry,
term_registry: TermRegistry = term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
@ -472,3 +475,9 @@ def load_terms_from_config(
)
for info in data.terms
}
def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry
) -> None:
registry.add_term(key, term)

View File

@ -1,6 +1,15 @@
import importlib
from functools import partial
from typing import Callable, Dict, List, Literal, Mapping, Optional, Union
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from pydantic import Field
from soundevent import data
@ -8,9 +17,13 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [
"SoundEventTransformation",
@ -282,9 +295,13 @@ class TransformConfig(BaseConfig):
discriminates between the different rule models.
"""
rules: List[Union[ReplaceRule, MapValueRule, DeriveTagRule]] = Field(
rules: List[
Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
] = Field(
default_factory=list,
discriminator="rule_type",
)
@ -442,7 +459,8 @@ def get_derivation(
def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
registry: DerivationRegistry = derivation_registry,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
@ -473,21 +491,33 @@ def build_transform_from_rule(
If dynamic import of a derivation function fails.
"""
if rule.rule_type == "replace":
source = get_tag_from_info(rule.original)
target = get_tag_from_info(rule.replacement)
source = get_tag_from_info(
rule.original,
term_registry=term_registry,
)
target = get_tag_from_info(
rule.replacement,
term_registry=term_registry,
)
return partial(replace_tag_transform, source=source, target=target)
if rule.rule_type == "derive_tag":
source_term = get_term_from_key(rule.source_term_key)
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(rule.target_term_key)
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
derivation = get_derivation(
key=rule.derivation_function,
import_derivation=rule.import_derivation,
registry=registry,
registry=derivation_registry,
)
return partial(
derivation_tag_transform,
@ -498,9 +528,15 @@ def build_transform_from_rule(
)
if rule.rule_type == "map_value":
source_term = get_term_from_key(rule.source_term_key)
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(rule.target_term_key)
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
@ -522,6 +558,8 @@ def build_transform_from_rule(
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
@ -542,7 +580,14 @@ def build_transformation_from_config(
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [build_transform_from_rule(rule) for rule in config.rules]
transforms = [
build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
for rule in config.rules
]
def transformation(
sound_event_annotation: data.SoundEventAnnotation,
@ -583,7 +628,10 @@ def load_transformation_config(
def load_transformation_from_config(
path: data.PathLike, field: Optional[str] = None
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
@ -618,4 +666,8 @@ def load_transformation_from_config(
fails.
"""
config = load_transformation_config(path=path, field=field)
return build_transformation_from_config(config)
return build_transformation_from_config(
config,
derivation_registry=derivation_registry,
term_registry=term_registry,
)

148
docs/targets/classes.md Normal file
View File

@ -0,0 +1,148 @@
## Step 4: Defining Target Classes for Training
### Purpose and Context
You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags).
Now, it's time to tell `batdetect2` **exactly what categories (classes) your model should learn to identify**.
This step involves defining rules that map the final tags on your sound event annotations to specific **class names** (like `pippip`, `myodau`, or `noise`).
These class names are the labels the machine learning model will be trained to predict.
Getting this definition right is essential for successful model training.
### How it Works: Defining Classes with Rules
You define your target classes in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`.
This section contains a **list** of class definitions.
Each item in the list defines one specific class your model should learn.
### Defining a Single Class
Each class definition rule requires a few key pieces of information:
1. `name`: **(Required)** This is the unique, simple name you want to give this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `echolocation_noise`).
This is the label the model will actually use.
Choose names that are clear and distinct.
**Each class name must be unique.**
2. `tags`: **(Required)** This is a list containing one or more specific tags that identify annotations belonging to this class.
Remember, each tag is specified using its term `key` (like `species` or `sound_type`, defaulting to `class` if omitted) and its specific `value` (like `Pipistrellus pipistrellus` or `Echolocation`).
3. `match_type`: **(Optional, defaults to `"all"`)** This tells the system how to use the list of tags you provided in the `tag` field:
- `"all"`: An annotation must have **ALL** of the tags listed in the `tags` section to be considered part of this class.
(This is the default if you don't specify `match_type`).
- `"any"`: An annotation only needs to have **AT LEAST ONE** of the tags listed in the `tags` section to be considered part of this class.
**Example: Defining two specific bat species classes**
```yaml
# In your main configuration file
classes:
# Definition for the first class
- name: pippip # Simple name for Pipistrellus pipistrellus
tags:
- key: species # Term key (could also default to 'class')
value: Pipistrellus pipistrellus # Specific tag value
# match_type defaults to "all" (which is fine for a single tag)
# Definition for the second class
- name: myodau # Simple name for Myotis daubentonii
tags:
- key: species
value: Myotis daubentonii
```
**Example: Defining a class requiring multiple conditions (`match_type: "all"`)**
```yaml
classes:
- name: high_quality_pippip # Name for high-quality P. pip calls
match_type: all # Annotation must match BOTH tags below
tags:
- key: species
value: Pipistrellus pipistrellus
- key: quality # Assumes 'quality' term key exists
value: Good
```
**Example: Defining a class matching multiple alternative tags (`match_type: "any"`)**
```yaml
classes:
- name: pipistrelle # Name for any Pipistrellus species in this list
match_type: any # Annotation must match AT LEAST ONE tag below
tags:
- key: species
value: Pipistrellus pipistrellus
- key: species
value: Pipistrellus pygmaeus
- key: species
value: Pipistrellus nathusii
```
### Handling Overlap: Priority Order Matters!
Sometimes, an annotation might have tags that match the rules for _more than one_ class definition.
For example, an annotation tagged `species: Pipistrellus pipistrellus` would match both a specific `'pippip'` class rule and a broader `'pipistrelle'` genus rule (like the examples above) if both were defined.
How does `batdetect2` decide which class name to assign? It uses the **order of the class definitions in your configuration list**.
- The system checks an annotation against your class rules one by one, starting from the **top** of the `classes` list and moving down.
- As soon as it finds a rule that the annotation matches, it assigns that rule's `name` to the annotation and **stops checking** further rules for that annotation.
- **The first match wins!**
Therefore, you should generally place your **most specific rules before more general rules** if you want the specific category to take precedence.
**Example: Prioritizing Species over Noise**
```yaml
classes:
# --- Specific Species Rules (Checked First) ---
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myodau
tags:
- key: species
value: Myotis daubentonii
# --- General Noise Rule (Checked Last) ---
- name: noise # Catch-all for anything tagged as Noise
match_type: any # Match if any noise tag is present
tags:
- key: sound_type # Assume 'sound_type' term key exists
value: Noise
- key: quality # Assume 'quality' term key exists
value: Low # Maybe low quality is also considered noise for training
```
In this example, an annotation tagged with `species: Myotis daubentonii` _and_ `quality: Low` would be assigned the class name `myodau` because that rule comes first in the list.
It would not be assigned `noise`, even though it also matches the second condition of the noise rule.
Okay, that's a very important clarification about how BatDetect2 handles sounds that don't match specific class definitions.
Let's refine that section to accurately reflect this behavior.
### What if No Class Matches? (The Generic "Bat" Class)
It's important to understand what happens if a sound event annotation passes through the filtering (Step 2) and transformation (Step 3) steps, but its final set of tags doesn't match _any_ of the specific class definitions you've listed in this section.
These annotations are **not ignored** during training.
Instead, they are typically assigned to a **generic "relevant sound" class**.
Think of this as a category for sounds that you considered important enough to keep after filtering, but which don't fit into one of your specific target classes for detailed classification (like a particular species).
This generic class is distinct from background noise.
In BatDetect2, this default generic class is often referred to as the **"Bat"** class.
The goal is generally that all relevant bat echolocation calls that pass the initial filtering should fall into _either_ one of your specific defined classes (like `pippip` or `myodau`) _or_ this generic "Bat" class.
**In summary:**
- Sounds passing **filtering** are considered relevant.
- If a relevant sound matches one of your **specific class rules** (in priority order), it gets that specific class label.
- If a relevant sound does **not** match any specific class rule, it gets the **generic "Bat" class** label.
**Crucially:** If you want certain types of sounds (even if they are bat calls) to be **completely excluded** from the training process altogether (not even included in the generic "Bat" class), you **must remove them using rules in the Filtering step (Step 2)**.
Any sound annotation that makes it past filtering _will_ be used in training, either under one of your specific classes or the generic one.
### Outcome
By defining this list of prioritized class rules, you provide `batdetect2` with a clear procedure to assign a specific target label (your class `name`) to each relevant sound event annotation based on its tags.
This labelled data is exactly what the model needs for training (Step 5).

View File

@ -0,0 +1,361 @@
from pathlib import Path
import pytest
from soundevent import data
from batdetect2.targets import (
DeriveTagRule,
MapValueRule,
ReplaceRule,
TagInfo,
TransformConfig,
build_transform_from_rule,
build_transformation_from_config,
)
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import DerivationRegistry
@pytest.fixture
def term_registry():
return TermRegistry()
@pytest.fixture
def derivation_registry():
return DerivationRegistry()
@pytest.fixture
def term1(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term1")
@pytest.fixture
def term2(term_registry: TermRegistry) -> data.Term:
return term_registry.add_custom_term(key="term2")
@pytest.fixture
def annotation(
sound_event: data.SoundEvent,
term1: data.Term,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
)
def test_map_value_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"other_value": "value2"},
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].value == "value1"
def test_replace_rule(
annotation: data.SoundEventAnnotation,
term2: data.Term,
term_registry: TermRegistry,
):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="value1"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_replace_rule_no_match(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term2: data.Term,
):
rule = ReplaceRule(
rule_type="replace",
original=TagInfo(key="term1", value="wrong_value"),
replacement=TagInfo(key="term2", value="value2"),
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].term != term2
assert transformed_annotation.tags[0].value == "value1"
def test_build_transformation_from_config(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
):
config = TransformConfig(
rules=[
MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
),
ReplaceRule(
rule_type="replace",
original=TagInfo(key="term2", value="value2"),
replacement=TagInfo(key="term3", value="value3"),
),
]
)
term_registry.add_custom_term("term2")
term_registry.add_custom_term("term3")
transform = build_transformation_from_config(
config,
term_registry=term_registry,
)
transformed_annotation = transform(annotation)
assert transformed_annotation.tags[0].key == "term1"
assert transformed_annotation.tags[0].value == "value2"
def test_derive_tag_rule(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_keep_source_false(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
keep_source=False,
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 1
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1_derived"
def test_derive_tag_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
term2: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key="term2",
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term2
assert transformed_annotation.tags[1].value == "value1_derived"
def test_derive_tag_rule_import_derivation(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
tmp_path: Path,
):
# Create a dummy derivation function in a temporary file
derivation_module_path = (
tmp_path / "temp_derivation.py"
) # Changed to /tmp since /home/santiago is not writable
derivation_module_path.write_text(
"""
def my_imported_derivation(x: str) -> str:
return x + "_imported"
"""
)
# Ensure the temporary file is importable by adding its directory to sys.path
import sys
sys.path.insert(0, str(tmp_path))
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="temp_derivation.my_imported_derivation",
import_derivation=True,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_imported"
# Clean up the temporary file and sys.path
sys.path.remove(str(tmp_path))
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="nonexistent_derivation",
)
with pytest.raises(KeyError):
build_transform_from_rule(rule, term_registry=term_registry)
def test_build_transform_from_rule_invalid_rule_type():
class InvalidRule:
rule_type = "invalid"
rule = InvalidRule() # type: ignore
with pytest.raises(ValueError):
build_transform_from_rule(rule) # type: ignore
def test_map_value_rule_target_term(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term2: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key="term2",
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term2
assert transformed_annotation.tags[0].value == "value2"
def test_map_value_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
term1: data.Term,
):
rule = MapValueRule(
rule_type="map_value",
source_term_key="term1",
value_mapping={"value1": "value2"},
target_term_key=None,
)
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
transformed_annotation = transform_fn(annotation)
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value2"
def test_derive_tag_rule_target_term_none(
annotation: data.SoundEventAnnotation,
term_registry: TermRegistry,
derivation_registry: DerivationRegistry,
term1: data.Term,
):
def derivation_func(x: str) -> str:
return x + "_derived"
derivation_registry.register("my_derivation", derivation_func)
rule = DeriveTagRule(
rule_type="derive_tag",
source_term_key="term1",
derivation_function="my_derivation",
target_term_key=None,
)
transform_fn = build_transform_from_rule(
rule,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
transformed_annotation = transform_fn(annotation)
assert len(transformed_annotation.tags) == 2
assert transformed_annotation.tags[0].term == term1
assert transformed_annotation.tags[0].value == "value1"
assert transformed_annotation.tags[1].term == term1
assert transformed_annotation.tags[1].value == "value1_derived"
def test_build_transformation_from_config_empty(
annotation: data.SoundEventAnnotation,
):
config = TransformConfig(rules=[])
transform = build_transformation_from_config(config)
transformed_annotation = transform(annotation)
assert transformed_annotation == annotation