mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add tests for target.classes
This commit is contained in:
parent
af48c33307
commit
62471664fa
281
tests/test_targets/test_classes.py
Normal file
281
tests/test_targets/test_classes.py
Normal file
@ -0,0 +1,281 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets.classes import (
|
||||
ClassesConfig,
|
||||
TargetClass,
|
||||
build_encoder_from_config,
|
||||
get_class_names_from_config,
|
||||
is_target_class,
|
||||
load_classes_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_term_registry() -> TermRegistry:
|
||||
"""Fixture for a sample TermRegistry."""
|
||||
registry = TermRegistry()
|
||||
registry.add_custom_term("species")
|
||||
registry.add_custom_term("sound_type")
|
||||
registry.add_custom_term("quality")
|
||||
return registry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(
|
||||
sound_event: data.SoundEvent,
|
||||
sample_term_registry: TermRegistry,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Fixture for a sample SoundEventAnnotation."""
|
||||
return data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("species"),
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(
|
||||
term=sample_term_registry.get_term("quality"),
|
||||
value="Good",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||
"""Create a temporary YAML file with the given content."""
|
||||
|
||||
def factory(content: str) -> Path:
|
||||
temp_file = tmp_path / f"{uuid4()}.yaml"
|
||||
temp_file.write_text(content)
|
||||
return temp_file
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def test_target_class_creation():
|
||||
target_class = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
assert target_class.name == "pippip"
|
||||
assert target_class.tags[0].key == "species"
|
||||
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
|
||||
assert target_class.match_type == "all"
|
||||
|
||||
|
||||
def test_classes_config_creation():
|
||||
target_class = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
config = ClassesConfig(classes=[target_class])
|
||||
assert len(config.classes) == 1
|
||||
assert config.classes[0].name == "pippip"
|
||||
|
||||
|
||||
def test_classes_config_unique_names():
|
||||
target_class1 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
name="myodau",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
ClassesConfig(classes=[target_class1, target_class2]) # No error
|
||||
|
||||
|
||||
def test_classes_config_non_unique_names():
|
||||
target_class1 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
ClassesConfig(classes=[target_class1, target_class2])
|
||||
|
||||
|
||||
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
config = load_classes_config(temp_yaml_path)
|
||||
assert len(config.classes) == 1
|
||||
assert config.classes[0].name == "pippip"
|
||||
|
||||
|
||||
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis daubentonii
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(ValidationError):
|
||||
load_classes_config(temp_yaml_path)
|
||||
|
||||
|
||||
def test_is_target_class_match_all(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
|
||||
|
||||
def test_is_target_class_match_any(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"],
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
|
||||
|
||||
def test_get_class_names_from_config():
|
||||
target_class1 = TargetClass(
|
||||
name="pippip",
|
||||
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
target_class2 = TargetClass(
|
||||
name="myodau",
|
||||
tags=[TagInfo(key="species", value="Myotis daubentonii")],
|
||||
)
|
||||
config = ClassesConfig(classes=[target_class1, target_class2])
|
||||
names = get_class_names_from_config(config)
|
||||
assert names == ["pippip", "myodau"]
|
||||
|
||||
|
||||
def test_build_encoder_from_config(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
encoder = build_encoder_from_config(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
result = encoder(sample_annotation)
|
||||
assert result == "pippip"
|
||||
|
||||
config = ClassesConfig(classes=[])
|
||||
encoder = build_encoder_from_config(
|
||||
config,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
result = encoder(sample_annotation)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_load_encoder_from_config_valid(
|
||||
sample_annotation: data.SoundEventAnnotation,
|
||||
sample_term_registry: TermRegistry,
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
encoder = load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
||||
# We cannot directly compare the function, so we test it.
|
||||
result = encoder(sample_annotation) # type: ignore
|
||||
assert result == "pippip"
|
||||
|
||||
|
||||
def test_load_encoder_from_config_invalid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: invalid_key
|
||||
value: Pipistrellus pipistrellus
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
with pytest.raises(KeyError):
|
||||
load_encoder_from_config(
|
||||
temp_yaml_path,
|
||||
term_registry=sample_term_registry,
|
||||
)
|
Loading…
Reference in New Issue
Block a user