Update classes

This commit is contained in:
mbsantiago 2025-04-16 00:01:37 +01:00
parent a2ec190b73
commit eda5f91c86

View File

@ -10,12 +10,15 @@ from batdetect2.targets.classes import (
DEFAULT_SPECIES_LIST,
ClassesConfig,
TargetClass,
build_encoder_from_config,
get_class_names_from_config,
_get_default_class_name,
_get_default_classes,
_is_target_class,
build_decoder_from_config,
build_encoder_from_config,
build_generic_class_tags_from_config,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.terms import TagInfo, TermRegistry
@ -25,8 +28,9 @@ from batdetect2.targets.terms import TagInfo, TermRegistry
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("order")
registry.add_custom_term("species")
registry.add_custom_term("sound_type")
registry.add_custom_term("call_type")
registry.add_custom_term("quality")
return registry
@ -296,3 +300,113 @@ def test_get_default_classes():
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
assert first_class.tags[0].key == "class"
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
output_tags=[TagInfo(key="call_type", value="Echolocation")],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_decoder_from_config(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation"
# Test when output_tags is None, should fall back to tags
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
],
generic_class=[TagInfo(key="order", value="Chiroptera")],
)
decoder = build_decoder_from_config(
config, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["species"]
assert tags[0].value == "Pipistrellus pipistrellus"
# Test raise_on_unmapped=True
decoder = build_decoder_from_config(
config, term_registry=sample_term_registry, raise_on_unmapped=True
)
with pytest.raises(ValueError):
decoder("unknown_class")
# Test raise_on_unmapped=False
decoder = build_decoder_from_config(
config, term_registry=sample_term_registry, raise_on_unmapped=False
)
tags = decoder("unknown_class")
assert len(tags) == 0
def test_load_decoder_from_config_valid(
create_temp_yaml: Callable[[str], Path],
sample_term_registry: TermRegistry,
):
yaml_content = """
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
output_tags:
- key: call_type
value: Echolocation
generic_class:
- key: order
value: Chiroptera
"""
temp_yaml_path = create_temp_yaml(yaml_content)
decoder = load_decoder_from_config(
temp_yaml_path, term_registry=sample_term_registry
)
tags = decoder("pippip")
assert len(tags) == 1
assert tags[0].term == sample_term_registry["call_type"]
assert tags[0].value == "Echolocation"
def test_build_generic_class_tags_from_config(
sample_term_registry: TermRegistry,
):
config = ClassesConfig(
classes=[
TargetClass(
name="pippip",
tags=[
TagInfo(key="species", value="Pipistrellus pipistrellus")
],
)
],
generic_class=[
TagInfo(key="order", value="Chiroptera"),
TagInfo(key="call_type", value="Echolocation"),
],
)
generic_tags = build_generic_class_tags_from_config(
config, term_registry=sample_term_registry
)
assert len(generic_tags) == 2
assert generic_tags[0].term == sample_term_registry["order"]
assert generic_tags[0].value == "Chiroptera"
assert generic_tags[1].term == sample_term_registry["call_type"]
assert generic_tags[1].value == "Echolocation"