mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Update classes
This commit is contained in:
parent
a2ec190b73
commit
eda5f91c86
@ -10,12 +10,15 @@ from batdetect2.targets.classes import (
|
|||||||
DEFAULT_SPECIES_LIST,
|
DEFAULT_SPECIES_LIST,
|
||||||
ClassesConfig,
|
ClassesConfig,
|
||||||
TargetClass,
|
TargetClass,
|
||||||
build_encoder_from_config,
|
|
||||||
get_class_names_from_config,
|
|
||||||
_get_default_class_name,
|
_get_default_class_name,
|
||||||
_get_default_classes,
|
_get_default_classes,
|
||||||
_is_target_class,
|
_is_target_class,
|
||||||
|
build_decoder_from_config,
|
||||||
|
build_encoder_from_config,
|
||||||
|
build_generic_class_tags_from_config,
|
||||||
|
get_class_names_from_config,
|
||||||
load_classes_config,
|
load_classes_config,
|
||||||
|
load_decoder_from_config,
|
||||||
load_encoder_from_config,
|
load_encoder_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||||
@ -25,8 +28,9 @@ from batdetect2.targets.terms import TagInfo, TermRegistry
|
|||||||
def sample_term_registry() -> TermRegistry:
|
def sample_term_registry() -> TermRegistry:
|
||||||
"""Fixture for a sample TermRegistry."""
|
"""Fixture for a sample TermRegistry."""
|
||||||
registry = TermRegistry()
|
registry = TermRegistry()
|
||||||
|
registry.add_custom_term("order")
|
||||||
registry.add_custom_term("species")
|
registry.add_custom_term("species")
|
||||||
registry.add_custom_term("sound_type")
|
registry.add_custom_term("call_type")
|
||||||
registry.add_custom_term("quality")
|
registry.add_custom_term("quality")
|
||||||
return registry
|
return registry
|
||||||
|
|
||||||
@ -296,3 +300,113 @@ def test_get_default_classes():
|
|||||||
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0])
|
||||||
assert first_class.tags[0].key == "class"
|
assert first_class.tags[0].key == "class"
|
||||||
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_decoder_from_config(sample_term_registry: TermRegistry):
|
||||||
|
config = ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[
|
||||||
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
],
|
||||||
|
output_tags=[TagInfo(key="call_type", value="Echolocation")],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
|
)
|
||||||
|
decoder = build_decoder_from_config(
|
||||||
|
config, term_registry=sample_term_registry
|
||||||
|
)
|
||||||
|
tags = decoder("pippip")
|
||||||
|
assert len(tags) == 1
|
||||||
|
assert tags[0].term == sample_term_registry["call_type"]
|
||||||
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
|
# Test when output_tags is None, should fall back to tags
|
||||||
|
config = ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[
|
||||||
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[TagInfo(key="order", value="Chiroptera")],
|
||||||
|
)
|
||||||
|
decoder = build_decoder_from_config(
|
||||||
|
config, term_registry=sample_term_registry
|
||||||
|
)
|
||||||
|
tags = decoder("pippip")
|
||||||
|
assert len(tags) == 1
|
||||||
|
assert tags[0].term == sample_term_registry["species"]
|
||||||
|
assert tags[0].value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
|
# Test raise_on_unmapped=True
|
||||||
|
decoder = build_decoder_from_config(
|
||||||
|
config, term_registry=sample_term_registry, raise_on_unmapped=True
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
decoder("unknown_class")
|
||||||
|
|
||||||
|
# Test raise_on_unmapped=False
|
||||||
|
decoder = build_decoder_from_config(
|
||||||
|
config, term_registry=sample_term_registry, raise_on_unmapped=False
|
||||||
|
)
|
||||||
|
tags = decoder("unknown_class")
|
||||||
|
assert len(tags) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_decoder_from_config_valid(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
output_tags:
|
||||||
|
- key: call_type
|
||||||
|
value: Echolocation
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
"""
|
||||||
|
temp_yaml_path = create_temp_yaml(yaml_content)
|
||||||
|
decoder = load_decoder_from_config(
|
||||||
|
temp_yaml_path, term_registry=sample_term_registry
|
||||||
|
)
|
||||||
|
tags = decoder("pippip")
|
||||||
|
assert len(tags) == 1
|
||||||
|
assert tags[0].term == sample_term_registry["call_type"]
|
||||||
|
assert tags[0].value == "Echolocation"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_generic_class_tags_from_config(
|
||||||
|
sample_term_registry: TermRegistry,
|
||||||
|
):
|
||||||
|
config = ClassesConfig(
|
||||||
|
classes=[
|
||||||
|
TargetClass(
|
||||||
|
name="pippip",
|
||||||
|
tags=[
|
||||||
|
TagInfo(key="species", value="Pipistrellus pipistrellus")
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
generic_class=[
|
||||||
|
TagInfo(key="order", value="Chiroptera"),
|
||||||
|
TagInfo(key="call_type", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
generic_tags = build_generic_class_tags_from_config(
|
||||||
|
config, term_registry=sample_term_registry
|
||||||
|
)
|
||||||
|
assert len(generic_tags) == 2
|
||||||
|
assert generic_tags[0].term == sample_term_registry["order"]
|
||||||
|
assert generic_tags[0].value == "Chiroptera"
|
||||||
|
assert generic_tags[1].term == sample_term_registry["call_type"]
|
||||||
|
assert generic_tags[1].value == "Echolocation"
|
||||||
|
Loading…
Reference in New Issue
Block a user