mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Create dedicated filtering module
This commit is contained in:
parent
26a2c5c851
commit
2fb3039f17
297
batdetect2/targets/filtering.py
Normal file
297
batdetect2/targets/filtering.py
Normal file
@ -0,0 +1,297 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Callable, List, Literal, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import TagInfo, get_tag_from_info
|
||||
|
||||
__all__ = [
|
||||
"build_filter_from_config",
|
||||
"SoundEventFilter",
|
||||
]
|
||||
|
||||
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
"""Type alias for a filter function.
|
||||
|
||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||
and returns True if the annotation should be kept based on the filter's
|
||||
criteria, or False if it should be discarded.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FilterRule(BaseConfig):
|
||||
"""Defines a single rule for filtering sound event annotations.
|
||||
|
||||
Based on the `match_type`, this rule checks if the tags associated with a
|
||||
sound event annotation meet certain criteria relative to the `tags` list
|
||||
defined in this rule.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
match_type : Literal["any", "all", "exclude", "equal"]
|
||||
Determines how the `tags` list is used:
|
||||
- "any": Pass if the annotation has at least one tag from the list.
|
||||
- "all": Pass if the annotation has all tags from the list (it can
|
||||
have others too).
|
||||
- "exclude": Pass if the annotation has none of the tags from the list.
|
||||
- "equal": Pass if the annotation's tags are exactly the same set as
|
||||
provided in the list.
|
||||
tags : List[TagInfo]
|
||||
A list of tags (defined using TagInfo for configuration) that this
|
||||
rule operates on.
|
||||
"""
|
||||
|
||||
match_type: Literal["any", "all", "exclude", "equal"]
|
||||
tags: List[TagInfo]
|
||||
|
||||
|
||||
def has_any_tag(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation has at least one of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags to look for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has one or more tags from the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return bool(tags & sound_event_tags)
|
||||
|
||||
|
||||
def contains_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation contains all of the specified tags.
|
||||
|
||||
The annotation may have additional tags beyond those specified.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must all be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags are a superset of the specified tags,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return tags < sound_event_tags
|
||||
|
||||
|
||||
def does_not_have_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
):
|
||||
"""Check if the annotation has none of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must *not* be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has zero tags in common with the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
return not has_any_tag(sound_event_annotation, tags)
|
||||
|
||||
|
||||
def equal_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation's tags are exactly equal to the specified set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The exact set of tags the annotation must have.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags set is identical to the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return tags == sound_event_tags
|
||||
|
||||
|
||||
def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
|
||||
"""Creates a callable filter function from a single FilterRule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rule : FilterRule
|
||||
The filter rule configuration object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A function that takes a SoundEventAnnotation and returns True if it
|
||||
passes the rule, False otherwise.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the rule contains an invalid `match_type`.
|
||||
"""
|
||||
tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
|
||||
|
||||
if rule.match_type == "any":
|
||||
return partial(has_any_tag, tags=tag_set)
|
||||
|
||||
if rule.match_type == "all":
|
||||
return partial(contains_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "exclude":
|
||||
return partial(does_not_have_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "equal":
|
||||
return partial(equal_tags, tags=tag_set)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid match type {rule.match_type}. Valid types "
|
||||
"are: 'any', 'all', 'exclude' and 'equal'"
|
||||
)
|
||||
|
||||
|
||||
def merge_filters(*filters: SoundEventFilter) -> SoundEventFilter:
|
||||
"""Combines multiple filter functions into a single filter function.
|
||||
|
||||
The resulting filter function applies AND logic: an annotation must pass
|
||||
*all* the input filters to pass the merged filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*filters_with_rules : Tuple[FilterRule, SoundEventFilter]
|
||||
Variable number of tuples, each containing the original FilterRule
|
||||
and its corresponding filter function (SoundEventFilter).
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A single function that returns True only if the annotation passes
|
||||
all the input filters.
|
||||
"""
|
||||
|
||||
def merged_filter(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
for filter_fn in filters:
|
||||
if not filter_fn(sound_event_annotation):
|
||||
logging.debug(
|
||||
f"Sound event annotation {sound_event_annotation.uuid} "
|
||||
f"excluded due to rule {filter_fn}",
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return merged_filter
|
||||
|
||||
|
||||
class FilterConfig(BaseConfig):
|
||||
"""Configuration model for defining a list of filter rules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rules : List[FilterRule]
|
||||
A list of FilterRule objects. An annotation must pass all rules in
|
||||
this list to be considered valid by the filter built from this config.
|
||||
"""
|
||||
|
||||
rules: List[FilterRule] = Field(default_factory=list)
|
||||
|
||||
|
||||
def build_filter_from_config(config: FilterConfig) -> SoundEventFilter:
|
||||
"""Builds a merged filter function from a FilterConfig object.
|
||||
|
||||
Creates individual filter functions for each rule in the configuration
|
||||
and merges them using AND logic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : FilterConfig
|
||||
The configuration object containing the list of filter rules.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A single callable filter function that applies all defined rules.
|
||||
"""
|
||||
filters = [build_filter_from_rule(rule) for rule in config.rules]
|
||||
return merge_filters(*filters)
|
||||
|
||||
|
||||
def load_filter_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> FilterConfig:
|
||||
"""Loads the filter configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FilterConfig
|
||||
The loaded and validated filter configuration object.
|
||||
"""
|
||||
return load_config(path, schema=FilterConfig, field=field)
|
||||
|
||||
|
||||
def load_filter_from_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> SoundEventFilter:
|
||||
"""Loads filter configuration from a file and builds the filter function.
|
||||
|
||||
This is a convenience function that combines loading the configuration
|
||||
and building the final callable filter function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
The final merged filter function ready to be used.
|
||||
"""
|
||||
config = load_filter_config(path=path, field=field)
|
||||
return build_filter_from_config(config)
|
@ -14,7 +14,6 @@ __all__ = [
|
||||
"load_target_config",
|
||||
"build_target_encoder",
|
||||
"build_decoder",
|
||||
"filter_sound_event",
|
||||
]
|
||||
|
||||
|
||||
@ -52,23 +51,6 @@ class TargetConfig(BaseConfig):
|
||||
replace: Optional[List[ReplaceConfig]] = None
|
||||
|
||||
|
||||
def build_sound_event_filter(
|
||||
include: Optional[List[TagInfo]] = None,
|
||||
exclude: Optional[List[TagInfo]] = None,
|
||||
) -> Callable[[data.SoundEventAnnotation], bool]:
|
||||
include_tags = (
|
||||
{get_tag_from_info(tag) for tag in include} if include else None
|
||||
)
|
||||
exclude_tags = (
|
||||
{get_tag_from_info(tag) for tag in exclude} if exclude else None
|
||||
)
|
||||
return partial(
|
||||
filter_sound_event,
|
||||
include=include_tags,
|
||||
exclude=exclude_tags,
|
||||
)
|
||||
|
||||
|
||||
def get_tag_label(tag_info: TagInfo) -> str:
|
||||
return tag_info.label if tag_info.label else tag_info.value
|
||||
|
||||
@ -138,22 +120,6 @@ def build_decoder(
|
||||
return decoder
|
||||
|
||||
|
||||
def filter_sound_event(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
include: Optional[Set[data.Tag]] = None,
|
||||
exclude: Optional[Set[data.Tag]] = None,
|
||||
) -> bool:
|
||||
tags = set(sound_event_annotation.tags)
|
||||
|
||||
if include is not None and not tags & include:
|
||||
return False
|
||||
|
||||
if exclude is not None and tags & exclude:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def load_target_config(
|
||||
path: Path, field: Optional[str] = None
|
||||
) -> TargetConfig:
|
||||
|
Loading…
Reference in New Issue
Block a user