feat: support configurable batdetect2 class labels

This commit is contained in:
mbsantiago 2026-05-06 19:27:00 +01:00
parent d35866439b
commit ec4fa6c262
2 changed files with 111 additions and 7 deletions

View File

@ -1,11 +1,10 @@
import json
from pathlib import Path
from typing import List, Literal, Sequence, TypedDict
from typing import List, Literal, Sequence, TypedDict, cast
import numpy as np
import pandas as pd
from soundevent import data
from soundevent import terms as soundevent_terms
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
@ -54,6 +53,9 @@ class BatDetect2OutputConfig(BaseConfig):
event_name: str = "Echolocation"
annotation_note: str = "Automatically generated."
class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag"
decoded_label_key: str = "dwc:scientificName"
fallback_to_class_name: bool = True
write_detection_csv: bool = True
write_cnn_features_csv: bool = False
save_if_empty: bool = False
@ -67,6 +69,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
targets: TargetProtocol,
event_name: str,
annotation_note: str,
class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag",
decoded_label_key: str = "dwc:scientificName",
fallback_to_class_name: bool = True,
write_detection_csv: bool = True,
write_cnn_features_csv: bool = False,
save_if_empty: bool = False,
@ -76,6 +81,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
self.targets = targets
self.event_name = event_name
self.annotation_note = annotation_note
self.class_label_mode = class_label_mode
self.decoded_label_key = decoded_label_key
self.fallback_to_class_name = fallback_to_class_name
self.write_detection_csv = write_detection_csv
self.write_cnn_features_csv = write_cnn_features_csv
self.save_if_empty = save_if_empty
@ -121,13 +129,14 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
if not self.include_file_path:
data.pop("file_path", None)
annotations = cast(list[Annotation], data["annotation"])
data["annotation"] = [
{
key: value
for key, value in annotation.items()
if key != "cnn_features"
}
for annotation in data["annotation"]
for annotation in annotations
]
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
@ -233,12 +242,25 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
def get_class_name(self, class_index: int) -> str:
    """Return the label to write for the class at ``class_index``.

    When ``class_label_mode`` is ``"class_name"`` the raw entry of
    ``targets.class_names`` is returned unchanged. Otherwise the class is
    decoded into tags with ``targets.decode_class`` and the value of the
    tag stored under ``decoded_label_key`` is used as the label.

    Args:
        class_index: Position of the class in ``targets.class_names``.

    Returns:
        The raw class name or the decoded tag value, depending on the
        configured label mode.

    Raises:
        ValueError: If no tag is found for ``decoded_label_key`` and
            ``fallback_to_class_name`` is disabled.
    """
    class_name = self.targets.class_names[class_index]

    if self.class_label_mode == "class_name":
        return class_name

    tags = self.targets.decode_class(class_name)

    # With the fallback disabled a missing tag has to surface as ``None``
    # so it can be reported below instead of being silently replaced.
    default = class_name if self.fallback_to_class_name else None
    decoded = data.find_tag_value(
        tags,
        key=self.decoded_label_key,
        default=default,
    )

    if decoded is None:
        raise ValueError(
            "Could not decode class label using key "
            f"{self.decoded_label_key!r} for class {class_name!r}."
        )

    return decoded
def get_recording_class(self, detections: Sequence[Detection]) -> str:
if not detections:
@ -318,6 +340,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
targets,
event_name=config.event_name,
annotation_note=config.annotation_note,
class_label_mode=config.class_label_mode,
decoded_label_key=config.decoded_label_key,
fallback_to_class_name=config.fallback_to_class_name,
write_detection_csv=config.write_detection_csv,
write_cnn_features_csv=config.write_cnn_features_csv,
save_if_empty=config.save_if_empty,

View File

@ -1,9 +1,11 @@
from pathlib import Path
from typing import cast
from unittest.mock import Mock
import numpy as np
import pandas as pd
import pytest
from soundevent import data as soundevent_data
from batdetect2.api_v2 import BatDetect2API
from batdetect2.outputs import build_output_formatter
@ -11,6 +13,7 @@ from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
SoundEventOutputConfig,
)
from batdetect2.outputs.formats.batdetect2 import BatDetect2Formatter
from batdetect2.postprocess.types import ClipDetections
@ -79,6 +82,82 @@ def test_save_predictions_with_batdetect2_override(
assert len(loaded[0]["annotation"]) == len(file_prediction.detections)
def test_batdetect2_formatter_can_use_raw_class_names(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Predictions saved in ``class_name`` mode carry raw class names."""
    raw_name_config = BatDetect2OutputConfig(class_label_mode="class_name")
    output_dir = tmp_path / "batdetect2_raw_class_names"

    api_v2.save_predictions(
        [file_prediction],
        path=output_dir,
        config=raw_name_config,
    )

    # Read the files back through the API so the check covers what was
    # actually written to disk.
    reloaded = api_v2.load_predictions(output_dir, format="batdetect2")
    loaded = cast(list[dict], reloaded)

    annotations = loaded[0]["annotation"]
    assert annotations[0]["class"] in api_v2.targets.class_names
def test_batdetect2_formatter_can_use_decoded_species_tag() -> None:
    """With default label settings the decoded tag value is the label."""
    species_tag = soundevent_data.Tag(
        key="dwc:scientificName",
        value="Myotis daubentonii",
    )

    # Stand-in for the targets object: one class that decodes to the tag.
    targets = Mock(class_names=["myodau"])
    targets.decode_class.return_value = [species_tag]

    formatter = BatDetect2Formatter(
        targets=targets,
        event_name="Echolocation",
        annotation_note="Automatically generated.",
    )

    label = formatter.get_class_name(0)

    assert label == "Myotis daubentonii"
def test_batdetect2_formatter_can_fallback_to_class_name_when_key_missing() -> (
    None
):
    """A class that decodes to no tags falls back to its raw class name."""
    # Stand-in for the targets object: decoding yields no tags at all, so
    # the configured key can never be found.
    targets = Mock(class_names=["myodau"])
    targets.decode_class.return_value = []

    formatter_kwargs = {
        "targets": targets,
        "event_name": "Echolocation",
        "annotation_note": "Automatically generated.",
        "decoded_label_key": "dwc:scientificName",
        "fallback_to_class_name": True,
    }
    formatter = BatDetect2Formatter(**formatter_kwargs)

    label = formatter.get_class_name(0)

    assert label == "myodau"
def test_batdetect2_formatter_rejects_missing_decoded_key_without_fallback() -> (
    None
):
    """Without the fallback, a class that decodes to no tags is an error."""
    # Stand-in for the targets object: decoding yields no tags at all, so
    # the configured key can never be found.
    targets = Mock(class_names=["myodau"])
    targets.decode_class.return_value = []

    formatter_kwargs = {
        "targets": targets,
        "event_name": "Echolocation",
        "annotation_note": "Automatically generated.",
        "decoded_label_key": "dwc:scientificName",
        "fallback_to_class_name": False,
    }
    formatter = BatDetect2Formatter(**formatter_kwargs)

    expected_error = pytest.raises(
        ValueError,
        match="Could not decode class label",
    )
    with expected_error:
        formatter.get_class_name(0)
def test_load_predictions_with_format_override(
api_v2: BatDetect2API,
file_prediction,