Add tests for target.classes

This commit is contained in:
mbsantiago 2025-04-15 18:22:19 +01:00
parent af48c33307
commit 62471664fa

View File

@ -0,0 +1,281 @@
from pathlib import Path
from typing import Callable
from uuid import uuid4
import pytest
from pydantic import ValidationError
from soundevent import data
from batdetect2.targets.classes import (
ClassesConfig,
TargetClass,
build_encoder_from_config,
get_class_names_from_config,
is_target_class,
load_classes_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo, TermRegistry
@pytest.fixture
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("species")
registry.add_custom_term("sound_type")
registry.add_custom_term("quality")
return registry
@pytest.fixture
def sample_annotation(
sound_event: data.SoundEvent,
sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation:
"""Fixture for a sample SoundEventAnnotation."""
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=[
data.Tag(
term=sample_term_registry.get_term("species"),
value="Pipistrellus pipistrellus",
),
data.Tag(
term=sample_term_registry.get_term("quality"),
value="Good",
),
],
)
@pytest.fixture
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
"""Create a temporary YAML file with the given content."""
def factory(content: str) -> Path:
temp_file = tmp_path / f"{uuid4()}.yaml"
temp_file.write_text(content)
return temp_file
return factory
def test_target_class_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
assert target_class.name == "pippip"
assert target_class.tags[0].key == "species"
assert target_class.tags[0].value == "Pipistrellus pipistrellus"
assert target_class.match_type == "all"
def test_classes_config_creation():
target_class = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
config = ClassesConfig(classes=[target_class])
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_classes_config_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
ClassesConfig(classes=[target_class1, target_class2]) # No error
def test_classes_config_non_unique_names():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
with pytest.raises(ValidationError):
ClassesConfig(classes=[target_class1, target_class2])
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
config = load_classes_config(temp_yaml_path)
assert len(config.classes) == 1
assert config.classes[0].name == "pippip"
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: pippip
tags:
- key: species
value: Myotis daubentonii
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(ValidationError):
load_classes_config(temp_yaml_path)
def test_is_target_class_match_all(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
),
data.Tag(term=sample_term_registry["quality"], value="Good"),
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {
data.Tag(
term=sample_term_registry["species"],
value="Pipistrellus pipistrellus",
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = {
data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii"
)
}
assert is_target_class(sample_annotation, tags, match_all=False) is False
def test_get_class_names_from_config():
target_class1 = TargetClass(
name="pippip",
tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
)
target_class2 = TargetClass(
name="myodau",
tags=[TagInfo(key="species", value="Myotis daubentonii")],
)
config = ClassesConfig(classes=[target_class1, target_class2])
names = get_class_names_from_config(config)
assert names == ["pippip", "myodau"]
def test_build_encoder_from_config(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
):
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
]
)
encoder = build_encoder_from_config(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation)
assert result == "pippip"
config = ClassesConfig(classes=[])
encoder = build_encoder_from_config(
config,
term_registry=sample_term_registry,
)
result = encoder(sample_annotation)
assert result is None
def test_load_encoder_from_config_valid(
sample_annotation: data.SoundEventAnnotation,
sample_term_registry: TermRegistry,
create_temp_yaml: Callable[[str], Path],
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
encoder = load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)
# We cannot directly compare the function, so we test it.
result = encoder(sample_annotation) # type: ignore
assert result == "pippip"
def test_load_encoder_from_config_invalid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
- name: pippip
tags:
- key: invalid_key
value: Pipistrellus pipistrellus
"""
temp_yaml_path = create_temp_yaml(yaml_content)
with pytest.raises(KeyError):
load_encoder_from_config(
temp_yaml_path,
term_registry=sample_term_registry,
)