From c8dd4155bfe0e8147ddb888d872f700bf28d0ae3 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 3 Apr 2026 16:40:11 +0100 Subject: [PATCH] Add conditions for clips and recordings --- src/batdetect2/data/conditions.py | 312 ----------------- src/batdetect2/data/conditions/__init__.py | 71 ++++ src/batdetect2/data/conditions/clips.py | 132 ++++++++ src/batdetect2/data/conditions/common.py | 314 ++++++++++++++++++ src/batdetect2/data/conditions/recordings.py | 98 ++++++ .../data/conditions/sound_events.py | 236 +++++++++++++ .../__init__.py | 0 tests/test_data/test_conditions/test_clip.py | 132 ++++++++ .../test_conditions/test_recording.py | 139 ++++++++ .../test_sound_events.py} | 37 ++- 10 files changed, 1157 insertions(+), 314 deletions(-) delete mode 100644 src/batdetect2/data/conditions.py create mode 100644 src/batdetect2/data/conditions/__init__.py create mode 100644 src/batdetect2/data/conditions/clips.py create mode 100644 src/batdetect2/data/conditions/common.py create mode 100644 src/batdetect2/data/conditions/recordings.py create mode 100644 src/batdetect2/data/conditions/sound_events.py rename tests/test_data/{test_transforms => test_conditions}/__init__.py (100%) create mode 100644 tests/test_data/test_conditions/test_clip.py create mode 100644 tests/test_data/test_conditions/test_recording.py rename tests/test_data/{test_transforms/test_conditions.py => test_conditions/test_sound_events.py} (91%) diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py deleted file mode 100644 index ac52152..0000000 --- a/src/batdetect2/data/conditions.py +++ /dev/null @@ -1,312 +0,0 @@ -from collections.abc import Callable -from typing import Annotated, List, Literal, Sequence - -from pydantic import Field -from soundevent import data -from soundevent.geometry import compute_bounds - -from batdetect2.core.configs import BaseConfig -from batdetect2.core.registries import ( - ImportConfig, - Registry, - add_import_config, -) - -SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] - -conditions: Registry[SoundEventCondition, []] = Registry("condition") - - -@add_import_config(conditions) -class SoundEventConditionImportConfig(ImportConfig): - """Use any callable as a sound event condition. - - Set ``name="import"`` and provide a ``target`` pointing to any - callable to use it instead of a built-in option. - """ - - name: Literal["import"] = "import" - - -class HasTagConfig(BaseConfig): - name: Literal["has_tag"] = "has_tag" - tag: data.Tag - - -class HasTag: - def __init__(self, tag: data.Tag): - self.tag = tag - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return any( - self.tag.term.name == tag.term.name and self.tag.value == tag.value - for tag in sound_event_annotation.tags - ) - - @conditions.register(HasTagConfig) - @staticmethod - def from_config(config: HasTagConfig): - return HasTag(tag=config.tag) - - -class HasAllTagsConfig(BaseConfig): - name: Literal["has_all_tags"] = "has_all_tags" - tags: List[data.Tag] - - -class HasAllTags: - def __init__(self, tags: List[data.Tag]): - if not tags: - raise ValueError("Need to specify at least one tag") - - self.tags = {(tag.term.name, tag.value) for tag in tags} - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return self.tags.issubset( - {(tag.term.name, tag.value) for tag in sound_event_annotation.tags} - ) - - @conditions.register(HasAllTagsConfig) - @staticmethod - def from_config(config: HasAllTagsConfig): - return HasAllTags(tags=config.tags) - - -class HasAnyTagConfig(BaseConfig): - name: Literal["has_any_tag"] = "has_any_tag" - tags: List[data.Tag] - - -class HasAnyTag: - def __init__(self, tags: List[data.Tag]): - if not tags: - raise ValueError("Need to specify at least one tag") - - self.tags = {(tag.term.name, tag.value) for tag in tags} - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return bool( - self.tags.intersection( - { - (tag.term.name, tag.value) - for tag in sound_event_annotation.tags - } - ) - ) - - @conditions.register(HasAnyTagConfig) - @staticmethod - def from_config(config: HasAnyTagConfig): - return HasAnyTag(tags=config.tags) - - -Operator = Literal["gt", "gte", "lt", "lte", "eq"] - - -class DurationConfig(BaseConfig): - name: Literal["duration"] = "duration" - operator: Operator - seconds: float - - -def _build_comparator( - operator: Operator, value: float -) -> Callable[[float], bool]: - if operator == "gt": - return lambda x: x > value - - if operator == "gte": - return lambda x: x >= value - - if operator == "lt": - return lambda x: x < value - - if operator == "lte": - return lambda x: x <= value - - if operator == "eq": - return lambda x: x == value - - raise ValueError(f"Invalid operator {operator}") - - -class Duration: - def __init__(self, operator: Operator, seconds: float): - self.operator = operator - self.seconds = seconds - self._comparator = _build_comparator(self.operator, self.seconds) - - def __call__( - self, - sound_event_annotation: data.SoundEventAnnotation, - ) -> bool: - geometry = sound_event_annotation.sound_event.geometry - - if geometry is None: - return False - - start_time, _, end_time, _ = compute_bounds(geometry) - duration = end_time - start_time - - return self._comparator(duration) - - @conditions.register(DurationConfig) - @staticmethod - def from_config(config: DurationConfig): - return Duration(operator=config.operator, seconds=config.seconds) - - -class FrequencyConfig(BaseConfig): - name: Literal["frequency"] = "frequency" - boundary: Literal["low", "high"] - operator: Operator - hertz: float - - -class Frequency: - def __init__( - self, - operator: Operator, - boundary: Literal["low", "high"], - hertz: float, - ): - self.operator = operator - self.hertz = hertz - self.boundary = boundary - self._comparator = _build_comparator(self.operator, self.hertz) - - def __call__( - self, - sound_event_annotation: data.SoundEventAnnotation, - ) -> bool: - geometry = sound_event_annotation.sound_event.geometry - - if geometry is None: - return False - - # Automatically false if geometry does not have a frequency range - if isinstance(geometry, (data.TimeInterval, data.TimeStamp)): - return False - - _, low_freq, _, high_freq = compute_bounds(geometry) - - if self.boundary == "low": - return self._comparator(low_freq) - - return self._comparator(high_freq) - - @conditions.register(FrequencyConfig) - @staticmethod - def from_config(config: FrequencyConfig): - return Frequency( - operator=config.operator, - boundary=config.boundary, - hertz=config.hertz, - ) - - -class AllOfConfig(BaseConfig): - name: Literal["all_of"] = "all_of" - conditions: Sequence["SoundEventConditionConfig"] - - -class AllOf: - def __init__(self, conditions: List[SoundEventCondition]): - self.conditions = conditions - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return all(c(sound_event_annotation) for c in self.conditions) - - @conditions.register(AllOfConfig) - @staticmethod - def from_config(config: AllOfConfig): - conditions = [ - build_sound_event_condition(cond) for cond in config.conditions - ] - return AllOf(conditions) - - -class AnyOfConfig(BaseConfig): - name: Literal["any_of"] = "any_of" - conditions: List["SoundEventConditionConfig"] - - -class AnyOf: - def __init__(self, conditions: List[SoundEventCondition]): - self.conditions = conditions - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return any(c(sound_event_annotation) for c in self.conditions) - - @conditions.register(AnyOfConfig) - @staticmethod - def from_config(config: AnyOfConfig): - conditions = [ - build_sound_event_condition(cond) for cond in config.conditions - ] - return AnyOf(conditions) - - -class NotConfig(BaseConfig): - name: Literal["not"] = "not" - condition: "SoundEventConditionConfig" - - -class Not: - def __init__(self, condition: SoundEventCondition): - self.condition = condition - - def __call__( - self, sound_event_annotation: data.SoundEventAnnotation - ) -> bool: - return not self.condition(sound_event_annotation) - - @conditions.register(NotConfig) - @staticmethod - def from_config(config: NotConfig): - condition = build_sound_event_condition(config.condition) - return Not(condition) - - -SoundEventConditionConfig = Annotated[ - HasTagConfig - | HasAllTagsConfig - | HasAnyTagConfig - | DurationConfig - | FrequencyConfig - | AllOfConfig - | AnyOfConfig - | NotConfig, - Field(discriminator="name"), -] - - -def build_sound_event_condition( - config: SoundEventConditionConfig, -) -> SoundEventCondition: - return conditions.build(config) - - -def filter_clip_annotation( - clip_annotation: data.ClipAnnotation, - condition: SoundEventCondition, -) -> data.ClipAnnotation: - return clip_annotation.model_copy( - update=dict( - sound_events=[ - sound_event - for sound_event in clip_annotation.sound_events - if condition(sound_event) - ] - ) - ) diff --git a/src/batdetect2/data/conditions/__init__.py b/src/batdetect2/data/conditions/__init__.py new file mode 100644 index 0000000..1115c9c --- /dev/null +++ b/src/batdetect2/data/conditions/__init__.py @@ -0,0 +1,71 @@ +from batdetect2.data.conditions.clips import ( + ClipAllOfConfig, + ClipAnnotationCondition, + ClipAnnotationConditionConfig, + ClipAnnotationConditionImportConfig, + ClipAnyOfConfig, + ClipNotConfig, + RecordingSatisfiesConfig, + build_clip_annotation_condition, +) +from batdetect2.data.conditions.common import ( + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, +) +from batdetect2.data.conditions.recordings import ( + RecordingAllOfConfig, + RecordingAnyOfConfig, + RecordingCondition, + RecordingConditionConfig, + RecordingConditionImportConfig, + RecordingNotConfig, + build_recording_condition, +) +from batdetect2.data.conditions.sound_events import ( + AllOfConfig, + AnyOfConfig, + DurationConfig, + FrequencyConfig, + NotConfig, + Operator, + SoundEventCondition, + SoundEventConditionConfig, + SoundEventConditionImportConfig, + build_sound_event_condition, + filter_clip_annotation, +) + +__all__ = [ + "AllOfConfig", + "AnyOfConfig", + "ClipAllOfConfig", + "ClipAnnotationCondition", + "ClipAnnotationConditionConfig", + "ClipAnnotationConditionImportConfig", + "ClipAnyOfConfig", + "ClipNotConfig", + "DurationConfig", + "FrequencyConfig", + "HasAllTagsConfig", + "HasAnyTagConfig", + "HasTagConfig", + "IdInListConfig", + "NotConfig", + "Operator", + "RecordingCondition", + "RecordingConditionConfig", + "RecordingConditionImportConfig", + "RecordingAllOfConfig", + "RecordingAnyOfConfig", + "RecordingNotConfig", + "RecordingSatisfiesConfig", + "SoundEventCondition", + "SoundEventConditionConfig", + "SoundEventConditionImportConfig", + "build_clip_annotation_condition", + "build_recording_condition", + "build_sound_event_condition", + "filter_clip_annotation", +] diff --git a/src/batdetect2/data/conditions/clips.py b/src/batdetect2/data/conditions/clips.py new file mode 100644 index 0000000..4ee187b --- /dev/null +++ b/src/batdetect2/data/conditions/clips.py @@ -0,0 +1,132 @@ +from collections.abc import Callable, Sequence +from typing import Annotated, Literal + +from pydantic import Field +from soundevent import data + +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) +from batdetect2.data.conditions.common import ( + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, + MultiConditionConfigBase, + NotConditionConfigBase, + register_all_of_condition, + register_any_of_condition, + register_has_all_tags_condition, + register_has_any_tag_condition, + register_has_tag_condition, + register_id_in_list_condition, + register_not_condition, +) +from batdetect2.data.conditions.recordings import ( + RecordingCondition, + RecordingConditionConfig, + build_recording_condition, +) + +__all__ = [ + "ClipAllOfConfig", + "ClipAnnotationCondition", + "ClipAnnotationConditionConfig", + "ClipAnnotationConditionImportConfig", + "ClipAnyOfConfig", + "ClipNotConfig", + "RecordingSatisfiesConfig", + "build_clip_annotation_condition", +] + +ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool] + +clip_annotation_conditions: Registry[ + ClipAnnotationCondition, + [data.PathLike | None], +] = Registry("clip_condition") + + +@add_import_config(clip_annotation_conditions, arg_names=["base_dir"]) +class ClipAnnotationConditionImportConfig(ImportConfig): + """Use any callable as a clip annotation condition. + + Set ``name="import"`` and provide a ``target`` pointing to any callable + to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + +class RecordingSatisfiesConfig(BaseConfig): + name: Literal["recording_satisfies"] = "recording_satisfies" + condition: RecordingConditionConfig + + +class RecordingSatisfies: + def __init__(self, condition: RecordingCondition): + self.condition = condition + + def __call__(self, clip_annotation: data.ClipAnnotation) -> bool: + recording = clip_annotation.clip.recording + return self.condition(recording) + + @clip_annotation_conditions.register(RecordingSatisfiesConfig) + @staticmethod + def from_config( + config: RecordingSatisfiesConfig, + base_dir: data.PathLike | None = None, + ) -> "RecordingSatisfies": + condition = build_recording_condition( + config.condition, + base_dir=base_dir, + ) + return RecordingSatisfies(condition) + + +register_has_tag_condition(clip_annotation_conditions)(HasTagConfig) +register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig) +register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig) +register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig) + + +@register_all_of_condition(clip_annotation_conditions) +class ClipAllOfConfig(MultiConditionConfigBase): + name: Literal["all_of"] = "all_of" + conditions: Sequence["ClipAnnotationConditionConfig"] + + +@register_any_of_condition(clip_annotation_conditions) +class ClipAnyOfConfig(MultiConditionConfigBase): + name: Literal["any_of"] = "any_of" + conditions: Sequence["ClipAnnotationConditionConfig"] + + +@register_not_condition(clip_annotation_conditions) +class ClipNotConfig(NotConditionConfigBase): + name: Literal["not"] = "not" + condition: "ClipAnnotationConditionConfig" + + +ClipAnnotationConditionConfig = Annotated[ + RecordingSatisfiesConfig + | IdInListConfig + | HasTagConfig + | HasAllTagsConfig + | HasAnyTagConfig + | ClipAllOfConfig + | ClipAnyOfConfig + | ClipNotConfig + | ClipAnnotationConditionImportConfig, + Field(discriminator="name"), +] + + +def build_clip_annotation_condition( + config: ClipAnnotationConditionConfig, + base_dir: data.PathLike | None = None, +) -> ClipAnnotationCondition: + return clip_annotation_conditions.build(config, base_dir) diff --git a/src/batdetect2/data/conditions/common.py b/src/batdetect2/data/conditions/common.py new file mode 100644 index 0000000..2f18eb0 --- /dev/null +++ b/src/batdetect2/data/conditions/common.py @@ -0,0 +1,314 @@ +import json +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Generic, Literal, ParamSpec, Protocol, TypeVar +from uuid import UUID + +from pydantic import BaseModel +from soundevent import data + +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry + +__all__ = [ + "AllOf", + "AnyOf", + "Condition", + "HasAllTags", + "HasAllTagsConfig", + "HasAnyTag", + "HasAnyTagConfig", + "HasTag", + "HasTagConfig", + "IdInList", + "IdInListConfig", + "MultiConditionConfigBase", + "Not", + "NotConditionConfigBase", + "ObjectWithTags", + "ObjectWithUUID", + "register_all_of_condition", + "register_any_of_condition", + "register_has_all_tags_condition", + "register_has_any_tag_condition", + "register_has_tag_condition", + "register_id_in_list_condition", + "register_not_condition", +] + + +class ObjectWithTags(Protocol): + tags: list[data.Tag] + + +class ObjectWithUUID(Protocol): + uuid: UUID + + +ConditionObject = TypeVar("ConditionObject") +TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags") +UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID") +P = ParamSpec("P") +NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase") +MultiConfigType = TypeVar( + "MultiConfigType", + bound="MultiConditionConfigBase", +) +Condition = Callable[[ConditionObject], bool] + + +class NotConditionConfigBase(BaseConfig): + condition: BaseModel + + +class MultiConditionConfigBase(BaseConfig): + conditions: Sequence[BaseModel] + + +class Not(Generic[ConditionObject]): + def __init__(self, condition: Condition[ConditionObject]): + self.condition = condition + + def __call__(self, obj: ConditionObject) -> bool: + return not self.condition(obj) + + +class AllOf(Generic[ConditionObject]): + def __init__(self, conditions: Sequence[Condition[ConditionObject]]): + self.conditions = list(conditions) + + def __call__(self, obj: ConditionObject) -> bool: + return all(condition(obj) for condition in self.conditions) + + +class AnyOf(Generic[ConditionObject]): + def __init__(self, conditions: Sequence[Condition[ConditionObject]]): + self.conditions = list(conditions) + + def __call__(self, obj: ConditionObject) -> bool: + return any(condition(obj) for condition in self.conditions) + + +class HasTag(Generic[TaggedObject]): + def __init__(self, tag: data.Tag): + self.tag_key = (tag.term.name, tag.value) + + def __call__(self, obj: TaggedObject) -> bool: + return any( + (tag.term.name, tag.value) == self.tag_key for tag in obj.tags + ) + + +class HasAllTags(Generic[TaggedObject]): + def __init__(self, tags: list[data.Tag]): + if not tags: + raise ValueError("Need to specify at least one tag") + + self.required_keys = {(tag.term.name, tag.value) for tag in tags} + + def __call__(self, obj: TaggedObject) -> bool: + tag_keys = {(tag.term.name, tag.value) for tag in obj.tags} + return self.required_keys.issubset(tag_keys) + + +class HasAnyTag(Generic[TaggedObject]): + def __init__(self, tags: list[data.Tag]): + if not tags: + raise ValueError("Need to specify at least one tag") + + self.required_keys = {(tag.term.name, tag.value) for tag in tags} + + def __call__(self, obj: TaggedObject) -> bool: + tag_keys = {(tag.term.name, tag.value) for tag in obj.tags} + return bool(self.required_keys.intersection(tag_keys)) + + +class IdInList(Generic[UUIDObject]): + def __init__(self, ids: set[UUID]): + self.ids = ids + + def __call__(self, obj: UUIDObject) -> bool: + return obj.uuid in self.ids + + +class HasTagConfig(BaseConfig): + name: Literal["has_tag"] = "has_tag" + tag: data.Tag + + +class HasAllTagsConfig(BaseConfig): + name: Literal["has_all_tags"] = "has_all_tags" + tags: list[data.Tag] + + +class HasAnyTagConfig(BaseConfig): + name: Literal["has_any_tag"] = "has_any_tag" + tags: list[data.Tag] + + +class IdInListConfig(BaseConfig): + name: Literal["id_in_list"] = "id_in_list" + path: Path + list_format: Literal["json", "txt"] = "json" + + +def _load_ids( + path: Path, + list_format: Literal["json", "txt"], +) -> list[str]: + if list_format == "json": + content = json.loads(path.read_text()) + + if not isinstance(content, list): + raise TypeError("Expected JSON list with IDs for 'id_in_list'.") + + return [str(value) for value in content] + + return [ + line.strip() for line in path.read_text().splitlines() if line.strip() + ] + + +def register_id_in_list_condition( + registry: Registry[Condition[UUIDObject], [data.PathLike | None]], +) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]: + def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]: + @registry.register(config_cls) + def builder( + config: IdInListConfig, + base_dir: data.PathLike | None = None, + ) -> Condition[UUIDObject]: + path = config.path + + if base_dir is not None and not path.is_absolute(): + path = Path(base_dir) / path + + ids = set() + for index, value in enumerate(_load_ids(path, config.list_format)): + try: + ids.add(UUID(value)) + except ValueError as err: + raise ValueError( + f"Invalid ID at index {index} in '{path}': {value!r}." + ) from err + + return IdInList(ids) + + return config_cls + + return decorator + + +def register_has_tag_condition( + registry: Registry[Condition[TaggedObject], P], +) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]: + def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]: + @registry.register(config_cls) + def builder( + config: HasTagConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasTag(config.tag) + + return config_cls + + return decorator + + +def register_has_all_tags_condition( + registry: Registry[Condition[TaggedObject], P], +) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]: + def decorator( + config_cls: type[HasAllTagsConfig], + ) -> type[HasAllTagsConfig]: + @registry.register(config_cls) + def builder( + config: HasAllTagsConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasAllTags(config.tags) + + return config_cls + + return decorator + + +def register_has_any_tag_condition( + registry: Registry[Condition[TaggedObject], P], +) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]: + def decorator( + config_cls: type[HasAnyTagConfig], + ) -> type[HasAnyTagConfig]: + @registry.register(config_cls) + def builder( + config: HasAnyTagConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasAnyTag(config.tags) + + return config_cls + + return decorator + + +def register_not_condition( + registry: Registry[Condition[ConditionObject], P], +) -> Callable[[type[NotConfigType]], type[NotConfigType]]: + def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]: + @registry.register(config_cls) + def builder( + config: NotConfigType, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[ConditionObject]: + condition = registry.build(config.condition, *args, **kwargs) + return Not(condition) + + return config_cls + + return decorator + + +def register_all_of_condition( + registry: Registry[Condition[ConditionObject], P], +) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]: + def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]: + @registry.register(config_cls) + def builder( + config: MultiConfigType, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[ConditionObject]: + conditions = [ + registry.build(condition, *args, **kwargs) + for condition in config.conditions + ] + return AllOf(conditions) + + return config_cls + + return decorator + + +def register_any_of_condition( + registry: Registry[Condition[ConditionObject], P], +) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]: + def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]: + @registry.register(config_cls) + def builder( + config: MultiConfigType, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[ConditionObject]: + conditions = [ + registry.build(condition, *args, **kwargs) + for condition in config.conditions + ] + return AnyOf(conditions) + + return config_cls + + return decorator diff --git a/src/batdetect2/data/conditions/recordings.py b/src/batdetect2/data/conditions/recordings.py new file mode 100644 index 0000000..dcb4762 --- /dev/null +++ b/src/batdetect2/data/conditions/recordings.py @@ -0,0 +1,98 @@ +from collections.abc import Callable, Sequence +from typing import Annotated, Literal + +from pydantic import Field +from soundevent import data + +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) +from batdetect2.data.conditions.common import ( + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, + MultiConditionConfigBase, + NotConditionConfigBase, + register_all_of_condition, + register_any_of_condition, + register_has_all_tags_condition, + register_has_any_tag_condition, + register_has_tag_condition, + register_id_in_list_condition, + register_not_condition, +) + +__all__ = [ + "IdInListConfig", + "RecordingAllOfConfig", + "RecordingAnyOfConfig", + "RecordingCondition", + "RecordingConditionConfig", + "RecordingConditionImportConfig", + "RecordingNotConfig", + "build_recording_condition", +] + +RecordingCondition = Callable[[data.Recording], bool] + +recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = ( + Registry("recording_condition") +) + + +@add_import_config(recording_conditions, arg_names=["base_dir"]) +class RecordingConditionImportConfig(ImportConfig): + """Use any callable as a recording condition. + + Set ``name="import"`` and provide a ``target`` pointing to any callable + to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + +register_id_in_list_condition(recording_conditions)(IdInListConfig) +register_has_tag_condition(recording_conditions)(HasTagConfig) +register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig) +register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig) + + +@register_all_of_condition(recording_conditions) +class RecordingAllOfConfig(MultiConditionConfigBase): + name: Literal["all_of"] = "all_of" + conditions: Sequence["RecordingConditionConfig"] + + +@register_any_of_condition(recording_conditions) +class RecordingAnyOfConfig(MultiConditionConfigBase): + name: Literal["any_of"] = "any_of" + conditions: Sequence["RecordingConditionConfig"] + + +@register_not_condition(recording_conditions) +class RecordingNotConfig(NotConditionConfigBase): + name: Literal["not"] = "not" + condition: "RecordingConditionConfig" + + +RecordingConditionConfig = Annotated[ + IdInListConfig + | HasTagConfig + | HasAllTagsConfig + | HasAnyTagConfig + | RecordingAllOfConfig + | RecordingAnyOfConfig + | RecordingNotConfig + | RecordingConditionImportConfig, + Field(discriminator="name"), +] + + +def build_recording_condition( + config: RecordingConditionConfig, + base_dir: data.PathLike | None = None, +) -> RecordingCondition: + return recording_conditions.build(config, base_dir) diff --git a/src/batdetect2/data/conditions/sound_events.py b/src/batdetect2/data/conditions/sound_events.py new file mode 100644 index 0000000..1fb38c0 --- /dev/null +++ b/src/batdetect2/data/conditions/sound_events.py @@ -0,0 +1,236 @@ +from collections.abc import Callable, Sequence +from typing import Annotated, Literal + +from pydantic import Field +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) +from batdetect2.data.conditions.common import ( + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, + MultiConditionConfigBase, + NotConditionConfigBase, + register_all_of_condition, + register_any_of_condition, + register_has_all_tags_condition, + register_has_any_tag_condition, + register_has_tag_condition, + register_id_in_list_condition, + register_not_condition, +) + +__all__ = [ + "AllOfConfig", + "AnyOfConfig", + "DurationConfig", + "FrequencyConfig", + "HasAllTagsConfig", + "HasAnyTagConfig", + "HasTagConfig", + "NotConfig", + "Operator", + "SoundEventCondition", + "SoundEventConditionConfig", + "SoundEventConditionImportConfig", + "build_sound_event_condition", + "filter_clip_annotation", +] + +SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] + +sound_event_conditions: Registry[ + SoundEventCondition, + [data.PathLike | None], +] = Registry("sound_event_condition") + + +@add_import_config(sound_event_conditions, arg_names=["base_dir"]) +class SoundEventConditionImportConfig(ImportConfig): + """Use any callable as a sound event condition. + + Set ``name="import"`` and provide a ``target`` pointing to any + callable to use it instead of a built-in option. + """ + + name: Literal["import"] = "import" + + +register_has_tag_condition(sound_event_conditions)(HasTagConfig) +register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig) +register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig) +register_id_in_list_condition(sound_event_conditions)(IdInListConfig) + + +Operator = Literal["gt", "gte", "lt", "lte", "eq"] + + +class DurationConfig(BaseConfig): + name: Literal["duration"] = "duration" + operator: Operator + seconds: float + + +def _build_comparator( + operator: Operator, value: float +) -> Callable[[float], bool]: + if operator == "gt": + return lambda x: x > value + + if operator == "gte": + return lambda x: x >= value + + if operator == "lt": + return lambda x: x < value + + if operator == "lte": + return lambda x: x <= value + + if operator == "eq": + return lambda x: x == value + + raise ValueError(f"Invalid operator {operator}") + + +class Duration: + def __init__(self, operator: Operator, seconds: float): + self.operator = operator + self.seconds = seconds + self._comparator = _build_comparator(self.operator, self.seconds) + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> bool: + geometry = sound_event_annotation.sound_event.geometry + + if geometry is None: + return False + + start_time, _, end_time, _ = compute_bounds(geometry) + duration = end_time - start_time + + return self._comparator(duration) + + @sound_event_conditions.register(DurationConfig) + @staticmethod + def from_config( + config: DurationConfig, + base_dir: data.PathLike | None = None, + ): + _ = base_dir + return Duration(operator=config.operator, seconds=config.seconds) + + +class FrequencyConfig(BaseConfig): + name: Literal["frequency"] = "frequency" + boundary: Literal["low", "high"] + operator: Operator + hertz: float + + +class Frequency: + def __init__( + self, + operator: Operator, + boundary: Literal["low", "high"], + hertz: float, + ): + self.operator = operator + self.hertz = hertz + self.boundary = boundary + self._comparator = _build_comparator(self.operator, self.hertz) + + def __call__( + self, + sound_event_annotation: data.SoundEventAnnotation, + ) -> bool: + geometry = sound_event_annotation.sound_event.geometry + + if geometry is None: + return False + + if isinstance(geometry, (data.TimeInterval, data.TimeStamp)): + return False + + _, low_freq, _, high_freq = compute_bounds(geometry) + + if self.boundary == "low": + return self._comparator(low_freq) + + return self._comparator(high_freq) + + @sound_event_conditions.register(FrequencyConfig) + @staticmethod + def from_config( + config: FrequencyConfig, + base_dir: data.PathLike | None = None, + ): + _ = base_dir + return Frequency( + operator=config.operator, + boundary=config.boundary, + hertz=config.hertz, + ) + + +@register_all_of_condition(sound_event_conditions) +class AllOfConfig(MultiConditionConfigBase): + name: Literal["all_of"] = "all_of" + conditions: Sequence["SoundEventConditionConfig"] + + +@register_any_of_condition(sound_event_conditions) +class AnyOfConfig(MultiConditionConfigBase): + name: Literal["any_of"] = "any_of" + conditions: list["SoundEventConditionConfig"] + + +@register_not_condition(sound_event_conditions) +class NotConfig(NotConditionConfigBase): + name: Literal["not"] = "not" + condition: "SoundEventConditionConfig" + + +SoundEventConditionConfig = Annotated[ + IdInListConfig + | HasTagConfig + | HasAllTagsConfig + | HasAnyTagConfig + | DurationConfig + | FrequencyConfig + | AllOfConfig + | AnyOfConfig + | NotConfig + | SoundEventConditionImportConfig, + Field(discriminator="name"), +] + + +def build_sound_event_condition( + config: SoundEventConditionConfig, + base_dir: data.PathLike | None = None, +) -> SoundEventCondition: + return sound_event_conditions.build(config, base_dir) + + +def filter_clip_annotation( + clip_annotation: data.ClipAnnotation, + condition: SoundEventCondition, +) -> data.ClipAnnotation: + return clip_annotation.model_copy( + update=dict( + sound_events=[ + sound_event + for sound_event in clip_annotation.sound_events + if condition(sound_event) + ] + ) + ) diff --git a/tests/test_data/test_transforms/__init__.py b/tests/test_data/test_conditions/__init__.py similarity index 100% rename from tests/test_data/test_transforms/__init__.py rename to tests/test_data/test_conditions/__init__.py diff --git a/tests/test_data/test_conditions/test_clip.py b/tests/test_data/test_conditions/test_clip.py new file mode 100644 index 0000000..e8c32f3 --- /dev/null +++ b/tests/test_data/test_conditions/test_clip.py @@ -0,0 +1,132 @@ +import json +from pathlib import Path + +from soundevent import data + +from batdetect2.data.conditions import ( + ClipAllOfConfig, + ClipAnyOfConfig, + ClipNotConfig, + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, + RecordingSatisfiesConfig, + build_clip_annotation_condition, +) + + +def test_recording_satisfies_condition( + tmp_path: Path, + create_recording, + create_clip, + create_clip_annotation, +) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + clip_a = create_clip(recording_a) + clip_b = create_clip(recording_b) + clip_annotation_a = create_clip_annotation(clip_a) + clip_annotation_b = create_clip_annotation(clip_b) + ids_path = tmp_path / "recording_ids.json" + ids_path.write_text(json.dumps([str(recording_a.uuid)])) + + condition = build_clip_annotation_condition( + RecordingSatisfiesConfig( + condition=IdInListConfig(path=ids_path), + ) + ) + + assert condition(clip_annotation_a) + assert not condition(clip_annotation_b) + + +def test_clip_id_in_list_condition( + tmp_path: Path, + create_recording, + create_clip, + create_clip_annotation, +) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + clip_annotation_a = create_clip_annotation(create_clip(recording_a)) + clip_annotation_b = create_clip_annotation(create_clip(recording_b)) + ids_path = tmp_path / "clip_annotation_ids.json" + ids_path.write_text(json.dumps([str(clip_annotation_a.uuid)])) + + condition = build_clip_annotation_condition(IdInListConfig(path=ids_path)) + + assert condition(clip_annotation_a) + assert not condition(clip_annotation_b) + + +def test_clip_has_tag_conditions( + tmp_path: Path, + create_recording, + create_clip, + create_clip_annotation, +) -> None: + reviewed = data.Tag(key="status", value="reviewed") + train = data.Tag(key="split", value="train") + val = data.Tag(key="split", value="val") + + recording = create_recording(path=tmp_path / "rec.wav") + clip = create_clip(recording) + clip_annotation = create_clip_annotation( + clip, + clip_tags=[reviewed, train], + ) + + has_tag = build_clip_annotation_condition(HasTagConfig(tag=reviewed)) + has_all = build_clip_annotation_condition( + HasAllTagsConfig(tags=[reviewed, train]) + ) + has_any = build_clip_annotation_condition( + HasAnyTagConfig(tags=[val, train]) + ) + + assert has_tag(clip_annotation) + assert has_all(clip_annotation) + assert has_any(clip_annotation) + + +def test_clip_logical_conditions( + tmp_path: Path, + create_recording, + create_clip, + create_clip_annotation, +) -> None: + reviewed = data.Tag(key="status", value="reviewed") + train = data.Tag(key="split", value="train") + val = data.Tag(key="split", value="val") + + recording = create_recording(path=tmp_path / "rec.wav") + clip = create_clip(recording) + clip_annotation = create_clip_annotation( + clip, + clip_tags=[reviewed, train], + ) + + all_condition = build_clip_annotation_condition( + ClipAllOfConfig( + conditions=[ + HasTagConfig(tag=reviewed), + HasAnyTagConfig(tags=[train, val]), + ] + ) + ) + any_condition = build_clip_annotation_condition( + ClipAnyOfConfig( + conditions=[ + HasTagConfig(tag=val), + HasTagConfig(tag=reviewed), + ] + ) + ) + not_condition = build_clip_annotation_condition( + ClipNotConfig(condition=HasTagConfig(tag=val)) + ) + + assert all_condition(clip_annotation) + assert any_condition(clip_annotation) + assert not_condition(clip_annotation) diff --git a/tests/test_data/test_conditions/test_recording.py b/tests/test_data/test_conditions/test_recording.py new file mode 100644 index 0000000..526106e --- /dev/null +++ b/tests/test_data/test_conditions/test_recording.py @@ -0,0 +1,139 @@ +import json +from pathlib import Path + +import pytest +from soundevent import data + +from batdetect2.data.conditions import ( + HasAllTagsConfig, + HasAnyTagConfig, + HasTagConfig, + IdInListConfig, + RecordingAllOfConfig, + RecordingAnyOfConfig, + RecordingNotConfig, + build_recording_condition, +) + + +def test_id_in_list_condition(tmp_path: Path, create_recording) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + ids_path = tmp_path / "recording_ids.json" + ids_path.write_text(json.dumps([str(recording_a.uuid)])) + + condition = build_recording_condition(IdInListConfig(path=ids_path)) + + assert condition(recording_a) + assert not condition(recording_b) + + +def test_id_in_list_condition_uses_base_dir( + tmp_path: Path, + create_recording, +) -> None: + recording = create_recording(path=tmp_path / "a.wav") + split_dir = tmp_path / "splits" + split_dir.mkdir() + ids_path = split_dir / "train_ids.json" + ids_path.write_text(json.dumps([str(recording.uuid)])) + + condition = build_recording_condition( + IdInListConfig(path=Path("splits/train_ids.json")), + base_dir=tmp_path, + ) + + assert condition(recording) + + +def test_id_in_list_condition_raises_for_non_list_json( + tmp_path: Path, +) -> None: + ids_path = tmp_path / "recording_ids.json" + ids_path.write_text(json.dumps({"id": "foo"})) + + with pytest.raises(TypeError, match="Expected JSON list"): + build_recording_condition(IdInListConfig(path=ids_path)) + + +def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None: + ids_path = tmp_path / "recording_ids.json" + ids_path.write_text(json.dumps(["not-a-uuid"])) + + with pytest.raises(ValueError, match="Invalid ID"): + build_recording_condition(IdInListConfig(path=ids_path)) + + +def test_id_in_list_condition_supports_txt_format( + tmp_path: Path, + create_recording, +) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + ids_path = tmp_path / "recording_ids.txt" + ids_path.write_text(f"{recording_a.uuid}\n") + + condition = build_recording_condition( + IdInListConfig(path=ids_path, list_format="txt") + ) + + assert condition(recording_a) + assert not condition(recording_b) + + +def test_recording_has_tag_conditions( + tmp_path: Path, create_recording +) -> None: + train = data.Tag(key="split", value="train") + uk = data.Tag(key="region", value="uk") + eu = data.Tag(key="region", value="eu") + + recording = create_recording( + path=tmp_path / "rec.wav", + tags=[train, uk], + ) + + has_train = build_recording_condition(HasTagConfig(tag=train)) + has_all = build_recording_condition(HasAllTagsConfig(tags=[train, uk])) + has_any = build_recording_condition(HasAnyTagConfig(tags=[eu, uk])) + + assert has_train(recording) + assert has_all(recording) + assert has_any(recording) + + +def test_recording_logical_conditions( + tmp_path: Path, create_recording +) -> None: + train = data.Tag(key="split", value="train") + uk = data.Tag(key="region", value="uk") + eu = data.Tag(key="region", value="eu") + + recording = create_recording( + path=tmp_path / "rec.wav", + tags=[train, uk], + ) + + all_condition = build_recording_condition( + RecordingAllOfConfig( + conditions=[ + HasTagConfig(tag=train), + HasAnyTagConfig(tags=[eu, uk]), + ] + ) + ) + any_condition = build_recording_condition( + RecordingAnyOfConfig( + conditions=[ + HasTagConfig(tag=eu), + HasTagConfig(tag=train), + ] + ) + ) + not_condition = build_recording_condition( + RecordingNotConfig(condition=HasTagConfig(tag=eu)) + ) + + assert all_condition(recording) + assert any_condition(recording) + assert not_condition(recording) diff --git a/tests/test_data/test_transforms/test_conditions.py b/tests/test_data/test_conditions/test_sound_events.py similarity index 91% rename from tests/test_data/test_transforms/test_conditions.py rename to tests/test_data/test_conditions/test_sound_events.py index bdb4eb9..f7cf7f8 100644 --- a/tests/test_data/test_transforms/test_conditions.py +++ b/tests/test_data/test_conditions/test_sound_events.py @@ -1,4 +1,6 @@ +import json import textwrap +from pathlib import Path import pytest import yaml @@ -6,16 +8,17 @@ from pydantic import TypeAdapter from soundevent import data from batdetect2.data.conditions import ( + IdInListConfig, SoundEventConditionConfig, build_sound_event_condition, ) -def build_condition_from_str(content): +def build_condition_from_str(content, base_dir: Path | None = None): content = textwrap.dedent(content) content = yaml.safe_load(content) config = TypeAdapter(SoundEventConditionConfig).validate_python(content) - return build_sound_event_condition(config) + return build_sound_event_condition(config, base_dir=base_dir) def test_has_tag(sound_event: data.SoundEvent): @@ -160,6 +163,36 @@ def test_not(sound_event: data.SoundEvent): assert not condition(sound_event_annotation) +def test_id_in_list(sound_event: data.SoundEvent, tmp_path: Path): + se1 = data.SoundEventAnnotation(sound_event=sound_event) + se2 = data.SoundEventAnnotation(sound_event=sound_event) + ids_path = tmp_path / "sound_event_ids.json" + ids_path.write_text(json.dumps([str(se1.uuid)])) + + condition = build_sound_event_condition(IdInListConfig(path=ids_path)) + + assert condition(se1) + assert not condition(se2) + + +def test_id_in_list_uses_base_dir( + sound_event: data.SoundEvent, + tmp_path: Path, +) -> None: + se = data.SoundEventAnnotation(sound_event=sound_event) + split_dir = tmp_path / "splits" + split_dir.mkdir() + ids_path = split_dir / "sound_event_ids.json" + ids_path.write_text(json.dumps([str(se.uuid)])) + + condition = build_sound_event_condition( + IdInListConfig(path=Path("splits/sound_event_ids.json")), + base_dir=tmp_path, + ) + + assert condition(se) + + def test_duration(recording: data.Recording): se1 = data.SoundEventAnnotation( sound_event=data.SoundEvent(