diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index e0ee8c8..69c30b5 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -10,12 +10,15 @@ from batdetect2.targets.classes import ( DEFAULT_SPECIES_LIST, ClassesConfig, TargetClass, - build_encoder_from_config, - get_class_names_from_config, _get_default_class_name, _get_default_classes, _is_target_class, + build_decoder_from_config, + build_encoder_from_config, + build_generic_class_tags_from_config, + get_class_names_from_config, load_classes_config, + load_decoder_from_config, load_encoder_from_config, ) from batdetect2.targets.terms import TagInfo, TermRegistry @@ -25,8 +28,9 @@ from batdetect2.targets.terms import TagInfo, TermRegistry def sample_term_registry() -> TermRegistry: """Fixture for a sample TermRegistry.""" registry = TermRegistry() + registry.add_custom_term("order") registry.add_custom_term("species") - registry.add_custom_term("sound_type") + registry.add_custom_term("call_type") registry.add_custom_term("quality") return registry @@ -296,3 +300,113 @@ def test_get_default_classes(): assert first_class.name == _get_default_class_name(DEFAULT_SPECIES_LIST[0]) assert first_class.tags[0].key == "class" assert first_class.tags[0].value == DEFAULT_SPECIES_LIST[0] + + +def test_build_decoder_from_config(sample_term_registry: TermRegistry): + config = ClassesConfig( + classes=[ + TargetClass( + name="pippip", + tags=[ + TagInfo(key="species", value="Pipistrellus pipistrellus") + ], + output_tags=[TagInfo(key="call_type", value="Echolocation")], + ) + ], + generic_class=[TagInfo(key="order", value="Chiroptera")], + ) + decoder = build_decoder_from_config( + config, term_registry=sample_term_registry + ) + tags = decoder("pippip") + assert len(tags) == 1 + assert tags[0].term == sample_term_registry["call_type"] + assert tags[0].value == "Echolocation" + + # Test when output_tags is None, should fall back to tags + config = ClassesConfig( + classes=[ + TargetClass( + name="pippip", + tags=[ + TagInfo(key="species", value="Pipistrellus pipistrellus") + ], + ) + ], + generic_class=[TagInfo(key="order", value="Chiroptera")], + ) + decoder = build_decoder_from_config( + config, term_registry=sample_term_registry + ) + tags = decoder("pippip") + assert len(tags) == 1 + assert tags[0].term == sample_term_registry["species"] + assert tags[0].value == "Pipistrellus pipistrellus" + + # Test raise_on_unmapped=True + decoder = build_decoder_from_config( + config, term_registry=sample_term_registry, raise_on_unmapped=True + ) + with pytest.raises(ValueError): + decoder("unknown_class") + + # Test raise_on_unmapped=False + decoder = build_decoder_from_config( + config, term_registry=sample_term_registry, raise_on_unmapped=False + ) + tags = decoder("unknown_class") + assert len(tags) == 0 + + +def test_load_decoder_from_config_valid( + create_temp_yaml: Callable[[str], Path], + sample_term_registry: TermRegistry, +): + yaml_content = """ + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + output_tags: + - key: call_type + value: Echolocation + generic_class: + - key: order + value: Chiroptera + """ + temp_yaml_path = create_temp_yaml(yaml_content) + decoder = load_decoder_from_config( + temp_yaml_path, term_registry=sample_term_registry + ) + tags = decoder("pippip") + assert len(tags) == 1 + assert tags[0].term == sample_term_registry["call_type"] + assert tags[0].value == "Echolocation" + + +def test_build_generic_class_tags_from_config( + sample_term_registry: TermRegistry, +): + config = ClassesConfig( + classes=[ + TargetClass( + name="pippip", + tags=[ + TagInfo(key="species", value="Pipistrellus pipistrellus") + ], + ) + ], + generic_class=[ + TagInfo(key="order", value="Chiroptera"), + TagInfo(key="call_type", value="Echolocation"), + ], + ) + generic_tags = build_generic_class_tags_from_config( + config, term_registry=sample_term_registry + ) + assert len(generic_tags) == 2 + assert generic_tags[0].term == sample_term_registry["order"] + assert generic_tags[0].value == "Chiroptera" + assert generic_tags[1].term == sample_term_registry["call_type"] + assert generic_tags[1].value == "Echolocation"