From 62471664fa95e2c2ce31985620f108f936998bd5 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 15 Apr 2025 18:22:19 +0100 Subject: [PATCH] Add tests for target.classes --- tests/test_targets/test_classes.py | 281 +++++++++++++++++++++++++++++ 1 file changed, 281 insertions(+) create mode 100644 tests/test_targets/test_classes.py diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py new file mode 100644 index 0000000..6bf8cae --- /dev/null +++ b/tests/test_targets/test_classes.py @@ -0,0 +1,281 @@ +from pathlib import Path +from typing import Callable +from uuid import uuid4 + +import pytest +from pydantic import ValidationError +from soundevent import data + +from batdetect2.targets.classes import ( + ClassesConfig, + TargetClass, + build_encoder_from_config, + get_class_names_from_config, + is_target_class, + load_classes_config, + load_encoder_from_config, +) +from batdetect2.targets.terms import TagInfo, TermRegistry + + +@pytest.fixture +def sample_term_registry() -> TermRegistry: + """Fixture for a sample TermRegistry.""" + registry = TermRegistry() + registry.add_custom_term("species") + registry.add_custom_term("sound_type") + registry.add_custom_term("quality") + return registry + + +@pytest.fixture +def sample_annotation( + sound_event: data.SoundEvent, + sample_term_registry: TermRegistry, +) -> data.SoundEventAnnotation: + """Fixture for a sample SoundEventAnnotation.""" + return data.SoundEventAnnotation( + sound_event=sound_event, + tags=[ + data.Tag( + term=sample_term_registry.get_term("species"), + value="Pipistrellus pipistrellus", + ), + data.Tag( + term=sample_term_registry.get_term("quality"), + value="Good", + ), + ], + ) + + +@pytest.fixture +def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]: + """Create a temporary YAML file with the given content.""" + + def factory(content: str) -> Path: + temp_file = tmp_path / f"{uuid4()}.yaml" + temp_file.write_text(content) + return temp_file + + return factory + + +def test_target_class_creation(): + target_class = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + ) + assert target_class.name == "pippip" + assert target_class.tags[0].key == "species" + assert target_class.tags[0].value == "Pipistrellus pipistrellus" + assert target_class.match_type == "all" + + +def test_classes_config_creation(): + target_class = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + ) + config = ClassesConfig(classes=[target_class]) + assert len(config.classes) == 1 + assert config.classes[0].name == "pippip" + + +def test_classes_config_unique_names(): + target_class1 = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + ) + target_class2 = TargetClass( + name="myodau", + tags=[TagInfo(key="species", value="Myotis daubentonii")], + ) + ClassesConfig(classes=[target_class1, target_class2]) # No error + + +def test_classes_config_non_unique_names(): + target_class1 = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + ) + target_class2 = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Myotis daubentonii")], + ) + with pytest.raises(ValidationError): + ClassesConfig(classes=[target_class1, target_class2]) + + +def test_load_classes_config_valid(create_temp_yaml: Callable[[str], Path]): + yaml_content = """ + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + """ + temp_yaml_path = create_temp_yaml(yaml_content) + config = load_classes_config(temp_yaml_path) + assert len(config.classes) == 1 + assert config.classes[0].name == "pippip" + + +def test_load_classes_config_invalid(create_temp_yaml: Callable[[str], Path]): + yaml_content = """ + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + - name: pippip + tags: + - key: species + value: Myotis daubentonii + """ + temp_yaml_path = create_temp_yaml(yaml_content) + with pytest.raises(ValidationError): + load_classes_config(temp_yaml_path) + + +def test_is_target_class_match_all( + sample_annotation: data.SoundEventAnnotation, + sample_term_registry: TermRegistry, +): + tags = { + data.Tag( + term=sample_term_registry["species"], + value="Pipistrellus pipistrellus", + ), + data.Tag(term=sample_term_registry["quality"], value="Good"), + } + assert is_target_class(sample_annotation, tags, match_all=True) is True + + tags = { + data.Tag( + term=sample_term_registry["species"], + value="Pipistrellus pipistrellus", + ) + } + assert is_target_class(sample_annotation, tags, match_all=True) is True + + tags = { + data.Tag( + term=sample_term_registry["species"], value="Myotis daubentonii" + ) + } + assert is_target_class(sample_annotation, tags, match_all=True) is False + + +def test_is_target_class_match_any( + sample_annotation: data.SoundEventAnnotation, + sample_term_registry: TermRegistry, +): + tags = { + data.Tag( + term=sample_term_registry["species"], + value="Pipistrellus pipistrellus", + ), + data.Tag(term=sample_term_registry["quality"], value="Good"), + } + assert is_target_class(sample_annotation, tags, match_all=False) is True + + tags = { + data.Tag( + term=sample_term_registry["species"], + value="Pipistrellus pipistrellus", + ) + } + assert is_target_class(sample_annotation, tags, match_all=False) is True + + tags = { + data.Tag( + term=sample_term_registry["species"], value="Myotis daubentonii" + ) + } + assert is_target_class(sample_annotation, tags, match_all=False) is False + + +def test_get_class_names_from_config(): + target_class1 = TargetClass( + name="pippip", + tags=[TagInfo(key="species", value="Pipistrellus pipistrellus")], + ) + target_class2 = TargetClass( + name="myodau", + tags=[TagInfo(key="species", value="Myotis daubentonii")], + ) + config = ClassesConfig(classes=[target_class1, target_class2]) + names = get_class_names_from_config(config) + assert names == ["pippip", "myodau"] + + +def test_build_encoder_from_config( + sample_annotation: data.SoundEventAnnotation, + sample_term_registry: TermRegistry, +): + config = ClassesConfig( + classes=[ + TargetClass( + name="pippip", + tags=[ + TagInfo(key="species", value="Pipistrellus pipistrellus") + ], + ) + ] + ) + encoder = build_encoder_from_config( + config, + term_registry=sample_term_registry, + ) + result = encoder(sample_annotation) + assert result == "pippip" + + config = ClassesConfig(classes=[]) + encoder = build_encoder_from_config( + config, + term_registry=sample_term_registry, + ) + result = encoder(sample_annotation) + assert result is None + + +def test_load_encoder_from_config_valid( + sample_annotation: data.SoundEventAnnotation, + sample_term_registry: TermRegistry, + create_temp_yaml: Callable[[str], Path], +): + yaml_content = """ + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + """ + temp_yaml_path = create_temp_yaml(yaml_content) + encoder = load_encoder_from_config( + temp_yaml_path, + term_registry=sample_term_registry, + ) + # We cannot directly compare the function, so we test it. + result = encoder(sample_annotation) # type: ignore + assert result == "pippip" + + +def test_load_encoder_from_config_invalid( + create_temp_yaml: Callable[[str], Path], + sample_term_registry: TermRegistry, +): + yaml_content = """ + classes: + - name: pippip + tags: + - key: invalid_key + value: Pipistrellus pipistrellus + """ + temp_yaml_path = create_temp_yaml(yaml_content) + with pytest.raises(KeyError): + load_encoder_from_config( + temp_yaml_path, + term_registry=sample_term_registry, + )