mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add conditions for clips and recordings
This commit is contained in:
parent
e80fe8675d
commit
c8dd4155bf
@ -1,312 +0,0 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import Annotated, List, Literal, Sequence
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import (
|
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
|
||||||
|
|
||||||
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(conditions)
|
|
||||||
class SoundEventConditionImportConfig(ImportConfig):
|
|
||||||
"""Use any callable as a sound event condition.
|
|
||||||
|
|
||||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
|
||||||
callable to use it instead of a built-in option.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["import"] = "import"
|
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_tag"] = "has_tag"
|
|
||||||
tag: data.Tag
|
|
||||||
|
|
||||||
|
|
||||||
class HasTag:
|
|
||||||
def __init__(self, tag: data.Tag):
|
|
||||||
self.tag = tag
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return any(
|
|
||||||
self.tag.term.name == tag.term.name and self.tag.value == tag.value
|
|
||||||
for tag in sound_event_annotation.tags
|
|
||||||
)
|
|
||||||
|
|
||||||
@conditions.register(HasTagConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: HasTagConfig):
|
|
||||||
return HasTag(tag=config.tag)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
|
||||||
name: Literal["has_all_tags"] = "has_all_tags"
|
|
||||||
tags: List[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTags:
|
|
||||||
def __init__(self, tags: List[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return self.tags.issubset(
|
|
||||||
{(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
|
|
||||||
)
|
|
||||||
|
|
||||||
@conditions.register(HasAllTagsConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: HasAllTagsConfig):
|
|
||||||
return HasAllTags(tags=config.tags)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_any_tag"] = "has_any_tag"
|
|
||||||
tags: List[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTag:
|
|
||||||
def __init__(self, tags: List[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return bool(
|
|
||||||
self.tags.intersection(
|
|
||||||
{
|
|
||||||
(tag.term.name, tag.value)
|
|
||||||
for tag in sound_event_annotation.tags
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@conditions.register(HasAnyTagConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: HasAnyTagConfig):
|
|
||||||
return HasAnyTag(tags=config.tags)
|
|
||||||
|
|
||||||
|
|
||||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
|
||||||
|
|
||||||
|
|
||||||
class DurationConfig(BaseConfig):
|
|
||||||
name: Literal["duration"] = "duration"
|
|
||||||
operator: Operator
|
|
||||||
seconds: float
|
|
||||||
|
|
||||||
|
|
||||||
def _build_comparator(
|
|
||||||
operator: Operator, value: float
|
|
||||||
) -> Callable[[float], bool]:
|
|
||||||
if operator == "gt":
|
|
||||||
return lambda x: x > value
|
|
||||||
|
|
||||||
if operator == "gte":
|
|
||||||
return lambda x: x >= value
|
|
||||||
|
|
||||||
if operator == "lt":
|
|
||||||
return lambda x: x < value
|
|
||||||
|
|
||||||
if operator == "lte":
|
|
||||||
return lambda x: x <= value
|
|
||||||
|
|
||||||
if operator == "eq":
|
|
||||||
return lambda x: x == value
|
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {operator}")
|
|
||||||
|
|
||||||
|
|
||||||
class Duration:
|
|
||||||
def __init__(self, operator: Operator, seconds: float):
|
|
||||||
self.operator = operator
|
|
||||||
self.seconds = seconds
|
|
||||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
|
||||||
duration = end_time - start_time
|
|
||||||
|
|
||||||
return self._comparator(duration)
|
|
||||||
|
|
||||||
@conditions.register(DurationConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: DurationConfig):
|
|
||||||
return Duration(operator=config.operator, seconds=config.seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
|
||||||
name: Literal["frequency"] = "frequency"
|
|
||||||
boundary: Literal["low", "high"]
|
|
||||||
operator: Operator
|
|
||||||
hertz: float
|
|
||||||
|
|
||||||
|
|
||||||
class Frequency:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
operator: Operator,
|
|
||||||
boundary: Literal["low", "high"],
|
|
||||||
hertz: float,
|
|
||||||
):
|
|
||||||
self.operator = operator
|
|
||||||
self.hertz = hertz
|
|
||||||
self.boundary = boundary
|
|
||||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Automatically false if geometry does not have a frequency range
|
|
||||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
|
||||||
|
|
||||||
if self.boundary == "low":
|
|
||||||
return self._comparator(low_freq)
|
|
||||||
|
|
||||||
return self._comparator(high_freq)
|
|
||||||
|
|
||||||
@conditions.register(FrequencyConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: FrequencyConfig):
|
|
||||||
return Frequency(
|
|
||||||
operator=config.operator,
|
|
||||||
boundary=config.boundary,
|
|
||||||
hertz=config.hertz,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AllOfConfig(BaseConfig):
|
|
||||||
name: Literal["all_of"] = "all_of"
|
|
||||||
conditions: Sequence["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
class AllOf:
|
|
||||||
def __init__(self, conditions: List[SoundEventCondition]):
|
|
||||||
self.conditions = conditions
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return all(c(sound_event_annotation) for c in self.conditions)
|
|
||||||
|
|
||||||
@conditions.register(AllOfConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: AllOfConfig):
|
|
||||||
conditions = [
|
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
|
||||||
]
|
|
||||||
return AllOf(conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class AnyOfConfig(BaseConfig):
|
|
||||||
name: Literal["any_of"] = "any_of"
|
|
||||||
conditions: List["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
class AnyOf:
|
|
||||||
def __init__(self, conditions: List[SoundEventCondition]):
|
|
||||||
self.conditions = conditions
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return any(c(sound_event_annotation) for c in self.conditions)
|
|
||||||
|
|
||||||
@conditions.register(AnyOfConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: AnyOfConfig):
|
|
||||||
conditions = [
|
|
||||||
build_sound_event_condition(cond) for cond in config.conditions
|
|
||||||
]
|
|
||||||
return AnyOf(conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class NotConfig(BaseConfig):
|
|
||||||
name: Literal["not"] = "not"
|
|
||||||
condition: "SoundEventConditionConfig"
|
|
||||||
|
|
||||||
|
|
||||||
class Not:
|
|
||||||
def __init__(self, condition: SoundEventCondition):
|
|
||||||
self.condition = condition
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
|
||||||
) -> bool:
|
|
||||||
return not self.condition(sound_event_annotation)
|
|
||||||
|
|
||||||
@conditions.register(NotConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: NotConfig):
|
|
||||||
condition = build_sound_event_condition(config.condition)
|
|
||||||
return Not(condition)
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
|
||||||
HasTagConfig
|
|
||||||
| HasAllTagsConfig
|
|
||||||
| HasAnyTagConfig
|
|
||||||
| DurationConfig
|
|
||||||
| FrequencyConfig
|
|
||||||
| AllOfConfig
|
|
||||||
| AnyOfConfig
|
|
||||||
| NotConfig,
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_condition(
|
|
||||||
config: SoundEventConditionConfig,
|
|
||||||
) -> SoundEventCondition:
|
|
||||||
return conditions.build(config)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_clip_annotation(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
condition: SoundEventCondition,
|
|
||||||
) -> data.ClipAnnotation:
|
|
||||||
return clip_annotation.model_copy(
|
|
||||||
update=dict(
|
|
||||||
sound_events=[
|
|
||||||
sound_event
|
|
||||||
for sound_event in clip_annotation.sound_events
|
|
||||||
if condition(sound_event)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
71
src/batdetect2/data/conditions/__init__.py
Normal file
71
src/batdetect2/data/conditions/__init__.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from batdetect2.data.conditions.clips import (
|
||||||
|
ClipAllOfConfig,
|
||||||
|
ClipAnnotationCondition,
|
||||||
|
ClipAnnotationConditionConfig,
|
||||||
|
ClipAnnotationConditionImportConfig,
|
||||||
|
ClipAnyOfConfig,
|
||||||
|
ClipNotConfig,
|
||||||
|
RecordingSatisfiesConfig,
|
||||||
|
build_clip_annotation_condition,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.common import (
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.recordings import (
|
||||||
|
RecordingAllOfConfig,
|
||||||
|
RecordingAnyOfConfig,
|
||||||
|
RecordingCondition,
|
||||||
|
RecordingConditionConfig,
|
||||||
|
RecordingConditionImportConfig,
|
||||||
|
RecordingNotConfig,
|
||||||
|
build_recording_condition,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.sound_events import (
|
||||||
|
AllOfConfig,
|
||||||
|
AnyOfConfig,
|
||||||
|
DurationConfig,
|
||||||
|
FrequencyConfig,
|
||||||
|
NotConfig,
|
||||||
|
Operator,
|
||||||
|
SoundEventCondition,
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
SoundEventConditionImportConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
|
filter_clip_annotation,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AllOfConfig",
|
||||||
|
"AnyOfConfig",
|
||||||
|
"ClipAllOfConfig",
|
||||||
|
"ClipAnnotationCondition",
|
||||||
|
"ClipAnnotationConditionConfig",
|
||||||
|
"ClipAnnotationConditionImportConfig",
|
||||||
|
"ClipAnyOfConfig",
|
||||||
|
"ClipNotConfig",
|
||||||
|
"DurationConfig",
|
||||||
|
"FrequencyConfig",
|
||||||
|
"HasAllTagsConfig",
|
||||||
|
"HasAnyTagConfig",
|
||||||
|
"HasTagConfig",
|
||||||
|
"IdInListConfig",
|
||||||
|
"NotConfig",
|
||||||
|
"Operator",
|
||||||
|
"RecordingCondition",
|
||||||
|
"RecordingConditionConfig",
|
||||||
|
"RecordingConditionImportConfig",
|
||||||
|
"RecordingAllOfConfig",
|
||||||
|
"RecordingAnyOfConfig",
|
||||||
|
"RecordingNotConfig",
|
||||||
|
"RecordingSatisfiesConfig",
|
||||||
|
"SoundEventCondition",
|
||||||
|
"SoundEventConditionConfig",
|
||||||
|
"SoundEventConditionImportConfig",
|
||||||
|
"build_clip_annotation_condition",
|
||||||
|
"build_recording_condition",
|
||||||
|
"build_sound_event_condition",
|
||||||
|
"filter_clip_annotation",
|
||||||
|
]
|
||||||
132
src/batdetect2/data/conditions/clips.py
Normal file
132
src/batdetect2/data/conditions/clips.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.common import (
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
MultiConditionConfigBase,
|
||||||
|
NotConditionConfigBase,
|
||||||
|
register_all_of_condition,
|
||||||
|
register_any_of_condition,
|
||||||
|
register_has_all_tags_condition,
|
||||||
|
register_has_any_tag_condition,
|
||||||
|
register_has_tag_condition,
|
||||||
|
register_id_in_list_condition,
|
||||||
|
register_not_condition,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.recordings import (
|
||||||
|
RecordingCondition,
|
||||||
|
RecordingConditionConfig,
|
||||||
|
build_recording_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClipAllOfConfig",
|
||||||
|
"ClipAnnotationCondition",
|
||||||
|
"ClipAnnotationConditionConfig",
|
||||||
|
"ClipAnnotationConditionImportConfig",
|
||||||
|
"ClipAnyOfConfig",
|
||||||
|
"ClipNotConfig",
|
||||||
|
"RecordingSatisfiesConfig",
|
||||||
|
"build_clip_annotation_condition",
|
||||||
|
]
|
||||||
|
|
||||||
|
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]
|
||||||
|
|
||||||
|
clip_annotation_conditions: Registry[
|
||||||
|
ClipAnnotationCondition,
|
||||||
|
[data.PathLike | None],
|
||||||
|
] = Registry("clip_condition")
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
|
||||||
|
class ClipAnnotationConditionImportConfig(ImportConfig):
|
||||||
|
"""Use any callable as a clip annotation condition.
|
||||||
|
|
||||||
|
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
||||||
|
to use it instead of a built-in option.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingSatisfiesConfig(BaseConfig):
|
||||||
|
name: Literal["recording_satisfies"] = "recording_satisfies"
|
||||||
|
condition: RecordingConditionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingSatisfies:
|
||||||
|
def __init__(self, condition: RecordingCondition):
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
|
||||||
|
recording = clip_annotation.clip.recording
|
||||||
|
return self.condition(recording)
|
||||||
|
|
||||||
|
@clip_annotation_conditions.register(RecordingSatisfiesConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: RecordingSatisfiesConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
) -> "RecordingSatisfies":
|
||||||
|
condition = build_recording_condition(
|
||||||
|
config.condition,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
return RecordingSatisfies(condition)
|
||||||
|
|
||||||
|
|
||||||
|
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
|
||||||
|
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
|
||||||
|
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
|
||||||
|
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@register_all_of_condition(clip_annotation_conditions)
|
||||||
|
class ClipAllOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["all_of"] = "all_of"
|
||||||
|
conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_any_of_condition(clip_annotation_conditions)
|
||||||
|
class ClipAnyOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["any_of"] = "any_of"
|
||||||
|
conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_not_condition(clip_annotation_conditions)
|
||||||
|
class ClipNotConfig(NotConditionConfigBase):
|
||||||
|
name: Literal["not"] = "not"
|
||||||
|
condition: "ClipAnnotationConditionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
ClipAnnotationConditionConfig = Annotated[
|
||||||
|
RecordingSatisfiesConfig
|
||||||
|
| IdInListConfig
|
||||||
|
| HasTagConfig
|
||||||
|
| HasAllTagsConfig
|
||||||
|
| HasAnyTagConfig
|
||||||
|
| ClipAllOfConfig
|
||||||
|
| ClipAnyOfConfig
|
||||||
|
| ClipNotConfig
|
||||||
|
| ClipAnnotationConditionImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_clip_annotation_condition(
|
||||||
|
config: ClipAnnotationConditionConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
) -> ClipAnnotationCondition:
|
||||||
|
return clip_annotation_conditions.build(config, base_dir)
|
||||||
314
src/batdetect2/data/conditions/common.py
Normal file
314
src/batdetect2/data/conditions/common.py
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AllOf",
|
||||||
|
"AnyOf",
|
||||||
|
"Condition",
|
||||||
|
"HasAllTags",
|
||||||
|
"HasAllTagsConfig",
|
||||||
|
"HasAnyTag",
|
||||||
|
"HasAnyTagConfig",
|
||||||
|
"HasTag",
|
||||||
|
"HasTagConfig",
|
||||||
|
"IdInList",
|
||||||
|
"IdInListConfig",
|
||||||
|
"MultiConditionConfigBase",
|
||||||
|
"Not",
|
||||||
|
"NotConditionConfigBase",
|
||||||
|
"ObjectWithTags",
|
||||||
|
"ObjectWithUUID",
|
||||||
|
"register_all_of_condition",
|
||||||
|
"register_any_of_condition",
|
||||||
|
"register_has_all_tags_condition",
|
||||||
|
"register_has_any_tag_condition",
|
||||||
|
"register_has_tag_condition",
|
||||||
|
"register_id_in_list_condition",
|
||||||
|
"register_not_condition",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectWithTags(Protocol):
|
||||||
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectWithUUID(Protocol):
|
||||||
|
uuid: UUID
|
||||||
|
|
||||||
|
|
||||||
|
ConditionObject = TypeVar("ConditionObject")
|
||||||
|
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
|
||||||
|
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
|
||||||
|
MultiConfigType = TypeVar(
|
||||||
|
"MultiConfigType",
|
||||||
|
bound="MultiConditionConfigBase",
|
||||||
|
)
|
||||||
|
Condition = Callable[[ConditionObject], bool]
|
||||||
|
|
||||||
|
|
||||||
|
class NotConditionConfigBase(BaseConfig):
|
||||||
|
condition: BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class MultiConditionConfigBase(BaseConfig):
|
||||||
|
conditions: Sequence[BaseModel]
|
||||||
|
|
||||||
|
|
||||||
|
class Not(Generic[ConditionObject]):
|
||||||
|
def __init__(self, condition: Condition[ConditionObject]):
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def __call__(self, obj: ConditionObject) -> bool:
|
||||||
|
return not self.condition(obj)
|
||||||
|
|
||||||
|
|
||||||
|
class AllOf(Generic[ConditionObject]):
|
||||||
|
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
||||||
|
self.conditions = list(conditions)
|
||||||
|
|
||||||
|
def __call__(self, obj: ConditionObject) -> bool:
|
||||||
|
return all(condition(obj) for condition in self.conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOf(Generic[ConditionObject]):
|
||||||
|
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
||||||
|
self.conditions = list(conditions)
|
||||||
|
|
||||||
|
def __call__(self, obj: ConditionObject) -> bool:
|
||||||
|
return any(condition(obj) for condition in self.conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class HasTag(Generic[TaggedObject]):
|
||||||
|
def __init__(self, tag: data.Tag):
|
||||||
|
self.tag_key = (tag.term.name, tag.value)
|
||||||
|
|
||||||
|
def __call__(self, obj: TaggedObject) -> bool:
|
||||||
|
return any(
|
||||||
|
(tag.term.name, tag.value) == self.tag_key for tag in obj.tags
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAllTags(Generic[TaggedObject]):
|
||||||
|
def __init__(self, tags: list[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
|
def __call__(self, obj: TaggedObject) -> bool:
|
||||||
|
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
||||||
|
return self.required_keys.issubset(tag_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAnyTag(Generic[TaggedObject]):
|
||||||
|
def __init__(self, tags: list[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
|
def __call__(self, obj: TaggedObject) -> bool:
|
||||||
|
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
||||||
|
return bool(self.required_keys.intersection(tag_keys))
|
||||||
|
|
||||||
|
|
||||||
|
class IdInList(Generic[UUIDObject]):
|
||||||
|
def __init__(self, ids: set[UUID]):
|
||||||
|
self.ids = ids
|
||||||
|
|
||||||
|
def __call__(self, obj: UUIDObject) -> bool:
|
||||||
|
return obj.uuid in self.ids
|
||||||
|
|
||||||
|
|
||||||
|
class HasTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_tag"] = "has_tag"
|
||||||
|
tag: data.Tag
|
||||||
|
|
||||||
|
|
||||||
|
class HasAllTagsConfig(BaseConfig):
|
||||||
|
name: Literal["has_all_tags"] = "has_all_tags"
|
||||||
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class HasAnyTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_any_tag"] = "has_any_tag"
|
||||||
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class IdInListConfig(BaseConfig):
|
||||||
|
name: Literal["id_in_list"] = "id_in_list"
|
||||||
|
path: Path
|
||||||
|
list_format: Literal["json", "txt"] = "json"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ids(
|
||||||
|
path: Path,
|
||||||
|
list_format: Literal["json", "txt"],
|
||||||
|
) -> list[str]:
|
||||||
|
if list_format == "json":
|
||||||
|
content = json.loads(path.read_text())
|
||||||
|
|
||||||
|
if not isinstance(content, list):
|
||||||
|
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
|
||||||
|
|
||||||
|
return [str(value) for value in content]
|
||||||
|
|
||||||
|
return [
|
||||||
|
line.strip() for line in path.read_text().splitlines() if line.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def register_id_in_list_condition(
|
||||||
|
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
||||||
|
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
|
||||||
|
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: IdInListConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
) -> Condition[UUIDObject]:
|
||||||
|
path = config.path
|
||||||
|
|
||||||
|
if base_dir is not None and not path.is_absolute():
|
||||||
|
path = Path(base_dir) / path
|
||||||
|
|
||||||
|
ids = set()
|
||||||
|
for index, value in enumerate(_load_ids(path, config.list_format)):
|
||||||
|
try:
|
||||||
|
ids.add(UUID(value))
|
||||||
|
except ValueError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid ID at index {index} in '{path}': {value!r}."
|
||||||
|
) from err
|
||||||
|
|
||||||
|
return IdInList(ids)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_has_tag_condition(
|
||||||
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
|
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
|
||||||
|
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: HasTagConfig,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[TaggedObject]:
|
||||||
|
return HasTag(config.tag)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_has_all_tags_condition(
|
||||||
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
|
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
|
||||||
|
def decorator(
|
||||||
|
config_cls: type[HasAllTagsConfig],
|
||||||
|
) -> type[HasAllTagsConfig]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: HasAllTagsConfig,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[TaggedObject]:
|
||||||
|
return HasAllTags(config.tags)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_has_any_tag_condition(
|
||||||
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
|
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
|
||||||
|
def decorator(
|
||||||
|
config_cls: type[HasAnyTagConfig],
|
||||||
|
) -> type[HasAnyTagConfig]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: HasAnyTagConfig,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[TaggedObject]:
|
||||||
|
return HasAnyTag(config.tags)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_not_condition(
|
||||||
|
registry: Registry[Condition[ConditionObject], P],
|
||||||
|
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
|
||||||
|
def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: NotConfigType,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[ConditionObject]:
|
||||||
|
condition = registry.build(config.condition, *args, **kwargs)
|
||||||
|
return Not(condition)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_of_condition(
|
||||||
|
registry: Registry[Condition[ConditionObject], P],
|
||||||
|
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
||||||
|
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: MultiConfigType,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[ConditionObject]:
|
||||||
|
conditions = [
|
||||||
|
registry.build(condition, *args, **kwargs)
|
||||||
|
for condition in config.conditions
|
||||||
|
]
|
||||||
|
return AllOf(conditions)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_any_of_condition(
|
||||||
|
registry: Registry[Condition[ConditionObject], P],
|
||||||
|
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
||||||
|
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
||||||
|
@registry.register(config_cls)
|
||||||
|
def builder(
|
||||||
|
config: MultiConfigType,
|
||||||
|
*args: P.args,
|
||||||
|
**kwargs: P.kwargs,
|
||||||
|
) -> Condition[ConditionObject]:
|
||||||
|
conditions = [
|
||||||
|
registry.build(condition, *args, **kwargs)
|
||||||
|
for condition in config.conditions
|
||||||
|
]
|
||||||
|
return AnyOf(conditions)
|
||||||
|
|
||||||
|
return config_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
98
src/batdetect2/data/conditions/recordings.py
Normal file
98
src/batdetect2/data/conditions/recordings.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.common import (
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
MultiConditionConfigBase,
|
||||||
|
NotConditionConfigBase,
|
||||||
|
register_all_of_condition,
|
||||||
|
register_any_of_condition,
|
||||||
|
register_has_all_tags_condition,
|
||||||
|
register_has_any_tag_condition,
|
||||||
|
register_has_tag_condition,
|
||||||
|
register_id_in_list_condition,
|
||||||
|
register_not_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"IdInListConfig",
|
||||||
|
"RecordingAllOfConfig",
|
||||||
|
"RecordingAnyOfConfig",
|
||||||
|
"RecordingCondition",
|
||||||
|
"RecordingConditionConfig",
|
||||||
|
"RecordingConditionImportConfig",
|
||||||
|
"RecordingNotConfig",
|
||||||
|
"build_recording_condition",
|
||||||
|
]
|
||||||
|
|
||||||
|
RecordingCondition = Callable[[data.Recording], bool]
|
||||||
|
|
||||||
|
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
|
||||||
|
Registry("recording_condition")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(recording_conditions, arg_names=["base_dir"])
|
||||||
|
class RecordingConditionImportConfig(ImportConfig):
|
||||||
|
"""Use any callable as a recording condition.
|
||||||
|
|
||||||
|
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
||||||
|
to use it instead of a built-in option.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
register_id_in_list_condition(recording_conditions)(IdInListConfig)
|
||||||
|
register_has_tag_condition(recording_conditions)(HasTagConfig)
|
||||||
|
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
|
||||||
|
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@register_all_of_condition(recording_conditions)
|
||||||
|
class RecordingAllOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["all_of"] = "all_of"
|
||||||
|
conditions: Sequence["RecordingConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_any_of_condition(recording_conditions)
|
||||||
|
class RecordingAnyOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["any_of"] = "any_of"
|
||||||
|
conditions: Sequence["RecordingConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_not_condition(recording_conditions)
|
||||||
|
class RecordingNotConfig(NotConditionConfigBase):
|
||||||
|
name: Literal["not"] = "not"
|
||||||
|
condition: "RecordingConditionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
RecordingConditionConfig = Annotated[
|
||||||
|
IdInListConfig
|
||||||
|
| HasTagConfig
|
||||||
|
| HasAllTagsConfig
|
||||||
|
| HasAnyTagConfig
|
||||||
|
| RecordingAllOfConfig
|
||||||
|
| RecordingAnyOfConfig
|
||||||
|
| RecordingNotConfig
|
||||||
|
| RecordingConditionImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_recording_condition(
|
||||||
|
config: RecordingConditionConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
) -> RecordingCondition:
|
||||||
|
return recording_conditions.build(config, base_dir)
|
||||||
236
src/batdetect2/data/conditions/sound_events.py
Normal file
236
src/batdetect2/data/conditions/sound_events.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
from batdetect2.data.conditions.common import (
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
MultiConditionConfigBase,
|
||||||
|
NotConditionConfigBase,
|
||||||
|
register_all_of_condition,
|
||||||
|
register_any_of_condition,
|
||||||
|
register_has_all_tags_condition,
|
||||||
|
register_has_any_tag_condition,
|
||||||
|
register_has_tag_condition,
|
||||||
|
register_id_in_list_condition,
|
||||||
|
register_not_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AllOfConfig",
|
||||||
|
"AnyOfConfig",
|
||||||
|
"DurationConfig",
|
||||||
|
"FrequencyConfig",
|
||||||
|
"HasAllTagsConfig",
|
||||||
|
"HasAnyTagConfig",
|
||||||
|
"HasTagConfig",
|
||||||
|
"NotConfig",
|
||||||
|
"Operator",
|
||||||
|
"SoundEventCondition",
|
||||||
|
"SoundEventConditionConfig",
|
||||||
|
"SoundEventConditionImportConfig",
|
||||||
|
"build_sound_event_condition",
|
||||||
|
"filter_clip_annotation",
|
||||||
|
]
|
||||||
|
|
||||||
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
sound_event_conditions: Registry[
|
||||||
|
SoundEventCondition,
|
||||||
|
[data.PathLike | None],
|
||||||
|
] = Registry("sound_event_condition")
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
|
||||||
|
class SoundEventConditionImportConfig(ImportConfig):
|
||||||
|
"""Use any callable as a sound event condition.
|
||||||
|
|
||||||
|
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||||
|
callable to use it instead of a built-in option.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
|
||||||
|
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
|
||||||
|
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
|
||||||
|
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
|
||||||
|
|
||||||
|
|
||||||
|
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||||
|
|
||||||
|
|
||||||
|
class DurationConfig(BaseConfig):
|
||||||
|
name: Literal["duration"] = "duration"
|
||||||
|
operator: Operator
|
||||||
|
seconds: float
|
||||||
|
|
||||||
|
|
||||||
|
def _build_comparator(
|
||||||
|
operator: Operator, value: float
|
||||||
|
) -> Callable[[float], bool]:
|
||||||
|
if operator == "gt":
|
||||||
|
return lambda x: x > value
|
||||||
|
|
||||||
|
if operator == "gte":
|
||||||
|
return lambda x: x >= value
|
||||||
|
|
||||||
|
if operator == "lt":
|
||||||
|
return lambda x: x < value
|
||||||
|
|
||||||
|
if operator == "lte":
|
||||||
|
return lambda x: x <= value
|
||||||
|
|
||||||
|
if operator == "eq":
|
||||||
|
return lambda x: x == value
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid operator {operator}")
|
||||||
|
|
||||||
|
|
||||||
|
class Duration:
|
||||||
|
def __init__(self, operator: Operator, seconds: float):
|
||||||
|
self.operator = operator
|
||||||
|
self.seconds = seconds
|
||||||
|
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||||
|
duration = end_time - start_time
|
||||||
|
|
||||||
|
return self._comparator(duration)
|
||||||
|
|
||||||
|
@sound_event_conditions.register(DurationConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: DurationConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
):
|
||||||
|
_ = base_dir
|
||||||
|
return Duration(operator=config.operator, seconds=config.seconds)
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyConfig(BaseConfig):
|
||||||
|
name: Literal["frequency"] = "frequency"
|
||||||
|
boundary: Literal["low", "high"]
|
||||||
|
operator: Operator
|
||||||
|
hertz: float
|
||||||
|
|
||||||
|
|
||||||
|
class Frequency:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
operator: Operator,
|
||||||
|
boundary: Literal["low", "high"],
|
||||||
|
hertz: float,
|
||||||
|
):
|
||||||
|
self.operator = operator
|
||||||
|
self.hertz = hertz
|
||||||
|
self.boundary = boundary
|
||||||
|
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
|
if self.boundary == "low":
|
||||||
|
return self._comparator(low_freq)
|
||||||
|
|
||||||
|
return self._comparator(high_freq)
|
||||||
|
|
||||||
|
@sound_event_conditions.register(FrequencyConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: FrequencyConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
):
|
||||||
|
_ = base_dir
|
||||||
|
return Frequency(
|
||||||
|
operator=config.operator,
|
||||||
|
boundary=config.boundary,
|
||||||
|
hertz=config.hertz,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_all_of_condition(sound_event_conditions)
|
||||||
|
class AllOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["all_of"] = "all_of"
|
||||||
|
conditions: Sequence["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_any_of_condition(sound_event_conditions)
|
||||||
|
class AnyOfConfig(MultiConditionConfigBase):
|
||||||
|
name: Literal["any_of"] = "any_of"
|
||||||
|
conditions: list["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
@register_not_condition(sound_event_conditions)
|
||||||
|
class NotConfig(NotConditionConfigBase):
|
||||||
|
name: Literal["not"] = "not"
|
||||||
|
condition: "SoundEventConditionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventConditionConfig = Annotated[
|
||||||
|
IdInListConfig
|
||||||
|
| HasTagConfig
|
||||||
|
| HasAllTagsConfig
|
||||||
|
| HasAnyTagConfig
|
||||||
|
| DurationConfig
|
||||||
|
| FrequencyConfig
|
||||||
|
| AllOfConfig
|
||||||
|
| AnyOfConfig
|
||||||
|
| NotConfig
|
||||||
|
| SoundEventConditionImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_sound_event_condition(
|
||||||
|
config: SoundEventConditionConfig,
|
||||||
|
base_dir: data.PathLike | None = None,
|
||||||
|
) -> SoundEventCondition:
|
||||||
|
return sound_event_conditions.build(config, base_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_clip_annotation(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
condition: SoundEventCondition,
|
||||||
|
) -> data.ClipAnnotation:
|
||||||
|
return clip_annotation.model_copy(
|
||||||
|
update=dict(
|
||||||
|
sound_events=[
|
||||||
|
sound_event
|
||||||
|
for sound_event in clip_annotation.sound_events
|
||||||
|
if condition(sound_event)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
132
tests/test_data/test_conditions/test_clip.py
Normal file
132
tests/test_data/test_conditions/test_clip.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
ClipAllOfConfig,
|
||||||
|
ClipAnyOfConfig,
|
||||||
|
ClipNotConfig,
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
RecordingSatisfiesConfig,
|
||||||
|
build_clip_annotation_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recording_satisfies_condition(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
create_clip,
|
||||||
|
create_clip_annotation,
|
||||||
|
) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
clip_a = create_clip(recording_a)
|
||||||
|
clip_b = create_clip(recording_b)
|
||||||
|
clip_annotation_a = create_clip_annotation(clip_a)
|
||||||
|
clip_annotation_b = create_clip_annotation(clip_b)
|
||||||
|
ids_path = tmp_path / "recording_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
||||||
|
|
||||||
|
condition = build_clip_annotation_condition(
|
||||||
|
RecordingSatisfiesConfig(
|
||||||
|
condition=IdInListConfig(path=ids_path),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(clip_annotation_a)
|
||||||
|
assert not condition(clip_annotation_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_id_in_list_condition(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
create_clip,
|
||||||
|
create_clip_annotation,
|
||||||
|
) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
clip_annotation_a = create_clip_annotation(create_clip(recording_a))
|
||||||
|
clip_annotation_b = create_clip_annotation(create_clip(recording_b))
|
||||||
|
ids_path = tmp_path / "clip_annotation_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(clip_annotation_a.uuid)]))
|
||||||
|
|
||||||
|
condition = build_clip_annotation_condition(IdInListConfig(path=ids_path))
|
||||||
|
|
||||||
|
assert condition(clip_annotation_a)
|
||||||
|
assert not condition(clip_annotation_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_has_tag_conditions(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
create_clip,
|
||||||
|
create_clip_annotation,
|
||||||
|
) -> None:
|
||||||
|
reviewed = data.Tag(key="status", value="reviewed")
|
||||||
|
train = data.Tag(key="split", value="train")
|
||||||
|
val = data.Tag(key="split", value="val")
|
||||||
|
|
||||||
|
recording = create_recording(path=tmp_path / "rec.wav")
|
||||||
|
clip = create_clip(recording)
|
||||||
|
clip_annotation = create_clip_annotation(
|
||||||
|
clip,
|
||||||
|
clip_tags=[reviewed, train],
|
||||||
|
)
|
||||||
|
|
||||||
|
has_tag = build_clip_annotation_condition(HasTagConfig(tag=reviewed))
|
||||||
|
has_all = build_clip_annotation_condition(
|
||||||
|
HasAllTagsConfig(tags=[reviewed, train])
|
||||||
|
)
|
||||||
|
has_any = build_clip_annotation_condition(
|
||||||
|
HasAnyTagConfig(tags=[val, train])
|
||||||
|
)
|
||||||
|
|
||||||
|
assert has_tag(clip_annotation)
|
||||||
|
assert has_all(clip_annotation)
|
||||||
|
assert has_any(clip_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_logical_conditions(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
create_clip,
|
||||||
|
create_clip_annotation,
|
||||||
|
) -> None:
|
||||||
|
reviewed = data.Tag(key="status", value="reviewed")
|
||||||
|
train = data.Tag(key="split", value="train")
|
||||||
|
val = data.Tag(key="split", value="val")
|
||||||
|
|
||||||
|
recording = create_recording(path=tmp_path / "rec.wav")
|
||||||
|
clip = create_clip(recording)
|
||||||
|
clip_annotation = create_clip_annotation(
|
||||||
|
clip,
|
||||||
|
clip_tags=[reviewed, train],
|
||||||
|
)
|
||||||
|
|
||||||
|
all_condition = build_clip_annotation_condition(
|
||||||
|
ClipAllOfConfig(
|
||||||
|
conditions=[
|
||||||
|
HasTagConfig(tag=reviewed),
|
||||||
|
HasAnyTagConfig(tags=[train, val]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
any_condition = build_clip_annotation_condition(
|
||||||
|
ClipAnyOfConfig(
|
||||||
|
conditions=[
|
||||||
|
HasTagConfig(tag=val),
|
||||||
|
HasTagConfig(tag=reviewed),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
not_condition = build_clip_annotation_condition(
|
||||||
|
ClipNotConfig(condition=HasTagConfig(tag=val))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all_condition(clip_annotation)
|
||||||
|
assert any_condition(clip_annotation)
|
||||||
|
assert not_condition(clip_annotation)
|
||||||
139
tests/test_data/test_conditions/test_recording.py
Normal file
139
tests/test_data/test_conditions/test_recording.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
HasAllTagsConfig,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
HasTagConfig,
|
||||||
|
IdInListConfig,
|
||||||
|
RecordingAllOfConfig,
|
||||||
|
RecordingAnyOfConfig,
|
||||||
|
RecordingNotConfig,
|
||||||
|
build_recording_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
ids_path = tmp_path / "recording_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
||||||
|
|
||||||
|
condition = build_recording_condition(IdInListConfig(path=ids_path))
|
||||||
|
|
||||||
|
assert condition(recording_a)
|
||||||
|
assert not condition(recording_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_uses_base_dir(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
) -> None:
|
||||||
|
recording = create_recording(path=tmp_path / "a.wav")
|
||||||
|
split_dir = tmp_path / "splits"
|
||||||
|
split_dir.mkdir()
|
||||||
|
ids_path = split_dir / "train_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(recording.uuid)]))
|
||||||
|
|
||||||
|
condition = build_recording_condition(
|
||||||
|
IdInListConfig(path=Path("splits/train_ids.json")),
|
||||||
|
base_dir=tmp_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(recording)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_raises_for_non_list_json(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
ids_path = tmp_path / "recording_ids.json"
|
||||||
|
ids_path.write_text(json.dumps({"id": "foo"}))
|
||||||
|
|
||||||
|
with pytest.raises(TypeError, match="Expected JSON list"):
|
||||||
|
build_recording_condition(IdInListConfig(path=ids_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
|
||||||
|
ids_path = tmp_path / "recording_ids.json"
|
||||||
|
ids_path.write_text(json.dumps(["not-a-uuid"]))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid ID"):
|
||||||
|
build_recording_condition(IdInListConfig(path=ids_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_supports_txt_format(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
ids_path = tmp_path / "recording_ids.txt"
|
||||||
|
ids_path.write_text(f"{recording_a.uuid}\n")
|
||||||
|
|
||||||
|
condition = build_recording_condition(
|
||||||
|
IdInListConfig(path=ids_path, list_format="txt")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(recording_a)
|
||||||
|
assert not condition(recording_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recording_has_tag_conditions(
|
||||||
|
tmp_path: Path, create_recording
|
||||||
|
) -> None:
|
||||||
|
train = data.Tag(key="split", value="train")
|
||||||
|
uk = data.Tag(key="region", value="uk")
|
||||||
|
eu = data.Tag(key="region", value="eu")
|
||||||
|
|
||||||
|
recording = create_recording(
|
||||||
|
path=tmp_path / "rec.wav",
|
||||||
|
tags=[train, uk],
|
||||||
|
)
|
||||||
|
|
||||||
|
has_train = build_recording_condition(HasTagConfig(tag=train))
|
||||||
|
has_all = build_recording_condition(HasAllTagsConfig(tags=[train, uk]))
|
||||||
|
has_any = build_recording_condition(HasAnyTagConfig(tags=[eu, uk]))
|
||||||
|
|
||||||
|
assert has_train(recording)
|
||||||
|
assert has_all(recording)
|
||||||
|
assert has_any(recording)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recording_logical_conditions(
|
||||||
|
tmp_path: Path, create_recording
|
||||||
|
) -> None:
|
||||||
|
train = data.Tag(key="split", value="train")
|
||||||
|
uk = data.Tag(key="region", value="uk")
|
||||||
|
eu = data.Tag(key="region", value="eu")
|
||||||
|
|
||||||
|
recording = create_recording(
|
||||||
|
path=tmp_path / "rec.wav",
|
||||||
|
tags=[train, uk],
|
||||||
|
)
|
||||||
|
|
||||||
|
all_condition = build_recording_condition(
|
||||||
|
RecordingAllOfConfig(
|
||||||
|
conditions=[
|
||||||
|
HasTagConfig(tag=train),
|
||||||
|
HasAnyTagConfig(tags=[eu, uk]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
any_condition = build_recording_condition(
|
||||||
|
RecordingAnyOfConfig(
|
||||||
|
conditions=[
|
||||||
|
HasTagConfig(tag=eu),
|
||||||
|
HasTagConfig(tag=train),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
not_condition = build_recording_condition(
|
||||||
|
RecordingNotConfig(condition=HasTagConfig(tag=eu))
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all_condition(recording)
|
||||||
|
assert any_condition(recording)
|
||||||
|
assert not_condition(recording)
|
||||||
@ -1,4 +1,6 @@
|
|||||||
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
@ -6,16 +8,17 @@ from pydantic import TypeAdapter
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
|
IdInListConfig,
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_condition_from_str(content):
|
def build_condition_from_str(content, base_dir: Path | None = None):
|
||||||
content = textwrap.dedent(content)
|
content = textwrap.dedent(content)
|
||||||
content = yaml.safe_load(content)
|
content = yaml.safe_load(content)
|
||||||
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||||
return build_sound_event_condition(config)
|
return build_sound_event_condition(config, base_dir=base_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_has_tag(sound_event: data.SoundEvent):
|
def test_has_tag(sound_event: data.SoundEvent):
|
||||||
@ -160,6 +163,36 @@ def test_not(sound_event: data.SoundEvent):
|
|||||||
assert not condition(sound_event_annotation)
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list(sound_event: data.SoundEvent, tmp_path: Path):
|
||||||
|
se1 = data.SoundEventAnnotation(sound_event=sound_event)
|
||||||
|
se2 = data.SoundEventAnnotation(sound_event=sound_event)
|
||||||
|
ids_path = tmp_path / "sound_event_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(se1.uuid)]))
|
||||||
|
|
||||||
|
condition = build_sound_event_condition(IdInListConfig(path=ids_path))
|
||||||
|
|
||||||
|
assert condition(se1)
|
||||||
|
assert not condition(se2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_uses_base_dir(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
se = data.SoundEventAnnotation(sound_event=sound_event)
|
||||||
|
split_dir = tmp_path / "splits"
|
||||||
|
split_dir.mkdir()
|
||||||
|
ids_path = split_dir / "sound_event_ids.json"
|
||||||
|
ids_path.write_text(json.dumps([str(se.uuid)]))
|
||||||
|
|
||||||
|
condition = build_sound_event_condition(
|
||||||
|
IdInListConfig(path=Path("splits/sound_event_ids.json")),
|
||||||
|
base_dir=tmp_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
|
||||||
def test_duration(recording: data.Recording):
|
def test_duration(recording: data.Recording):
|
||||||
se1 = data.SoundEventAnnotation(
|
se1 = data.SoundEventAnnotation(
|
||||||
sound_event=data.SoundEvent(
|
sound_event=data.SoundEvent(
|
||||||
Loading…
Reference in New Issue
Block a user