mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Update classes
This commit is contained in:
parent
a2ec190b73
commit
eda5f91c86
@ -10,12 +10,15 @@ from batdetect2.targets.classes import (
|
||||
DEFAULT_SPECIES_LIST,
|
||||
ClassesConfig,
|
||||
TargetClass,
|
||||
build_encoder_from_config,
|
||||
get_class_names_from_config,
|
||||
_get_default_class_name,
|
||||
_get_default_classes,
|
||||
_is_target_class,
|
||||
build_decoder_from_config,
|
||||
build_encoder_from_config,
|
||||
build_generic_class_tags_from_config,
|
||||
get_class_names_from_config,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
@ -25,8 +28,9 @@ from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
def sample_term_registry() -> TermRegistry:
|
||||
"""Fixture for a sample TermRegistry."""
|
||||
registry = TermRegistry()
|
||||
registry.add_custom_term("order")
|
||||
registry.add_custom_term("species")
|
||||
registry.add_custom_term("sound_type")
|
||||
registry.add_custom_term("call_type")
|
||||
registry.add_custom_term("quality")
|
||||
return registry
|
||||
|
||||
@ -296,3 +300,113 @@ def test_get_default_classes():
|
||||
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
||||
assert first_class.tags[0].key == "class"
|
||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||
|
||||
|
||||
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
output_tags=[TagInfo(key="call_type", value="Echolocation")],
|
||||
)
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_decoder_from_config(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
# Test when output_tags is None, should fall back to tags
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
],
|
||||
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||
)
|
||||
decoder = build_decoder_from_config(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["species"]
|
||||
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||
|
||||
# Test raise_on_unmapped=True
|
||||
decoder = build_decoder_from_config(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
decoder("unknown_class")
|
||||
|
||||
# Test raise_on_unmapped=False
|
||||
decoder = build_decoder_from_config(
|
||||
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||
)
|
||||
tags = decoder("unknown_class")
|
||||
assert len(tags) == 0
|
||||
|
||||
|
||||
def test_load_decoder_from_config_valid(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
yaml_content = """
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
output_tags:
|
||||
- key: call_type
|
||||
value: Echolocation
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
"""
|
||||
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||
decoder = load_decoder_from_config(
|
||||
temp_yaml_path, term_registry=sample_term_registry
|
||||
)
|
||||
tags = decoder("pippip")
|
||||
assert len(tags) == 1
|
||||
assert tags[0].term == sample_term_registry["call_type"]
|
||||
assert tags[0].value == "Echolocation"
|
||||
|
||||
|
||||
def test_build_generic_class_tags_from_config(
|
||||
sample_term_registry: TermRegistry,
|
||||
):
|
||||
config = ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
name="pippip",
|
||||
tags=[
|
||||
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||
],
|
||||
)
|
||||
],
|
||||
generic_class=[
|
||||
TagInfo(key="order", value="Chiroptera"),
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
generic_tags = build_generic_class_tags_from_config(
|
||||
config, term_registry=sample_term_registry
|
||||
)
|
||||
assert len(generic_tags) == 2
|
||||
assert generic_tags[0].term == sample_term_registry["order"]
|
||||
assert generic_tags[0].value == "Chiroptera"
|
||||
assert generic_tags[1].term == sample_term_registry["call_type"]
|
||||
assert generic_tags[1].value == "Echolocation"
|
||||
|
Loading…
Reference in New Issue
Block a user