mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Create dedicated filtering module
This commit is contained in:
parent
26a2c5c851
commit
2fb3039f17
297
batdetect2/targets/filtering.py
Normal file
297
batdetect2/targets/filtering.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from typing import Callable, List, Literal, Optional, Set
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.terms import TagInfo, get_tag_from_info
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_filter_from_config",
|
||||||
|
"SoundEventFilter",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
"""Type alias for a filter function.
|
||||||
|
|
||||||
|
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||||
|
and returns True if the annotation should be kept based on the filter's
|
||||||
|
criteria, or False if it should be discarded.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterRule(BaseConfig):
|
||||||
|
"""Defines a single rule for filtering sound event annotations.
|
||||||
|
|
||||||
|
Based on the `match_type`, this rule checks if the tags associated with a
|
||||||
|
sound event annotation meet certain criteria relative to the `tags` list
|
||||||
|
defined in this rule.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
match_type : Literal["any", "all", "exclude", "equal"]
|
||||||
|
Determines how the `tags` list is used:
|
||||||
|
- "any": Pass if the annotation has at least one tag from the list.
|
||||||
|
- "all": Pass if the annotation has all tags from the list (it can
|
||||||
|
have others too).
|
||||||
|
- "exclude": Pass if the annotation has none of the tags from the list.
|
||||||
|
- "equal": Pass if the annotation's tags are exactly the same set as
|
||||||
|
provided in the list.
|
||||||
|
tags : List[TagInfo]
|
||||||
|
A list of tags (defined using TagInfo for configuration) that this
|
||||||
|
rule operates on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
match_type: Literal["any", "all", "exclude", "equal"]
|
||||||
|
tags: List[TagInfo]
|
||||||
|
|
||||||
|
|
||||||
|
def has_any_tag(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation has at least one of the specified tags.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags to look for.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation has one or more tags from the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return bool(tags & sound_event_tags)
|
||||||
|
|
||||||
|
|
||||||
|
def contains_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation contains all of the specified tags.
|
||||||
|
|
||||||
|
The annotation may have additional tags beyond those specified.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags that must all be present in the annotation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation's tags are a superset of the specified tags,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return tags < sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
|
def does_not_have_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
):
|
||||||
|
"""Check if the annotation has none of the specified tags.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The set of tags that must *not* be present in the annotation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation has zero tags in common with the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
return not has_any_tag(sound_event_annotation, tags)
|
||||||
|
|
||||||
|
|
||||||
|
def equal_tags(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
tags: Set[data.Tag],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the annotation's tags are exactly equal to the specified set.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event_annotation : data.SoundEventAnnotation
|
||||||
|
The annotation to check.
|
||||||
|
tags : Set[data.Tag]
|
||||||
|
The exact set of tags the annotation must have.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation's tags set is identical to the specified set,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
sound_event_tags = set(sound_event_annotation.tags)
|
||||||
|
return tags == sound_event_tags
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
|
||||||
|
"""Creates a callable filter function from a single FilterRule.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rule : FilterRule
|
||||||
|
The filter rule configuration object.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
A function that takes a SoundEventAnnotation and returns True if it
|
||||||
|
passes the rule, False otherwise.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the rule contains an invalid `match_type`.
|
||||||
|
"""
|
||||||
|
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
|
||||||
|
|
||||||
|
if rule.match_type == "any":
|
||||||
|
return partial(has_any_tag, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "all":
|
||||||
|
return partial(contains_tags, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "exclude":
|
||||||
|
return partial(does_not_have_tags, tags=tag_set)
|
||||||
|
|
||||||
|
if rule.match_type == "equal":
|
||||||
|
return partial(equal_tags, tags=tag_set)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid match type {rule.match_type}. Valid types "
|
||||||
|
"are: 'any', 'all', 'exclude' and 'equal'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_filters(*filters: SoundEventFilter) -> SoundEventFilter:
|
||||||
|
"""Combines multiple filter functions into a single filter function.
|
||||||
|
|
||||||
|
The resulting filter function applies AND logic: an annotation must pass
|
||||||
|
*all* the input filters to pass the merged filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
*filters_with_rules : Tuple[FilterRule, SoundEventFilter]
|
||||||
|
Variable number of tuples, each containing the original FilterRule
|
||||||
|
and its corresponding filter function (SoundEventFilter).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
A single function that returns True only if the annotation passes
|
||||||
|
all the input filters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def merged_filter(
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
for filter_fn in filters:
|
||||||
|
if not filter_fn(sound_event_annotation):
|
||||||
|
logging.debug(
|
||||||
|
f"Sound event annotation {sound_event_annotation.uuid} "
|
||||||
|
f"excluded due to rule {filter_fn}",
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
return merged_filter
|
||||||
|
|
||||||
|
|
||||||
|
class FilterConfig(BaseConfig):
|
||||||
|
"""Configuration model for defining a list of filter rules.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
rules : List[FilterRule]
|
||||||
|
A list of FilterRule objects. An annotation must pass all rules in
|
||||||
|
this list to be considered valid by the filter built from this config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rules: List[FilterRule] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def build_filter_from_config(config: FilterConfig) -> SoundEventFilter:
|
||||||
|
"""Builds a merged filter function from a FilterConfig object.
|
||||||
|
|
||||||
|
Creates individual filter functions for each rule in the configuration
|
||||||
|
and merges them using AND logic.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : FilterConfig
|
||||||
|
The configuration object containing the list of filter rules.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
A single callable filter function that applies all defined rules.
|
||||||
|
"""
|
||||||
|
filters = [build_filter_from_rule(rule) for rule in config.rules]
|
||||||
|
return merge_filters(*filters)
|
||||||
|
|
||||||
|
|
||||||
|
def load_filter_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> FilterConfig:
|
||||||
|
"""Loads the filter configuration from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : Optional[str], optional
|
||||||
|
If the filter configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
FilterConfig
|
||||||
|
The loaded and validated filter configuration object.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=FilterConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def load_filter_from_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> SoundEventFilter:
|
||||||
|
"""Loads filter configuration from a file and builds the filter function.
|
||||||
|
|
||||||
|
This is a convenience function that combines loading the configuration
|
||||||
|
and building the final callable filter function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file (YAML).
|
||||||
|
field : Optional[str], optional
|
||||||
|
If the filter configuration is nested under a specific key in the
|
||||||
|
file, specify the key here. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
SoundEventFilter
|
||||||
|
The final merged filter function ready to be used.
|
||||||
|
"""
|
||||||
|
config = load_filter_config(path=path, field=field)
|
||||||
|
return build_filter_from_config(config)
|
@ -14,7 +14,6 @@ __all__ = [
|
|||||||
"load_target_config",
|
"load_target_config",
|
||||||
"build_target_encoder",
|
"build_target_encoder",
|
||||||
"build_decoder",
|
"build_decoder",
|
||||||
"filter_sound_event",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -52,23 +51,6 @@ class TargetConfig(BaseConfig):
|
|||||||
replace: Optional[List[ReplaceConfig]] = None
|
replace: Optional[List[ReplaceConfig]] = None
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_filter(
|
|
||||||
include: Optional[List[TagInfo]] = None,
|
|
||||||
exclude: Optional[List[TagInfo]] = None,
|
|
||||||
) -> Callable[[data.SoundEventAnnotation], bool]:
|
|
||||||
include_tags = (
|
|
||||||
{get_tag_from_info(tag) for tag in include} if include else None
|
|
||||||
)
|
|
||||||
exclude_tags = (
|
|
||||||
{get_tag_from_info(tag) for tag in exclude} if exclude else None
|
|
||||||
)
|
|
||||||
return partial(
|
|
||||||
filter_sound_event,
|
|
||||||
include=include_tags,
|
|
||||||
exclude=exclude_tags,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tag_label(tag_info: TagInfo) -> str:
|
def get_tag_label(tag_info: TagInfo) -> str:
|
||||||
return tag_info.label if tag_info.label else tag_info.value
|
return tag_info.label if tag_info.label else tag_info.value
|
||||||
|
|
||||||
@ -138,22 +120,6 @@ def build_decoder(
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def filter_sound_event(
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
include: Optional[Set[data.Tag]] = None,
|
|
||||||
exclude: Optional[Set[data.Tag]] = None,
|
|
||||||
) -> bool:
|
|
||||||
tags = set(sound_event_annotation.tags)
|
|
||||||
|
|
||||||
if include is not None and not tags & include:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if exclude is not None and tags & exclude:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
path: Path, field: Optional[str] = None
|
path: Path, field: Optional[str] = None
|
||||||
) -> TargetConfig:
|
) -> TargetConfig:
|
||||||
|
Loading…
Reference in New Issue
Block a user