diff --git a/batdetect2/targets/filtering.py b/batdetect2/targets/filtering.py
new file mode 100644
index 0000000..ed8cf82
--- /dev/null
+++ b/batdetect2/targets/filtering.py
@@ -0,0 +1,297 @@
+import logging
+from functools import partial
+from typing import Callable, List, Literal, Optional, Set
+
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.targets.terms import TagInfo, get_tag_from_info
+
+__all__ = [
+    "build_filter_from_config",
+    "SoundEventFilter",
+]
+
+
+SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
+"""Type alias for a filter function.
+
+A filter function accepts a soundevent.data.SoundEventAnnotation object
+and returns True if the annotation should be kept based on the filter's
+criteria, or False if it should be discarded.
+"""
+
+logger = logging.getLogger(__name__)
+
+
+class FilterRule(BaseConfig):
+    """Defines a single rule for filtering sound event annotations.
+
+    Based on the `match_type`, this rule checks if the tags associated with a
+    sound event annotation meet certain criteria relative to the `tags` list
+    defined in this rule.
+
+    Attributes
+    ----------
+    match_type : Literal["any", "all", "exclude", "equal"]
+        Determines how the `tags` list is used:
+        - "any": Pass if the annotation has at least one tag from the list.
+        - "all": Pass if the annotation has all tags from the list (it can
+          have others too).
+        - "exclude": Pass if the annotation has none of the tags from the list.
+        - "equal": Pass if the annotation's tags are exactly the same set as
+          provided in the list.
+    tags : List[TagInfo]
+        A list of tags (defined using TagInfo for configuration) that this
+        rule operates on.
+    """
+
+    match_type: Literal["any", "all", "exclude", "equal"]
+    tags: List[TagInfo]
+
+
+def has_any_tag(
+    sound_event_annotation: data.SoundEventAnnotation,
+    tags: Set[data.Tag],
+) -> bool:
+    """Check if the annotation has at least one of the specified tags.
+
+    Parameters
+    ----------
+    sound_event_annotation : data.SoundEventAnnotation
+        The annotation to check.
+    tags : Set[data.Tag]
+        The set of tags to look for.
+
+    Returns
+    -------
+    bool
+        True if the annotation has one or more tags from the specified set,
+        False otherwise.
+    """
+    sound_event_tags = set(sound_event_annotation.tags)
+    return bool(tags & sound_event_tags)
+
+
+def contains_tags(
+    sound_event_annotation: data.SoundEventAnnotation,
+    tags: Set[data.Tag],
+) -> bool:
+    """Check if the annotation contains all of the specified tags.
+
+    Extra tags on the annotation are allowed but not required.
+
+    Parameters
+    ----------
+    sound_event_annotation : data.SoundEventAnnotation
+        The annotation to check.
+    tags : Set[data.Tag]
+        The set of tags that must all be present in the annotation.
+
+    Returns
+    -------
+    bool
+        True if the annotation's tags are a superset of the specified tags,
+        False otherwise.
+    """
+    sound_event_tags = set(sound_event_annotation.tags)
+    return tags <= sound_event_tags
+
+
+def does_not_have_tags(
+    sound_event_annotation: data.SoundEventAnnotation,
+    tags: Set[data.Tag],
+) -> bool:
+    """Check if the annotation has none of the specified tags.
+
+    Parameters
+    ----------
+    sound_event_annotation : data.SoundEventAnnotation
+        The annotation to check.
+    tags : Set[data.Tag]
+        The set of tags that must *not* be present in the annotation.
+
+    Returns
+    -------
+    bool
+        True if the annotation has zero tags in common with the specified set,
+        False otherwise.
+    """
+    return not has_any_tag(sound_event_annotation, tags)
+
+
+def equal_tags(
+    sound_event_annotation: data.SoundEventAnnotation,
+    tags: Set[data.Tag],
+) -> bool:
+    """Check if the annotation's tags are exactly equal to the specified set.
+
+    Parameters
+    ----------
+    sound_event_annotation : data.SoundEventAnnotation
+        The annotation to check.
+    tags : Set[data.Tag]
+        The exact set of tags the annotation must have.
+
+    Returns
+    -------
+    bool
+        True if the annotation's tags set is identical to the specified set,
+        False otherwise.
+    """
+    sound_event_tags = set(sound_event_annotation.tags)
+    return tags == sound_event_tags
+
+
+def build_filter_from_rule(rule: FilterRule) -> SoundEventFilter:
+    """Creates a callable filter function from a single FilterRule.
+
+    Parameters
+    ----------
+    rule : FilterRule
+        The filter rule configuration object.
+
+    Returns
+    -------
+    SoundEventFilter
+        A function that takes a SoundEventAnnotation and returns True if it
+        passes the rule, False otherwise.
+
+    Raises
+    ------
+    ValueError
+        If the rule contains an invalid `match_type`.
+    """
+    tag_set = {get_tag_from_info(tag_info) for tag_info in rule.tags}
+
+    if rule.match_type == "any":
+        return partial(has_any_tag, tags=tag_set)
+
+    if rule.match_type == "all":
+        return partial(contains_tags, tags=tag_set)
+
+    if rule.match_type == "exclude":
+        return partial(does_not_have_tags, tags=tag_set)
+
+    if rule.match_type == "equal":
+        return partial(equal_tags, tags=tag_set)
+
+    raise ValueError(
+        f"Invalid match type {rule.match_type}. Valid types "
+        "are: 'any', 'all', 'exclude' and 'equal'"
+    )
+
+
+def merge_filters(*filters: SoundEventFilter) -> SoundEventFilter:
+    """Combines multiple filter functions into a single filter function.
+
+    The resulting filter function applies AND logic: an annotation must pass
+    *all* the input filters to pass the merged filter.
+
+    Parameters
+    ----------
+    *filters : SoundEventFilter
+        Variable number of filter functions to combine. If none are given,
+        the merged filter accepts every annotation.
+
+    Returns
+    -------
+    SoundEventFilter
+        A single function that returns True only if the annotation passes
+        all the input filters.
+    """
+
+    def merged_filter(
+        sound_event_annotation: data.SoundEventAnnotation,
+    ) -> bool:
+        for filter_fn in filters:
+            if not filter_fn(sound_event_annotation):
+                logger.debug(
+                    f"Sound event annotation {sound_event_annotation.uuid} "
+                    f"excluded due to rule {filter_fn}",
+                )
+                return False
+
+        return True
+
+    return merged_filter
+
+
+class FilterConfig(BaseConfig):
+    """Configuration model for defining a list of filter rules.
+
+    Attributes
+    ----------
+    rules : List[FilterRule]
+        A list of FilterRule objects. An annotation must pass all rules in
+        this list to be considered valid by the filter built from this config.
+    """
+
+    rules: List[FilterRule] = Field(default_factory=list)
+
+
+def build_filter_from_config(config: FilterConfig) -> SoundEventFilter:
+    """Builds a merged filter function from a FilterConfig object.
+
+    Creates individual filter functions for each rule in the configuration
+    and merges them using AND logic.
+
+    Parameters
+    ----------
+    config : FilterConfig
+        The configuration object containing the list of filter rules.
+
+    Returns
+    -------
+    SoundEventFilter
+        A single callable filter function that applies all defined rules.
+    """
+    filters = [build_filter_from_rule(rule) for rule in config.rules]
+    return merge_filters(*filters)
+
+
+def load_filter_config(
+    path: data.PathLike, field: Optional[str] = None
+) -> FilterConfig:
+    """Loads the filter configuration from a file.
+
+    Parameters
+    ----------
+    path : data.PathLike
+        Path to the configuration file (YAML).
+    field : Optional[str], optional
+        If the filter configuration is nested under a specific key in the
+        file, specify the key here. Defaults to None.
+
+    Returns
+    -------
+    FilterConfig
+        The loaded and validated filter configuration object.
+    """
+    return load_config(path, schema=FilterConfig, field=field)
+
+
+def load_filter_from_config(
+    path: data.PathLike, field: Optional[str] = None
+) -> SoundEventFilter:
+    """Loads filter configuration from a file and builds the filter function.
+
+    This is a convenience function that combines loading the configuration
+    and building the final callable filter function.
+
+    Parameters
+    ----------
+    path : data.PathLike
+        Path to the configuration file (YAML).
+    field : Optional[str], optional
+        If the filter configuration is nested under a specific key in the
+        file, specify the key here. Defaults to None.
+
+    Returns
+    -------
+    SoundEventFilter
+        The final merged filter function ready to be used.
+    """
+    config = load_filter_config(path=path, field=field)
+    return build_filter_from_config(config)
diff --git a/batdetect2/targets/targets.py b/batdetect2/targets/targets.py
index e04005e..ad49114 100644
--- a/batdetect2/targets/targets.py
+++ b/batdetect2/targets/targets.py
@@ -14,7 +14,6 @@ __all__ = [
     "load_target_config",
     "build_target_encoder",
     "build_decoder",
-    "filter_sound_event",
 ]
 
 
@@ -52,23 +51,6 @@ class TargetConfig(BaseConfig):
     replace: Optional[List[ReplaceConfig]] = None
 
 
-def build_sound_event_filter(
-    include: Optional[List[TagInfo]] = None,
-    exclude: Optional[List[TagInfo]] = None,
-) -> Callable[[data.SoundEventAnnotation], bool]:
-    include_tags = (
-        {get_tag_from_info(tag) for tag in include} if include else None
-    )
-    exclude_tags = (
-        {get_tag_from_info(tag) for tag in exclude} if exclude else None
-    )
-    return partial(
-        filter_sound_event,
-        include=include_tags,
-        exclude=exclude_tags,
-    )
-
-
 def get_tag_label(tag_info: TagInfo) -> str:
     return tag_info.label if tag_info.label else tag_info.value
 
@@ -138,22 +120,6 @@ def build_decoder(
     return decoder
 
 
-def filter_sound_event(
-    sound_event_annotation: data.SoundEventAnnotation,
-    include: Optional[Set[data.Tag]] = None,
-    exclude: Optional[Set[data.Tag]] = None,
-) -> bool:
-    tags = set(sound_event_annotation.tags)
-
-    if include is not None and not tags & include:
-        return False
-
-    if exclude is not None and tags & exclude:
-        return False
-
-    return True
-
-
 def load_target_config(
     path: Path, field: Optional[str] = None
 ) -> TargetConfig: