mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: support configurable batdetect2 class labels
This commit is contained in:
parent
d35866439b
commit
ec4fa6c262
@ -1,11 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Sequence, TypedDict
|
from typing import List, Literal, Sequence, TypedDict, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent import terms as soundevent_terms
|
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
@ -54,6 +53,9 @@ class BatDetect2OutputConfig(BaseConfig):
|
|||||||
|
|
||||||
event_name: str = "Echolocation"
|
event_name: str = "Echolocation"
|
||||||
annotation_note: str = "Automatically generated."
|
annotation_note: str = "Automatically generated."
|
||||||
|
class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag"
|
||||||
|
decoded_label_key: str = "dwc:scientificName"
|
||||||
|
fallback_to_class_name: bool = True
|
||||||
write_detection_csv: bool = True
|
write_detection_csv: bool = True
|
||||||
write_cnn_features_csv: bool = False
|
write_cnn_features_csv: bool = False
|
||||||
save_if_empty: bool = False
|
save_if_empty: bool = False
|
||||||
@ -67,6 +69,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
event_name: str,
|
event_name: str,
|
||||||
annotation_note: str,
|
annotation_note: str,
|
||||||
|
class_label_mode: Literal["class_name", "decoded_tag"] = "decoded_tag",
|
||||||
|
decoded_label_key: str = "dwc:scientificName",
|
||||||
|
fallback_to_class_name: bool = True,
|
||||||
write_detection_csv: bool = True,
|
write_detection_csv: bool = True,
|
||||||
write_cnn_features_csv: bool = False,
|
write_cnn_features_csv: bool = False,
|
||||||
save_if_empty: bool = False,
|
save_if_empty: bool = False,
|
||||||
@ -76,6 +81,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.event_name = event_name
|
self.event_name = event_name
|
||||||
self.annotation_note = annotation_note
|
self.annotation_note = annotation_note
|
||||||
|
self.class_label_mode = class_label_mode
|
||||||
|
self.decoded_label_key = decoded_label_key
|
||||||
|
self.fallback_to_class_name = fallback_to_class_name
|
||||||
self.write_detection_csv = write_detection_csv
|
self.write_detection_csv = write_detection_csv
|
||||||
self.write_cnn_features_csv = write_cnn_features_csv
|
self.write_cnn_features_csv = write_cnn_features_csv
|
||||||
self.save_if_empty = save_if_empty
|
self.save_if_empty = save_if_empty
|
||||||
@ -121,13 +129,14 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
if not self.include_file_path:
|
if not self.include_file_path:
|
||||||
data.pop("file_path", None)
|
data.pop("file_path", None)
|
||||||
|
|
||||||
|
annotations = cast(list[Annotation], data["annotation"])
|
||||||
data["annotation"] = [
|
data["annotation"] = [
|
||||||
{
|
{
|
||||||
key: value
|
key: value
|
||||||
for key, value in annotation.items()
|
for key, value in annotation.items()
|
||||||
if key != "cnn_features"
|
if key != "cnn_features"
|
||||||
}
|
}
|
||||||
for annotation in data["annotation"]
|
for annotation in annotations
|
||||||
]
|
]
|
||||||
|
|
||||||
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
|
pred_path.write_text(json.dumps(data, indent=2, sort_keys=True))
|
||||||
@ -233,12 +242,25 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
|
|
||||||
def get_class_name(self, class_index: int) -> str:
    """Return the label written to the output for a class index.

    When ``self.class_label_mode`` is ``"class_name"`` the entry of
    ``self.targets.class_names`` at ``class_index`` is returned unchanged.
    Otherwise the class is decoded with ``self.targets.decode_class`` and
    the label is looked up through ``data.find_tag_value`` using
    ``self.decoded_label_key``.

    Parameters
    ----------
    class_index : int
        Position of the class in ``self.targets.class_names``.

    Returns
    -------
    str
        The raw class name or the label resolved from the decoded tags.

    Raises
    ------
    ValueError
        If the lookup yields ``None``, which can only happen when
        ``self.fallback_to_class_name`` is disabled.
    """
    raw_name = self.targets.class_names[class_index]

    if self.class_label_mode == "class_name":
        return raw_name

    decoded_tags = self.targets.decode_class(raw_name)

    # With the fallback enabled the raw name doubles as the lookup default,
    # so ``None`` can only come back when the fallback is switched off.
    fallback = raw_name if self.fallback_to_class_name else None
    label = data.find_tag_value(
        decoded_tags,
        key=self.decoded_label_key,
        default=fallback,
    )

    if label is not None:
        return label

    raise ValueError(
        "Could not decode class label using key "
        f"{self.decoded_label_key!r} for class {raw_name!r}."
    )
|
||||||
|
|
||||||
def get_recording_class(self, detections: Sequence[Detection]) -> str:
|
def get_recording_class(self, detections: Sequence[Detection]) -> str:
|
||||||
if not detections:
|
if not detections:
|
||||||
@ -318,6 +340,9 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
targets,
|
targets,
|
||||||
event_name=config.event_name,
|
event_name=config.event_name,
|
||||||
annotation_note=config.annotation_note,
|
annotation_note=config.annotation_note,
|
||||||
|
class_label_mode=config.class_label_mode,
|
||||||
|
decoded_label_key=config.decoded_label_key,
|
||||||
|
fallback_to_class_name=config.fallback_to_class_name,
|
||||||
write_detection_csv=config.write_detection_csv,
|
write_detection_csv=config.write_detection_csv,
|
||||||
write_cnn_features_csv=config.write_cnn_features_csv,
|
write_cnn_features_csv=config.write_cnn_features_csv,
|
||||||
save_if_empty=config.save_if_empty,
|
save_if_empty=config.save_if_empty,
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
|
from soundevent import data as soundevent_data
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.outputs import build_output_formatter
|
from batdetect2.outputs import build_output_formatter
|
||||||
@ -11,6 +13,7 @@ from batdetect2.outputs.formats import (
|
|||||||
BatDetect2OutputConfig,
|
BatDetect2OutputConfig,
|
||||||
SoundEventOutputConfig,
|
SoundEventOutputConfig,
|
||||||
)
|
)
|
||||||
|
from batdetect2.outputs.formats.batdetect2 import BatDetect2Formatter
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +82,82 @@ def test_save_predictions_with_batdetect2_override(
|
|||||||
assert len(loaded[0]["annotation"]) == len(file_prediction.detections)
|
assert len(loaded[0]["annotation"]) == len(file_prediction.detections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batdetect2_formatter_can_use_raw_class_names(
    api_v2: BatDetect2API,
    file_prediction,
    tmp_path: Path,
) -> None:
    """Saving with ``class_label_mode="class_name"`` writes raw class names.

    The prediction is saved in batdetect2 format, loaded back, and the
    ``"class"`` of the first annotation is checked against the target
    class names exposed by the API.
    """
    destination = tmp_path / "batdetect2_raw_class_names"
    raw_name_config = BatDetect2OutputConfig(class_label_mode="class_name")

    api_v2.save_predictions(
        [file_prediction],
        path=destination,
        config=raw_name_config,
    )

    reloaded = cast(
        list[dict],
        api_v2.load_predictions(destination, format="batdetect2"),
    )
    written_class = reloaded[0]["annotation"][0]["class"]

    assert written_class in api_v2.targets.class_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_batdetect2_formatter_can_use_decoded_species_tag() -> None:
    """With default settings the decoded scientific-name tag is the label."""
    species_tag = soundevent_data.Tag(
        key="dwc:scientificName",
        value="Myotis daubentonii",
    )

    # Stand-in targets exposing one class that decodes to the tag above.
    fake_targets = Mock()
    fake_targets.class_names = ["myodau"]
    fake_targets.decode_class.return_value = [species_tag]

    formatter = BatDetect2Formatter(
        targets=fake_targets,
        event_name="Echolocation",
        annotation_note="Automatically generated.",
    )

    resolved = formatter.get_class_name(0)

    assert resolved == "Myotis daubentonii"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batdetect2_formatter_can_fallback_to_class_name_when_key_missing() -> (
    None
):
    """An empty decode result falls back to the raw class name."""
    # Stand-in targets whose only class decodes to no tags at all.
    fake_targets = Mock()
    fake_targets.class_names = ["myodau"]
    fake_targets.decode_class.return_value = []

    formatter_kwargs = {
        "targets": fake_targets,
        "event_name": "Echolocation",
        "annotation_note": "Automatically generated.",
        "decoded_label_key": "dwc:scientificName",
        "fallback_to_class_name": True,
    }
    formatter = BatDetect2Formatter(**formatter_kwargs)

    resolved = formatter.get_class_name(0)

    assert resolved == "myodau"
|
||||||
|
|
||||||
|
|
||||||
|
def test_batdetect2_formatter_rejects_missing_decoded_key_without_fallback() -> (
    None
):
    """Without the fallback, an empty decode result raises ``ValueError``."""
    # Stand-in targets whose only class decodes to no tags at all.
    fake_targets = Mock()
    fake_targets.class_names = ["myodau"]
    fake_targets.decode_class.return_value = []

    formatter_kwargs = {
        "targets": fake_targets,
        "event_name": "Echolocation",
        "annotation_note": "Automatically generated.",
        "decoded_label_key": "dwc:scientificName",
        "fallback_to_class_name": False,
    }
    formatter = BatDetect2Formatter(**formatter_kwargs)

    expected_message = "Could not decode class label"
    with pytest.raises(ValueError, match=expected_message):
        formatter.get_class_name(0)
|
||||||
|
|
||||||
|
|
||||||
def test_load_predictions_with_format_override(
|
def test_load_predictions_with_format_override(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
file_prediction,
|
file_prediction,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user