mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Add target classes module
This commit is contained in:
parent
02d4779207
commit
af48c33307
@ -24,22 +24,53 @@ from batdetect2.targets.terms import (
|
||||
TermInfo,
|
||||
call_type,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
individual,
|
||||
register_term,
|
||||
term_registry,
|
||||
)
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
DeriveTagRule,
|
||||
MapValueRule,
|
||||
ReplaceRule,
|
||||
SoundEventTransformation,
|
||||
TransformConfig,
|
||||
build_transform_from_rule,
|
||||
build_transformation_from_config,
|
||||
derivation_registry,
|
||||
get_derivation,
|
||||
load_transformation_config,
|
||||
load_transformation_from_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
"DeriveTagRule",
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"MapValueRule",
|
||||
"ReplaceRule",
|
||||
"SoundEventTransformation",
|
||||
"TagInfo",
|
||||
"TargetConfig",
|
||||
"TermInfo",
|
||||
"TransformConfig",
|
||||
"build_decoder",
|
||||
"build_target_encoder",
|
||||
"build_transform_from_rule",
|
||||
"build_transformation_from_config",
|
||||
"call_type",
|
||||
"derivation_registry",
|
||||
"generate_heatmaps",
|
||||
"get_class_names",
|
||||
"get_derivation",
|
||||
"get_tag_from_info",
|
||||
"individual",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
"register_term",
|
||||
"term_registry",
|
||||
]
|
||||
|
300
batdetect2/targets/classes.py
Normal file
300
batdetect2/targets/classes.py
Normal file
@ -0,0 +1,300 @@
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
from typing import Callable, List, Literal, Optional, Set
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SoundEventEncoder",
|
||||
"TargetClass",
|
||||
"ClassesConfig",
|
||||
"load_classes_config",
|
||||
"build_encoder_from_config",
|
||||
"load_encoder_from_config",
|
||||
"get_class_names_from_config",
|
||||
]
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
An encoder function takes a sound event annotation and returns the string name
|
||||
of the target class it belongs to, based on a predefined set of rules.
|
||||
If the annotation does not match any defined target class according to the
|
||||
rules, the function returns None.
|
||||
"""
|
||||
|
||||
|
||||
class TargetClass(BaseConfig):
|
||||
"""Defines the criteria for assigning an annotation to a specific class.
|
||||
|
||||
Each instance represents one potential output class for the classification
|
||||
model. It specifies the class name and the tag conditions an annotation
|
||||
must meet to be assigned this class label.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
The unique name assigned to this target class (e.g., 'pippip',
|
||||
'myodau', 'noise'). This name will be used as the label during model
|
||||
training and output. Should be unique across all TargetClass
|
||||
definitions in a configuration.
|
||||
tag : List[TagInfo]
|
||||
A list of one or more tags (defined using `TagInfo`) that an annotation
|
||||
must possess to potentially match this class.
|
||||
match_type : Literal["all", "any"], default="all"
|
||||
Determines how the `tag` list is evaluated:
|
||||
- "all": The annotation must have *all* the tags listed in the `tag`
|
||||
field to match this class definition.
|
||||
- "any": The annotation must have *at least one* of the tags listed
|
||||
in the `tag` field to match this class definition.
|
||||
"""
|
||||
|
||||
name: str
|
||||
tags: List[TagInfo] = Field(default_factory=list, min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
|
||||
|
||||
class ClassesConfig(BaseConfig):
|
||||
"""Configuration model holding the list of target class definitions.
|
||||
|
||||
The order of `TargetClass` objects in the `classes` list defines the
|
||||
priority for classification. When encoding an annotation, the system checks
|
||||
against the class definitions in this sequence and assigns the name of the
|
||||
*first* matching class.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
classes : List[TargetClass]
|
||||
An ordered list of target class definitions. The order determines
|
||||
matching priority (first match wins).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If validation fails (e.g., non-unique class names).
|
||||
|
||||
Notes
|
||||
-----
|
||||
It is crucial that the `name` attribute of each `TargetClass` in the
|
||||
`classes` list is unique. This configuration includes a validator to
|
||||
enforce this uniqueness.
|
||||
"""
|
||||
|
||||
classes: List[TargetClass] = Field(default_factory=list)
|
||||
|
||||
@field_validator("classes")
|
||||
def check_unique_class_names(cls, v: List[TargetClass]):
|
||||
"""Ensure all defined class names are unique."""
|
||||
names = [c.name for c in v]
|
||||
|
||||
if len(names) != len(set(names)):
|
||||
name_counts = Counter(names)
|
||||
duplicates = [
|
||||
name for name, count in name_counts.items() if count > 1
|
||||
]
|
||||
raise ValueError(
|
||||
"Class names must be unique. Found duplicates: "
|
||||
f"{', '.join(duplicates)}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
def load_classes_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> ClassesConfig:
|
||||
"""Load the target classes configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
"""
|
||||
return load_config(path, schema=ClassesConfig, field=field)
|
||||
|
||||
|
||||
def is_target_class(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
match_all: bool = True,
|
||||
) -> bool:
|
||||
"""Check if a sound event annotation matches a set of required tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
required_tags : Set[data.Tag]
|
||||
A set of `soundevent.data.Tag` objects that define the class criteria.
|
||||
match_all : bool, default=True
|
||||
If True, checks if *all* `required_tags` are present in the
|
||||
annotation's tags (subset check). If False, checks if *at least one*
|
||||
of the `required_tags` is present (intersection check).
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation meets the tag criteria, False otherwise.
|
||||
"""
|
||||
annotation_tags = set(sound_event_annotation.tags)
|
||||
|
||||
if match_all:
|
||||
return tags <= annotation_tags
|
||||
|
||||
return bool(tags & annotation_tags)
|
||||
|
||||
|
||||
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||
"""Extract the list of class names from a ClassesConfig object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded classes configuration object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
An ordered list of unique class names defined in the configuration.
|
||||
"""
|
||||
return [class_info.name for class_info in config.classes]
|
||||
|
||||
|
||||
def build_encoder_from_config(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Build a sound event encoder function from the classes configuration.
|
||||
|
||||
The returned encoder function iterates through the class definitions in the
|
||||
order specified in the config. It assigns an annotation the name of the
|
||||
first class definition it matches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used to look up term keys specified in the
|
||||
`TagInfo` objects within the configuration. Defaults to the global
|
||||
`batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
A callable function that takes a `SoundEventAnnotation` and returns
|
||||
an optional string representing the matched class name, or None if no
|
||||
class matches.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
binary_classifiers = [
|
||||
(
|
||||
class_info.name,
|
||||
partial(
|
||||
is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
},
|
||||
match_all=class_info.match_type == "all",
|
||||
),
|
||||
)
|
||||
for class_info in config.classes
|
||||
]
|
||||
|
||||
def encoder(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> Optional[str]:
|
||||
"""Assign a class name to an annotation based on configured rules.
|
||||
|
||||
Iterates through pre-compiled classifiers in priority order. Returns
|
||||
the name of the first matching class, or None if no match is found.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to encode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The name of the matched class, or None.
|
||||
"""
|
||||
for class_name, classifier in binary_classifiers:
|
||||
if classifier(sound_event_annotation):
|
||||
return class_name
|
||||
|
||||
return None
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def load_encoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Load a class encoder function directly from a configuration file.
|
||||
|
||||
This is a convenience function that combines loading the `ClassesConfig`
|
||||
from a file and building the final `SoundEventEncoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used for term lookups. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
The final encoder function ready to classify annotations.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_encoder_from_config(config, term_registry=term_registry)
|
@ -143,7 +143,7 @@ class TermRegistry(Mapping[str, data.Term]):
|
||||
raise KeyError(
|
||||
"No term found for key "
|
||||
f"'{key}'. Ensure it is registered or loaded. "
|
||||
"Use `get_term_keys()` to list available terms."
|
||||
f"Available keys: {', '.join(self.get_keys())}"
|
||||
) from err
|
||||
|
||||
def add_custom_term(
|
||||
@ -215,8 +215,11 @@ class TermRegistry(Mapping[str, data.Term]):
|
||||
"""
|
||||
return list(self._terms.values())
|
||||
|
||||
def remove_key(self, key: str) -> None:
|
||||
del self._terms[key]
|
||||
|
||||
registry = TermRegistry(
|
||||
|
||||
term_registry = TermRegistry(
|
||||
terms=dict(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
@ -237,7 +240,7 @@ is explicitly passed.
|
||||
|
||||
def get_term_from_key(
|
||||
key: str,
|
||||
term_registry: TermRegistry = registry,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> data.Term:
|
||||
"""Convenience function to retrieve a term by key from a registry.
|
||||
|
||||
@ -265,7 +268,7 @@ def get_term_from_key(
|
||||
return term_registry.get_term(key)
|
||||
|
||||
|
||||
def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
|
||||
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
||||
"""Convenience function to get all registered keys from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
@ -284,7 +287,7 @@ def get_term_keys(term_registry: TermRegistry = registry) -> List[str]:
|
||||
return term_registry.get_keys()
|
||||
|
||||
|
||||
def get_terms(term_registry: TermRegistry = registry) -> List[data.Term]:
|
||||
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
|
||||
"""Convenience function to get all registered terms from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
@ -327,7 +330,7 @@ class TagInfo(BaseModel):
|
||||
|
||||
def get_tag_from_info(
|
||||
tag_info: TagInfo,
|
||||
term_registry: TermRegistry = registry,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> data.Tag:
|
||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||
|
||||
@ -424,7 +427,7 @@ class TermConfig(BaseModel):
|
||||
def load_terms_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = registry,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> Dict[str, data.Term]:
|
||||
"""Loads term definitions from a configuration file and registers them.
|
||||
|
||||
@ -472,3 +475,9 @@ def load_terms_from_config(
|
||||
)
|
||||
for info in data.terms
|
||||
}
|
||||
|
||||
|
||||
def register_term(
|
||||
key: str, term: data.Term, registry: TermRegistry = term_registry
|
||||
) -> None:
|
||||
registry.add_term(key, term)
|
||||
|
@ -1,6 +1,15 @@
|
||||
import importlib
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Literal, Mapping, Optional, Union
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -8,9 +17,13 @@ from soundevent import data
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
term_registry as default_term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SoundEventTransformation",
|
||||
@ -282,9 +295,13 @@ class TransformConfig(BaseConfig):
|
||||
discriminates between the different rule models.
|
||||
"""
|
||||
|
||||
rules: List[Union[ReplaceRule, MapValueRule, DeriveTagRule]] = Field(
|
||||
rules: List[
|
||||
Annotated[
|
||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
Field(discriminator="rule_type"),
|
||||
]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
discriminator="rule_type",
|
||||
)
|
||||
|
||||
|
||||
@ -442,7 +459,8 @@ def get_derivation(
|
||||
|
||||
def build_transform_from_rule(
|
||||
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
registry: DerivationRegistry = derivation_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a specific SoundEventTransformation function from a rule config.
|
||||
|
||||
@ -473,21 +491,33 @@ def build_transform_from_rule(
|
||||
If dynamic import of a derivation function fails.
|
||||
"""
|
||||
if rule.rule_type == "replace":
|
||||
source = get_tag_from_info(rule.original)
|
||||
target = get_tag_from_info(rule.replacement)
|
||||
source = get_tag_from_info(
|
||||
rule.original,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target = get_tag_from_info(
|
||||
rule.replacement,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
return partial(replace_tag_transform, source=source, target=target)
|
||||
|
||||
if rule.rule_type == "derive_tag":
|
||||
source_term = get_term_from_key(rule.source_term_key)
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(rule.target_term_key)
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
derivation = get_derivation(
|
||||
key=rule.derivation_function,
|
||||
import_derivation=rule.import_derivation,
|
||||
registry=registry,
|
||||
registry=derivation_registry,
|
||||
)
|
||||
return partial(
|
||||
derivation_tag_transform,
|
||||
@ -498,9 +528,15 @@ def build_transform_from_rule(
|
||||
)
|
||||
|
||||
if rule.rule_type == "map_value":
|
||||
source_term = get_term_from_key(rule.source_term_key)
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(rule.target_term_key)
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
@ -522,6 +558,8 @@ def build_transform_from_rule(
|
||||
|
||||
def build_transformation_from_config(
|
||||
config: TransformConfig,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a composite transformation function from a TransformConfig.
|
||||
|
||||
@ -542,7 +580,14 @@ def build_transformation_from_config(
|
||||
SoundEventTransformation
|
||||
A single function that applies all configured transformations in order.
|
||||
"""
|
||||
transforms = [build_transform_from_rule(rule) for rule in config.rules]
|
||||
transforms = [
|
||||
build_transform_from_rule(
|
||||
rule,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
for rule in config.rules
|
||||
]
|
||||
|
||||
def transformation(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
@ -583,7 +628,10 @@ def load_transformation_config(
|
||||
|
||||
|
||||
def load_transformation_from_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Load transformation config from a file and build the final function.
|
||||
|
||||
@ -618,4 +666,8 @@ def load_transformation_from_config(
|
||||
fails.
|
||||
"""
|
||||
config = load_transformation_config(path=path, field=field)
|
||||
return build_transformation_from_config(config)
|
||||
return build_transformation_from_config(
|
||||
config,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
148
docs/targets/classes.md
Normal file
148
docs/targets/classes.md
Normal file
@ -0,0 +1,148 @@
|
||||
## Step 4: Defining Target Classes for Training
|
||||
|
||||
### Purpose and Context
|
||||
|
||||
You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags).
|
||||
Now, it's time to tell `batdetect2` **exactly what categories (classes) your model should learn to identify**.
|
||||
|
||||
This step involves defining rules that map the final tags on your sound event annotations to specific **class names** (like `pippip`, `myodau`, or `noise`).
|
||||
These class names are the labels the machine learning model will be trained to predict.
|
||||
Getting this definition right is essential for successful model training.
|
||||
|
||||
### How it Works: Defining Classes with Rules
|
||||
|
||||
You define your target classes in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`.
|
||||
This section contains a **list** of class definitions.
|
||||
Each item in the list defines one specific class your model should learn.
|
||||
|
||||
### Defining a Single Class
|
||||
|
||||
Each class definition rule requires a few key pieces of information:
|
||||
|
||||
1. `name`: **(Required)** This is the unique, simple name you want to give this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `echolocation_noise`).
|
||||
This is the label the model will actually use.
|
||||
Choose names that are clear and distinct.
|
||||
**Each class name must be unique.**
|
||||
2. `tags`: **(Required)** This is a list containing one or more specific tags that identify annotations belonging to this class.
|
||||
Remember, each tag is specified using its term `key` (like `species` or `sound_type`, defaulting to `class` if omitted) and its specific `value` (like `Pipistrellus pipistrellus` or `Echolocation`).
|
||||
3. `match_type`: **(Optional, defaults to `"all"`)** This tells the system how to use the list of tags you provided in the `tag` field:
|
||||
- `"all"`: An annotation must have **ALL** of the tags listed in the `tags` section to be considered part of this class.
|
||||
(This is the default if you don't specify `match_type`).
|
||||
- `"any"`: An annotation only needs to have **AT LEAST ONE** of the tags listed in the `tags` section to be considered part of this class.
|
||||
|
||||
**Example: Defining two specific bat species classes**
|
||||
|
||||
```yaml
|
||||
# In your main configuration file
|
||||
classes:
|
||||
# Definition for the first class
|
||||
- name: pippip # Simple name for Pipistrellus pipistrellus
|
||||
tags:
|
||||
- key: species # Term key (could also default to 'class')
|
||||
value: Pipistrellus pipistrellus # Specific tag value
|
||||
# match_type defaults to "all" (which is fine for a single tag)
|
||||
|
||||
# Definition for the second class
|
||||
- name: myodau # Simple name for Myotis daubentonii
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis daubentonii
|
||||
```
|
||||
|
||||
**Example: Defining a class requiring multiple conditions (`match_type: "all"`)**
|
||||
|
||||
```yaml
|
||||
classes:
|
||||
- name: high_quality_pippip # Name for high-quality P. pip calls
|
||||
match_type: all # Annotation must match BOTH tags below
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- key: quality # Assumes 'quality' term key exists
|
||||
value: Good
|
||||
```
|
||||
|
||||
**Example: Defining a class matching multiple alternative tags (`match_type: "any"`)**
|
||||
|
||||
```yaml
|
||||
classes:
|
||||
- name: pipistrelle # Name for any Pipistrellus species in this list
|
||||
match_type: any # Annotation must match AT LEAST ONE tag below
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- key: species
|
||||
value: Pipistrellus pygmaeus
|
||||
- key: species
|
||||
value: Pipistrellus nathusii
|
||||
```
|
||||
|
||||
### Handling Overlap: Priority Order Matters!
|
||||
|
||||
Sometimes, an annotation might have tags that match the rules for _more than one_ class definition.
|
||||
For example, an annotation tagged `species: Pipistrellus pipistrellus` would match both a specific `'pippip'` class rule and a broader `'pipistrelle'` genus rule (like the examples above) if both were defined.
|
||||
|
||||
How does `batdetect2` decide which class name to assign? It uses the **order of the class definitions in your configuration list**.
|
||||
|
||||
- The system checks an annotation against your class rules one by one, starting from the **top** of the `classes` list and moving down.
|
||||
- As soon as it finds a rule that the annotation matches, it assigns that rule's `name` to the annotation and **stops checking** further rules for that annotation.
|
||||
- **The first match wins!**
|
||||
|
||||
Therefore, you should generally place your **most specific rules before more general rules** if you want the specific category to take precedence.
|
||||
|
||||
**Example: Prioritizing Species over Noise**
|
||||
|
||||
```yaml
|
||||
classes:
|
||||
# --- Specific Species Rules (Checked First) ---
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
|
||||
- name: myodau
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis daubentonii
|
||||
|
||||
# --- General Noise Rule (Checked Last) ---
|
||||
- name: noise # Catch-all for anything tagged as Noise
|
||||
match_type: any # Match if any noise tag is present
|
||||
tags:
|
||||
- key: sound_type # Assume 'sound_type' term key exists
|
||||
value: Noise
|
||||
- key: quality # Assume 'quality' term key exists
|
||||
value: Low # Maybe low quality is also considered noise for training
|
||||
```
|
||||
|
||||
In this example, an annotation tagged with `species: Myotis daubentonii` _and_ `quality: Low` would be assigned the class name `myodau` because that rule comes first in the list.
|
||||
It would not be assigned `noise`, even though it also matches the second condition of the noise rule.
|
||||
|
||||
Okay, that's a very important clarification about how BatDetect2 handles sounds that don't match specific class definitions.
|
||||
Let's refine that section to accurately reflect this behavior.
|
||||
|
||||
### What if No Class Matches? (The Generic "Bat" Class)
|
||||
|
||||
It's important to understand what happens if a sound event annotation passes through the filtering (Step 2) and transformation (Step 3) steps, but its final set of tags doesn't match _any_ of the specific class definitions you've listed in this section.
|
||||
|
||||
These annotations are **not ignored** during training.
|
||||
Instead, they are typically assigned to a **generic "relevant sound" class**.
|
||||
Think of this as a category for sounds that you considered important enough to keep after filtering, but which don't fit into one of your specific target classes for detailed classification (like a particular species).
|
||||
This generic class is distinct from background noise.
|
||||
|
||||
In BatDetect2, this default generic class is often referred to as the **"Bat"** class.
|
||||
The goal is generally that all relevant bat echolocation calls that pass the initial filtering should fall into _either_ one of your specific defined classes (like `pippip` or `myodau`) _or_ this generic "Bat" class.
|
||||
|
||||
**In summary:**
|
||||
|
||||
- Sounds passing **filtering** are considered relevant.
|
||||
- If a relevant sound matches one of your **specific class rules** (in priority order), it gets that specific class label.
|
||||
- If a relevant sound does **not** match any specific class rule, it gets the **generic "Bat" class** label.
|
||||
|
||||
**Crucially:** If you want certain types of sounds (even if they are bat calls) to be **completely excluded** from the training process altogether (not even included in the generic "Bat" class), you **must remove them using rules in the Filtering step (Step 2)**.
|
||||
Any sound annotation that makes it past filtering _will_ be used in training, either under one of your specific classes or the generic one.
|
||||
|
||||
### Outcome
|
||||
|
||||
By defining this list of prioritized class rules, you provide `batdetect2` with a clear procedure to assign a specific target label (your class `name`) to each relevant sound event annotation based on its tags.
|
||||
This labelled data is exactly what the model needs for training (Step 5).
|
361
tests/test_targets/test_transform.py
Normal file
361
tests/test_targets/test_transform.py
Normal file
@ -0,0 +1,361 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import (
|
||||
DeriveTagRule,
|
||||
MapValueRule,
|
||||
ReplaceRule,
|
||||
TagInfo,
|
||||
TransformConfig,
|
||||
build_transform_from_rule,
|
||||
build_transformation_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TermRegistry
|
||||
from batdetect2.targets.transform import DerivationRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term_registry():
|
||||
return TermRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def derivation_registry():
|
||||
return DerivationRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term1(term_registry: TermRegistry) -> data.Term:
|
||||
return term_registry.add_custom_term(key="term1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def term2(term_registry: TermRegistry) -> data.Term:
|
||||
return term_registry.add_custom_term(key="term2")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
term1: data.Term,
|
||||
) -> data.SoundEventAnnotation:
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event, tags=[data.Tag(term=term1, value="value1")]
|
||||
)
|
||||
|
||||
|
||||
def test_map_value_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"value1": "value2"},
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_map_value_rule_no_match(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"other_value": "value2"},
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
|
||||
|
||||
def test_replace_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term2: data.Term,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
rule = ReplaceRule(
|
||||
rule_type="replace",
|
||||
original=TagInfo(key="term1", value="value1"),
|
||||
replacement=TagInfo(key="term2", value="value2"),
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term2
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_replace_rule_no_match(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term2: data.Term,
|
||||
):
|
||||
rule = ReplaceRule(
|
||||
rule_type="replace",
|
||||
original=TagInfo(key="term1", value="wrong_value"),
|
||||
replacement=TagInfo(key="term2", value="value2"),
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].key == "term1"
|
||||
assert transformed_annotation.tags[0].term != term2
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
|
||||
|
||||
def test_build_transformation_from_config(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
):
|
||||
config = TransformConfig(
|
||||
rules=[
|
||||
MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"value1": "value2"},
|
||||
),
|
||||
ReplaceRule(
|
||||
rule_type="replace",
|
||||
original=TagInfo(key="term2", value="value2"),
|
||||
replacement=TagInfo(key="term3", value="value3"),
|
||||
),
|
||||
]
|
||||
)
|
||||
term_registry.add_custom_term("term2")
|
||||
term_registry.add_custom_term("term3")
|
||||
transform = build_transformation_from_config(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
transformed_annotation = transform(annotation)
|
||||
assert transformed_annotation.tags[0].key == "term1"
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_derive_tag_rule(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
def derivation_func(x: str) -> str:
|
||||
return x + "_derived"
|
||||
|
||||
derivation_registry.register("my_derivation", derivation_func)
|
||||
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="my_derivation",
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 2
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
assert transformed_annotation.tags[1].term == term1
|
||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||
|
||||
|
||||
def test_derive_tag_rule_keep_source_false(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
def derivation_func(x: str) -> str:
|
||||
return x + "_derived"
|
||||
|
||||
derivation_registry.register("my_derivation", derivation_func)
|
||||
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="my_derivation",
|
||||
keep_source=False,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 1
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value1_derived"
|
||||
|
||||
|
||||
def test_derive_tag_rule_target_term(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
term2: data.Term,
|
||||
):
|
||||
def derivation_func(x: str) -> str:
|
||||
return x + "_derived"
|
||||
|
||||
derivation_registry.register("my_derivation", derivation_func)
|
||||
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="my_derivation",
|
||||
target_term_key="term2",
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 2
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
assert transformed_annotation.tags[1].term == term2
|
||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||
|
||||
|
||||
def test_derive_tag_rule_import_derivation(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term1: data.Term,
|
||||
tmp_path: Path,
|
||||
):
|
||||
# Create a dummy derivation function in a temporary file
|
||||
derivation_module_path = (
|
||||
tmp_path / "temp_derivation.py"
|
||||
) # Changed to /tmp since /home/santiago is not writable
|
||||
derivation_module_path.write_text(
|
||||
"""
|
||||
def my_imported_derivation(x: str) -> str:
|
||||
return x + "_imported"
|
||||
"""
|
||||
)
|
||||
# Ensure the temporary file is importable by adding its directory to sys.path
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="temp_derivation.my_imported_derivation",
|
||||
import_derivation=True,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 2
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
assert transformed_annotation.tags[1].term == term1
|
||||
assert transformed_annotation.tags[1].value == "value1_imported"
|
||||
|
||||
# Clean up the temporary file and sys.path
|
||||
sys.path.remove(str(tmp_path))
|
||||
|
||||
|
||||
def test_derive_tag_rule_invalid_derivation(term_registry: TermRegistry):
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="nonexistent_derivation",
|
||||
)
|
||||
with pytest.raises(KeyError):
|
||||
build_transform_from_rule(rule, term_registry=term_registry)
|
||||
|
||||
|
||||
def test_build_transform_from_rule_invalid_rule_type():
|
||||
class InvalidRule:
|
||||
rule_type = "invalid"
|
||||
|
||||
rule = InvalidRule() # type: ignore
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
build_transform_from_rule(rule) # type: ignore
|
||||
|
||||
|
||||
def test_map_value_rule_target_term(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term2: data.Term,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"value1": "value2"},
|
||||
target_term_key="term2",
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term2
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_map_value_rule_target_term_none(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
rule = MapValueRule(
|
||||
rule_type="map_value",
|
||||
source_term_key="term1",
|
||||
value_mapping={"value1": "value2"},
|
||||
target_term_key=None,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(rule, term_registry=term_registry)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value2"
|
||||
|
||||
|
||||
def test_derive_tag_rule_target_term_none(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
term_registry: TermRegistry,
|
||||
derivation_registry: DerivationRegistry,
|
||||
term1: data.Term,
|
||||
):
|
||||
def derivation_func(x: str) -> str:
|
||||
return x + "_derived"
|
||||
|
||||
derivation_registry.register("my_derivation", derivation_func)
|
||||
|
||||
rule = DeriveTagRule(
|
||||
rule_type="derive_tag",
|
||||
source_term_key="term1",
|
||||
derivation_function="my_derivation",
|
||||
target_term_key=None,
|
||||
)
|
||||
transform_fn = build_transform_from_rule(
|
||||
rule,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
transformed_annotation = transform_fn(annotation)
|
||||
|
||||
assert len(transformed_annotation.tags) == 2
|
||||
assert transformed_annotation.tags[0].term == term1
|
||||
assert transformed_annotation.tags[0].value == "value1"
|
||||
assert transformed_annotation.tags[1].term == term1
|
||||
assert transformed_annotation.tags[1].value == "value1_derived"
|
||||
|
||||
|
||||
def test_build_transformation_from_config_empty(
|
||||
annotation: data.SoundEventAnnotation,
|
||||
):
|
||||
config = TransformConfig(rules=[])
|
||||
transform = build_transformation_from_config(config)
|
||||
transformed_annotation = transform(annotation)
|
||||
assert transformed_annotation == annotation
|
Loading…
Reference in New Issue
Block a user