mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Add test for target filtering
This commit is contained in:
parent
d97614a10d
commit
02d4779207
@ -1,7 +1,6 @@
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from functools import partial
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Set
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
@ -128,6 +128,14 @@ def clip(recording: data.Recording) -> data.Clip:
|
|||||||
return data.Clip(recording=recording, start_time=0, end_time=0.5)
|
return data.Clip(recording=recording, start_time=0, end_time=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sound_event(recording: data.Recording) -> data.SoundEvent:
|
||||||
|
return data.SoundEvent(
|
||||||
|
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
|
||||||
|
recording=recording,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
|
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
|
||||||
return data.SoundEventAnnotation(
|
return data.SoundEventAnnotation(
|
||||||
|
195
tests/test_targets/test_filtering.py
Normal file
195
tests/test_targets/test_filtering.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, List, Set
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets.filtering import (
|
||||||
|
FilterConfig,
|
||||||
|
FilterRule,
|
||||||
|
build_filter_from_config,
|
||||||
|
build_filter_from_rule,
|
||||||
|
contains_tags,
|
||||||
|
does_not_have_tags,
|
||||||
|
equal_tags,
|
||||||
|
has_any_tag,
|
||||||
|
load_filter_config,
|
||||||
|
load_filter_from_config,
|
||||||
|
merge_filters,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.terms import TagInfo, generic_class
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_annotation(
|
||||||
|
sound_event: data.SoundEvent,
|
||||||
|
) -> Callable[[List[str]], data.SoundEventAnnotation]:
|
||||||
|
"""Helper function to create a SoundEventAnnotation with given tags."""
|
||||||
|
|
||||||
|
def factory(tags: List[str]) -> data.SoundEventAnnotation:
|
||||||
|
return data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(
|
||||||
|
term=generic_class,
|
||||||
|
value=tag,
|
||||||
|
)
|
||||||
|
for tag in tags
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
|
||||||
|
def create_tag_set(tags: List[str]) -> Set[data.Tag]:
|
||||||
|
"""Helper function to create a set of data.Tag objects from a list of strings."""
|
||||||
|
return {
|
||||||
|
data.Tag(
|
||||||
|
term=generic_class,
|
||||||
|
value=tag,
|
||||||
|
)
|
||||||
|
for tag in tags
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_any_tag(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert has_any_tag(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag2", "tag4"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert has_any_tag(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_contains_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert contains_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2", "tag3"])
|
||||||
|
assert contains_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_have_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag3", "tag4"])
|
||||||
|
assert does_not_have_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag3"])
|
||||||
|
assert does_not_have_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_equal_tags(create_annotation):
|
||||||
|
annotation = create_annotation(["tag1", "tag2"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert equal_tags(annotation, tags) is True
|
||||||
|
|
||||||
|
annotation = create_annotation(["tag1", "tag2", "tag3"])
|
||||||
|
tags = create_tag_set(["tag1", "tag2"])
|
||||||
|
assert equal_tags(annotation, tags) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_filter_from_rule():
|
||||||
|
rule_any = FilterRule(match_type="any", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_any)
|
||||||
|
|
||||||
|
rule_all = FilterRule(match_type="all", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_all)
|
||||||
|
|
||||||
|
rule_exclude = FilterRule(
|
||||||
|
match_type="exclude", tags=[TagInfo(value="tag1")]
|
||||||
|
)
|
||||||
|
build_filter_from_rule(rule_exclude)
|
||||||
|
|
||||||
|
rule_equal = FilterRule(match_type="equal", tags=[TagInfo(value="tag1")])
|
||||||
|
build_filter_from_rule(rule_equal)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||||
|
build_filter_from_rule(
|
||||||
|
FilterRule(match_type="invalid", tags=[TagInfo(value="tag1")]) # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_filters(create_annotation):
|
||||||
|
def filter1(annotation):
|
||||||
|
return "tag1" in [tag.value for tag in annotation.tags]
|
||||||
|
|
||||||
|
def filter2(annotation):
|
||||||
|
return "tag2" in [tag.value for tag in annotation.tags]
|
||||||
|
|
||||||
|
merged_filter = merge_filters(filter1, filter2)
|
||||||
|
|
||||||
|
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||||
|
assert merged_filter(annotation_pass) is True
|
||||||
|
|
||||||
|
annotation_fail = create_annotation(["tag1"])
|
||||||
|
assert merged_filter(annotation_fail) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_filter_from_config(create_annotation):
|
||||||
|
config = FilterConfig(
|
||||||
|
rules=[
|
||||||
|
FilterRule(match_type="any", tags=[TagInfo(value="tag1")]),
|
||||||
|
FilterRule(match_type="any", tags=[TagInfo(value="tag2")]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
filter_from_config = build_filter_from_config(config)
|
||||||
|
|
||||||
|
annotation_pass = create_annotation(["tag1", "tag2"])
|
||||||
|
assert filter_from_config(annotation_pass)
|
||||||
|
|
||||||
|
annotation_fail = create_annotation(["tag1"])
|
||||||
|
assert not filter_from_config(annotation_fail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_filter_config(tmp_path: Path):
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
config = load_filter_config(test_config_path)
|
||||||
|
assert isinstance(config, FilterConfig)
|
||||||
|
assert len(config.rules) == 1
|
||||||
|
rule = config.rules[0]
|
||||||
|
assert rule.match_type == "any"
|
||||||
|
assert len(rule.tags) == 1
|
||||||
|
assert rule.tags[0].value == "tag1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_filter_from_config(tmp_path: Path, create_annotation):
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
|
assert filter_result(annotation)
|
||||||
|
|
||||||
|
test_config_path = tmp_path / "filtering.yaml"
|
||||||
|
test_config_path.write_text(
|
||||||
|
"""
|
||||||
|
rules:
|
||||||
|
- match_type: any
|
||||||
|
tags:
|
||||||
|
- value: tag2
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_result = load_filter_from_config(test_config_path)
|
||||||
|
annotation = create_annotation(["tag1", "tag3"])
|
||||||
|
assert filter_result(annotation) is False
|
Loading…
Reference in New Issue
Block a user