From ec4fa6c26255ae5648e9e0f1087fc6701ff378c9 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 6 May 2026 19:27:00 +0100 Subject: [PATCH] feat: support configurable batdetect2 class labels --- src/batdetect2/outputs/formats/batdetect2.py | 39 ++++++++-- tests/test_api_v2/test_outputs_io.py | 79 ++++++++++++++++++++ 2 files changed, 111 insertions(+), 7 deletions(-) diff --git a/src/batdetect2/outputs/formats/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py index 4db5f11..d913ea6 100644 --- a/src/batdetect2/outputs/formats/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -1,11 +1,10 @@ import json from pathlib import Path -from typing import List, Literal, Sequence, TypedDict +from typing import List, Literal, Sequence, TypedDict, cast import numpy as np import pandas as pd from soundevent import data -from soundevent import terms as soundevent_terms from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig @@ -54,6 +53,9 @@ class BatDetect2OutputConfig(BaseConfig): event_name: str = "Echolocation" annotation_note: str = "Automatically generated." + class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag" + decoded_label_key: str = "dwc:scientificName" + fallback_to_class_name: bool = True write_detection_csv: bool = True write_cnn_features_csv: bool = False save_if_empty: bool = False @@ -67,6 +69,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): targets: TargetProtocol, event_name: str, annotation_note: str, + class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag", + decoded_label_key: str = "dwc:scientificName", + fallback_to_class_name: bool = True, write_detection_csv: bool = True, write_cnn_features_csv: bool = False, save_if_empty: bool = False, @@ -76,6 +81,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): self.targets = targets self.event_name = event_name self.annotation_note = annotation_note + self.class_label_mode = class_label_mode + self.decoded_label_key = decoded_label_key + self.fallback_to_class_name = fallback_to_class_name self.write_detection_csv = write_detection_csv self.write_cnn_features_csv = write_cnn_features_csv self.save_if_empty = save_if_empty @@ -121,13 +129,14 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): if not self.include_file_path: data.pop("file_path", None) + annotations = cast(list[Annotation], data["annotation"]) data["annotation"] = [ { key: value for key, value in annotation.items() if key != "cnn_features" } - for annotation in data["annotation"] + for annotation in annotations ] pred_path.write_text(json.dumps(data, indent=2, sort_keys=True)) @@ -233,12 +242,25 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): def get_class_name(self, class_index: int) -> str: class_name = self.targets.class_names[class_index] + + if self.class_label_mode == "class_name": + return class_name + tags = self.targets.decode_class(class_name) - return data.find_tag_value( + default = class_name if self.fallback_to_class_name else None + decoded = data.find_tag_value( tags, - term=soundevent_terms.scientific_name, - default=class_name, - ) # type: ignore + key=self.decoded_label_key, + default=default, + ) + + if decoded is None: + raise ValueError( + "Could not decode class label using key " + f"{self.decoded_label_key!r} for class {class_name!r}." + ) + + return decoded def get_recording_class(self, detections: Sequence[Detection]) -> str: if not detections: @@ -318,6 +340,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): targets, event_name=config.event_name, annotation_note=config.annotation_note, + class_label_mode=config.class_label_mode, + decoded_label_key=config.decoded_label_key, + fallback_to_class_name=config.fallback_to_class_name, write_detection_csv=config.write_detection_csv, write_cnn_features_csv=config.write_cnn_features_csv, save_if_empty=config.save_if_empty, diff --git a/tests/test_api_v2/test_outputs_io.py b/tests/test_api_v2/test_outputs_io.py index 99c6ade..914b0ce 100644 --- a/tests/test_api_v2/test_outputs_io.py +++ b/tests/test_api_v2/test_outputs_io.py @@ -1,9 +1,11 @@ from pathlib import Path from typing import cast +from unittest.mock import Mock import numpy as np import pandas as pd import pytest +from soundevent import data as soundevent_data from batdetect2.api_v2 import BatDetect2API from batdetect2.outputs import build_output_formatter @@ -11,6 +13,7 @@ from batdetect2.outputs.formats import ( BatDetect2OutputConfig, SoundEventOutputConfig, ) +from batdetect2.outputs.formats.batdetect2 import BatDetect2Formatter from batdetect2.postprocess.types import ClipDetections @@ -79,6 +82,82 @@ def test_save_predictions_with_batdetect2_override( assert len(loaded[0]["annotation"]) == len(file_prediction.detections) +def test_batdetect2_formatter_can_use_raw_class_names( + api_v2: BatDetect2API, + file_prediction, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "batdetect2_raw_class_names" + api_v2.save_predictions( + [file_prediction], + path=output_dir, + config=BatDetect2OutputConfig(class_label_mode="class_name"), + ) + + loaded = cast( + list[dict], api_v2.load_predictions(output_dir, format="batdetect2") + ) + first_annotation = loaded[0]["annotation"][0] + + assert first_annotation["class"] in api_v2.targets.class_names + + +def test_batdetect2_formatter_can_use_decoded_species_tag() -> None: + targets = Mock() + targets.class_names = ["myodau"] + targets.decode_class.return_value = [ + soundevent_data.Tag( + key="dwc:scientificName", + value="Myotis daubentonii", + ) + ] + + formatter = BatDetect2Formatter( + targets=targets, + event_name="Echolocation", + annotation_note="Automatically generated.", + ) + + assert formatter.get_class_name(0) == "Myotis daubentonii" + + +def test_batdetect2_formatter_can_fallback_to_class_name_when_key_missing() -> ( + None +): + targets = Mock() + targets.class_names = ["myodau"] + targets.decode_class.return_value = [] + + formatter = BatDetect2Formatter( + targets=targets, + event_name="Echolocation", + annotation_note="Automatically generated.", + decoded_label_key="dwc:scientificName", + fallback_to_class_name=True, + ) + + assert formatter.get_class_name(0) == "myodau" + + +def test_batdetect2_formatter_rejects_missing_decoded_key_without_fallback() -> ( + None +): + targets = Mock() + targets.class_names = ["myodau"] + targets.decode_class.return_value = [] + + formatter = BatDetect2Formatter( + targets=targets, + event_name="Echolocation", + annotation_note="Automatically generated.", + decoded_label_key="dwc:scientificName", + fallback_to_class_name=False, + ) + + with pytest.raises(ValueError, match="Could not decode class label"): + formatter.get_class_name(0) + + def test_load_predictions_with_format_override( api_v2: BatDetect2API, file_prediction,