From af48c333077a2e6e804333705b0d2cae59bf71ae Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 15 Apr 2025 07:32:58 +0100 Subject: [PATCH] Add target classes module --- batdetect2/targets/__init__.py | 31 +++ batdetect2/targets/classes.py | 300 ++++++++++++++++++++++ batdetect2/targets/terms.py | 23 +- batdetect2/targets/transform.py | 80 ++++-- docs/targets/classes.md | 148 +++++++++++ tests/test_targets/test_transform.py | 361 +++++++++++++++++++++++++++ 6 files changed, 922 insertions(+), 21 deletions(-) create mode 100644 batdetect2/targets/classes.py create mode 100644 docs/targets/classes.md create mode 100644 tests/test_targets/test_transform.py diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index a04d685..03a8f8a 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -24,22 +24,53 @@ from batdetect2.targets.terms import ( TermInfo, call_type, get_tag_from_info, + get_term_from_key, individual, + register_term, + term_registry, +) +from batdetect2.targets.transform import ( + DerivationRegistry, + DeriveTagRule, + MapValueRule, + ReplaceRule, + SoundEventTransformation, + TransformConfig, + build_transform_from_rule, + build_transformation_from_config, + derivation_registry, + get_derivation, + load_transformation_config, + load_transformation_from_config, ) __all__ = [ + "DerivationRegistry", + "DeriveTagRule", "HeatmapsConfig", "LabelConfig", + "MapValueRule", + "ReplaceRule", + "SoundEventTransformation", "TagInfo", "TargetConfig", "TermInfo", + "TransformConfig", "build_decoder", "build_target_encoder", + "build_transform_from_rule", + "build_transformation_from_config", "call_type", + "derivation_registry", "generate_heatmaps", "get_class_names", + "get_derivation", "get_tag_from_info", "individual", "load_label_config", "load_target_config", + "load_transformation_config", + "load_transformation_from_config", + "register_term", + "term_registry", ] diff --git a/batdetect2/targets/classes.py b/batdetect2/targets/classes.py new file mode 100644 index 0000000..07df2f7 --- /dev/null +++ b/batdetect2/targets/classes.py @@ -0,0 +1,300 @@ +from collections import Counter +from functools import partial +from typing import Callable, List, Literal, Optional, Set + +from pydantic import Field, field_validator +from soundevent import data + +from batdetect2.configs import BaseConfig, load_config +from batdetect2.targets.terms import ( + TagInfo, + TermRegistry, + get_tag_from_info, + term_registry, +) + +__all__ = [ + "SoundEventEncoder", + "TargetClass", + "ClassesConfig", + "load_classes_config", + "build_encoder_from_config", + "load_encoder_from_config", + "get_class_names_from_config", +] + +SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] +"""Type alias for a sound event class encoder function. + +An encoder function takes a sound event annotation and returns the string name +of the target class it belongs to, based on a predefined set of rules. +If the annotation does not match any defined target class according to the +rules, the function returns None. +""" + + +class TargetClass(BaseConfig): + """Defines the criteria for assigning an annotation to a specific class. + + Each instance represents one potential output class for the classification + model. It specifies the class name and the tag conditions an annotation + must meet to be assigned this class label. + + Attributes + ---------- + name : str + The unique name assigned to this target class (e.g., 'pippip', + 'myodau', 'noise'). This name will be used as the label during model + training and output. Should be unique across all TargetClass + definitions in a configuration. + tag : List[TagInfo] + A list of one or more tags (defined using `TagInfo`) that an annotation + must possess to potentially match this class. + match_type : Literal["all", "any"], default="all" + Determines how the `tag` list is evaluated: + - "all": The annotation must have *all* the tags listed in the `tag` + field to match this class definition. + - "any": The annotation must have *at least one* of the tags listed + in the `tag` field to match this class definition. + """ + + name: str + tags: List[TagInfo] = Field(default_factory=list, min_length=1) + match_type: Literal["all", "any"] = Field(default="all") + + +class ClassesConfig(BaseConfig): + """Configuration model holding the list of target class definitions. + + The order of `TargetClass` objects in the `classes` list defines the + priority for classification. When encoding an annotation, the system checks + against the class definitions in this sequence and assigns the name of the + *first* matching class. + + Attributes + ---------- + classes : List[TargetClass] + An ordered list of target class definitions. The order determines + matching priority (first match wins). + + Raises + ------ + ValueError + If validation fails (e.g., non-unique class names). + + Notes + ----- + It is crucial that the `name` attribute of each `TargetClass` in the + `classes` list is unique. This configuration includes a validator to + enforce this uniqueness. + """ + + classes: List[TargetClass] = Field(default_factory=list) + + @field_validator("classes") + def check_unique_class_names(cls, v: List[TargetClass]): + """Ensure all defined class names are unique.""" + names = [c.name for c in v] + + if len(names) != len(set(names)): + name_counts = Counter(names) + duplicates = [ + name for name, count in name_counts.items() if count > 1 + ] + raise ValueError( + "Class names must be unique. Found duplicates: " + f"{', '.join(duplicates)}" + ) + return v + + +def load_classes_config( + path: data.PathLike, + field: Optional[str] = None, +) -> ClassesConfig: + """Load the target classes configuration from a file. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (YAML). + field : str, optional + If the classes configuration is nested under a specific key in the + file, specify the key here. Defaults to None. + + Returns + ------- + ClassesConfig + The loaded and validated classes configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + pydantic.ValidationError + If the config file structure does not match the ClassesConfig schema + or if class names are not unique. + """ + return load_config(path, schema=ClassesConfig, field=field) + + +def is_target_class( + sound_event_annotation: data.SoundEventAnnotation, + tags: Set[data.Tag], + match_all: bool = True, +) -> bool: + """Check if a sound event annotation matches a set of required tags. + + Parameters + ---------- + sound_event_annotation : data.SoundEventAnnotation + The annotation to check. + required_tags : Set[data.Tag] + A set of `soundevent.data.Tag` objects that define the class criteria. + match_all : bool, default=True + If True, checks if *all* `required_tags` are present in the + annotation's tags (subset check). If False, checks if *at least one* + of the `required_tags` is present (intersection check). + + Returns + ------- + bool + True if the annotation meets the tag criteria, False otherwise. + """ + annotation_tags = set(sound_event_annotation.tags) + + if match_all: + return tags <= annotation_tags + + return bool(tags & annotation_tags) + + +def get_class_names_from_config(config: ClassesConfig) -> List[str]: + """Extract the list of class names from a ClassesConfig object. + + Parameters + ---------- + config : ClassesConfig + The loaded classes configuration object. + + Returns + ------- + List[str] + An ordered list of unique class names defined in the configuration. + """ + return [class_info.name for class_info in config.classes] + + +def build_encoder_from_config( + config: ClassesConfig, + term_registry: TermRegistry = term_registry, +) -> SoundEventEncoder: + """Build a sound event encoder function from the classes configuration. + + The returned encoder function iterates through the class definitions in the + order specified in the config. It assigns an annotation the name of the + first class definition it matches. + + Parameters + ---------- + config : ClassesConfig + The loaded and validated classes configuration object. + term_registry : TermRegistry, optional + The TermRegistry instance used to look up term keys specified in the + `TagInfo` objects within the configuration. Defaults to the global + `batdetect2.targets.terms.registry`. + + Returns + ------- + SoundEventEncoder + A callable function that takes a `SoundEventAnnotation` and returns + an optional string representing the matched class name, or None if no + class matches. + + Raises + ------ + KeyError + If a term key specified in the configuration is not found in the + provided `term_registry`. + """ + binary_classifiers = [ + ( + class_info.name, + partial( + is_target_class, + tags={ + get_tag_from_info(tag_info, term_registry=term_registry) + for tag_info in class_info.tags + }, + match_all=class_info.match_type == "all", + ), + ) + for class_info in config.classes + ] + + def encoder( + sound_event_annotation: data.SoundEventAnnotation, + ) -> Optional[str]: + """Assign a class name to an annotation based on configured rules. + + Iterates through pre-compiled classifiers in priority order. Returns + the name of the first matching class, or None if no match is found. + + Parameters + ---------- + sound_event_annotation : data.SoundEventAnnotation + The annotation to encode. + + Returns + ------- + str or None + The name of the matched class, or None. + """ + for class_name, classifier in binary_classifiers: + if classifier(sound_event_annotation): + return class_name + + return None + + return encoder + + +def load_encoder_from_config( + path: data.PathLike, + field: Optional[str] = None, + term_registry: TermRegistry = term_registry, +) -> SoundEventEncoder: + """Load a class encoder function directly from a configuration file. + + This is a convenience function that combines loading the `ClassesConfig` + from a file and building the final `SoundEventEncoder` function. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (e.g., YAML). + field : str, optional + If the classes configuration is nested under a specific key in the + file, specify the key here. Defaults to None. + term_registry : TermRegistry, optional + The TermRegistry instance used for term lookups. Defaults to the + global `batdetect2.targets.terms.registry`. + + Returns + ------- + SoundEventEncoder + The final encoder function ready to classify annotations. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + pydantic.ValidationError + If the config file structure does not match the ClassesConfig schema + or if class names are not unique. + KeyError + If a term key specified in the configuration is not found in the + provided `term_registry` during the build process. + """ + config = load_classes_config(path, field=field) + return build_encoder_from_config(config, term_registry=term_registry) diff --git a/batdetect2/targets/terms.py b/batdetect2/targets/terms.py index 308b5e4..07f2a2e 100644 --- a/batdetect2/targets/terms.py +++ b/batdetect2/targets/terms.py @@ -143,7 +143,7 @@ class TermRegistry(Mapping[str, data.Term]): raise KeyError( "No term found for key " f"'{key}'. Ensure it is registered or loaded. " - "Use `get_term_keys()` to list available terms." + f"Available keys: {', '.join(self.get_keys())}" ) from err def add_custom_term( @@ -215,8 +215,11 @@ class TermRegistry(Mapping[str, data.Term]): """ return list(self._terms.values()) + def remove_key(self, key: str) -> None: + del self._terms[key] -registry = TermRegistry( + +term_registry = TermRegistry( terms=dict( [ *getmembers(terms, lambda x: isinstance(x, data.Term)), @@ -237,7 +240,7 @@ is explicitly passed. def get_term_from_key( key: str, - term_registry: TermRegistry = registry, + term_registry: TermRegistry = term_registry, ) -> data.Term: """Convenience function to retrieve a term by key from a registry. @@ -265,7 +268,7 @@ def get_term_from_key( return term_registry.get_term(key) -def get_term_keys(term_registry: TermRegistry = registry) -> List[str]: +def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]: """Convenience function to get all registered keys from a registry. Uses the global default registry unless a specific `term_registry` @@ -284,7 +287,7 @@ def get_term_keys(term_registry: TermRegistry = registry) -> List[str]: return term_registry.get_keys() -def get_terms(term_registry: TermRegistry = registry) -> List[data.Term]: +def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]: """Convenience function to get all registered terms from a registry. Uses the global default registry unless a specific `term_registry` @@ -327,7 +330,7 @@ class TagInfo(BaseModel): def get_tag_from_info( tag_info: TagInfo, - term_registry: TermRegistry = registry, + term_registry: TermRegistry = term_registry, ) -> data.Tag: """Creates a soundevent.data.Tag object from TagInfo data. @@ -424,7 +427,7 @@ class TermConfig(BaseModel): def load_terms_from_config( path: data.PathLike, field: Optional[str] = None, - term_registry: TermRegistry = registry, + term_registry: TermRegistry = term_registry, ) -> Dict[str, data.Term]: """Loads term definitions from a configuration file and registers them. @@ -472,3 +475,9 @@ def load_terms_from_config( ) for info in data.terms } + + +def register_term( + key: str, term: data.Term, registry: TermRegistry = term_registry +) -> None: + registry.add_term(key, term) diff --git a/batdetect2/targets/transform.py b/batdetect2/targets/transform.py index 93f57e4..5d04c37 100644 --- a/batdetect2/targets/transform.py +++ b/batdetect2/targets/transform.py @@ -1,6 +1,15 @@ import importlib from functools import partial -from typing import Callable, Dict, List, Literal, Mapping, Optional, Union +from typing import ( + Annotated, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Union, +) from pydantic import Field from soundevent import data @@ -8,9 +17,13 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.targets.terms import ( TagInfo, + TermRegistry, get_tag_from_info, get_term_from_key, ) +from batdetect2.targets.terms import ( + term_registry as default_term_registry, +) __all__ = [ "SoundEventTransformation", @@ -282,9 +295,13 @@ class TransformConfig(BaseConfig): discriminates between the different rule models. """ - rules: List[Union[ReplaceRule, MapValueRule, DeriveTagRule]] = Field( + rules: List[ + Annotated[ + Union[ReplaceRule, MapValueRule, DeriveTagRule], + Field(discriminator="rule_type"), + ] + ] = Field( default_factory=list, - discriminator="rule_type", ) @@ -442,7 +459,8 @@ def get_derivation( def build_transform_from_rule( rule: Union[ReplaceRule, MapValueRule, DeriveTagRule], - registry: DerivationRegistry = derivation_registry, + derivation_registry: DerivationRegistry = derivation_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventTransformation: """Build a specific SoundEventTransformation function from a rule config. @@ -473,21 +491,33 @@ def build_transform_from_rule( If dynamic import of a derivation function fails. """ if rule.rule_type == "replace": - source = get_tag_from_info(rule.original) - target = get_tag_from_info(rule.replacement) + source = get_tag_from_info( + rule.original, + term_registry=term_registry, + ) + target = get_tag_from_info( + rule.replacement, + term_registry=term_registry, + ) return partial(replace_tag_transform, source=source, target=target) if rule.rule_type == "derive_tag": - source_term = get_term_from_key(rule.source_term_key) + source_term = get_term_from_key( + rule.source_term_key, + term_registry=term_registry, + ) target_term = ( - get_term_from_key(rule.target_term_key) + get_term_from_key( + rule.target_term_key, + term_registry=term_registry, + ) if rule.target_term_key else source_term ) derivation = get_derivation( key=rule.derivation_function, import_derivation=rule.import_derivation, - registry=registry, + registry=derivation_registry, ) return partial( derivation_tag_transform, @@ -498,9 +528,15 @@ def build_transform_from_rule( ) if rule.rule_type == "map_value": - source_term = get_term_from_key(rule.source_term_key) + source_term = get_term_from_key( + rule.source_term_key, + term_registry=term_registry, + ) target_term = ( - get_term_from_key(rule.target_term_key) + get_term_from_key( + rule.target_term_key, + term_registry=term_registry, + ) if rule.target_term_key else source_term ) @@ -522,6 +558,8 @@ def build_transform_from_rule( def build_transformation_from_config( config: TransformConfig, + derivation_registry: DerivationRegistry = derivation_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventTransformation: """Build a composite transformation function from a TransformConfig. @@ -542,7 +580,14 @@ def build_transformation_from_config( SoundEventTransformation A single function that applies all configured transformations in order. """ - transforms = [build_transform_from_rule(rule) for rule in config.rules] + transforms = [ + build_transform_from_rule( + rule, + derivation_registry=derivation_registry, + term_registry=term_registry, + ) + for rule in config.rules + ] def transformation( sound_event_annotation: data.SoundEventAnnotation, @@ -583,7 +628,10 @@ def load_transformation_config( def load_transformation_from_config( - path: data.PathLike, field: Optional[str] = None + path: data.PathLike, + field: Optional[str] = None, + derivation_registry: DerivationRegistry = derivation_registry, + term_registry: TermRegistry = default_term_registry, ) -> SoundEventTransformation: """Load transformation config from a file and build the final function. @@ -618,4 +666,8 @@ def load_transformation_from_config( fails. """ config = load_transformation_config(path=path, field=field) - return build_transformation_from_config(config) + return build_transformation_from_config( + config, + derivation_registry=derivation_registry, + term_registry=term_registry, + ) diff --git a/docs/targets/classes.md b/docs/targets/classes.md new file mode 100644 index 0000000..f5dc3cd --- /dev/null +++ b/docs/targets/classes.md @@ -0,0 +1,148 @@ +## Step 4: Defining Target Classes for Training + +### Purpose and Context + +You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags). +Now, it's time to tell `batdetect2` **exactly what categories (classes) your model should learn to identify**. + +This step involves defining rules that map the final tags on your sound event annotations to specific **class names** (like `pippip`, `myodau`, or `noise`). +These class names are the labels the machine learning model will be trained to predict. +Getting this definition right is essential for successful model training. + +### How it Works: Defining Classes with Rules + +You define your target classes in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`. +This section contains a **list** of class definitions. +Each item in the list defines one specific class your model should learn. + +### Defining a Single Class + +Each class definition rule requires a few key pieces of information: + +1. `name`: **(Required)** This is the unique, simple name you want to give this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `echolocation_noise`). + This is the label the model will actually use. + Choose names that are clear and distinct. + **Each class name must be unique.** +2. `tags`: **(Required)** This is a list containing one or more specific tags that identify annotations belonging to this class. + Remember, each tag is specified using its term `key` (like `species` or `sound_type`, defaulting to `class` if omitted) and its specific `value` (like `Pipistrellus pipistrellus` or `Echolocation`). +3. `match_type`: **(Optional, defaults to `"all"`)** This tells the system how to use the list of tags you provided in the `tag` field: + - `"all"`: An annotation must have **ALL** of the tags listed in the `tags` section to be considered part of this class. + (This is the default if you don't specify `match_type`). + - `"any"`: An annotation only needs to have **AT LEAST ONE** of the tags listed in the `tags` section to be considered part of this class. + +**Example: Defining two specific bat species classes** + +```yaml +# In your main configuration file +classes: + # Definition for the first class + - name: pippip # Simple name for Pipistrellus pipistrellus + tags: + - key: species # Term key (could also default to 'class') + value: Pipistrellus pipistrellus # Specific tag value + # match_type defaults to "all" (which is fine for a single tag) + + # Definition for the second class + - name: myodau # Simple name for Myotis daubentonii + tags: + - key: species + value: Myotis daubentonii +``` + +**Example: Defining a class requiring multiple conditions (`match_type: "all"`)** + +```yaml +classes: + - name: high_quality_pippip # Name for high-quality P. pip calls + match_type: all # Annotation must match BOTH tags below + tags: + - key: species + value: Pipistrellus pipistrellus + - key: quality # Assumes 'quality' term key exists + value: Good +``` + +**Example: Defining a class matching multiple alternative tags (`match_type: "any"`)** + +```yaml +classes: + - name: pipistrelle # Name for any Pipistrellus species in this list + match_type: any # Annotation must match AT LEAST ONE tag below + tags: + - key: species + value: Pipistrellus pipistrellus + - key: species + value: Pipistrellus pygmaeus + - key: species + value: Pipistrellus nathusii +``` + +### Handling Overlap: Priority Order Matters! + +Sometimes, an annotation might have tags that match the rules for _more than one_ class definition. +For example, an annotation tagged `species: Pipistrellus pipistrellus` would match both a specific `'pippip'` class rule and a broader `'pipistrelle'` genus rule (like the examples above) if both were defined. + +How does `batdetect2` decide which class name to assign? It uses the **order of the class definitions in your configuration list**. + +- The system checks an annotation against your class rules one by one, starting from the **top** of the `classes` list and moving down. +- As soon as it finds a rule that the annotation matches, it assigns that rule's `name` to the annotation and **stops checking** further rules for that annotation. +- **The first match wins!** + +Therefore, you should generally place your **most specific rules before more general rules** if you want the specific category to take precedence. + +**Example: Prioritizing Species over Noise** + +```yaml +classes: + # --- Specific Species Rules (Checked First) --- + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + + - name: myodau + tags: + - key: species + value: Myotis daubentonii + + # --- General Noise Rule (Checked Last) --- + - name: noise # Catch-all for anything tagged as Noise + match_type: any # Match if any noise tag is present + tags: + - key: sound_type # Assume 'sound_type' term key exists + value: Noise + - key: quality # Assume 'quality' term key exists + value: Low # Maybe low quality is also considered noise for training +``` + +In this example, an annotation tagged with `species: Myotis daubentonii` _and_ `quality: Low` would be assigned the class name `myodau` because that rule comes first in the list. +It would not be assigned `noise`, even though it also matches the second condition of the noise rule. + +Okay, that's a very important clarification about how BatDetect2 handles sounds that don't match specific class definitions. +Let's refine that section to accurately reflect this behavior. + +### What if No Class Matches? (The Generic "Bat" Class) + +It's important to understand what happens if a sound event annotation passes through the filtering (Step 2) and transformation (Step 3) steps, but its final set of tags doesn't match _any_ of the specific class definitions you've listed in this section. + +These annotations are **not ignored** during training. +Instead, they are typically assigned to a **generic "relevant sound" class**. +Think of this as a category for sounds that you considered important enough to keep after filtering, but which don't fit into one of your specific target classes for detailed classification (like a particular species). +This generic class is distinct from background noise. + +In BatDetect2, this default generic class is often referred to as the **"Bat"** class. +The goal is generally that all relevant bat echolocation calls that pass the initial filtering should fall into _either_ one of your specific defined classes (like `pippip` or `myodau`) _or_ this generic "Bat" class. + +**In summary:** + +- Sounds passing **filtering** are considered relevant. +- If a relevant sound matches one of your **specific class rules** (in priority order), it gets that specific class label. +- If a relevant sound does **not** match any specific class rule, it gets the **generic "Bat" class** label. + +**Crucially:** If you want certain types of sounds (even if they are bat calls) to be **completely excluded** from the training process altogether (not even included in the generic "Bat" class), you **must remove them using rules in the Filtering step (Step 2)**. +Any sound annotation that makes it past filtering _will_ be used in training, either under one of your specific classes or the generic one. + +### Outcome + +By defining this list of prioritized class rules, you provide `batdetect2` with a clear procedure to assign a specific target label (your class `name`) to each relevant sound event annotation based on its tags. +This labelled data is exactly what the model needs for training (Step 5). diff --git a/tests/test_targets/test_transform.py b/tests/test_targets/test_transform.py new file mode 100644 index 0000000..c55ec07 --- /dev/null +++ b/tests/test_targets/test_transform.py @@ -0,0 +1,361 @@ +from pathlib import Path + +import pytest +from soundevent import data + +from batdetect2.targets import ( + DeriveTagRule, + MapValueRule, + ReplaceRule, + TagInfo, + TransformConfig, + build_transform_from_rule, + build_transformation_from_config, +) +from batdetect2.targets.terms import TermRegistry +from batdetect2.targets.transform import DerivationRegistry + + +@pytest.fixture +def term_registry(): + return TermRegistry() + + +@pytest.fixture +def derivation_registry(): + return DerivationRegistry() + + +@pytest.fixture +def term1(term_registry: TermRegistry) -> data.Term: + return term_registry.add_custom_term(key="term1") + + +@pytest.fixture +def term2(term_registry: TermRegistry) -> data.Term: + return term_registry.add_custom_term(key="term2") + + +@pytest.fixture +def annotation( + sound_event: data.SoundEvent, + term1: data.Term, +) -> data.SoundEventAnnotation: + return data.SoundEventAnnotation( + sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")] + ) + + +def test_map_value_rule( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, +): + rule = MapValueRule( + rule_type="map_value", + source_term_key="term1", + value_mapping={"value1": "value2"}, + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].value == "value2" + + +def test_map_value_rule_no_match( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, +): + rule = MapValueRule( + rule_type="map_value", + source_term_key="term1", + value_mapping={"other_value": "value2"}, + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].value == "value1" + + +def test_replace_rule( + annotation: data.SoundEventAnnotation, + term2: data.Term, + term_registry: TermRegistry, +): + rule = ReplaceRule( + rule_type="replace", + original=TagInfo(key="term1", value="value1"), + replacement=TagInfo(key="term2", value="value2"), + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].term == term2 + assert transformed_annotation.tags[0].value == "value2" + + +def test_replace_rule_no_match( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + term2: data.Term, +): + rule = ReplaceRule( + rule_type="replace", + original=TagInfo(key="term1", value="wrong_value"), + replacement=TagInfo(key="term2", value="value2"), + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].key == "term1" + assert transformed_annotation.tags[0].term != term2 + assert transformed_annotation.tags[0].value == "value1" + + +def test_build_transformation_from_config( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, +): + config = TransformConfig( + rules=[ + MapValueRule( + rule_type="map_value", + source_term_key="term1", + value_mapping={"value1": "value2"}, + ), + ReplaceRule( + rule_type="replace", + original=TagInfo(key="term2", value="value2"), + replacement=TagInfo(key="term3", value="value3"), + ), + ] + ) + term_registry.add_custom_term("term2") + term_registry.add_custom_term("term3") + transform = build_transformation_from_config( + config, + term_registry=term_registry, + ) + transformed_annotation = transform(annotation) + assert transformed_annotation.tags[0].key == "term1" + assert transformed_annotation.tags[0].value == "value2" + + +def test_derive_tag_rule( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + derivation_registry: DerivationRegistry, + term1: data.Term, +): + def derivation_func(x: str) -> str: + return x + "_derived" + + derivation_registry.register("my_derivation", derivation_func) + + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="my_derivation", + ) + transform_fn = build_transform_from_rule( + rule, + term_registry=term_registry, + derivation_registry=derivation_registry, + ) + transformed_annotation = transform_fn(annotation) + + assert len(transformed_annotation.tags) == 2 + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value1" + assert transformed_annotation.tags[1].term == term1 + assert transformed_annotation.tags[1].value == "value1_derived" + + +def test_derive_tag_rule_keep_source_false( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + derivation_registry: DerivationRegistry, + term1: data.Term, +): + def derivation_func(x: str) -> str: + return x + "_derived" + + derivation_registry.register("my_derivation", derivation_func) + + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="my_derivation", + keep_source=False, + ) + transform_fn = build_transform_from_rule( + rule, + term_registry=term_registry, + derivation_registry=derivation_registry, + ) + transformed_annotation = transform_fn(annotation) + + assert len(transformed_annotation.tags) == 1 + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value1_derived" + + +def test_derive_tag_rule_target_term( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + derivation_registry: DerivationRegistry, + term1: data.Term, + term2: data.Term, +): + def derivation_func(x: str) -> str: + return x + "_derived" + + derivation_registry.register("my_derivation", derivation_func) + + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="my_derivation", + target_term_key="term2", + ) + transform_fn = build_transform_from_rule( + rule, + term_registry=term_registry, + derivation_registry=derivation_registry, + ) + transformed_annotation = transform_fn(annotation) + + assert len(transformed_annotation.tags) == 2 + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value1" + assert transformed_annotation.tags[1].term == term2 + assert transformed_annotation.tags[1].value == "value1_derived" + + +def test_derive_tag_rule_import_derivation( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + term1: data.Term, + tmp_path: Path, +): + # Create a dummy derivation function in a temporary file + derivation_module_path = ( + tmp_path / "temp_derivation.py" + ) # Changed to /tmp since /home/santiago is not writable + derivation_module_path.write_text( + """ +def my_imported_derivation(x: str) -> str: + return x + "_imported" +""" + ) + # Ensure the temporary file is importable by adding its directory to sys.path + import sys + + sys.path.insert(0, str(tmp_path)) + + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="temp_derivation.my_imported_derivation", + import_derivation=True, + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + + assert len(transformed_annotation.tags) == 2 + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value1" + assert transformed_annotation.tags[1].term == term1 + assert transformed_annotation.tags[1].value == "value1_imported" + + # Clean up the temporary file and sys.path + sys.path.remove(str(tmp_path)) + + +def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry): + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="nonexistent_derivation", + ) + with pytest.raises(KeyError): + build_transform_from_rule(rule, term_registry=term_registry) + + +def test_build_transform_from_rule_invalid_rule_type(): + class InvalidRule: + rule_type = "invalid" + + rule = InvalidRule() # type: ignore + + with pytest.raises(ValueError): + build_transform_from_rule(rule) # type: ignore + + +def test_map_value_rule_target_term( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + term2: data.Term, +): + rule = MapValueRule( + rule_type="map_value", + source_term_key="term1", + value_mapping={"value1": "value2"}, + target_term_key="term2", + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].term == term2 + assert transformed_annotation.tags[0].value == "value2" + + +def test_map_value_rule_target_term_none( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + term1: data.Term, +): + rule = MapValueRule( + rule_type="map_value", + source_term_key="term1", + value_mapping={"value1": "value2"}, + target_term_key=None, + ) + transform_fn = build_transform_from_rule(rule, term_registry=term_registry) + transformed_annotation = transform_fn(annotation) + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value2" + + +def test_derive_tag_rule_target_term_none( + annotation: data.SoundEventAnnotation, + term_registry: TermRegistry, + derivation_registry: DerivationRegistry, + term1: data.Term, +): + def derivation_func(x: str) -> str: + return x + "_derived" + + derivation_registry.register("my_derivation", derivation_func) + + rule = DeriveTagRule( + rule_type="derive_tag", + source_term_key="term1", + derivation_function="my_derivation", + target_term_key=None, + ) + transform_fn = build_transform_from_rule( + rule, + term_registry=term_registry, + derivation_registry=derivation_registry, + ) + transformed_annotation = transform_fn(annotation) + + assert len(transformed_annotation.tags) == 2 + assert transformed_annotation.tags[0].term == term1 + assert transformed_annotation.tags[0].value == "value1" + assert transformed_annotation.tags[1].term == term1 + assert transformed_annotation.tags[1].value == "value1_derived" + + +def test_build_transformation_from_config_empty( + annotation: data.SoundEventAnnotation, +): + config = TransformConfig(rules=[]) + transform = build_transformation_from_config(config) + transformed_annotation = transform(annotation) + assert transformed_annotation == annotation