From a2ec190b73a7c1173b19d0f10d32edfee52bd8bd Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 15 Apr 2025 23:56:24 +0100 Subject: [PATCH] Add decode functions to classes module --- batdetect2/targets/__init__.py | 12 - batdetect2/targets/classes.py | 476 ++++++++++++++++++++++++----- batdetect2/targets/targets.py | 147 --------- docs/source/targets/classes.md | 185 ++++++----- tests/conftest.py | 29 +- tests/test_targets/test_classes.py | 31 +- 6 files changed, 512 insertions(+), 368 deletions(-) delete mode 100644 batdetect2/targets/targets.py diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index 1a99d80..fd4f52e 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -11,13 +11,6 @@ from batdetect2.targets.labels import ( generate_heatmaps, load_label_config, ) -from batdetect2.targets.targets import ( - TargetConfig, - build_decoder, - build_target_encoder, - get_class_names, - load_target_config, -) from batdetect2.targets.terms import ( TagInfo, TermInfo, @@ -51,23 +44,18 @@ __all__ = [ "ReplaceRule", "SoundEventTransformation", "TagInfo", - "TargetConfig", "TermInfo", "TransformConfig", - "build_decoder", - "build_target_encoder", "build_transform_from_rule", "build_transformation_from_config", "call_type", "derivation_registry", "generate_heatmaps", - "get_class_names", "get_derivation", "get_tag_from_info", "get_term_from_key", "individual", "load_label_config", - "load_target_config", "load_transformation_config", "load_transformation_from_config", "register_term", diff --git a/batdetect2/targets/classes.py b/batdetect2/targets/classes.py index 07df2f7..79b5563 100644 --- a/batdetect2/targets/classes.py +++ b/batdetect2/targets/classes.py @@ -1,12 +1,13 @@ from collections import Counter from functools import partial -from typing import Callable, List, Literal, Optional, Set +from typing import Callable, Dict, List, Literal, Optional, Set, Tuple from pydantic import Field, field_validator from 
soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.targets.terms import ( + GENERIC_CLASS_KEY, TagInfo, TermRegistry, get_tag_from_info, @@ -15,12 +16,17 @@ from batdetect2.targets.terms import ( __all__ = [ "SoundEventEncoder", + "SoundEventDecoder", "TargetClass", "ClassesConfig", "load_classes_config", - "build_encoder_from_config", "load_encoder_from_config", + "load_decoder_from_config", + "build_encoder_from_config", + "build_decoder_from_config", + "build_generic_class_tags_from_config", "get_class_names_from_config", + "DEFAULT_SPECIES_LIST", ] SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] @@ -33,63 +39,184 @@ rules, the function returns None. """ +SoundEventDecoder = Callable[[str], List[data.Tag]] +"""Type alias for a sound event class decoder function. + +A decoder function takes a class name string (as predicted by the model or +assigned during encoding) and returns a list of `soundevent.data.Tag` objects +that represent that class according to the configuration. This is used to +translate model outputs back into meaningful annotations. +""" + +DEFAULT_SPECIES_LIST = [ + "Barbastella barbastellus", + "Eptesicus serotinus", + "Myotis alcathoe", + "Myotis bechsteinii", + "Myotis brandtii", + "Myotis daubentonii", + "Myotis mystacinus", + "Myotis nattereri", + "Nyctalus leisleri", + "Nyctalus noctula", + "Pipistrellus nathusii", + "Pipistrellus pipistrellus", + "Pipistrellus pygmaeus", + "Plecotus auritus", + "Plecotus austriacus", + "Rhinolophus ferrumequinum", + "Rhinolophus hipposideros", +] +"""A default list of common bat species names found in the UK.""" + + class TargetClass(BaseConfig): - """Defines the criteria for assigning an annotation to a specific class. + """Defines criteria for encoding annotations and decoding predictions. Each instance represents one potential output class for the classification - model. 
It specifies the class name and the tag conditions an annotation - must meet to be assigned this class label. + model. It specifies: + 1. A unique `name` for the class. + 2. The tag conditions (`tags` and `match_type`) an annotation must meet to + be assigned this class name during training data preparation (encoding). + 3. An optional, alternative set of tags (`output_tags`) to be used when + converting a model's prediction of this class name back into annotation + tags (decoding). Attributes ---------- name : str The unique name assigned to this target class (e.g., 'pippip', - 'myodau', 'noise'). This name will be used as the label during model - training and output. Should be unique across all TargetClass - definitions in a configuration. - tag : List[TagInfo] - A list of one or more tags (defined using `TagInfo`) that an annotation - must possess to potentially match this class. + 'myodau', 'noise'). This name is used as the label during model + training and is the expected output from the model's prediction. + Should be unique across all TargetClass definitions in a configuration. + tags : List[TagInfo] + A list of one or more tags (defined using `TagInfo`) used to identify + if an existing annotation belongs to this class during encoding (data + preparation for training). The `match_type` attribute determines how + these tags are evaluated. match_type : Literal["all", "any"], default="all" - Determines how the `tag` list is evaluated: - - "all": The annotation must have *all* the tags listed in the `tag` - field to match this class definition. + Determines how the `tags` list is evaluated during encoding: + - "all": The annotation must have *all* the tags listed to match. - "any": The annotation must have *at least one* of the tags listed - in the `tag` field to match this class definition. + to match. 
+ output_tags: Optional[List[TagInfo]], default=None + An optional list of tags (defined using `TagInfo`) to be assigned to a + new annotation when the model predicts this class `name`. If `None` + (default), the tags listed in the `tags` field will be used for + decoding. If provided, this list overrides the `tags` field for the + purpose of decoding predictions back into meaningful annotation tags. + This allows, for example, training on broader categories but decoding + to more specific representative tags. """ name: str - tags: List[TagInfo] = Field(default_factory=list, min_length=1) + tags: List[TagInfo] = Field(min_length=1) match_type: Literal["all", "any"] = Field(default="all") + output_tags: Optional[List[TagInfo]] = None + + +def _get_default_classes() -> List[TargetClass]: + """Generate a list of default target classes. + + Returns + ------- + List[TargetClass] + A list of TargetClass objects, one for each species in + DEFAULT_SPECIES_LIST. The class names are simplified versions of the + species names. + """ + return [ + TargetClass( + name=_get_default_class_name(value), + tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)], + ) + for value in DEFAULT_SPECIES_LIST + ] + + +def _get_default_class_name(species: str) -> str: + """Generate a default class name from a species name. + + Parameters + ---------- + species : str + The species name (e.g., "Myotis daubentonii"). + + Returns + ------- + str + A simplified class name (e.g., "myodau"). + The genus and species names are converted to lowercase, + the first three letters of each are taken, and concatenated. + """ + genus, species = species.strip().split(" ") + return f"{genus.lower()[:3]}{species.lower()[:3]}" + + +def _get_default_generic_class() -> List[TagInfo]: + """Generate the default list of TagInfo objects for the generic class. + + Provides a default set of tags used to represent the generic "Bat" category + when decoding predictions that didn't match a specific class. 
+ + Returns + ------- + List[TagInfo] + A list containing default TagInfo objects, typically representing + `call_type: Echolocation` and `order: Chiroptera`. + """ + return [ + TagInfo(key="call_type", value="Echolocation"), + TagInfo(key="order", value="Chiroptera"), + ] class ClassesConfig(BaseConfig): - """Configuration model holding the list of target class definitions. + """Configuration defining target classes and the generic fallback category. + + Holds the ordered list of specific target class definitions (`TargetClass`) + and defines the tags representing the generic category for sounds that pass + filtering but do not match any specific class. The order of `TargetClass` objects in the `classes` list defines the - priority for classification. When encoding an annotation, the system checks - against the class definitions in this sequence and assigns the name of the - *first* matching class. + priority for classification during encoding. The system checks annotations + against these definitions sequentially and assigns the name of the *first* + matching class. Attributes ---------- classes : List[TargetClass] - An ordered list of target class definitions. The order determines - matching priority (first match wins). + An ordered list of specific target class definitions. The order + determines matching priority (first match wins). Defaults to a + standard set of classes via `_get_default_classes`. + generic_class : List[TagInfo] + A list of tags defining the "generic" or "unclassified but relevant" + category (e.g., representing a generic 'Bat' call that wasn't + assigned to a specific species). These tags are typically assigned + during decoding when a sound event was detected and passed filtering + but did not match any specific class rule defined in the `classes` list. + Defaults to a standard set of tags via `_get_default_generic_class`. Raises ------ ValueError - If validation fails (e.g., non-unique class names). 
+ If validation fails (e.g., non-unique class names in the `classes` + list). Notes ----- - It is crucial that the `name` attribute of each `TargetClass` in the - `classes` list is unique. This configuration includes a validator to - enforce this uniqueness. + - It is crucial that the `name` attribute of each `TargetClass` in the + `classes` list is unique. This configuration includes a validator to + enforce this uniqueness. + - The `generic_class` tags provide a baseline identity for relevant sounds + that don't fit into more specific defined categories. """ - classes: List[TargetClass] = Field(default_factory=list) + classes: List[TargetClass] = Field(default_factory=_get_default_classes) + + generic_class: List[TagInfo] = Field( + default_factory=_get_default_generic_class + ) @field_validator("classes") def check_unique_class_names(cls, v: List[TargetClass]): @@ -108,37 +235,7 @@ class ClassesConfig(BaseConfig): return v -def load_classes_config( - path: data.PathLike, - field: Optional[str] = None, -) -> ClassesConfig: - """Load the target classes configuration from a file. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (YAML). - field : str, optional - If the classes configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - - Returns - ------- - ClassesConfig - The loaded and validated classes configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the ClassesConfig schema - or if class names are not unique. 
- """ - return load_config(path, schema=ClassesConfig, field=field) - - -def is_target_class( +def _is_target_class( sound_event_annotation: data.SoundEventAnnotation, tags: Set[data.Tag], match_all: bool = True, @@ -185,6 +282,38 @@ def get_class_names_from_config(config: ClassesConfig) -> List[str]: return [class_info.name for class_info in config.classes] +def _encode_with_multiple_classifiers( + sound_event_annotation: data.SoundEventAnnotation, + classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]], +) -> Optional[str]: + """Encode an annotation by checking against a list of classifiers. + + Internal helper function used by the `SoundEventEncoder`. It iterates + through the provided list of (class_name, classifier_function) pairs. + Returns the name associated with the first classifier function that + returns True for the given annotation. + + Parameters + ---------- + sound_event_annotation : data.SoundEventAnnotation + The annotation to encode. + classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]] + An ordered list where each tuple contains a class name and a function + that returns True if the annotation matches that class. The order + determines priority. + + Returns + ------- + str or None + The name of the first matching class, or None if no classifier matches. 
+ """ + for class_name, classifier in classifiers: + if classifier(sound_event_annotation): + return class_name + + return None + + def build_encoder_from_config( config: ClassesConfig, term_registry: TermRegistry = term_registry, @@ -221,7 +350,7 @@ def build_encoder_from_config( ( class_info.name, partial( - is_target_class, + _is_target_class, tags={ get_tag_from_info(tag_info, term_registry=term_registry) for tag_info in class_info.tags @@ -232,31 +361,170 @@ def build_encoder_from_config( for class_info in config.classes ] - def encoder( - sound_event_annotation: data.SoundEventAnnotation, - ) -> Optional[str]: - """Assign a class name to an annotation based on configured rules. + return partial( + _encode_with_multiple_classifiers, + classifiers=binary_classifiers, + ) - Iterates through pre-compiled classifiers in priority order. Returns - the name of the first matching class, or None if no match is found. - Parameters - ---------- - sound_event_annotation : data.SoundEventAnnotation - The annotation to encode. +def _decode_class( + name: str, + mapping: Dict[str, List[data.Tag]], + raise_on_error: bool = True, +) -> List[data.Tag]: + """Decode a class name into a list of representative tags using a mapping. - Returns - ------- - str or None - The name of the matched class, or None. - """ - for class_name, classifier in binary_classifiers: - if classifier(sound_event_annotation): - return class_name + Internal helper function used by the `SoundEventDecoder`. Looks up the + provided class `name` in the `mapping` dictionary. - return None + Parameters + ---------- + name : str + The class name to decode. + mapping : Dict[str, List[data.Tag]] + A dictionary mapping class names to lists of `soundevent.data.Tag` + objects. + raise_on_error : bool, default=True + If True, raises a ValueError if the `name` is not found in the + `mapping`. If False, returns an empty list if the `name` is not found. 
- return encoder + Returns + ------- + List[data.Tag] + The list of tags associated with the class name, or an empty list if + not found and `raise_on_error` is False. + + Raises + ------ + ValueError + If `name` is not found in `mapping` and `raise_on_error` is True. + """ + if name not in mapping and raise_on_error: + raise ValueError(f"Class {name} not found in mapping.") + + if name not in mapping: + return [] + + return mapping[name] + + +def build_decoder_from_config( + config: ClassesConfig, + term_registry: TermRegistry = term_registry, + raise_on_unmapped: bool = False, +) -> SoundEventDecoder: + """Build a sound event decoder function from the classes configuration. + + Creates a callable `SoundEventDecoder` that maps a class name string + back to a list of representative `soundevent.data.Tag` objects based on + the `ClassesConfig`. It uses the `output_tags` field if provided in a + `TargetClass`, otherwise falls back to the `tags` field. + + Parameters + ---------- + config : ClassesConfig + The loaded and validated classes configuration object. + term_registry : TermRegistry, optional + The TermRegistry instance used to look up term keys. Defaults to the + global `batdetect2.targets.terms.registry`. + raise_on_unmapped : bool, default=False + If True, the returned decoder function will raise a ValueError if asked + to decode a class name that is not in the configuration. If False, it + will return an empty list for unmapped names. + + Returns + ------- + SoundEventDecoder + A callable function that takes a class name string and returns a list + of `soundevent.data.Tag` objects. + + Raises + ------ + KeyError + If a term key specified in the configuration (`output_tags` or `tags`) + is not found in the provided `term_registry`. 
+ """ + mapping = {} + for class_info in config.classes: + tags_to_use = ( + class_info.output_tags + if class_info.output_tags is not None + else class_info.tags + ) + mapping[class_info.name] = [ + get_tag_from_info(tag_info, term_registry=term_registry) + for tag_info in tags_to_use + ] + + return partial( + _decode_class, + mapping=mapping, + raise_on_error=raise_on_unmapped, + ) + + +def build_generic_class_tags_from_config( + config: ClassesConfig, + term_registry: TermRegistry = term_registry, +) -> List[data.Tag]: + """Extract and build the list of tags for the generic class from config. + + Converts the list of `TagInfo` objects defined in `config.generic_class` + into a list of `soundevent.data.Tag` objects using the term registry. + + Parameters + ---------- + config : ClassesConfig + The loaded classes configuration object. + term_registry : TermRegistry, optional + The TermRegistry instance for term lookups. Defaults to the global + `batdetect2.targets.terms.registry`. + + Returns + ------- + List[data.Tag] + The list of fully constructed tags representing the generic class. + + Raises + ------ + KeyError + If a term key specified in `config.generic_class` is not found in the + provided `term_registry`. + """ + return [ + get_tag_from_info(tag_info, term_registry=term_registry) + for tag_info in config.generic_class + ] + + +def load_classes_config( + path: data.PathLike, + field: Optional[str] = None, +) -> ClassesConfig: + """Load the target classes configuration from a file. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (YAML). + field : str, optional + If the classes configuration is nested under a specific key in the + file, specify the key here. Defaults to None. + + Returns + ------- + ClassesConfig + The loaded and validated classes configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. 
+ pydantic.ValidationError + If the config file structure does not match the ClassesConfig schema + or if class names are not unique. + """ + return load_config(path, schema=ClassesConfig, field=field) def load_encoder_from_config( @@ -298,3 +566,53 @@ def load_encoder_from_config( """ config = load_classes_config(path, field=field) return build_encoder_from_config(config, term_registry=term_registry) + + +def load_decoder_from_config( + path: data.PathLike, + field: Optional[str] = None, + term_registry: TermRegistry = term_registry, + raise_on_unmapped: bool = False, +) -> SoundEventDecoder: + """Load a class decoder function directly from a configuration file. + + This is a convenience function that combines loading the `ClassesConfig` + from a file and building the final `SoundEventDecoder` function. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (e.g., YAML). + field : str, optional + If the classes configuration is nested under a specific key in the + file, specify the key here. Defaults to None. + term_registry : TermRegistry, optional + The TermRegistry instance used for term lookups. Defaults to the + global `batdetect2.targets.terms.registry`. + raise_on_unmapped : bool, default=False + If True, the returned decoder function will raise a ValueError if asked + to decode a class name that is not in the configuration. If False, it + will return an empty list for unmapped names. + + Returns + ------- + SoundEventDecoder + The final decoder function ready to convert class names back into tags. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + pydantic.ValidationError + If the config file structure does not match the ClassesConfig schema + or if class names are not unique. + KeyError + If a term key specified in the configuration is not found in the + provided `term_registry` during the build process. 
+ """ + config = load_classes_config(path, field=field) + return build_decoder_from_config( + config, + term_registry=term_registry, + raise_on_unmapped=raise_on_unmapped, + ) diff --git a/batdetect2/targets/targets.py b/batdetect2/targets/targets.py deleted file mode 100644 index c105c42..0000000 --- a/batdetect2/targets/targets.py +++ /dev/null @@ -1,147 +0,0 @@ -from collections.abc import Iterable -from pathlib import Path -from typing import Callable, List, Optional - -from pydantic import Field -from soundevent import data - -from batdetect2.configs import BaseConfig, load_config -from batdetect2.targets.terms import TagInfo, get_tag_from_info - -__all__ = [ - "TargetConfig", - "load_target_config", - "build_target_encoder", - "build_decoder", -] - - -class ReplaceConfig(BaseConfig): - """Configuration for replacing tags.""" - - original: TagInfo - replacement: TagInfo - - -class TargetConfig(BaseConfig): - """Configuration for target generation.""" - - classes: List[TagInfo] = Field( - default_factory=lambda: [ - TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST - ] - ) - generic_class: Optional[TagInfo] = Field( - default_factory=lambda: TagInfo(key="class", value="Bat") - ) - - include: Optional[List[TagInfo]] = Field( - default_factory=lambda: [TagInfo(key="event", value="Echolocation")] - ) - - exclude: Optional[List[TagInfo]] = Field( - default_factory=lambda: [ - TagInfo(key="class", value=""), - TagInfo(key="class", value=" "), - TagInfo(key="class", value="Unknown"), - ] - ) - - replace: Optional[List[ReplaceConfig]] = None - - -def get_tag_label(tag_info: TagInfo) -> str: - # TODO: Review this - return tag_info.value - - -def get_class_names(classes: List[TagInfo]) -> List[str]: - return sorted({get_tag_label(tag) for tag in classes}) - - -def build_replacer( - rules: List[ReplaceConfig], -) -> Callable[[data.Tag], data.Tag]: - mapping = { - get_tag_from_info(rule.original): get_tag_from_info(rule.replacement) - for rule in rules - 
} - - def replacer(tag: data.Tag) -> data.Tag: - return mapping.get(tag, tag) - - return replacer - - -def build_target_encoder( - classes: List[TagInfo], - replacement_rules: Optional[List[ReplaceConfig]] = None, -) -> Callable[[Iterable[data.Tag]], Optional[str]]: - target_tags = set([get_tag_from_info(tag) for tag in classes]) - - tag_mapping = { - tag: get_tag_label(tag_info) - for tag, tag_info in zip(target_tags, classes) - } - - replacer = ( - build_replacer(replacement_rules) if replacement_rules else lambda x: x - ) - - def encoder( - tags: Iterable[data.Tag], - ) -> Optional[str]: - sanitized_tags = {replacer(tag) for tag in tags} - - intersection = sanitized_tags & target_tags - - if not intersection: - return None - - first = intersection.pop() - return tag_mapping[first] - - return encoder - - -def build_decoder( - classes: List[TagInfo], -) -> Callable[[str], List[data.Tag]]: - target_tags = set([get_tag_from_info(tag) for tag in classes]) - tag_mapping = { - get_tag_label(tag_info): tag - for tag, tag_info in zip(target_tags, classes) - } - - def decoder(label: str) -> List[data.Tag]: - tag = tag_mapping.get(label) - return [tag] if tag else [] - - return decoder - - -def load_target_config( - path: Path, field: Optional[str] = None -) -> TargetConfig: - return load_config(path, schema=TargetConfig, field=field) - - -DEFAULT_SPECIES_LIST = [ - "Barbastellus barbastellus", - "Eptesicus serotinus", - "Myotis alcathoe", - "Myotis bechsteinii", - "Myotis brandtii", - "Myotis daubentonii", - "Myotis mystacinus", - "Myotis nattereri", - "Nyctalus leisleri", - "Nyctalus noctula", - "Pipistrellus nathusii", - "Pipistrellus pipistrellus", - "Pipistrellus pygmaeus", - "Plecotus auritus", - "Plecotus austriacus", - "Rhinolophus ferrumequinum", - "Rhinolophus hipposideros", -] diff --git a/docs/source/targets/classes.md b/docs/source/targets/classes.md index 5b0aa37..1ce09a3 100644 --- a/docs/source/targets/classes.md +++ b/docs/source/targets/classes.md @@ 
-1,148 +1,141 @@ -# Step 4: Defining Target Classes for Training +# Step 4: Defining Target Classes and Decoding Rules ## Purpose and Context You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags). -Now, it's time to tell `batdetect2` **exactly what categories (classes) your model should learn to identify**. +Now, it's time for a crucial step with two related goals: -This step involves defining rules that map the final tags on your sound event annotations to specific **class names** (like `pippip`, `myodau`, or `noise`). -These class names are the labels the machine learning model will be trained to predict. -Getting this definition right is essential for successful model training. +1. Telling `batdetect2` **exactly what categories (classes) your model should learn to identify** by defining rules that map annotation tags to class names (like `pippip`, `myodau`, or `noise`). + This process is often called **encoding**. +2. Defining how the model's predictions (those same class names) should be translated back into meaningful, structured **annotation tags** when you use the trained model. + This is often called **decoding**. + +These definitions are essential for both training the model correctly and interpreting its output later. ## How it Works: Defining Classes with Rules -You define your target classes in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`. -This section contains a **list** of class definitions. -Each item in the list defines one specific class your model should learn. +You define your target classes and their corresponding decoding rules in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`. +This section contains: + +1. A **list** of specific class definitions. +2. 
A definition for the **generic class** tags. + +Each item in the `classes` list defines one specific class your model should learn. ## Defining a Single Class -Each class definition rule requires a few key pieces of information: +Each specific class definition rule requires the following information: -1. `name`: **(Required)** This is the unique, simple name you want to give this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `echolocation_noise`). - This is the label the model will actually use. - Choose names that are clear and distinct. +1. `name`: **(Required)** This is the unique, simple name for this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `noise`). + This label is used during training and is what the model predicts. + Choose clear, distinct names. **Each class name must be unique.** -2. `tags`: **(Required)** This is a list containing one or more specific tags that identify annotations belonging to this class. - Remember, each tag is specified using its term `key` (like `species` or `sound_type`, defaulting to `class` if omitted) and its specific `value` (like `Pipistrellus pipistrellus` or `Echolocation`). -3. `match_type`: **(Optional, defaults to `"all"`)** This tells the system how to use the list of tags you provided in the `tag` field: - - `"all"`: An annotation must have **ALL** of the tags listed in the `tags` section to be considered part of this class. - (This is the default if you don't specify `match_type`). - - `"any"`: An annotation only needs to have **AT LEAST ONE** of the tags listed in the `tags` section to be considered part of this class. +2. `tags`: **(Required)** This list contains one or more specific tags (using `key` and `value`) used to identify if an _existing_ annotation belongs to this class during the _encoding_ phase (preparing training data). +3. 
`match_type`: **(Optional, defaults to `"all"`)** Determines how the `tags` list is evaluated during _encoding_: + - `"all"`: The annotation must have **ALL** listed tags to match. + (Default). + - `"any"`: The annotation needs **AT LEAST ONE** listed tag to match. +4. `output_tags`: **(Optional)** This list specifies the tags that should be assigned to an annotation when the model _predicts_ this class `name`. + This is used during the _decoding_ phase (interpreting model output). + - **If you omit `output_tags` (or set it to `null`/~), the system will default to using the same tags listed in the `tags` field for decoding.** This is often what you want. + - Providing `output_tags` allows you to specify a different, potentially more canonical or detailed, set of tags to represent the class upon prediction. + For example, you could match based on simplified tags but output standardized tags. -**Example: Defining two specific bat species classes** +**Example: Defining Species Classes (Encoding & Default Decoding)** + +Here, the `tags` used for matching during encoding will also be used for decoding, as `output_tags` is omitted. 
```yaml # In your main configuration file classes: # Definition for the first class - name: pippip # Simple name for Pipistrellus pipistrellus - tags: - - key: species # Term key (could also default to 'class') - value: Pipistrellus pipistrellus # Specific tag value - # match_type defaults to "all" (which is fine for a single tag) + tags: # Used for BOTH encoding match and decoding output + - key: species + value: Pipistrellus pipistrellus + # match_type defaults to "all" + # output_tags is omitted, defaults to using 'tags' above # Definition for the second class - name: myodau # Simple name for Myotis daubentonii - tags: + tags: # Used for BOTH encoding match and decoding output - key: species value: Myotis daubentonii ``` -**Example: Defining a class requiring multiple conditions (`match_type: "all"`)** +**Example: Defining a Class with Separate Encoding and Decoding Tags** + +Here, we match based on _either_ of two tags (`match_type: any`), but when the model predicts `'pipistrelle'`, we decode it _only_ to the specific `Pipistrellus pipistrellus` tag plus a genus tag. ```yaml classes: - - name: high_quality_pippip # Name for high-quality P. 
pip calls - match_type: all # Annotation must match BOTH tags below - tags: - - key: species - value: Pipistrellus pipistrellus - - key: quality # Assumes 'quality' term key exists - value: Good -``` - -**Example: Defining a class matching multiple alternative tags (`match_type: "any"`)** - -```yaml -classes: - - name: pipistrelle # Name for any Pipistrellus species in this list - match_type: any # Annotation must match AT LEAST ONE tag below + - name: pipistrelle # Name for a Pipistrellus group + match_type: any # Match if EITHER tag below is present during encoding tags: - key: species value: Pipistrellus pipistrellus - key: species - value: Pipistrellus pygmaeus + value: Pipistrellus pygmaeus # Match pygmaeus too + output_tags: # BUT, when decoding 'pipistrelle', assign THESE tags: - key: species - value: Pipistrellus nathusii + value: Pipistrellus pipistrellus # Canonical species + - key: genus # Assumes 'genus' key exists + value: Pipistrellus # Add genus tag ``` -## Handling Overlap: Priority Order Matters +## Handling Overlap During Encoding: Priority Order Matters -Sometimes, an annotation might have tags that match the rules for _more than one_ class definition. -For example, an annotation tagged `species: Pipistrellus pipistrellus` would match both a specific `'pippip'` class rule and a broader `'pipistrelle'` genus rule (like the examples above) if both were defined. +As before, when preparing training data (encoding), if an annotation matches the `tags` and `match_type` rules for multiple class definitions, the **order of the class definitions in the configuration list determines the priority**. -How does `batdetect2` decide which class name to assign? It uses the **order of the class definitions in your configuration list**. +- The system checks rules from the **top** of the `classes` list down. +- The annotation gets assigned the `name` of the **first class rule it matches**. 
+- **Place more specific rules before more general rules.** -- The system checks an annotation against your class rules one by one, starting from the **top** of the `classes` list and moving down. -- As soon as it finds a rule that the annotation matches, it assigns that rule's `name` to the annotation and **stops checking** further rules for that annotation. -- **The first match wins!** +**Example:** if a specific `pippip` rule is listed above a catch-all `noise` rule, an annotation that matches both is assigned `pippip`, because that rule is checked first. -Therefore, you should generally place your **most specific rules before more general rules** if you want the specific category to take precedence. +## Handling Non-Matches & Decoding the Generic Class -**Example: Prioritizing Species over Noise** +What happens if an annotation passes filtering/transformation but doesn't match any specific class rule during encoding? + +- **Encoding:** As explained previously, these annotations are **not ignored**. + They are typically assigned to a generic "relevant sound" category, often called the **"Bat"** class in BatDetect2, intended for all relevant bat calls not specifically classified. +- **Decoding:** When the model predicts this generic "Bat" category (or when processing sounds that weren't assigned a specific class during encoding), we need a way to represent this generic status with tags. + This is defined by the `generic_class` list directly within the main `classes` configuration section. + +**Defining the Generic Class Tags:** + +You specify the tags for the generic class like this: ```yaml -classes: - # --- Specific Species Rules (Checked First) --- - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus +# In your main configuration file +classes: # Main configuration section for classes + # --- List of specific class definitions --- + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + # ... other specific classes ...
- - name: myodau - tags: - - key: species - value: Myotis daubentonii - - # --- General Noise Rule (Checked Last) --- - - name: noise # Catch-all for anything tagged as Noise - match_type: any # Match if any noise tag is present - tags: - - key: sound_type # Assume 'sound_type' term key exists - value: Noise - - key: quality # Assume 'quality' term key exists - value: Low # Maybe low quality is also considered noise for training + # --- Definition of the generic class tags --- + generic_class: # Define tags for the generic 'Bat' category + - key: call_type + value: Echolocation + - key: order + value: Chiroptera + # These tags will be assigned when decoding the generic category ``` -In this example, an annotation tagged with `species: Myotis daubentonii` _and_ `quality: Low` would be assigned the class name `myodau` because that rule comes first in the list. -It would not be assigned `noise`, even though it also matches the second condition of the noise rule. +This `generic_class` list provides the standard tags assigned when a sound is identified as relevant (passed filtering) but doesn't belong to one of the specific target classes you defined. +Like the specific classes, sensible defaults are often provided if you don't explicitly define `generic_class`. -Okay, that's a very important clarification about how BatDetect2 handles sounds that don't match specific class definitions. -Let's refine that section to accurately reflect this behavior. +**Crucially:** Remember, if sounds should be **completely excluded** from training (not even considered "generic"), use **Filtering rules (Step 2)**. -## What if No Class Matches? +### Outcome -It's important to understand what happens if a sound event annotation passes through the filtering (Step 2) and transformation (Step 3) steps, but its final set of tags doesn't match _any_ of the specific class definitions you've listed in this section. 
+By defining this list of prioritized class rules (including their `name`, matching `tags`, `match_type`, and optional `output_tags`) and the `generic_class` tags, you provide `batdetect2` with: -These annotations are **not ignored** during training. -Instead, they are typically assigned to a **generic "relevant sound" class**. -Think of this as a category for sounds that you considered important enough to keep after filtering, but which don't fit into one of your specific target classes for detailed classification (like a particular species). -This generic class is distinct from background noise. +1. A clear procedure to assign a target label (`name`) to each relevant annotation for training. +2. A clear mapping to convert predicted class names (including the generic case) back into meaningful annotation tags. -In BatDetect2, this default generic class is often referred to as the **"Bat"** class. -The goal is generally that all relevant bat echolocation calls that pass the initial filtering should fall into _either_ one of your specific defined classes (like `pippip` or `myodau`) _or_ this generic "Bat" class. - -**In summary:** - -- Sounds passing **filtering** are considered relevant. -- If a relevant sound matches one of your **specific class rules** (in priority order), it gets that specific class label. -- If a relevant sound does **not** match any specific class rule, it gets the **generic "Bat" class** label. - -**Crucially:** If you want certain types of sounds (even if they are bat calls) to be **completely excluded** from the training process altogether (not even included in the generic "Bat" class), you **must remove them using rules in the Filtering step (Step 2)**. -Any sound annotation that makes it past filtering _will_ be used in training, either under one of your specific classes or the generic one. 
- -## Outcome - -By defining this list of prioritized class rules, you provide `batdetect2` with a clear procedure to assign a specific target label (your class `name`) to each relevant sound event annotation based on its tags. -This labelled data is exactly what the model needs for training (Step 5). +This complete definition prepares your data for the final heatmap generation (Step 5) and enables interpretation of the model's results. diff --git a/tests/conftest.py b/tests/conftest.py index e39ccff..99c9877 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,13 @@ import uuid from pathlib import Path -from typing import Callable, Iterable, List, Optional +from typing import Callable, List, Optional import numpy as np import pytest import soundfile as sf from soundevent import data, terms -from batdetect2.targets import ( - TargetConfig, - build_target_encoder, - call_type, - get_class_names, -) +from batdetect2.targets import call_type @pytest.fixture @@ -201,23 +196,3 @@ def clip_annotation( non_relevant_sound_event, ], ) - - -@pytest.fixture -def target_config() -> TargetConfig: - return TargetConfig() - - -@pytest.fixture -def class_names(target_config: TargetConfig) -> List[str]: - return get_class_names(target_config.classes) - - -@pytest.fixture -def encoder( - target_config: TargetConfig, -) -> Callable[[Iterable[data.Tag]], Optional[str]]: - return build_target_encoder( - classes=target_config.classes, - replacement_rules=target_config.replace, - ) diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index 6bf8cae..e0ee8c8 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -7,11 +7,14 @@ from pydantic import ValidationError from soundevent import data from batdetect2.targets.classes import ( + DEFAULT_SPECIES_LIST, ClassesConfig, TargetClass, build_encoder_from_config, get_class_names_from_config, - is_target_class, + _get_default_class_name, + _get_default_classes, + 
_is_target_class, load_classes_config, load_encoder_from_config, ) @@ -149,7 +152,7 @@ def test_is_target_class_match_all( ), data.Tag(term=sample_term_registry["quality"], value="Good"), } - assert is_target_class(sample_annotation, tags, match_all=True) is True + assert _is_target_class(sample_annotation, tags, match_all=True) is True tags = { data.Tag( @@ -157,14 +160,14 @@ def test_is_target_class_match_all( value="Pipistrellus pipistrellus", ) } - assert is_target_class(sample_annotation, tags, match_all=True) is True + assert _is_target_class(sample_annotation, tags, match_all=True) is True tags = { data.Tag( term=sample_term_registry["species"], value="Myotis daubentonii" ) } - assert is_target_class(sample_annotation, tags, match_all=True) is False + assert _is_target_class(sample_annotation, tags, match_all=True) is False def test_is_target_class_match_any( @@ -178,7 +181,7 @@ def test_is_target_class_match_any( ), data.Tag(term=sample_term_registry["quality"], value="Good"), } - assert is_target_class(sample_annotation, tags, match_all=False) is True + assert _is_target_class(sample_annotation, tags, match_all=False) is True tags = { data.Tag( @@ -186,14 +189,14 @@ def test_is_target_class_match_any( value="Pipistrellus pipistrellus", ) } - assert is_target_class(sample_annotation, tags, match_all=False) is True + assert _is_target_class(sample_annotation, tags, match_all=False) is True tags = { data.Tag( term=sample_term_registry["species"], value="Myotis daubentonii" ) } - assert is_target_class(sample_annotation, tags, match_all=False) is False + assert _is_target_class(sample_annotation, tags, match_all=False) is False def test_get_class_names_from_config(): @@ -279,3 +282,17 @@ def test_load_encoder_from_config_invalid( temp_yaml_path, term_registry=sample_term_registry, ) + + +def test_get_default_class_name(): + assert _get_default_class_name("Myotis daubentonii") == "myodau" + + +def test_get_default_classes(): + default_classes = 
_get_default_classes() + assert len(default_classes) == len(DEFAULT_SPECIES_LIST) + first_class = default_classes[0] + assert isinstance(first_class, TargetClass) + assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0]) + assert first_class.tags[0].key == "class" + assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]