Add conditions for clips and recordings

This commit is contained in:
mbsantiago 2026-04-03 16:40:11 +01:00
parent e80fe8675d
commit c8dd4155bf
10 changed files with 1157 additions and 314 deletions

View File

@ -1,312 +0,0 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
# A sound event condition is any callable that receives a sound event
# annotation and returns a boolean.
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

# Registry created under the name "condition"; the condition classes in this
# module attach their builders to it through ``conditions.register``.
conditions: Registry[SoundEventCondition, []] = Registry("condition")
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Fixed identifier of this config type.
    name: Literal["import"] = "import"
class HasTagConfig(BaseConfig):
    """Configuration for the ``has_tag`` sound event condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_tag"] = "has_tag"
    # Tag to look for; ``HasTag`` compares its term name and value.
    tag: data.Tag
class HasTag:
    """Condition satisfied when an annotation carries a given tag.

    Two tags are considered the same when both their term name and
    their value are equal.
    """

    def __init__(self, tag: data.Tag):
        self.tag = tag

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if any tag of the annotation matches ``self.tag``."""
        wanted = self.tag
        for candidate in sound_event_annotation.tags:
            same_term = wanted.term.name == candidate.term.name
            if same_term and wanted.value == candidate.value:
                return True
        return False

    @conditions.register(HasTagConfig)
    @staticmethod
    def from_config(config: HasTagConfig):
        """Build the condition from its configuration object."""
        return HasTag(tag=config.tag)
class HasAllTagsConfig(BaseConfig):
    """Configuration for the ``has_all_tags`` sound event condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present; ``HasAllTags`` rejects an empty list.
    tags: List[data.Tag]
class HasAllTags:
    """Condition satisfied when an annotation carries every listed tag.

    Tags are reduced to ``(term name, value)`` pairs before comparison.

    Raises
    ------
    ValueError
        If ``tags`` is empty.
    """

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.tags = {(tag.term.name, tag.value) for tag in tags}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if the required pairs are a subset of the tags."""
        present = set()
        for tag in sound_event_annotation.tags:
            present.add((tag.term.name, tag.value))
        return self.tags <= present

    @conditions.register(HasAllTagsConfig)
    @staticmethod
    def from_config(config: HasAllTagsConfig):
        """Build the condition from its configuration object."""
        return HasAllTags(tags=config.tags)
class HasAnyTagConfig(BaseConfig):
    """Configuration for the ``has_any_tag`` sound event condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Candidate tags; ``HasAnyTag`` rejects an empty list.
    tags: List[data.Tag]
class HasAnyTag:
    """Condition satisfied when an annotation carries one of the tags.

    Tags are reduced to ``(term name, value)`` pairs before comparison.

    Raises
    ------
    ValueError
        If ``tags`` is empty.
    """

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.tags = {(tag.term.name, tag.value) for tag in tags}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if at least one annotation tag is in the set."""
        accepted = self.tags
        return any(
            (tag.term.name, tag.value) in accepted
            for tag in sound_event_annotation.tags
        )

    @conditions.register(HasAnyTagConfig)
    @staticmethod
    def from_config(config: HasAnyTagConfig):
        """Build the condition from its configuration object."""
        return HasAnyTag(tags=config.tags)
# Comparison operators understood by ``_build_comparator``: greater than,
# greater or equal, less than, less or equal, and equal.
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
    """Configuration for the ``duration`` sound event condition."""

    # Fixed identifier of this condition type.
    name: Literal["duration"] = "duration"
    # How the measured duration is compared against ``seconds``.
    operator: Operator
    # Reference value the duration is compared with.
    seconds: float
def _build_comparator(
    operator: Operator, value: float
) -> Callable[[float], bool]:
    """Create a predicate that compares its argument against ``value``.

    Parameters
    ----------
    operator
        One of ``"gt"``, ``"gte"``, ``"lt"``, ``"lte"`` or ``"eq"``.
    value
        Right-hand side of the comparison.

    Returns
    -------
    Callable[[float], bool]
        Function evaluating ``x <op> value``.

    Raises
    ------
    ValueError
        If ``operator`` is not one of the supported names.
    """
    # Validate eagerly so an unknown operator fails at build time.
    if operator not in ("gt", "gte", "lt", "lte", "eq"):
        raise ValueError(f"Invalid operator {operator}")

    def compare(x: float) -> bool:
        if operator == "gt":
            return x > value
        if operator == "gte":
            return x >= value
        if operator == "lt":
            return x < value
        if operator == "lte":
            return x <= value
        return x == value

    return compare
class Duration:
    """Condition on the time extent of a sound event geometry.

    The duration is the difference between the end and start times
    returned by ``compute_bounds``. Annotations without a geometry never
    satisfy the condition.
    """

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(operator, seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return True if the duration compares as configured."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is not None:
            start_time, _, end_time, _ = compute_bounds(geometry)
            return self._comparator(end_time - start_time)
        return False

    @conditions.register(DurationConfig)
    @staticmethod
    def from_config(config: DurationConfig):
        """Build the condition from its configuration object."""
        return Duration(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
    """Configuration for the ``frequency`` sound event condition."""

    # Fixed identifier of this condition type.
    name: Literal["frequency"] = "frequency"
    # Which frequency bound of the geometry is compared.
    boundary: Literal["low", "high"]
    # How the selected bound is compared against ``hertz``.
    operator: Operator
    # Reference value the bound is compared with.
    hertz: float
class Frequency:
    """Condition on the low or high frequency bound of a sound event.

    Annotations without a geometry, or whose geometry is a
    ``TimeInterval`` or ``TimeStamp``, never satisfy the condition.
    """

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(operator, hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return True if the selected bound compares as configured."""
        geometry = sound_event_annotation.sound_event.geometry

        # Time-only geometries carry no frequency range to compare.
        time_only = (data.TimeInterval, data.TimeStamp)
        if geometry is None or isinstance(geometry, time_only):
            return False

        _, low_freq, _, high_freq = compute_bounds(geometry)
        selected = low_freq if self.boundary == "low" else high_freq
        return self._comparator(selected)

    @conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(config: FrequencyConfig):
        """Build the condition from its configuration object."""
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
class AllOfConfig(BaseConfig):
    """Configuration for the ``all_of`` combinator."""

    # Fixed identifier of this condition type.
    name: Literal["all_of"] = "all_of"
    # Nested condition configs; forward reference to the union defined
    # further down in this module.
    conditions: Sequence["SoundEventConditionConfig"]
class AllOf:
    """Condition satisfied only when every sub-condition is satisfied."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return False at the first failing sub-condition, else True."""
        for check in self.conditions:
            if not check(sound_event_annotation):
                return False
        return True

    @conditions.register(AllOfConfig)
    @staticmethod
    def from_config(config: AllOfConfig):
        """Build every nested condition and wrap them in ``AllOf``."""
        built = []
        for nested in config.conditions:
            built.append(build_sound_event_condition(nested))
        return AllOf(built)
class AnyOfConfig(BaseConfig):
    """Configuration for the ``any_of`` combinator."""

    # Fixed identifier of this condition type.
    name: Literal["any_of"] = "any_of"
    # Nested condition configs; forward reference to the union defined
    # further down in this module.
    # NOTE(review): ``AllOfConfig.conditions`` is typed ``Sequence`` while
    # this field is typed ``List`` - confirm whether that is intentional.
    conditions: List["SoundEventConditionConfig"]
class AnyOf:
    """Condition satisfied when at least one sub-condition is satisfied."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True at the first passing sub-condition, else False."""
        for check in self.conditions:
            if check(sound_event_annotation):
                return True
        return False

    @conditions.register(AnyOfConfig)
    @staticmethod
    def from_config(config: AnyOfConfig):
        """Build every nested condition and wrap them in ``AnyOf``."""
        built = []
        for nested in config.conditions:
            built.append(build_sound_event_condition(nested))
        return AnyOf(built)
class NotConfig(BaseConfig):
    """Configuration for the ``not`` combinator."""

    # Fixed identifier of this condition type.
    name: Literal["not"] = "not"
    # Config of the condition to negate; forward reference to the union
    # defined further down in this module.
    condition: "SoundEventConditionConfig"
class Not:
    """Condition that inverts the outcome of another condition."""

    def __init__(self, condition: SoundEventCondition):
        self.condition = condition

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True exactly when the wrapped condition is falsy."""
        if self.condition(sound_event_annotation):
            return False
        return True

    @conditions.register(NotConfig)
    @staticmethod
    def from_config(config: NotConfig):
        """Build the nested condition and wrap it in ``Not``."""
        return Not(build_sound_event_condition(config.condition))
# Union of all built-in sound event condition configs. The ``name`` field
# is declared as the pydantic discriminator, so it decides which member a
# given payload is validated against.
SoundEventConditionConfig = Annotated[
    HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | DurationConfig
    | FrequencyConfig
    | AllOfConfig
    | AnyOfConfig
    | NotConfig,
    Field(discriminator="name"),
]
def build_sound_event_condition(
    config: SoundEventConditionConfig,
) -> SoundEventCondition:
    """Build the condition described by ``config``.

    The work is delegated to the ``conditions`` registry.
    """
    condition = conditions.build(config)
    return condition
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of the clip annotation with filtered sound events.

    Only the sound events for which ``condition`` is truthy are kept; the
    copy is made with ``model_copy`` and the input is not modified here.
    """
    kept = list(filter(condition, clip_annotation.sound_events))
    return clip_annotation.model_copy(update={"sound_events": kept})

View File

@ -0,0 +1,71 @@
from batdetect2.data.conditions.clips import (
ClipAllOfConfig,
ClipAnnotationCondition,
ClipAnnotationConditionConfig,
ClipAnnotationConditionImportConfig,
ClipAnyOfConfig,
ClipNotConfig,
RecordingSatisfiesConfig,
build_clip_annotation_condition,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
)
from batdetect2.data.conditions.recordings import (
RecordingAllOfConfig,
RecordingAnyOfConfig,
RecordingCondition,
RecordingConditionConfig,
RecordingConditionImportConfig,
RecordingNotConfig,
build_recording_condition,
)
from batdetect2.data.conditions.sound_events import (
AllOfConfig,
AnyOfConfig,
DurationConfig,
FrequencyConfig,
NotConfig,
Operator,
SoundEventCondition,
SoundEventConditionConfig,
SoundEventConditionImportConfig,
build_sound_event_condition,
filter_clip_annotation,
)
# Public names of the conditions package, kept in alphabetical order with
# the ``build_*`` / ``filter_*`` helpers at the end.
__all__ = [
    "AllOfConfig",
    "AnyOfConfig",
    "ClipAllOfConfig",
    "ClipAnnotationCondition",
    "ClipAnnotationConditionConfig",
    "ClipAnnotationConditionImportConfig",
    "ClipAnyOfConfig",
    "ClipNotConfig",
    "DurationConfig",
    "FrequencyConfig",
    "HasAllTagsConfig",
    "HasAnyTagConfig",
    "HasTagConfig",
    "IdInListConfig",
    "NotConfig",
    "Operator",
    "RecordingAllOfConfig",
    "RecordingAnyOfConfig",
    "RecordingCondition",
    "RecordingConditionConfig",
    "RecordingConditionImportConfig",
    "RecordingNotConfig",
    "RecordingSatisfiesConfig",
    "SoundEventCondition",
    "SoundEventConditionConfig",
    "SoundEventConditionImportConfig",
    "build_clip_annotation_condition",
    "build_recording_condition",
    "build_sound_event_condition",
    "filter_clip_annotation",
]

View File

@ -0,0 +1,132 @@
from collections.abc import Callable, Sequence
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
from batdetect2.data.conditions.recordings import (
RecordingCondition,
RecordingConditionConfig,
build_recording_condition,
)
# Public names of the clip conditions module. The shared tag / ID configs
# imported from ``common`` are not listed here.
__all__ = [
    "ClipAllOfConfig",
    "ClipAnnotationCondition",
    "ClipAnnotationConditionConfig",
    "ClipAnnotationConditionImportConfig",
    "ClipAnyOfConfig",
    "ClipNotConfig",
    "RecordingSatisfiesConfig",
    "build_clip_annotation_condition",
]
# A clip annotation condition is any callable that receives a clip
# annotation and returns a boolean.
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]

# Registry created under the name "clip_condition". Besides the config,
# ``build_clip_annotation_condition`` passes one extra argument (the base
# directory, possibly None) when building.
clip_annotation_conditions: Registry[
    ClipAnnotationCondition,
    [data.PathLike | None],
] = Registry("clip_condition")
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
class ClipAnnotationConditionImportConfig(ImportConfig):
    """Use any callable as a clip annotation condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.
    """

    # Discriminator value of this config inside
    # ``ClipAnnotationConditionConfig``.
    name: Literal["import"] = "import"
class RecordingSatisfiesConfig(BaseConfig):
    """Configuration for the ``recording_satisfies`` clip condition."""

    # Discriminator value of this config.
    name: Literal["recording_satisfies"] = "recording_satisfies"
    # Recording condition that ``RecordingSatisfies`` applies to the
    # recording of the clip.
    condition: RecordingConditionConfig
class RecordingSatisfies:
    """Clip annotation condition that delegates to a recording condition.

    The wrapped condition is evaluated on ``clip_annotation.clip.recording``.
    """

    def __init__(self, condition: RecordingCondition):
        self.condition = condition

    def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
        """Return the wrapped condition's verdict for the clip's recording."""
        return self.condition(clip_annotation.clip.recording)

    @clip_annotation_conditions.register(RecordingSatisfiesConfig)
    @staticmethod
    def from_config(
        config: RecordingSatisfiesConfig,
        base_dir: data.PathLike | None = None,
    ) -> "RecordingSatisfies":
        """Build the nested recording condition, forwarding ``base_dir``."""
        return RecordingSatisfies(
            build_recording_condition(config.condition, base_dir=base_dir)
        )
# Attach the shared tag and ID-list configs to the clip registry. Each
# ``register_*`` helper returns a decorator that is applied to the config
# class right away.
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
@register_all_of_condition(clip_annotation_conditions)
class ClipAllOfConfig(MultiConditionConfigBase):
    """Configuration for the ``all_of`` combinator over clip conditions."""

    # Discriminator value of this config.
    name: Literal["all_of"] = "all_of"
    # Nested clip condition configs (forward reference to the union below).
    conditions: Sequence["ClipAnnotationConditionConfig"]
@register_any_of_condition(clip_annotation_conditions)
class ClipAnyOfConfig(MultiConditionConfigBase):
    """Configuration for the ``any_of`` combinator over clip conditions."""

    # Discriminator value of this config.
    name: Literal["any_of"] = "any_of"
    # Nested clip condition configs (forward reference to the union below).
    conditions: Sequence["ClipAnnotationConditionConfig"]
@register_not_condition(clip_annotation_conditions)
class ClipNotConfig(NotConditionConfigBase):
    """Configuration for the ``not`` combinator over clip conditions."""

    # Discriminator value of this config.
    name: Literal["not"] = "not"
    # Config of the clip condition to negate (forward reference to the
    # union below).
    condition: "ClipAnnotationConditionConfig"
# Union of every clip annotation condition config. The ``name`` field is
# declared as the pydantic discriminator, so it decides which member a
# given payload is validated against.
ClipAnnotationConditionConfig = Annotated[
    RecordingSatisfiesConfig
    | IdInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | ClipAllOfConfig
    | ClipAnyOfConfig
    | ClipNotConfig
    | ClipAnnotationConditionImportConfig,
    Field(discriminator="name"),
]
def build_clip_annotation_condition(
    config: ClipAnnotationConditionConfig,
    base_dir: data.PathLike | None = None,
) -> ClipAnnotationCondition:
    """Build the clip annotation condition described by ``config``.

    ``base_dir`` is handed to the registry as the extra build argument;
    the ID-list builder uses it to resolve relative paths.
    """
    condition = clip_annotation_conditions.build(config, base_dir)
    return condition

View File

@ -0,0 +1,314 @@
import json
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
from uuid import UUID
from pydantic import BaseModel
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
# Public names of the shared conditions module: generic condition classes,
# their configs, the structural protocols and the registration helpers.
__all__ = [
    "AllOf",
    "AnyOf",
    "Condition",
    "HasAllTags",
    "HasAllTagsConfig",
    "HasAnyTag",
    "HasAnyTagConfig",
    "HasTag",
    "HasTagConfig",
    "IdInList",
    "IdInListConfig",
    "MultiConditionConfigBase",
    "Not",
    "NotConditionConfigBase",
    "ObjectWithTags",
    "ObjectWithUUID",
    "register_all_of_condition",
    "register_any_of_condition",
    "register_has_all_tags_condition",
    "register_has_any_tag_condition",
    "register_has_tag_condition",
    "register_id_in_list_condition",
    "register_not_condition",
]
class ObjectWithTags(Protocol):
    """Structural type for any object exposing a ``tags`` list."""

    # Tags inspected by the tag-based conditions.
    tags: list[data.Tag]
class ObjectWithUUID(Protocol):
    """Structural type for any object exposing a ``uuid`` attribute."""

    # Identifier tested by ``IdInList``.
    uuid: UUID
# Type of the object a condition is evaluated on.
ConditionObject = TypeVar("ConditionObject")
# Same, restricted to objects that expose ``tags``.
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
# Same, restricted to objects that expose ``uuid``.
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
# Extra build arguments of a registry; forwarded unchanged by the builders.
P = ParamSpec("P")
# Config types accepted by the ``not`` and ``all_of`` / ``any_of`` helpers.
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
MultiConfigType = TypeVar(
    "MultiConfigType",
    bound="MultiConditionConfigBase",
)
# A condition is a callable mapping an object to a boolean.
Condition = Callable[[ConditionObject], bool]
class NotConditionConfigBase(BaseConfig):
    """Base class for configs that wrap a single nested condition."""

    # Nested condition config; ``register_not_condition`` builds it through
    # the registry. Subclasses in this package narrow the type.
    condition: BaseModel
class MultiConditionConfigBase(BaseConfig):
    """Base class for configs that wrap several nested conditions."""

    # Nested condition configs; the ``all_of`` / ``any_of`` helpers build
    # each one through the registry. Subclasses narrow the element type.
    conditions: Sequence[BaseModel]
class Not(Generic[ConditionObject]):
    """Condition that inverts the outcome of another condition."""

    def __init__(self, condition: Condition[ConditionObject]):
        self.condition = condition

    def __call__(self, obj: ConditionObject) -> bool:
        """Return True exactly when the wrapped condition is falsy."""
        if self.condition(obj):
            return False
        return True
class AllOf(Generic[ConditionObject]):
    """Condition satisfied only when every sub-condition is satisfied."""

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        # Copy into a list so later iteration does not depend on the input.
        self.conditions = [*conditions]

    def __call__(self, obj: ConditionObject) -> bool:
        """Return False at the first failing sub-condition, else True."""
        for check in self.conditions:
            if not check(obj):
                return False
        return True
class AnyOf(Generic[ConditionObject]):
    """Condition satisfied when at least one sub-condition is satisfied."""

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        # Copy into a list so later iteration does not depend on the input.
        self.conditions = [*conditions]

    def __call__(self, obj: ConditionObject) -> bool:
        """Return True at the first passing sub-condition, else False."""
        for check in self.conditions:
            if check(obj):
                return True
        return False
class HasTag(Generic[TaggedObject]):
    """Condition satisfied when an object carries a given tag.

    Tags are compared as ``(term name, value)`` pairs.
    """

    def __init__(self, tag: data.Tag):
        self.tag_key = (tag.term.name, tag.value)

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if any tag of ``obj`` matches the stored pair."""
        for tag in obj.tags:
            if (tag.term.name, tag.value) == self.tag_key:
                return True
        return False
class HasAllTags(Generic[TaggedObject]):
    """Condition satisfied when an object carries every listed tag.

    Tags are compared as ``(term name, value)`` pairs.

    Raises
    ------
    ValueError
        If ``tags`` is empty.
    """

    def __init__(self, tags: list[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.required_keys = {(tag.term.name, tag.value) for tag in tags}

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if the required pairs are a subset of the tags."""
        present = set()
        for tag in obj.tags:
            present.add((tag.term.name, tag.value))
        return self.required_keys <= present
class HasAnyTag(Generic[TaggedObject]):
    """Condition satisfied when an object carries one of the listed tags.

    Tags are compared as ``(term name, value)`` pairs.

    Raises
    ------
    ValueError
        If ``tags`` is empty.
    """

    def __init__(self, tags: list[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.required_keys = {(tag.term.name, tag.value) for tag in tags}

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if the tags share at least one pair with the set."""
        present = {(tag.term.name, tag.value) for tag in obj.tags}
        return not self.required_keys.isdisjoint(present)
class IdInList(Generic[UUIDObject]):
    """Condition satisfied when an object's ``uuid`` is in a given set."""

    def __init__(self, ids: set[UUID]):
        # The set is stored as given (no copy is made).
        self.ids = ids

    def __call__(self, obj: UUIDObject) -> bool:
        """Return True if ``obj.uuid`` belongs to the stored IDs."""
        identifier = obj.uuid
        return identifier in self.ids
class HasTagConfig(BaseConfig):
    """Configuration for the ``has_tag`` condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_tag"] = "has_tag"
    # Tag to look for; ``HasTag`` compares its term name and value.
    tag: data.Tag
class HasAllTagsConfig(BaseConfig):
    """Configuration for the ``has_all_tags`` condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present; ``HasAllTags`` rejects an empty list.
    tags: list[data.Tag]
class HasAnyTagConfig(BaseConfig):
    """Configuration for the ``has_any_tag`` condition."""

    # Fixed identifier of this condition type.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Candidate tags; ``HasAnyTag`` rejects an empty list.
    tags: list[data.Tag]
class IdInListConfig(BaseConfig):
    """Configuration for the ``id_in_list`` condition."""

    # Fixed identifier of this condition type.
    name: Literal["id_in_list"] = "id_in_list"
    # File holding the IDs; a relative path is joined to the ``base_dir``
    # given at build time, when one is provided.
    path: Path
    # "json" expects a JSON list; "txt" expects one ID per line.
    list_format: Literal["json", "txt"] = "json"
def _load_ids(
    path: Path,
    list_format: Literal["json", "txt"],
) -> list[str]:
    """Read raw ID strings from a JSON or plain-text list file.

    Parameters
    ----------
    path
        File to read.
    list_format
        ``"json"`` for a JSON list whose items are converted with ``str``;
        any other value is treated as text with one ID per line, where
        lines are stripped and blank lines are skipped.

    Returns
    -------
    list[str]
        The IDs in file order, not yet validated as UUIDs.

    Raises
    ------
    TypeError
        If the JSON document is not a list.
    """
    # Decode explicitly instead of relying on the locale default.
    # "utf-8-sig" reads plain UTF-8 and also drops a leading byte-order
    # mark, which would otherwise break JSON parsing or end up glued to
    # the first ID of a text list.
    text = path.read_text(encoding="utf-8-sig")

    if list_format == "json":
        content = json.loads(text)
        if not isinstance(content, list):
            raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
        return [str(value) for value in content]

    return [
        stripped for line in text.splitlines() if (stripped := line.strip())
    ]
def register_id_in_list_condition(
    registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
    """Return a decorator that registers an ID-list builder on ``registry``.

    The registered builder loads the IDs from ``config.path`` (joined to
    ``base_dir`` when that is given and the path is relative), parses each
    one as a UUID and returns an ``IdInList`` condition. A value that is
    not a valid UUID raises ``ValueError`` naming its index and the file.
    The decorator returns the config class unchanged.
    """

    def register(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
        @registry.register(config_cls)
        def build_id_in_list(
            config: IdInListConfig,
            base_dir: data.PathLike | None = None,
        ) -> Condition[UUIDObject]:
            id_file = config.path
            if not (base_dir is None or id_file.is_absolute()):
                id_file = Path(base_dir) / id_file

            parsed_ids = set()
            raw_ids = _load_ids(id_file, config.list_format)
            for position, raw in enumerate(raw_ids):
                try:
                    parsed = UUID(raw)
                except ValueError as err:
                    raise ValueError(
                        f"Invalid ID at index {position} in '{id_file}': {raw!r}."
                    ) from err
                parsed_ids.add(parsed)

            return IdInList(parsed_ids)

        return config_cls

    return register
def register_has_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
    """Return a decorator that registers a ``HasTag`` builder on ``registry``.

    The builder ignores any extra build arguments. The decorator returns
    the config class unchanged.
    """

    def register(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
        @registry.register(config_cls)
        def build_has_tag(
            config: HasTagConfig,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[TaggedObject]:
            wanted = config.tag
            return HasTag(wanted)

        return config_cls

    return register
def register_has_all_tags_condition(
    registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
    """Return a decorator registering a ``HasAllTags`` builder on ``registry``.

    The builder ignores any extra build arguments. The decorator returns
    the config class unchanged.
    """

    def register(
        config_cls: type[HasAllTagsConfig],
    ) -> type[HasAllTagsConfig]:
        @registry.register(config_cls)
        def build_has_all_tags(
            config: HasAllTagsConfig,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[TaggedObject]:
            required = config.tags
            return HasAllTags(required)

        return config_cls

    return register
def register_has_any_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
    """Return a decorator registering a ``HasAnyTag`` builder on ``registry``.

    The builder ignores any extra build arguments. The decorator returns
    the config class unchanged.
    """

    def register(
        config_cls: type[HasAnyTagConfig],
    ) -> type[HasAnyTagConfig]:
        @registry.register(config_cls)
        def build_has_any_tag(
            config: HasAnyTagConfig,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[TaggedObject]:
            candidates = config.tags
            return HasAnyTag(candidates)

        return config_cls

    return register
def register_not_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
    """Return a decorator that registers a negation builder on ``registry``.

    The builder constructs the nested condition through the same registry,
    forwarding the extra build arguments, and wraps it in ``Not``. The
    decorator returns the config class unchanged.
    """

    def register(config_cls: type[NotConfigType]) -> type[NotConfigType]:
        @registry.register(config_cls)
        def build_not(
            config: NotConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            return Not(registry.build(config.condition, *args, **kwargs))

        return config_cls

    return register
def register_all_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a decorator that registers an ``AllOf`` builder on ``registry``.

    The builder constructs every nested condition through the same
    registry, forwarding the extra build arguments, and combines them
    with ``AllOf``. The decorator returns the config class unchanged.
    """

    def register(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        @registry.register(config_cls)
        def build_all_of(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AllOf(built)

        return config_cls

    return register
def register_any_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a decorator that registers an ``AnyOf`` builder on ``registry``.

    The builder constructs every nested condition through the same
    registry, forwarding the extra build arguments, and combines them
    with ``AnyOf``. The decorator returns the config class unchanged.
    """

    def register(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        @registry.register(config_cls)
        def build_any_of(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AnyOf(built)

        return config_cls

    return register

View File

@ -0,0 +1,98 @@
from collections.abc import Callable, Sequence
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
# Public names of the recording conditions module. Of the shared configs
# imported from ``common`` only ``IdInListConfig`` is listed here.
__all__ = [
    "IdInListConfig",
    "RecordingAllOfConfig",
    "RecordingAnyOfConfig",
    "RecordingCondition",
    "RecordingConditionConfig",
    "RecordingConditionImportConfig",
    "RecordingNotConfig",
    "build_recording_condition",
]
# A recording condition is any callable that receives a recording and
# returns a boolean.
RecordingCondition = Callable[[data.Recording], bool]

# Registry created under the name "recording_condition". Besides the
# config, ``build_recording_condition`` passes one extra argument (the base
# directory, possibly None) when building.
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
    Registry("recording_condition")
)
@add_import_config(recording_conditions, arg_names=["base_dir"])
class RecordingConditionImportConfig(ImportConfig):
    """Use any callable as a recording condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.
    """

    # Discriminator value of this config inside ``RecordingConditionConfig``.
    name: Literal["import"] = "import"
# Attach the shared ID-list and tag configs to the recording registry.
# Each ``register_*`` helper returns a decorator that is applied to the
# config class right away.
register_id_in_list_condition(recording_conditions)(IdInListConfig)
register_has_tag_condition(recording_conditions)(HasTagConfig)
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
@register_all_of_condition(recording_conditions)
class RecordingAllOfConfig(MultiConditionConfigBase):
    """Configuration for the ``all_of`` combinator over recordings."""

    # Discriminator value of this config.
    name: Literal["all_of"] = "all_of"
    # Nested recording condition configs (forward reference to the union
    # defined below).
    conditions: Sequence["RecordingConditionConfig"]
@register_any_of_condition(recording_conditions)
class RecordingAnyOfConfig(MultiConditionConfigBase):
    """Configuration for the ``any_of`` combinator over recordings."""

    # Discriminator value of this config.
    name: Literal["any_of"] = "any_of"
    # Nested recording condition configs (forward reference to the union
    # defined below).
    conditions: Sequence["RecordingConditionConfig"]
@register_not_condition(recording_conditions)
class RecordingNotConfig(NotConditionConfigBase):
    """Configuration for the ``not`` combinator over recordings."""

    # Discriminator value of this config.
    name: Literal["not"] = "not"
    # Config of the recording condition to negate (forward reference to
    # the union defined below).
    condition: "RecordingConditionConfig"
# Union of every recording condition config. The ``name`` field is
# declared as the pydantic discriminator, so it decides which member a
# given payload is validated against.
RecordingConditionConfig = Annotated[
    IdInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | RecordingAllOfConfig
    | RecordingAnyOfConfig
    | RecordingNotConfig
    | RecordingConditionImportConfig,
    Field(discriminator="name"),
]
def build_recording_condition(
    config: RecordingConditionConfig,
    base_dir: data.PathLike | None = None,
) -> RecordingCondition:
    """Build the recording condition described by ``config``.

    ``base_dir`` is handed to the registry as the extra build argument;
    the ID-list builder uses it to resolve relative paths.
    """
    condition = recording_conditions.build(config, base_dir)
    return condition

View File

@ -0,0 +1,236 @@
from collections.abc import Callable, Sequence
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
# Public names of the sound event conditions module.
# NOTE(review): ``IdInListConfig`` is registered below and is part of
# ``SoundEventConditionConfig`` but is not listed here - confirm whether
# that omission is intentional.
__all__ = [
    "AllOfConfig",
    "AnyOfConfig",
    "DurationConfig",
    "FrequencyConfig",
    "HasAllTagsConfig",
    "HasAnyTagConfig",
    "HasTagConfig",
    "NotConfig",
    "Operator",
    "SoundEventCondition",
    "SoundEventConditionConfig",
    "SoundEventConditionImportConfig",
    "build_sound_event_condition",
    "filter_clip_annotation",
]
# A sound event condition is any callable that receives a sound event
# annotation and returns a boolean.
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

# Registry created under the name "sound_event_condition". Besides the
# config, ``build_sound_event_condition`` passes one extra argument (the
# base directory, possibly None) when building.
sound_event_conditions: Registry[
    SoundEventCondition,
    [data.PathLike | None],
] = Registry("sound_event_condition")
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Discriminator value of this config inside
    # ``SoundEventConditionConfig``.
    name: Literal["import"] = "import"
# Attach the shared tag and ID-list configs to the sound event registry.
# Each ``register_*`` helper returns a decorator that is applied to the
# config class right away.
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
# Comparison operators understood by ``_build_comparator``: greater than,
# greater or equal, less than, less or equal, and equal.
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
    """Configuration for the ``duration`` sound event condition."""

    # Discriminator value of this config.
    name: Literal["duration"] = "duration"
    # How the measured duration is compared against ``seconds``.
    operator: Operator
    # Reference value the duration is compared with.
    seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
class Duration:
    """Condition on the time extent of a sound event geometry.

    The duration is the difference between the end and start times
    returned by ``compute_bounds``. Annotations without a geometry never
    satisfy the condition.
    """

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(operator, seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return True if the duration compares as configured."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is not None:
            start_time, _, end_time, _ = compute_bounds(geometry)
            return self._comparator(end_time - start_time)
        return False

    @sound_event_conditions.register(DurationConfig)
    @staticmethod
    def from_config(
        config: DurationConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition; ``base_dir`` is accepted but not used."""
        del base_dir
        return Duration(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
    """Configuration for the ``frequency`` sound event condition."""

    # Discriminator value of this config.
    name: Literal["frequency"] = "frequency"
    # Which frequency bound of the geometry is compared.
    boundary: Literal["low", "high"]
    # How the selected bound is compared against ``hertz``.
    operator: Operator
    # Reference value the bound is compared with.
    hertz: float
class Frequency:
    """Condition on the low or high frequency bound of a sound event.

    Annotations without a geometry, or whose geometry is a
    ``TimeInterval`` or ``TimeStamp``, never satisfy the condition.
    """

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(operator, hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return True if the selected bound compares as configured."""
        geometry = sound_event_annotation.sound_event.geometry

        # Time-only geometries carry no frequency range to compare.
        time_only = (data.TimeInterval, data.TimeStamp)
        if geometry is None or isinstance(geometry, time_only):
            return False

        _, low_freq, _, high_freq = compute_bounds(geometry)
        selected = low_freq if self.boundary == "low" else high_freq
        return self._comparator(selected)

    @sound_event_conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(
        config: FrequencyConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition; ``base_dir`` is accepted but not used."""
        del base_dir
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
@register_all_of_condition(sound_event_conditions)
class AllOfConfig(MultiConditionConfigBase):
    """Configuration for the ``all_of`` combinator over sound events."""

    # Discriminator value of this config.
    name: Literal["all_of"] = "all_of"
    # Nested sound event condition configs (forward reference to the
    # union defined below).
    conditions: Sequence["SoundEventConditionConfig"]
@register_any_of_condition(sound_event_conditions)
class AnyOfConfig(MultiConditionConfigBase):
    """Configuration for the ``any_of`` combinator over sound events."""

    # Discriminator value of this config.
    name: Literal["any_of"] = "any_of"
    # Nested sound event condition configs (forward reference to the
    # union defined below). Typed as ``Sequence`` to match
    # ``MultiConditionConfigBase`` and the sibling ``AllOfConfig``.
    conditions: Sequence["SoundEventConditionConfig"]
@register_not_condition(sound_event_conditions)
class NotConfig(NotConditionConfigBase):
    """Configuration for the ``not`` combinator over sound events."""

    # Discriminator value of this config.
    name: Literal["not"] = "not"
    # Config of the sound event condition to negate (forward reference to
    # the union defined below).
    condition: "SoundEventConditionConfig"
# Union of every sound event condition config. The ``name`` field is
# declared as the pydantic discriminator, so it decides which member a
# given payload is validated against.
SoundEventConditionConfig = Annotated[
    IdInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | DurationConfig
    | FrequencyConfig
    | AllOfConfig
    | AnyOfConfig
    | NotConfig
    | SoundEventConditionImportConfig,
    Field(discriminator="name"),
]
def build_sound_event_condition(
    config: SoundEventConditionConfig,
    base_dir: data.PathLike | None = None,
) -> SoundEventCondition:
    """Build the sound event condition described by ``config``.

    ``base_dir`` is handed to the registry as the extra build argument;
    the ID-list builder uses it to resolve relative paths.
    """
    condition = sound_event_conditions.build(config, base_dir)
    return condition
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of the clip annotation with filtered sound events.

    Only the sound events for which ``condition`` is truthy are kept; the
    copy is made with ``model_copy`` and the input is not modified here.
    """
    kept = []
    for annotation in clip_annotation.sound_events:
        if condition(annotation):
            kept.append(annotation)
    return clip_annotation.model_copy(update={"sound_events": kept})

View File

@ -0,0 +1,132 @@
import json
from pathlib import Path
from soundevent import data
from batdetect2.data.conditions import (
ClipAllOfConfig,
ClipAnyOfConfig,
ClipNotConfig,
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
RecordingSatisfiesConfig,
build_clip_annotation_condition,
)
def test_recording_satisfies_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """A clip matches only when its recording's ID is in the list file."""
    listed_recording = create_recording(path=tmp_path / "a.wav")
    other_recording = create_recording(path=tmp_path / "b.wav")
    listed_clip = create_clip(listed_recording)
    other_clip = create_clip(other_recording)
    listed_annotation = create_clip_annotation(listed_clip)
    other_annotation = create_clip_annotation(other_clip)

    id_file = tmp_path / "recording_ids.json"
    id_file.write_text(json.dumps([str(listed_recording.uuid)]))

    config = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=id_file),
    )
    condition = build_clip_annotation_condition(config)

    assert condition(listed_annotation)
    assert not condition(other_annotation)
def test_clip_id_in_list_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """A clip annotation matches only when its own ID is in the list file."""
    first_recording = create_recording(path=tmp_path / "a.wav")
    second_recording = create_recording(path=tmp_path / "b.wav")
    listed = create_clip_annotation(create_clip(first_recording))
    unlisted = create_clip_annotation(create_clip(second_recording))

    id_file = tmp_path / "clip_annotation_ids.json"
    id_file.write_text(json.dumps([str(listed.uuid)]))

    config = IdInListConfig(path=id_file)
    condition = build_clip_annotation_condition(config)

    assert condition(listed)
    assert not condition(unlisted)
def test_clip_has_tag_conditions(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """Tag conditions accept present tags and reject missing ones."""
    reviewed = data.Tag(key="status", value="reviewed")
    train = data.Tag(key="split", value="train")
    val = data.Tag(key="split", value="val")

    recording = create_recording(path=tmp_path / "rec.wav")
    clip = create_clip(recording)
    clip_annotation = create_clip_annotation(
        clip,
        clip_tags=[reviewed, train],
    )

    has_tag = build_clip_annotation_condition(HasTagConfig(tag=reviewed))
    has_all = build_clip_annotation_condition(
        HasAllTagsConfig(tags=[reviewed, train])
    )
    has_any = build_clip_annotation_condition(
        HasAnyTagConfig(tags=[val, train])
    )

    assert has_tag(clip_annotation)
    assert has_all(clip_annotation)
    assert has_any(clip_annotation)

    # Negative cases: ``val`` shares its key with ``train`` but has a
    # different value, so it must not count as present. Without these a
    # condition that always returned True would pass the test.
    lacks_tag = build_clip_annotation_condition(HasTagConfig(tag=val))
    lacks_all = build_clip_annotation_condition(
        HasAllTagsConfig(tags=[reviewed, val])
    )
    lacks_any = build_clip_annotation_condition(HasAnyTagConfig(tags=[val]))

    assert not lacks_tag(clip_annotation)
    assert not lacks_all(clip_annotation)
    assert not lacks_any(clip_annotation)
def test_clip_logical_conditions(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``all_of`` / ``any_of`` / ``not`` combine clip conditions correctly."""
    reviewed = data.Tag(key="status", value="reviewed")
    train = data.Tag(key="split", value="train")
    val = data.Tag(key="split", value="val")

    recording = create_recording(path=tmp_path / "rec.wav")
    clip = create_clip(recording)
    clip_annotation = create_clip_annotation(
        clip,
        clip_tags=[reviewed, train],
    )

    all_condition = build_clip_annotation_condition(
        ClipAllOfConfig(
            conditions=[
                HasTagConfig(tag=reviewed),
                HasAnyTagConfig(tags=[train, val]),
            ]
        )
    )
    any_condition = build_clip_annotation_condition(
        ClipAnyOfConfig(
            conditions=[
                HasTagConfig(tag=val),
                HasTagConfig(tag=reviewed),
            ]
        )
    )
    not_condition = build_clip_annotation_condition(
        ClipNotConfig(condition=HasTagConfig(tag=val))
    )

    assert all_condition(clip_annotation)
    assert any_condition(clip_annotation)
    assert not_condition(clip_annotation)

    # Negative cases, so that combinators which always return True would
    # be caught: one failing member breaks ``all_of``, no passing member
    # breaks ``any_of``, and negating a satisfied condition is False.
    failing_all = build_clip_annotation_condition(
        ClipAllOfConfig(
            conditions=[
                HasTagConfig(tag=reviewed),
                HasTagConfig(tag=val),
            ]
        )
    )
    failing_any = build_clip_annotation_condition(
        ClipAnyOfConfig(conditions=[HasTagConfig(tag=val)])
    )
    failing_not = build_clip_annotation_condition(
        ClipNotConfig(condition=HasTagConfig(tag=reviewed))
    )

    assert not failing_all(clip_annotation)
    assert not failing_any(clip_annotation)
    assert not failing_not(clip_annotation)

View File

@ -0,0 +1,139 @@
import json
from pathlib import Path
import pytest
from soundevent import data
from batdetect2.data.conditions import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
RecordingAllOfConfig,
RecordingAnyOfConfig,
RecordingNotConfig,
build_recording_condition,
)
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
    """A recording matches only when its ID is in the JSON list file."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    id_file = tmp_path / "recording_ids.json"
    id_file.write_text(json.dumps([str(listed.uuid)]))

    config = IdInListConfig(path=id_file)
    condition = build_recording_condition(config)

    assert condition(listed)
    assert not condition(unlisted)
def test_id_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """A relative list path is resolved against the given ``base_dir``."""
    recording = create_recording(path=tmp_path / "a.wav")

    relative_path = Path("splits/train_ids.json")
    id_file = tmp_path / relative_path
    id_file.parent.mkdir()
    id_file.write_text(json.dumps([str(recording.uuid)]))

    config = IdInListConfig(path=relative_path)
    condition = build_recording_condition(config, base_dir=tmp_path)

    assert condition(recording)
def test_id_in_list_condition_raises_for_non_list_json(
    tmp_path: Path,
) -> None:
    """Check that a JSON object instead of a JSON list raises ``TypeError``."""
    id_file = tmp_path / "recording_ids.json"
    not_a_list = json.dumps({"id": "foo"})
    id_file.write_text(not_a_list)

    # Config construction stays inside the context manager so exactly the
    # same statements are guarded by ``pytest.raises``.
    with pytest.raises(TypeError, match="Expected JSON list"):
        build_recording_condition(IdInListConfig(path=id_file))
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
    """Check that a list entry that is not a UUID raises ``ValueError``."""
    id_file = tmp_path / "recording_ids.json"
    bad_entries = json.dumps(["not-a-uuid"])
    id_file.write_text(bad_entries)

    # Config construction stays inside the context manager so exactly the
    # same statements are guarded by ``pytest.raises``.
    with pytest.raises(ValueError, match="Invalid ID"):
        build_recording_condition(IdInListConfig(path=id_file))
def test_id_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """Check that IDs can be loaded with ``list_format="txt"``.

    The file holds a single UUID followed by a newline; only the recording
    with that UUID should satisfy the condition.
    """
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    id_file = tmp_path / "recording_ids.txt"
    id_file.write_text(f"{listed.uuid}\n")

    txt_config = IdInListConfig(path=id_file, list_format="txt")
    matches = build_recording_condition(txt_config)

    assert matches(listed)
    assert not matches(unlisted)
def test_recording_has_tag_conditions(
    tmp_path: Path, create_recording
) -> None:
    """Check the single-tag, all-tags and any-tag conditions on a recording.

    The recording carries the ``split=train`` and ``region=uk`` tags, so each
    of the three conditions built below is expected to accept it.
    """
    split_train = data.Tag(key="split", value="train")
    region_uk = data.Tag(key="region", value="uk")
    region_eu = data.Tag(key="region", value="eu")

    tagged = create_recording(
        path=tmp_path / "rec.wav",
        tags=[split_train, region_uk],
    )

    single = build_recording_condition(HasTagConfig(tag=split_train))
    every = build_recording_condition(
        HasAllTagsConfig(tags=[split_train, region_uk])
    )
    either = build_recording_condition(
        HasAnyTagConfig(tags=[region_eu, region_uk])
    )

    assert single(tagged)
    assert every(tagged)
    assert either(tagged)
def test_recording_logical_conditions(
    tmp_path: Path, create_recording
) -> None:
    """Check the all-of, any-of and not combinators for recordings.

    The recording is tagged ``split=train`` and ``region=uk``; it does not
    carry ``region=eu``, which is the tag used in the negated condition.
    """
    split_train = data.Tag(key="split", value="train")
    region_uk = data.Tag(key="region", value="uk")
    region_eu = data.Tag(key="region", value="eu")

    tagged = create_recording(
        path=tmp_path / "rec.wav",
        tags=[split_train, region_uk],
    )

    all_of_config = RecordingAllOfConfig(
        conditions=[
            HasTagConfig(tag=split_train),
            HasAnyTagConfig(tags=[region_eu, region_uk]),
        ]
    )
    requires_all = build_recording_condition(all_of_config)

    any_of_config = RecordingAnyOfConfig(
        conditions=[
            HasTagConfig(tag=region_eu),
            HasTagConfig(tag=split_train),
        ]
    )
    requires_any = build_recording_condition(any_of_config)

    negation_config = RecordingNotConfig(
        condition=HasTagConfig(tag=region_eu)
    )
    negated = build_recording_condition(negation_config)

    assert requires_all(tagged)
    assert requires_any(tagged)
    assert negated(tagged)

View File

@ -1,4 +1,6 @@
import json
import textwrap import textwrap
from pathlib import Path
import pytest import pytest
import yaml import yaml
@ -6,16 +8,17 @@ from pydantic import TypeAdapter
from soundevent import data from soundevent import data
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
IdInListConfig,
SoundEventConditionConfig, SoundEventConditionConfig,
build_sound_event_condition, build_sound_event_condition,
) )
def build_condition_from_str(content, base_dir: Path | None = None):
    """Build a sound event condition from a YAML snippet.

    Parameters
    ----------
    content
        YAML text describing the condition. It is dedented before parsing,
        so indented triple-quoted strings can be passed directly.
    base_dir
        Optional directory forwarded to ``build_sound_event_condition``.
        Defaults to ``None``, so existing single-argument calls keep
        working.

    Returns
    -------
    The condition returned by ``build_sound_event_condition`` for the
    validated config.
    """
    content = textwrap.dedent(content)
    content = yaml.safe_load(content)
    config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
    return build_sound_event_condition(config, base_dir=base_dir)
def test_has_tag(sound_event: data.SoundEvent): def test_has_tag(sound_event: data.SoundEvent):
@ -160,6 +163,36 @@ def test_not(sound_event: data.SoundEvent):
assert not condition(sound_event_annotation) assert not condition(sound_event_annotation)
def test_id_in_list(sound_event: data.SoundEvent, tmp_path: Path):
    """Check that only the annotation whose UUID is listed passes.

    Both annotations wrap the same sound event; only the first annotation's
    UUID is written to the JSON list.
    """
    listed = data.SoundEventAnnotation(sound_event=sound_event)
    unlisted = data.SoundEventAnnotation(sound_event=sound_event)

    id_file = tmp_path / "sound_event_ids.json"
    id_file.write_text(json.dumps([str(listed.uuid)]))

    matches = build_sound_event_condition(IdInListConfig(path=id_file))

    assert matches(listed)
    assert not matches(unlisted)
def test_id_in_list_uses_base_dir(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """Check that a relative ID-list path works when ``base_dir`` is given.

    The list is written to ``tmp_path / "splits"`` while the config only
    carries the relative path; ``tmp_path`` is passed as ``base_dir``.
    """
    annotation = data.SoundEventAnnotation(sound_event=sound_event)

    splits = tmp_path / "splits"
    splits.mkdir()
    id_file = splits / "sound_event_ids.json"
    id_file.write_text(json.dumps([str(annotation.uuid)]))

    relative_config = IdInListConfig(path=Path("splits/sound_event_ids.json"))
    matches = build_sound_event_condition(relative_config, base_dir=tmp_path)

    assert matches(annotation)
def test_duration(recording: data.Recording): def test_duration(recording: data.Recording):
se1 = data.SoundEventAnnotation( se1 = data.SoundEventAnnotation(
sound_event=data.SoundEvent( sound_event=data.SoundEvent(