mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
299 lines
8.4 KiB
Python
from pathlib import Path
|
|
from typing import Callable
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
from soundevent import data
|
|
|
|
from batdetect2.targets.classes import (
|
|
DEFAULT_SPECIES_LIST,
|
|
ClassesConfig,
|
|
TargetClass,
|
|
build_encoder_from_config,
|
|
get_class_names_from_config,
|
|
_get_default_class_name,
|
|
_get_default_classes,
|
|
_is_target_class,
|
|
load_classes_config,
|
|
load_encoder_from_config,
|
|
)
|
|
from batdetect2.targets.terms import TagInfo, TermRegistry
|
|
|
|
|
|
@pytest.fixture
def sample_term_registry() -> TermRegistry:
    """Provide a TermRegistry holding the custom terms these tests use.

    The terms ``species``, ``sound_type`` and ``quality`` are registered in
    that order; other fixtures and tests look them up by key.
    """
    registry = TermRegistry()
    for term_key in ("species", "sound_type", "quality"):
        registry.add_custom_term(term_key)
    return registry
|
|
|
|
|
|
@pytest.fixture
def sample_annotation(
    sound_event: data.SoundEvent,
    sample_term_registry: TermRegistry,
) -> data.SoundEventAnnotation:
    """Annotate ``sound_event`` as a good-quality Pipistrellus pipistrellus.

    The returned annotation carries exactly two tags, resolved through
    ``sample_term_registry``: ``species=Pipistrellus pipistrellus`` and
    ``quality=Good``.
    """
    species_tag = data.Tag(
        term=sample_term_registry.get_term("species"),
        value="Pipistrellus pipistrellus",
    )
    quality_tag = data.Tag(
        term=sample_term_registry.get_term("quality"),
        value="Good",
    )
    return data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[species_tag, quality_tag],
    )
|
|
|
|
|
|
@pytest.fixture
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
    """Return a factory that writes YAML text to a new file under ``tmp_path``.

    Each call writes ``content`` to a file named ``<uuid4>.yaml`` and returns
    its path. A single test can therefore create several distinct files.
    """

    def _write_yaml(content: str) -> Path:
        yaml_path = tmp_path.joinpath(f"{uuid4()}.yaml")
        yaml_path.write_text(content)
        return yaml_path

    return _write_yaml
|
|
|
|
|
|
def test_target_class_creation():
    """TargetClass keeps its name and tags and defaults ``match_type`` to "all"."""
    pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )

    assert pippip.name == "pippip"
    first_tag = pippip.tags[0]
    assert first_tag.key == "species"
    assert first_tag.value == "Pipistrellus pipistrellus"
    assert pippip.match_type == "all"
|
|
|
|
|
|
def test_classes_config_creation():
    """A ClassesConfig built from one TargetClass exposes that single class."""
    pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )

    classes_config = ClassesConfig(classes=[pippip])

    assert len(classes_config.classes) == 1
    assert classes_config.classes[0].name == "pippip"
|
|
|
|
|
|
def test_classes_config_unique_names():
    """ClassesConfig accepts classes whose names are all distinct."""
    pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )
    myodau = TargetClass(
        name="myodau",
        tags=[TagInfo(key="species", value="Myotis daubentonii")],
    )

    # Construction must succeed; a validation error would fail the test.
    ClassesConfig(classes=[pippip, myodau])
|
|
|
|
|
|
def test_classes_config_non_unique_names():
    """ClassesConfig rejects two classes that share the same name."""
    first_pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )
    # Same name but different tags: the duplicate name alone must be refused.
    second_pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Myotis daubentonii")],
    )

    with pytest.raises(ValidationError):
        ClassesConfig(classes=[first_pippip, second_pippip])
|
|
|
|
|
|
def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]):
    """A well-formed YAML file with one class loads into a ClassesConfig."""
    yaml_content = """
    classes:
      - name: pippip
        tags:
          - key: species
            value: Pipistrellus pipistrellus
    """
    config_path = create_temp_yaml(yaml_content)

    loaded = load_classes_config(config_path)

    assert len(loaded.classes) == 1
    assert loaded.classes[0].name == "pippip"
|
|
|
|
|
|
def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]):
    """Loading a YAML file that repeats a class name raises ValidationError."""
    yaml_content = """
    classes:
      - name: pippip
        tags:
          - key: species
            value: Pipistrellus pipistrellus
      - name: pippip
        tags:
          - key: species
            value: Myotis daubentonii
    """
    config_path = create_temp_yaml(yaml_content)

    with pytest.raises(ValidationError):
        load_classes_config(config_path)
|
|
|
|
|
|
def test_is_target_class_match_all(
    sample_annotation: data.SoundEventAnnotation,
    sample_term_registry: TermRegistry,
):
    """With ``match_all=True``, only annotations carrying every target tag match.

    ``sample_annotation`` is tagged ``species=Pipistrellus pipistrellus`` and
    ``quality=Good``.
    """
    species_match = data.Tag(
        term=sample_term_registry["species"],
        value="Pipistrellus pipistrellus",
    )
    species_mismatch = data.Tag(
        term=sample_term_registry["species"], value="Myotis daubentonii"
    )
    quality_match = data.Tag(term=sample_term_registry["quality"], value="Good")

    # All target tags are present on the annotation.
    tags = {species_match, quality_match}
    assert _is_target_class(sample_annotation, tags, match_all=True) is True

    # A subset of the annotation's tags still matches.
    tags = {species_match}
    assert _is_target_class(sample_annotation, tags, match_all=True) is True

    # None of the target tags are present.
    tags = {species_mismatch}
    assert _is_target_class(sample_annotation, tags, match_all=True) is False

    # Only some of the target tags are present. This case separates "all"
    # matching from "any" matching, so it must NOT match here.
    tags = {species_mismatch, quality_match}
    assert _is_target_class(sample_annotation, tags, match_all=True) is False
|
|
|
|
|
|
def test_is_target_class_match_any(
    sample_annotation: data.SoundEventAnnotation,
    sample_term_registry: TermRegistry,
):
    """With ``match_all=False``, one shared tag is enough for a match.

    ``sample_annotation`` is tagged ``species=Pipistrellus pipistrellus`` and
    ``quality=Good``.
    """
    species_match = data.Tag(
        term=sample_term_registry["species"],
        value="Pipistrellus pipistrellus",
    )
    species_mismatch = data.Tag(
        term=sample_term_registry["species"], value="Myotis daubentonii"
    )
    quality_match = data.Tag(term=sample_term_registry["quality"], value="Good")

    # All target tags are present on the annotation.
    tags = {species_match, quality_match}
    assert _is_target_class(sample_annotation, tags, match_all=False) is True

    # A single matching tag is present.
    tags = {species_match}
    assert _is_target_class(sample_annotation, tags, match_all=False) is True

    # No target tag is present.
    tags = {species_mismatch}
    assert _is_target_class(sample_annotation, tags, match_all=False) is False

    # Only some of the target tags are present. This case separates "any"
    # matching from "all" matching, so it MUST match here.
    tags = {species_mismatch, quality_match}
    assert _is_target_class(sample_annotation, tags, match_all=False) is True
|
|
|
|
|
|
def test_get_class_names_from_config():
    """Class names are returned in the order the classes were configured."""
    pippip = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )
    myodau = TargetClass(
        name="myodau",
        tags=[TagInfo(key="species", value="Myotis daubentonii")],
    )
    classes_config = ClassesConfig(classes=[pippip, myodau])

    assert get_class_names_from_config(classes_config) == ["pippip", "myodau"]
|
|
|
|
|
|
def test_build_encoder_from_config(
    sample_annotation: data.SoundEventAnnotation,
    sample_term_registry: TermRegistry,
):
    """The built encoder maps a matching annotation to its class name.

    With no classes configured, the encoder returns ``None`` for the same
    annotation.
    """
    pippip_class = TargetClass(
        name="pippip",
        tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")],
    )
    pippip_encoder = build_encoder_from_config(
        ClassesConfig(classes=[pippip_class]),
        term_registry=sample_term_registry,
    )
    assert pippip_encoder(sample_annotation) == "pippip"

    empty_encoder = build_encoder_from_config(
        ClassesConfig(classes=[]),
        term_registry=sample_term_registry,
    )
    assert empty_encoder(sample_annotation) is None
|
|
|
|
|
|
def test_load_encoder_from_config_valid(
    sample_annotation: data.SoundEventAnnotation,
    sample_term_registry: TermRegistry,
    create_temp_yaml: Callable[[str], Path],
):
    """An encoder loaded from a valid YAML file classifies the annotation."""
    yaml_content = """
    classes:
      - name: pippip
        tags:
          - key: species
            value: Pipistrellus pipistrellus
    """
    config_path = create_temp_yaml(yaml_content)

    loaded_encoder = load_encoder_from_config(
        config_path,
        term_registry=sample_term_registry,
    )

    # Functions cannot be compared directly, so check behaviour instead.
    assert loaded_encoder(sample_annotation) == "pippip"  # type: ignore
|
|
|
|
|
|
def test_load_encoder_from_config_invalid(
    create_temp_yaml: Callable[[str], Path],
    sample_term_registry: TermRegistry,
):
    """A tag key unknown to the term registry makes loading raise KeyError."""
    yaml_content = """
    classes:
      - name: pippip
        tags:
          - key: invalid_key
            value: Pipistrellus pipistrellus
    """
    config_path = create_temp_yaml(yaml_content)

    with pytest.raises(KeyError):
        load_encoder_from_config(
            config_path,
            term_registry=sample_term_registry,
        )
|
|
|
|
|
|
def test_get_default_class_name():
    """The default name for a species is its abbreviated form, e.g. ``myodau``."""
    species = "Myotis daubentonii"
    expected_name = "myodau"
    assert _get_default_class_name(species) == expected_name
|
|
|
|
|
|
def test_get_default_classes():
    """One default TargetClass is built per entry of DEFAULT_SPECIES_LIST.

    The first class is named via ``_get_default_class_name`` and carries a
    ``class`` tag whose value is the species name.
    """
    classes = _get_default_classes()
    assert len(classes) == len(DEFAULT_SPECIES_LIST)

    first_species = DEFAULT_SPECIES_LIST[0]
    head = classes[0]
    assert isinstance(head, TargetClass)
    assert head.name == _get_default_class_name(first_species)

    head_tag = head.tags[0]
    assert head_tag.key == "class"
    assert head_tag.value == first_species
|