mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
316 lines
8.7 KiB
Python
316 lines
8.7 KiB
Python
import logging
|
|
from functools import partial
|
|
from typing import Callable, List, Literal, Optional, Set
|
|
|
|
from pydantic import Field
|
|
from soundevent import data
|
|
|
|
from batdetect2.configs import BaseConfig, load_config
|
|
from batdetect2.targets.terms import (
|
|
TagInfo,
|
|
TermRegistry,
|
|
get_tag_from_info,
|
|
term_registry,
|
|
)
|
|
|
|
# Public names exported by this module (what `from ... import *` exposes).
__all__ = [
    "FilterConfig",
    "FilterRule",
    "SoundEventFilter",
    "build_sound_event_filter",
    "build_filter_from_rule",
    "load_filter_config",
    "load_filter_from_config",
]
|
|
|
|
|
|
# Signature shared by every filter returned by the builder functions in this
# module: one annotation in, a keep/discard boolean out.
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
"""Type alias for a filter function.

A filter function accepts a soundevent.data.SoundEventAnnotation object
and returns True if the annotation should be kept based on the filter's
criteria, or False if it should be discarded.
"""
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FilterRule(BaseConfig):
    """Defines a single rule for filtering sound event annotations.

    Based on the `match_type`, this rule checks if the tags associated with a
    sound event annotation meet certain criteria relative to the `tags` list
    defined in this rule. Both fields are declared without a default value.

    Attributes
    ----------
    match_type : Literal["any", "all", "exclude", "equal"]
        Determines how the `tags` list is used:

        - "any": Pass if the annotation has at least one tag from the list.
        - "all": Pass if the annotation has all tags from the list (it can
          have others too).
        - "exclude": Pass if the annotation has none of the tags from the
          list.
        - "equal": Pass if the annotation's tags are exactly the same set as
          provided in the list.
    tags : List[TagInfo]
        A list of tags (defined using TagInfo for configuration) that this
        rule operates on.
    """

    # Matching strategy; `build_filter_from_rule` maps each literal value to
    # one of the tag predicates defined in this module.
    match_type: Literal["any", "all", "exclude", "equal"]
    # Tags the rule compares against. `build_filter_from_rule` collects the
    # resolved tags into a set, so repeated entries collapse.
    tags: List[TagInfo]
|
|
|
|
|
|
def has_any_tag(
    sound_event_annotation: data.SoundEventAnnotation,
    tags: Set[data.Tag],
) -> bool:
    """Tell whether an annotation shares at least one tag with `tags`.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose `tags` attribute is inspected.
    tags : Set[data.Tag]
        Candidate tags; a single match is enough.

    Returns
    -------
    bool
        True when the annotation's tags and `tags` overlap, False when they
        have nothing in common (which includes either side being empty).
    """
    # "At least one in common" is the negation of "disjoint".
    return not tags.isdisjoint(sound_event_annotation.tags)
|
|
|
|
|
|
def contains_tags(
    sound_event_annotation: data.SoundEventAnnotation,
    tags: Set[data.Tag],
) -> bool:
    """Tell whether an annotation carries every tag in `tags`.

    Tags on the annotation that are not listed in `tags` are ignored, so the
    annotation may have more tags than required.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose `tags` attribute is inspected.
    tags : Set[data.Tag]
        Tags that must all be present on the annotation.

    Returns
    -------
    bool
        True when `tags` is a subset of the annotation's tags (an empty
        `tags` set is therefore always satisfied), False otherwise.
    """
    return tags.issubset(sound_event_annotation.tags)
|
|
|
|
|
|
def does_not_have_tags(
    sound_event_annotation: data.SoundEventAnnotation,
    tags: Set[data.Tag],
) -> bool:
    """Check if the annotation has none of the specified tags.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation to check.
    tags : Set[data.Tag]
        The set of tags that must *not* be present in the annotation.

    Returns
    -------
    bool
        True if the annotation has zero tags in common with the specified
        set (always the case when either side is empty), False otherwise.
    """
    # "None of the tags" is exactly set disjointness; this is the logical
    # negation of the "any" match.
    return tags.isdisjoint(sound_event_annotation.tags)
|
|
|
|
|
|
def equal_tags(
    sound_event_annotation: data.SoundEventAnnotation,
    tags: Set[data.Tag],
) -> bool:
    """Tell whether an annotation's tags match `tags` exactly.

    The comparison is made between sets, so the order of the annotation's
    tags and any repeated tags on it do not matter.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        Annotation whose `tags` attribute is inspected.
    tags : Set[data.Tag]
        The precise set of tags the annotation is required to have.

    Returns
    -------
    bool
        True when both sides contain the same tags and nothing else, False
        otherwise.
    """
    annotation_tags = frozenset(sound_event_annotation.tags)
    return tags == annotation_tags
|
|
|
|
|
|
def build_filter_from_rule(
    rule: FilterRule,
    term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
    """Turn one FilterRule into a callable annotation filter.

    Parameters
    ----------
    rule : FilterRule
        Rule providing the match strategy (`match_type`) and the tags it
        applies to (`tags`).
    term_registry : TermRegistry, optional
        Registry forwarded to `get_tag_from_info` when each entry of
        `rule.tags` is resolved. Defaults to the `term_registry` imported
        from `batdetect2.targets.terms`.

    Returns
    -------
    SoundEventFilter
        A `functools.partial` binding the rule's tag set to the predicate
        selected by `match_type`; it returns True for annotations that pass
        the rule and False otherwise.

    Raises
    ------
    ValueError
        If `rule.match_type` is not one of "any", "all", "exclude" or
        "equal".
    """
    # Tags are resolved before the match type is examined, so a failure while
    # resolving a tag surfaces ahead of any match-type error.
    resolved_tags = set()
    for tag_info in rule.tags:
        resolved_tags.add(
            get_tag_from_info(tag_info, term_registry=term_registry)
        )

    predicate_by_match_type = {
        "any": has_any_tag,
        "all": contains_tags,
        "exclude": does_not_have_tags,
        "equal": equal_tags,
    }
    predicate = predicate_by_match_type.get(rule.match_type)

    if predicate is None:
        raise ValueError(
            f"Invalid match type {rule.match_type}. Valid types "
            "are: 'any', 'all', 'exclude' and 'equal'"
        )

    return partial(predicate, tags=resolved_tags)
|
|
|
|
|
|
def _passes_all_filters(
    sound_event_annotation: data.SoundEventAnnotation,
    filters: List[SoundEventFilter],
) -> bool:
    """Check if the annotation passes all provided filters.

    Filters are evaluated in list order and evaluation stops at the first
    one that rejects the annotation; that rejection is reported at DEBUG
    level on the module logger.

    Parameters
    ----------
    sound_event_annotation : data.SoundEventAnnotation
        The annotation to check.
    filters : List[SoundEventFilter]
        A list of filter functions to apply.

    Returns
    -------
    bool
        True if the annotation passes all filters (including when `filters`
        is empty), False otherwise.
    """
    for filter_fn in filters:
        if filter_fn(sound_event_annotation):
            continue

        # Log through the module logger rather than the root-level
        # `logging.debug`, which would bypass per-module configuration and
        # implicitly call `basicConfig()` on an unconfigured root logger.
        # %-style arguments defer string formatting until a handler actually
        # emits the record.
        logger.debug(
            "Sound event annotation %s excluded due to rule %s",
            sound_event_annotation.uuid,
            filter_fn,
        )
        return False

    return True
|
|
|
|
|
|
class FilterConfig(BaseConfig):
    """Configuration model for defining a list of filter rules.

    Attributes
    ----------
    rules : List[FilterRule]
        A list of FilterRule objects. An annotation must pass all rules in
        this list to be considered valid by the filter built from this config.
        Defaults to an empty list, in which case the built filter has no
        rules to apply and keeps every annotation.
    """

    # `default_factory=list` makes the field optional, defaulting to `[]`.
    rules: List[FilterRule] = Field(default_factory=list)
|
|
|
|
|
|
def build_sound_event_filter(
    config: FilterConfig,
    term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
    """Combine every rule of a FilterConfig into one filter function.

    Each rule is converted with `build_filter_from_rule`; the resulting
    filters are AND-ed together, so an annotation is kept only when no rule
    rejects it.

    Parameters
    ----------
    config : FilterConfig
        Configuration whose `rules` list defines the filter.
    term_registry : TermRegistry, optional
        Registry forwarded to `build_filter_from_rule` for every rule.
        Defaults to the `term_registry` imported from
        `batdetect2.targets.terms`.

    Returns
    -------
    SoundEventFilter
        A single callable that applies all configured rules in the order
        they appear in `config.rules`.
    """
    build_one = partial(build_filter_from_rule, term_registry=term_registry)
    rule_filters = list(map(build_one, config.rules))
    return partial(_passes_all_filters, filters=rule_filters)
|
|
|
|
|
|
def load_filter_config(
    path: data.PathLike, field: Optional[str] = None
) -> FilterConfig:
    """Read a filter configuration file into a FilterConfig.

    This delegates entirely to `load_config`, using FilterConfig as the
    schema.

    Parameters
    ----------
    path : data.PathLike
        Location of the configuration file (described upstream as YAML);
        passed unchanged to `load_config`.
    field : Optional[str], optional
        Key under which the filter configuration is nested inside the file,
        if any. Defaults to None.

    Returns
    -------
    FilterConfig
        The configuration object produced by `load_config`.
    """
    filter_config = load_config(path, schema=FilterConfig, field=field)
    return filter_config
|
|
|
|
|
|
def load_filter_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
    """Load a filter configuration file and build its filter in one step.

    Convenience wrapper equivalent to calling `load_filter_config` and then
    passing the result to `build_sound_event_filter`.

    Parameters
    ----------
    path : data.PathLike
        Location of the configuration file (described upstream as YAML).
    field : Optional[str], optional
        Key under which the filter configuration is nested inside the file,
        if any. Defaults to None.
    term_registry : TermRegistry, optional
        Registry forwarded to `build_sound_event_filter`. Defaults to the
        `term_registry` imported from `batdetect2.targets.terms`.

    Returns
    -------
    SoundEventFilter
        The merged filter function, ready to be applied to annotations.
    """
    return build_sound_event_filter(
        load_filter_config(path=path, field=field),
        term_registry=term_registry,
    )
|