mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add conditions for clips and recordings
This commit is contained in:
parent
e80fe8675d
commit
c8dd4155bf
@ -1,312 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, List, Literal, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
||||
|
||||
|
||||
@add_import_config(conditions)
|
||||
class SoundEventConditionImportConfig(ImportConfig):
|
||||
"""Use any callable as a sound event condition.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
|
||||
|
||||
class HasTag:
|
||||
def __init__(self, tag: data.Tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return any(
|
||||
self.tag.term.name == tag.term.name and self.tag.value == tag.value
|
||||
for tag in sound_event_annotation.tags
|
||||
)
|
||||
|
||||
@conditions.register(HasTagConfig)
|
||||
@staticmethod
|
||||
def from_config(config: HasTagConfig):
|
||||
return HasTag(tag=config.tag)
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
|
||||
name: Literal["has_all_tags"] = "has_all_tags"
|
||||
tags: List[data.Tag]
|
||||
|
||||
|
||||
class HasAllTags:
|
||||
def __init__(self, tags: List[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return self.tags.issubset(
|
||||
{(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
|
||||
)
|
||||
|
||||
@conditions.register(HasAllTagsConfig)
|
||||
@staticmethod
|
||||
def from_config(config: HasAllTagsConfig):
|
||||
return HasAllTags(tags=config.tags)
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
|
||||
name: Literal["has_any_tag"] = "has_any_tag"
|
||||
tags: List[data.Tag]
|
||||
|
||||
|
||||
class HasAnyTag:
|
||||
def __init__(self, tags: List[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return bool(
|
||||
self.tags.intersection(
|
||||
{
|
||||
(tag.term.name, tag.value)
|
||||
for tag in sound_event_annotation.tags
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@conditions.register(HasAnyTagConfig)
|
||||
@staticmethod
|
||||
def from_config(config: HasAnyTagConfig):
|
||||
return HasAnyTag(tags=config.tags)
|
||||
|
||||
|
||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
|
||||
class DurationConfig(BaseConfig):
|
||||
name: Literal["duration"] = "duration"
|
||||
operator: Operator
|
||||
seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
|
||||
|
||||
class Duration:
|
||||
def __init__(self, operator: Operator, seconds: float):
|
||||
self.operator = operator
|
||||
self.seconds = seconds
|
||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||
duration = end_time - start_time
|
||||
|
||||
return self._comparator(duration)
|
||||
|
||||
@conditions.register(DurationConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DurationConfig):
|
||||
return Duration(operator=config.operator, seconds=config.seconds)
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
name: Literal["frequency"] = "frequency"
|
||||
boundary: Literal["low", "high"]
|
||||
operator: Operator
|
||||
hertz: float
|
||||
|
||||
|
||||
class Frequency:
|
||||
def __init__(
|
||||
self,
|
||||
operator: Operator,
|
||||
boundary: Literal["low", "high"],
|
||||
hertz: float,
|
||||
):
|
||||
self.operator = operator
|
||||
self.hertz = hertz
|
||||
self.boundary = boundary
|
||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
# Automatically false if geometry does not have a frequency range
|
||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||
return False
|
||||
|
||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
if self.boundary == "low":
|
||||
return self._comparator(low_freq)
|
||||
|
||||
return self._comparator(high_freq)
|
||||
|
||||
@conditions.register(FrequencyConfig)
|
||||
@staticmethod
|
||||
def from_config(config: FrequencyConfig):
|
||||
return Frequency(
|
||||
operator=config.operator,
|
||||
boundary=config.boundary,
|
||||
hertz=config.hertz,
|
||||
)
|
||||
|
||||
|
||||
class AllOfConfig(BaseConfig):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
class AllOf:
|
||||
def __init__(self, conditions: List[SoundEventCondition]):
|
||||
self.conditions = conditions
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return all(c(sound_event_annotation) for c in self.conditions)
|
||||
|
||||
@conditions.register(AllOfConfig)
|
||||
@staticmethod
|
||||
def from_config(config: AllOfConfig):
|
||||
conditions = [
|
||||
build_sound_event_condition(cond) for cond in config.conditions
|
||||
]
|
||||
return AllOf(conditions)
|
||||
|
||||
|
||||
class AnyOfConfig(BaseConfig):
|
||||
name: Literal["any_of"] = "any_of"
|
||||
conditions: List["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
class AnyOf:
|
||||
def __init__(self, conditions: List[SoundEventCondition]):
|
||||
self.conditions = conditions
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return any(c(sound_event_annotation) for c in self.conditions)
|
||||
|
||||
@conditions.register(AnyOfConfig)
|
||||
@staticmethod
|
||||
def from_config(config: AnyOfConfig):
|
||||
conditions = [
|
||||
build_sound_event_condition(cond) for cond in config.conditions
|
||||
]
|
||||
return AnyOf(conditions)
|
||||
|
||||
|
||||
class NotConfig(BaseConfig):
|
||||
name: Literal["not"] = "not"
|
||||
condition: "SoundEventConditionConfig"
|
||||
|
||||
|
||||
class Not:
|
||||
def __init__(self, condition: SoundEventCondition):
|
||||
self.condition = condition
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return not self.condition(sound_event_annotation)
|
||||
|
||||
@conditions.register(NotConfig)
|
||||
@staticmethod
|
||||
def from_config(config: NotConfig):
|
||||
condition = build_sound_event_condition(config.condition)
|
||||
return Not(condition)
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| DurationConfig
|
||||
| FrequencyConfig
|
||||
| AllOfConfig
|
||||
| AnyOfConfig
|
||||
| NotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_sound_event_condition(
|
||||
config: SoundEventConditionConfig,
|
||||
) -> SoundEventCondition:
|
||||
return conditions.build(config)
|
||||
|
||||
|
||||
def filter_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
condition: SoundEventCondition,
|
||||
) -> data.ClipAnnotation:
|
||||
return clip_annotation.model_copy(
|
||||
update=dict(
|
||||
sound_events=[
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if condition(sound_event)
|
||||
]
|
||||
)
|
||||
)
|
||||
71
src/batdetect2/data/conditions/__init__.py
Normal file
71
src/batdetect2/data/conditions/__init__.py
Normal file
@ -0,0 +1,71 @@
|
||||
from batdetect2.data.conditions.clips import (
|
||||
ClipAllOfConfig,
|
||||
ClipAnnotationCondition,
|
||||
ClipAnnotationConditionConfig,
|
||||
ClipAnnotationConditionImportConfig,
|
||||
ClipAnyOfConfig,
|
||||
ClipNotConfig,
|
||||
RecordingSatisfiesConfig,
|
||||
build_clip_annotation_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
RecordingAllOfConfig,
|
||||
RecordingAnyOfConfig,
|
||||
RecordingCondition,
|
||||
RecordingConditionConfig,
|
||||
RecordingConditionImportConfig,
|
||||
RecordingNotConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.sound_events import (
|
||||
AllOfConfig,
|
||||
AnyOfConfig,
|
||||
DurationConfig,
|
||||
FrequencyConfig,
|
||||
NotConfig,
|
||||
Operator,
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
SoundEventConditionImportConfig,
|
||||
build_sound_event_condition,
|
||||
filter_clip_annotation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllOfConfig",
|
||||
"AnyOfConfig",
|
||||
"ClipAllOfConfig",
|
||||
"ClipAnnotationCondition",
|
||||
"ClipAnnotationConditionConfig",
|
||||
"ClipAnnotationConditionImportConfig",
|
||||
"ClipAnyOfConfig",
|
||||
"ClipNotConfig",
|
||||
"DurationConfig",
|
||||
"FrequencyConfig",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTagConfig",
|
||||
"HasTagConfig",
|
||||
"IdInListConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"build_clip_annotation_condition",
|
||||
"build_recording_condition",
|
||||
"build_sound_event_condition",
|
||||
"filter_clip_annotation",
|
||||
]
|
||||
132
src/batdetect2/data/conditions/clips.py
Normal file
132
src/batdetect2/data/conditions/clips.py
Normal file
@ -0,0 +1,132 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
RecordingCondition,
|
||||
RecordingConditionConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClipAllOfConfig",
|
||||
"ClipAnnotationCondition",
|
||||
"ClipAnnotationConditionConfig",
|
||||
"ClipAnnotationConditionImportConfig",
|
||||
"ClipAnyOfConfig",
|
||||
"ClipNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"build_clip_annotation_condition",
|
||||
]
|
||||
|
||||
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]
|
||||
|
||||
clip_annotation_conditions: Registry[
|
||||
ClipAnnotationCondition,
|
||||
[data.PathLike | None],
|
||||
] = Registry("clip_condition")
|
||||
|
||||
|
||||
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
|
||||
class ClipAnnotationConditionImportConfig(ImportConfig):
|
||||
"""Use any callable as a clip annotation condition.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
||||
to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class RecordingSatisfiesConfig(BaseConfig):
|
||||
name: Literal["recording_satisfies"] = "recording_satisfies"
|
||||
condition: RecordingConditionConfig
|
||||
|
||||
|
||||
class RecordingSatisfies:
|
||||
def __init__(self, condition: RecordingCondition):
|
||||
self.condition = condition
|
||||
|
||||
def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
|
||||
recording = clip_annotation.clip.recording
|
||||
return self.condition(recording)
|
||||
|
||||
@clip_annotation_conditions.register(RecordingSatisfiesConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: RecordingSatisfiesConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> "RecordingSatisfies":
|
||||
condition = build_recording_condition(
|
||||
config.condition,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
return RecordingSatisfies(condition)
|
||||
|
||||
|
||||
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
|
||||
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
|
||||
|
||||
|
||||
@register_all_of_condition(clip_annotation_conditions)
|
||||
class ClipAllOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||
|
||||
|
||||
@register_any_of_condition(clip_annotation_conditions)
|
||||
class ClipAnyOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["any_of"] = "any_of"
|
||||
conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||
|
||||
|
||||
@register_not_condition(clip_annotation_conditions)
|
||||
class ClipNotConfig(NotConditionConfigBase):
|
||||
name: Literal["not"] = "not"
|
||||
condition: "ClipAnnotationConditionConfig"
|
||||
|
||||
|
||||
ClipAnnotationConditionConfig = Annotated[
|
||||
RecordingSatisfiesConfig
|
||||
| IdInListConfig
|
||||
| HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| ClipAllOfConfig
|
||||
| ClipAnyOfConfig
|
||||
| ClipNotConfig
|
||||
| ClipAnnotationConditionImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_clip_annotation_condition(
|
||||
config: ClipAnnotationConditionConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> ClipAnnotationCondition:
|
||||
return clip_annotation_conditions.build(config, base_dir)
|
||||
314
src/batdetect2/data/conditions/common.py
Normal file
314
src/batdetect2/data/conditions/common.py
Normal file
@ -0,0 +1,314 @@
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
__all__ = [
|
||||
"AllOf",
|
||||
"AnyOf",
|
||||
"Condition",
|
||||
"HasAllTags",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTag",
|
||||
"HasAnyTagConfig",
|
||||
"HasTag",
|
||||
"HasTagConfig",
|
||||
"IdInList",
|
||||
"IdInListConfig",
|
||||
"MultiConditionConfigBase",
|
||||
"Not",
|
||||
"NotConditionConfigBase",
|
||||
"ObjectWithTags",
|
||||
"ObjectWithUUID",
|
||||
"register_all_of_condition",
|
||||
"register_any_of_condition",
|
||||
"register_has_all_tags_condition",
|
||||
"register_has_any_tag_condition",
|
||||
"register_has_tag_condition",
|
||||
"register_id_in_list_condition",
|
||||
"register_not_condition",
|
||||
]
|
||||
|
||||
|
||||
class ObjectWithTags(Protocol):
|
||||
tags: list[data.Tag]
|
||||
|
||||
|
||||
class ObjectWithUUID(Protocol):
|
||||
uuid: UUID
|
||||
|
||||
|
||||
ConditionObject = TypeVar("ConditionObject")
|
||||
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
|
||||
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
|
||||
P = ParamSpec("P")
|
||||
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
|
||||
MultiConfigType = TypeVar(
|
||||
"MultiConfigType",
|
||||
bound="MultiConditionConfigBase",
|
||||
)
|
||||
Condition = Callable[[ConditionObject], bool]
|
||||
|
||||
|
||||
class NotConditionConfigBase(BaseConfig):
|
||||
condition: BaseModel
|
||||
|
||||
|
||||
class MultiConditionConfigBase(BaseConfig):
|
||||
conditions: Sequence[BaseModel]
|
||||
|
||||
|
||||
class Not(Generic[ConditionObject]):
|
||||
def __init__(self, condition: Condition[ConditionObject]):
|
||||
self.condition = condition
|
||||
|
||||
def __call__(self, obj: ConditionObject) -> bool:
|
||||
return not self.condition(obj)
|
||||
|
||||
|
||||
class AllOf(Generic[ConditionObject]):
|
||||
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
||||
self.conditions = list(conditions)
|
||||
|
||||
def __call__(self, obj: ConditionObject) -> bool:
|
||||
return all(condition(obj) for condition in self.conditions)
|
||||
|
||||
|
||||
class AnyOf(Generic[ConditionObject]):
|
||||
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
||||
self.conditions = list(conditions)
|
||||
|
||||
def __call__(self, obj: ConditionObject) -> bool:
|
||||
return any(condition(obj) for condition in self.conditions)
|
||||
|
||||
|
||||
class HasTag(Generic[TaggedObject]):
|
||||
def __init__(self, tag: data.Tag):
|
||||
self.tag_key = (tag.term.name, tag.value)
|
||||
|
||||
def __call__(self, obj: TaggedObject) -> bool:
|
||||
return any(
|
||||
(tag.term.name, tag.value) == self.tag_key for tag in obj.tags
|
||||
)
|
||||
|
||||
|
||||
class HasAllTags(Generic[TaggedObject]):
|
||||
def __init__(self, tags: list[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(self, obj: TaggedObject) -> bool:
|
||||
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
||||
return self.required_keys.issubset(tag_keys)
|
||||
|
||||
|
||||
class HasAnyTag(Generic[TaggedObject]):
|
||||
def __init__(self, tags: list[data.Tag]):
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(self, obj: TaggedObject) -> bool:
|
||||
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
||||
return bool(self.required_keys.intersection(tag_keys))
|
||||
|
||||
|
||||
class IdInList(Generic[UUIDObject]):
|
||||
def __init__(self, ids: set[UUID]):
|
||||
self.ids = ids
|
||||
|
||||
def __call__(self, obj: UUIDObject) -> bool:
|
||||
return obj.uuid in self.ids
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
|
||||
name: Literal["has_tag"] = "has_tag"
|
||||
tag: data.Tag
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
|
||||
name: Literal["has_all_tags"] = "has_all_tags"
|
||||
tags: list[data.Tag]
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
|
||||
name: Literal["has_any_tag"] = "has_any_tag"
|
||||
tags: list[data.Tag]
|
||||
|
||||
|
||||
class IdInListConfig(BaseConfig):
|
||||
name: Literal["id_in_list"] = "id_in_list"
|
||||
path: Path
|
||||
list_format: Literal["json", "txt"] = "json"
|
||||
|
||||
|
||||
def _load_ids(
|
||||
path: Path,
|
||||
list_format: Literal["json", "txt"],
|
||||
) -> list[str]:
|
||||
if list_format == "json":
|
||||
content = json.loads(path.read_text())
|
||||
|
||||
if not isinstance(content, list):
|
||||
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
|
||||
|
||||
return [str(value) for value in content]
|
||||
|
||||
return [
|
||||
line.strip() for line in path.read_text().splitlines() if line.strip()
|
||||
]
|
||||
|
||||
|
||||
def register_id_in_list_condition(
|
||||
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
||||
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
|
||||
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: IdInListConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Condition[UUIDObject]:
|
||||
path = config.path
|
||||
|
||||
if base_dir is not None and not path.is_absolute():
|
||||
path = Path(base_dir) / path
|
||||
|
||||
ids = set()
|
||||
for index, value in enumerate(_load_ids(path, config.list_format)):
|
||||
try:
|
||||
ids.add(UUID(value))
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
f"Invalid ID at index {index} in '{path}': {value!r}."
|
||||
) from err
|
||||
|
||||
return IdInList(ids)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_has_tag_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
|
||||
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasTag(config.tag)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_has_all_tags_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
|
||||
def decorator(
|
||||
config_cls: type[HasAllTagsConfig],
|
||||
) -> type[HasAllTagsConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasAllTagsConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAllTags(config.tags)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_has_any_tag_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
|
||||
def decorator(
|
||||
config_cls: type[HasAnyTagConfig],
|
||||
) -> type[HasAnyTagConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasAnyTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAnyTag(config.tags)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_not_condition(
|
||||
registry: Registry[Condition[ConditionObject], P],
|
||||
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
|
||||
def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: NotConfigType,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[ConditionObject]:
|
||||
condition = registry.build(config.condition, *args, **kwargs)
|
||||
return Not(condition)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_all_of_condition(
|
||||
registry: Registry[Condition[ConditionObject], P],
|
||||
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
||||
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: MultiConfigType,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[ConditionObject]:
|
||||
conditions = [
|
||||
registry.build(condition, *args, **kwargs)
|
||||
for condition in config.conditions
|
||||
]
|
||||
return AllOf(conditions)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_any_of_condition(
|
||||
registry: Registry[Condition[ConditionObject], P],
|
||||
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
||||
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: MultiConfigType,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[ConditionObject]:
|
||||
conditions = [
|
||||
registry.build(condition, *args, **kwargs)
|
||||
for condition in config.conditions
|
||||
]
|
||||
return AnyOf(conditions)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
98
src/batdetect2/data/conditions/recordings.py
Normal file
98
src/batdetect2/data/conditions/recordings.py
Normal file
@ -0,0 +1,98 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IdInListConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingNotConfig",
|
||||
"build_recording_condition",
|
||||
]
|
||||
|
||||
RecordingCondition = Callable[[data.Recording], bool]
|
||||
|
||||
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
|
||||
Registry("recording_condition")
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(recording_conditions, arg_names=["base_dir"])
|
||||
class RecordingConditionImportConfig(ImportConfig):
|
||||
"""Use any callable as a recording condition.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
||||
to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
register_id_in_list_condition(recording_conditions)(IdInListConfig)
|
||||
register_has_tag_condition(recording_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
|
||||
|
||||
|
||||
@register_all_of_condition(recording_conditions)
|
||||
class RecordingAllOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
conditions: Sequence["RecordingConditionConfig"]
|
||||
|
||||
|
||||
@register_any_of_condition(recording_conditions)
|
||||
class RecordingAnyOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["any_of"] = "any_of"
|
||||
conditions: Sequence["RecordingConditionConfig"]
|
||||
|
||||
|
||||
@register_not_condition(recording_conditions)
|
||||
class RecordingNotConfig(NotConditionConfigBase):
|
||||
name: Literal["not"] = "not"
|
||||
condition: "RecordingConditionConfig"
|
||||
|
||||
|
||||
RecordingConditionConfig = Annotated[
|
||||
IdInListConfig
|
||||
| HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| RecordingAllOfConfig
|
||||
| RecordingAnyOfConfig
|
||||
| RecordingNotConfig
|
||||
| RecordingConditionImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_recording_condition(
|
||||
config: RecordingConditionConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> RecordingCondition:
|
||||
return recording_conditions.build(config, base_dir)
|
||||
236
src/batdetect2/data/conditions/sound_events.py
Normal file
236
src/batdetect2/data/conditions/sound_events.py
Normal file
@ -0,0 +1,236 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllOfConfig",
|
||||
"AnyOfConfig",
|
||||
"DurationConfig",
|
||||
"FrequencyConfig",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTagConfig",
|
||||
"HasTagConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"build_sound_event_condition",
|
||||
"filter_clip_annotation",
|
||||
]
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
sound_event_conditions: Registry[
|
||||
SoundEventCondition,
|
||||
[data.PathLike | None],
|
||||
] = Registry("sound_event_condition")
|
||||
|
||||
|
||||
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
|
||||
class SoundEventConditionImportConfig(ImportConfig):
|
||||
"""Use any callable as a sound event condition.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
|
||||
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
|
||||
|
||||
|
||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
|
||||
class DurationConfig(BaseConfig):
|
||||
name: Literal["duration"] = "duration"
|
||||
operator: Operator
|
||||
seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
|
||||
|
||||
class Duration:
|
||||
def __init__(self, operator: Operator, seconds: float):
|
||||
self.operator = operator
|
||||
self.seconds = seconds
|
||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||
duration = end_time - start_time
|
||||
|
||||
return self._comparator(duration)
|
||||
|
||||
@sound_event_conditions.register(DurationConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: DurationConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
):
|
||||
_ = base_dir
|
||||
return Duration(operator=config.operator, seconds=config.seconds)
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
name: Literal["frequency"] = "frequency"
|
||||
boundary: Literal["low", "high"]
|
||||
operator: Operator
|
||||
hertz: float
|
||||
|
||||
|
||||
class Frequency:
|
||||
def __init__(
|
||||
self,
|
||||
operator: Operator,
|
||||
boundary: Literal["low", "high"],
|
||||
hertz: float,
|
||||
):
|
||||
self.operator = operator
|
||||
self.hertz = hertz
|
||||
self.boundary = boundary
|
||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> bool:
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||
return False
|
||||
|
||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||
|
||||
if self.boundary == "low":
|
||||
return self._comparator(low_freq)
|
||||
|
||||
return self._comparator(high_freq)
|
||||
|
||||
@sound_event_conditions.register(FrequencyConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: FrequencyConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
):
|
||||
_ = base_dir
|
||||
return Frequency(
|
||||
operator=config.operator,
|
||||
boundary=config.boundary,
|
||||
hertz=config.hertz,
|
||||
)
|
||||
|
||||
|
||||
@register_all_of_condition(sound_event_conditions)
|
||||
class AllOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@register_any_of_condition(sound_event_conditions)
|
||||
class AnyOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["any_of"] = "any_of"
|
||||
conditions: list["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@register_not_condition(sound_event_conditions)
|
||||
class NotConfig(NotConditionConfigBase):
|
||||
name: Literal["not"] = "not"
|
||||
condition: "SoundEventConditionConfig"
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
IdInListConfig
|
||||
| HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| DurationConfig
|
||||
| FrequencyConfig
|
||||
| AllOfConfig
|
||||
| AnyOfConfig
|
||||
| NotConfig
|
||||
| SoundEventConditionImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_sound_event_condition(
|
||||
config: SoundEventConditionConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> SoundEventCondition:
|
||||
return sound_event_conditions.build(config, base_dir)
|
||||
|
||||
|
||||
def filter_clip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
condition: SoundEventCondition,
|
||||
) -> data.ClipAnnotation:
|
||||
return clip_annotation.model_copy(
|
||||
update=dict(
|
||||
sound_events=[
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if condition(sound_event)
|
||||
]
|
||||
)
|
||||
)
|
||||
132
tests/test_data/test_conditions/test_clip.py
Normal file
132
tests/test_data/test_conditions/test_clip.py
Normal file
@ -0,0 +1,132 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import (
|
||||
ClipAllOfConfig,
|
||||
ClipAnyOfConfig,
|
||||
ClipNotConfig,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
RecordingSatisfiesConfig,
|
||||
build_clip_annotation_condition,
|
||||
)
|
||||
|
||||
|
||||
def test_recording_satisfies_condition(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
create_clip,
|
||||
create_clip_annotation,
|
||||
) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
clip_a = create_clip(recording_a)
|
||||
clip_b = create_clip(recording_b)
|
||||
clip_annotation_a = create_clip_annotation(clip_a)
|
||||
clip_annotation_b = create_clip_annotation(clip_b)
|
||||
ids_path = tmp_path / "recording_ids.json"
|
||||
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
||||
|
||||
condition = build_clip_annotation_condition(
|
||||
RecordingSatisfiesConfig(
|
||||
condition=IdInListConfig(path=ids_path),
|
||||
)
|
||||
)
|
||||
|
||||
assert condition(clip_annotation_a)
|
||||
assert not condition(clip_annotation_b)
|
||||
|
||||
|
||||
def test_clip_id_in_list_condition(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
create_clip,
|
||||
create_clip_annotation,
|
||||
) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
clip_annotation_a = create_clip_annotation(create_clip(recording_a))
|
||||
clip_annotation_b = create_clip_annotation(create_clip(recording_b))
|
||||
ids_path = tmp_path / "clip_annotation_ids.json"
|
||||
ids_path.write_text(json.dumps([str(clip_annotation_a.uuid)]))
|
||||
|
||||
condition = build_clip_annotation_condition(IdInListConfig(path=ids_path))
|
||||
|
||||
assert condition(clip_annotation_a)
|
||||
assert not condition(clip_annotation_b)
|
||||
|
||||
|
||||
def test_clip_has_tag_conditions(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
create_clip,
|
||||
create_clip_annotation,
|
||||
) -> None:
|
||||
reviewed = data.Tag(key="status", value="reviewed")
|
||||
train = data.Tag(key="split", value="train")
|
||||
val = data.Tag(key="split", value="val")
|
||||
|
||||
recording = create_recording(path=tmp_path / "rec.wav")
|
||||
clip = create_clip(recording)
|
||||
clip_annotation = create_clip_annotation(
|
||||
clip,
|
||||
clip_tags=[reviewed, train],
|
||||
)
|
||||
|
||||
has_tag = build_clip_annotation_condition(HasTagConfig(tag=reviewed))
|
||||
has_all = build_clip_annotation_condition(
|
||||
HasAllTagsConfig(tags=[reviewed, train])
|
||||
)
|
||||
has_any = build_clip_annotation_condition(
|
||||
HasAnyTagConfig(tags=[val, train])
|
||||
)
|
||||
|
||||
assert has_tag(clip_annotation)
|
||||
assert has_all(clip_annotation)
|
||||
assert has_any(clip_annotation)
|
||||
|
||||
|
||||
def test_clip_logical_conditions(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
create_clip,
|
||||
create_clip_annotation,
|
||||
) -> None:
|
||||
reviewed = data.Tag(key="status", value="reviewed")
|
||||
train = data.Tag(key="split", value="train")
|
||||
val = data.Tag(key="split", value="val")
|
||||
|
||||
recording = create_recording(path=tmp_path / "rec.wav")
|
||||
clip = create_clip(recording)
|
||||
clip_annotation = create_clip_annotation(
|
||||
clip,
|
||||
clip_tags=[reviewed, train],
|
||||
)
|
||||
|
||||
all_condition = build_clip_annotation_condition(
|
||||
ClipAllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=reviewed),
|
||||
HasAnyTagConfig(tags=[train, val]),
|
||||
]
|
||||
)
|
||||
)
|
||||
any_condition = build_clip_annotation_condition(
|
||||
ClipAnyOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=val),
|
||||
HasTagConfig(tag=reviewed),
|
||||
]
|
||||
)
|
||||
)
|
||||
not_condition = build_clip_annotation_condition(
|
||||
ClipNotConfig(condition=HasTagConfig(tag=val))
|
||||
)
|
||||
|
||||
assert all_condition(clip_annotation)
|
||||
assert any_condition(clip_annotation)
|
||||
assert not_condition(clip_annotation)
|
||||
139
tests/test_data/test_conditions/test_recording.py
Normal file
139
tests/test_data/test_conditions/test_recording.py
Normal file
@ -0,0 +1,139 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
RecordingAllOfConfig,
|
||||
RecordingAnyOfConfig,
|
||||
RecordingNotConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
|
||||
|
||||
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
ids_path = tmp_path / "recording_ids.json"
|
||||
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
||||
|
||||
condition = build_recording_condition(IdInListConfig(path=ids_path))
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_id_in_list_condition_uses_base_dir(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
recording = create_recording(path=tmp_path / "a.wav")
|
||||
split_dir = tmp_path / "splits"
|
||||
split_dir.mkdir()
|
||||
ids_path = split_dir / "train_ids.json"
|
||||
ids_path.write_text(json.dumps([str(recording.uuid)]))
|
||||
|
||||
condition = build_recording_condition(
|
||||
IdInListConfig(path=Path("splits/train_ids.json")),
|
||||
base_dir=tmp_path,
|
||||
)
|
||||
|
||||
assert condition(recording)
|
||||
|
||||
|
||||
def test_id_in_list_condition_raises_for_non_list_json(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
ids_path = tmp_path / "recording_ids.json"
|
||||
ids_path.write_text(json.dumps({"id": "foo"}))
|
||||
|
||||
with pytest.raises(TypeError, match="Expected JSON list"):
|
||||
build_recording_condition(IdInListConfig(path=ids_path))
|
||||
|
||||
|
||||
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
|
||||
ids_path = tmp_path / "recording_ids.json"
|
||||
ids_path.write_text(json.dumps(["not-a-uuid"]))
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid ID"):
|
||||
build_recording_condition(IdInListConfig(path=ids_path))
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_txt_format(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
ids_path = tmp_path / "recording_ids.txt"
|
||||
ids_path.write_text(f"{recording_a.uuid}\n")
|
||||
|
||||
condition = build_recording_condition(
|
||||
IdInListConfig(path=ids_path, list_format="txt")
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_recording_has_tag_conditions(
|
||||
tmp_path: Path, create_recording
|
||||
) -> None:
|
||||
train = data.Tag(key="split", value="train")
|
||||
uk = data.Tag(key="region", value="uk")
|
||||
eu = data.Tag(key="region", value="eu")
|
||||
|
||||
recording = create_recording(
|
||||
path=tmp_path / "rec.wav",
|
||||
tags=[train, uk],
|
||||
)
|
||||
|
||||
has_train = build_recording_condition(HasTagConfig(tag=train))
|
||||
has_all = build_recording_condition(HasAllTagsConfig(tags=[train, uk]))
|
||||
has_any = build_recording_condition(HasAnyTagConfig(tags=[eu, uk]))
|
||||
|
||||
assert has_train(recording)
|
||||
assert has_all(recording)
|
||||
assert has_any(recording)
|
||||
|
||||
|
||||
def test_recording_logical_conditions(
|
||||
tmp_path: Path, create_recording
|
||||
) -> None:
|
||||
train = data.Tag(key="split", value="train")
|
||||
uk = data.Tag(key="region", value="uk")
|
||||
eu = data.Tag(key="region", value="eu")
|
||||
|
||||
recording = create_recording(
|
||||
path=tmp_path / "rec.wav",
|
||||
tags=[train, uk],
|
||||
)
|
||||
|
||||
all_condition = build_recording_condition(
|
||||
RecordingAllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=train),
|
||||
HasAnyTagConfig(tags=[eu, uk]),
|
||||
]
|
||||
)
|
||||
)
|
||||
any_condition = build_recording_condition(
|
||||
RecordingAnyOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=eu),
|
||||
HasTagConfig(tag=train),
|
||||
]
|
||||
)
|
||||
)
|
||||
not_condition = build_recording_condition(
|
||||
RecordingNotConfig(condition=HasTagConfig(tag=eu))
|
||||
)
|
||||
|
||||
assert all_condition(recording)
|
||||
assert any_condition(recording)
|
||||
assert not_condition(recording)
|
||||
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@ -6,16 +8,17 @@ from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import (
|
||||
IdInListConfig,
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
|
||||
|
||||
def build_condition_from_str(content):
|
||||
def build_condition_from_str(content, base_dir: Path | None = None):
|
||||
content = textwrap.dedent(content)
|
||||
content = yaml.safe_load(content)
|
||||
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||
return build_sound_event_condition(config)
|
||||
return build_sound_event_condition(config, base_dir=base_dir)
|
||||
|
||||
|
||||
def test_has_tag(sound_event: data.SoundEvent):
|
||||
@ -160,6 +163,36 @@ def test_not(sound_event: data.SoundEvent):
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_id_in_list(sound_event: data.SoundEvent, tmp_path: Path):
|
||||
se1 = data.SoundEventAnnotation(sound_event=sound_event)
|
||||
se2 = data.SoundEventAnnotation(sound_event=sound_event)
|
||||
ids_path = tmp_path / "sound_event_ids.json"
|
||||
ids_path.write_text(json.dumps([str(se1.uuid)]))
|
||||
|
||||
condition = build_sound_event_condition(IdInListConfig(path=ids_path))
|
||||
|
||||
assert condition(se1)
|
||||
assert not condition(se2)
|
||||
|
||||
|
||||
def test_id_in_list_uses_base_dir(
|
||||
sound_event: data.SoundEvent,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
se = data.SoundEventAnnotation(sound_event=sound_event)
|
||||
split_dir = tmp_path / "splits"
|
||||
split_dir.mkdir()
|
||||
ids_path = split_dir / "sound_event_ids.json"
|
||||
ids_path.write_text(json.dumps([str(se.uuid)]))
|
||||
|
||||
condition = build_sound_event_condition(
|
||||
IdInListConfig(path=Path("splits/sound_event_ids.json")),
|
||||
base_dir=tmp_path,
|
||||
)
|
||||
|
||||
assert condition(se)
|
||||
|
||||
|
||||
def test_duration(recording: data.Recording):
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
Loading…
Reference in New Issue
Block a user