Add test for target filtering

This commit is contained in:
mbsantiago 2025-04-14 13:31:30 +01:00
parent d97614a10d
commit 02d4779207
3 changed files with 204 additions and 2 deletions

View File

@ -1,7 +1,6 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Set
from typing import Callable, List, Optional
from pydantic import Field
from soundevent import data

View File

@ -128,6 +128,14 @@ def clip(recording: data.Recording) -> data.Clip:
return data.Clip(recording=recording, start_time=0, end_time=0.5)
@pytest.fixture
def sound_event(recording: data.Recording) -> data.SoundEvent:
return data.SoundEvent(
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
recording=recording,
)
@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(

View File

@ -0,0 +1,195 @@
from pathlib import Path
from typing import Callable, List, Set
import pytest
from soundevent import data
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
build_filter_from_config,
build_filter_from_rule,
contains_tags,
does_not_have_tags,
equal_tags,
has_any_tag,
load_filter_config,
load_filter_from_config,
merge_filters,
)
from batdetect2.targets.terms import TagInfo, generic_class
@pytest.fixture
def create_annotation(
sound_event: data.SoundEvent,
) -> Callable[[List[str]], data.SoundEventAnnotation]:
"""Helper function to create a SoundEventAnnotation with given tags."""
def factory(tags: List[str]) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
],
)
return factory
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
"""Helper function to create a set of data.Tag objects from a list of strings."""
return {
data.Tag(
term=generic_class,
value=tag,
)
for tag in tags
}
def test_has_any_tag(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is True
annotation = create_annotation(["tag2", "tag4"])
tags = create_tag_set(["tag1", "tag3"])
assert has_any_tag(annotation, tags) is False
def test_contains_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert contains_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2", "tag3"])
assert contains_tags(annotation, tags) is False
def test_does_not_have_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag3", "tag4"])
assert does_not_have_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag3"])
assert does_not_have_tags(annotation, tags) is False
def test_equal_tags(create_annotation):
annotation = create_annotation(["tag1", "tag2"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is True
annotation = create_annotation(["tag1", "tag2", "tag3"])
tags = create_tag_set(["tag1", "tag2"])
assert equal_tags(annotation, tags) is False
def test_build_filter_from_rule():
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_any)
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_all)
rule_exclude = FilterRule(
match_type="exclude", tags=[TagInfo(value="tag1")]
)
build_filter_from_rule(rule_exclude)
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
build_filter_from_rule(rule_equal)
with pytest.raises(ValueError):
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
build_filter_from_rule(
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
)
def test_merge_filters(create_annotation):
def filter1(annotation):
return "tag1" in [tag.value for tag in annotation.tags]
def filter2(annotation):
return "tag2" in [tag.value for tag in annotation.tags]
merged_filter = merge_filters(filter1, filter2)
annotation_pass = create_annotation(["tag1", "tag2"])
assert merged_filter(annotation_pass) is True
annotation_fail = create_annotation(["tag1"])
assert merged_filter(annotation_fail) is False
def test_build_filter_from_config(create_annotation):
config = FilterConfig(
rules=[
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
]
)
filter_from_config = build_filter_from_config(config)
annotation_pass = create_annotation(["tag1", "tag2"])
assert filter_from_config(annotation_pass)
annotation_fail = create_annotation(["tag1"])
assert not filter_from_config(annotation_fail)
def test_load_filter_config(tmp_path: Path):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
config = load_filter_config(test_config_path)
assert isinstance(config, FilterConfig)
assert len(config.rules) == 1
rule = config.rules[0]
assert rule.match_type == "any"
assert len(rule.tags) == 1
assert rule.tags[0].value == "tag1"
def test_load_filter_from_config(tmp_path: Path, create_annotation):
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag1
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation)
test_config_path = tmp_path / "filtering.yaml"
test_config_path.write_text(
"""
rules:
- match_type: any
tags:
- value: tag2
"""
)
filter_result = load_filter_from_config(test_config_path)
annotation = create_annotation(["tag1", "tag3"])
assert filter_result(annotation) is False