mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add test for target filtering
This commit is contained in:
parent
d97614a10d
commit
02d4779207
@ -1,7 +1,6 @@
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Set
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
@ -128,6 +128,14 @@ def clip(recording: data.Recording) -> data.Clip:
|
||||
return data.Clip(recording=recording, start_time=0, end_time=0.5)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sound_event(recording: data.Recording) -> data.SoundEvent:
|
||||
return data.SoundEvent(
|
||||
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
|
||||
recording=recording,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
|
||||
return data.SoundEventAnnotation(
|
||||
|
195
tests/test_targets/test_filtering.py
Normal file
195
tests/test_targets/test_filtering.py
Normal file
@ -0,0 +1,195 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Set
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
build_filter_from_config,
|
||||
build_filter_from_rule,
|
||||
contains_tags,
|
||||
does_not_have_tags,
|
||||
equal_tags,
|
||||
has_any_tag,
|
||||
load_filter_config,
|
||||
load_filter_from_config,
|
||||
merge_filters,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, generic_class
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
) -> Callable[[List[str]], data.SoundEventAnnotation]:
|
||||
"""Helper function to create a SoundEventAnnotation with given tags."""
|
||||
|
||||
def factory(tags: List[str]) -> data.SoundEventAnnotation:
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=generic_class,
|
||||
value=tag,
|
||||
)
|
||||
for tag in tags
|
||||
],
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
|
||||
"""Helper function to create a set of data.Tag objects from a list of strings."""
|
||||
return {
|
||||
data.Tag(
|
||||
term=generic_class,
|
||||
value=tag,
|
||||
)
|
||||
for tag in tags
|
||||
}
|
||||
|
||||
|
||||
def test_has_any_tag(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag3"])
|
||||
assert has_any_tag(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag2", "tag4"])
|
||||
tags = create_tag_set(["tag1", "tag3"])
|
||||
assert has_any_tag(annotation, tags) is False
|
||||
|
||||
|
||||
def test_contains_tags(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||
tags = create_tag_set(["tag1", "tag2"])
|
||||
assert contains_tags(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag2", "tag3"])
|
||||
assert contains_tags(annotation, tags) is False
|
||||
|
||||
|
||||
def test_does_not_have_tags(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag3", "tag4"])
|
||||
assert does_not_have_tags(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag3"])
|
||||
assert does_not_have_tags(annotation, tags) is False
|
||||
|
||||
|
||||
def test_equal_tags(create_annotation):
|
||||
annotation = create_annotation(["tag1", "tag2"])
|
||||
tags = create_tag_set(["tag1", "tag2"])
|
||||
assert equal_tags(annotation, tags) is True
|
||||
|
||||
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||
tags = create_tag_set(["tag1", "tag2"])
|
||||
assert equal_tags(annotation, tags) is False
|
||||
|
||||
|
||||
def test_build_filter_from_rule():
|
||||
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
|
||||
build_filter_from_rule(rule_any)
|
||||
|
||||
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
|
||||
build_filter_from_rule(rule_all)
|
||||
|
||||
rule_exclude = FilterRule(
|
||||
match_type="exclude", tags=[TagInfo(value="tag1")]
|
||||
)
|
||||
build_filter_from_rule(rule_exclude)
|
||||
|
||||
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
|
||||
build_filter_from_rule(rule_equal)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||
build_filter_from_rule(
|
||||
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def test_merge_filters(create_annotation):
|
||||
def filter1(annotation):
|
||||
return "tag1" in [tag.value for tag in annotation.tags]
|
||||
|
||||
def filter2(annotation):
|
||||
return "tag2" in [tag.value for tag in annotation.tags]
|
||||
|
||||
merged_filter = merge_filters(filter1, filter2)
|
||||
|
||||
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||
assert merged_filter(annotation_pass) is True
|
||||
|
||||
annotation_fail = create_annotation(["tag1"])
|
||||
assert merged_filter(annotation_fail) is False
|
||||
|
||||
|
||||
def test_build_filter_from_config(create_annotation):
|
||||
config = FilterConfig(
|
||||
rules=[
|
||||
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
|
||||
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
||||
]
|
||||
)
|
||||
filter_from_config = build_filter_from_config(config)
|
||||
|
||||
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||
assert filter_from_config(annotation_pass)
|
||||
|
||||
annotation_fail = create_annotation(["tag1"])
|
||||
assert not filter_from_config(annotation_fail)
|
||||
|
||||
|
||||
def test_load_filter_config(tmp_path: Path):
|
||||
test_config_path = tmp_path / "filtering.yaml"
|
||||
test_config_path.write_text(
|
||||
"""
|
||||
rules:
|
||||
- match_type: any
|
||||
tags:
|
||||
- value: tag1
|
||||
"""
|
||||
)
|
||||
config = load_filter_config(test_config_path)
|
||||
assert isinstance(config, FilterConfig)
|
||||
assert len(config.rules) == 1
|
||||
rule = config.rules[0]
|
||||
assert rule.match_type == "any"
|
||||
assert len(rule.tags) == 1
|
||||
assert rule.tags[0].value == "tag1"
|
||||
|
||||
|
||||
def test_load_filter_from_config(tmp_path: Path, create_annotation):
|
||||
test_config_path = tmp_path / "filtering.yaml"
|
||||
test_config_path.write_text(
|
||||
"""
|
||||
rules:
|
||||
- match_type: any
|
||||
tags:
|
||||
- value: tag1
|
||||
"""
|
||||
)
|
||||
|
||||
filter_result = load_filter_from_config(test_config_path)
|
||||
annotation = create_annotation(["tag1", "tag3"])
|
||||
assert filter_result(annotation)
|
||||
|
||||
test_config_path = tmp_path / "filtering.yaml"
|
||||
test_config_path.write_text(
|
||||
"""
|
||||
rules:
|
||||
- match_type: any
|
||||
tags:
|
||||
- value: tag2
|
||||
"""
|
||||
)
|
||||
|
||||
filter_result = load_filter_from_config(test_config_path)
|
||||
annotation = create_annotation(["tag1", "tag3"])
|
||||
assert filter_result(annotation) is False
|
Loading…
Reference in New Issue
Block a user