batdetect2/tests/test_postprocessing/test_decoding.py
2025-04-30 22:51:49 +01:00

671 lines
20 KiB
Python

from pathlib import Path
from typing import List, Tuple
import numpy as np
import pytest
import xarray as xr
from soundevent import data
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
get_class_tags,
get_generic_tags,
get_prediction_features,
)
from batdetect2.postprocess.types import RawPrediction
@pytest.fixture
def dummy_geometry_builder():
    """Provide a geometry builder that centres a bounding box on a point.

    The returned callable takes a ``(time, freq)`` position and a
    ``dimensions`` array that is selected by the ``dimension`` labels
    ``"width"`` and ``"height"``. It returns a ``data.BoundingBox`` that
    extends half the width and half the height on either side of the
    position.
    """

    def _centred_bbox(
        position: Tuple[float, float],
        dimensions: xr.DataArray,
    ) -> data.BoundingBox:
        centre_time, centre_freq = position
        half_width = dimensions.sel(dimension="width").item() / 2
        half_height = dimensions.sel(dimension="height").item() / 2
        # Coordinate order: start time, low freq, end time, high freq.
        box = [
            centre_time - half_width,
            centre_freq - half_height,
            centre_time + half_width,
            centre_freq + half_height,
        ]
        return data.BoundingBox(coordinates=box)

    return _centred_bbox
@pytest.fixture
def dummy_sound_event_decoder():
    """Provide a decoder that maps class names to lists of tags.

    Lookup is case-insensitive (the name is lower-cased first). Names that
    are not one of ``"bat"``, ``"noise"`` or ``"unknown"`` decode to an
    empty list.
    """

    def _tag(key: str, value: str) -> data.Tag:
        return data.Tag(term=data.term_from_key(key=key), value=value)

    tags_by_class = {
        "bat": [_tag("species", "Myotis")],
        "noise": [_tag("category", "noise")],
        "unknown": [_tag("status", "uncertain")],
    }

    def _decode(class_name: str) -> List[data.Tag]:
        normalised = class_name.lower()
        if normalised in tags_by_class:
            return tags_by_class[normalised]
        return []

    return _decode
@pytest.fixture
def generic_tags() -> List[data.Tag]:
    """Return a single-element list holding a ``detector`` tag."""
    detector_term = data.term_from_key(key="detector")
    detector_tag = data.Tag(term=detector_term, value="batdetect2")
    return [detector_tag]
@pytest.fixture
def sample_recording() -> data.Recording:
    """Return a 60 s, single-channel recording sampled at 192 kHz."""
    recording = data.Recording(
        path=Path("/path/to/recording.wav"),
        duration=60.0,
        channels=1,
        samplerate=192000,
    )
    return recording
@pytest.fixture
def sample_clip(sample_recording) -> data.Clip:
    """Return a clip spanning 10.0-20.0 of the ``sample_recording`` fixture."""
    clip = data.Clip(
        recording=sample_recording,
        start_time=10.0,
        end_time=20.0,
    )
    return clip
@pytest.fixture
def sample_detection_dataset() -> xr.Dataset:
    """Build a two-detection dataset with score, size, class and feature data.

    Every variable shares the ``detection`` dimension and carries the
    ``time`` and ``frequency`` coordinates along it.
    """
    per_detection = {
        "time": ("detection", np.array([20, 10])),
        "frequency": ("detection", np.array([300, 200])),
    }

    def _variable(values, name, extra_dim=None, labels=None):
        # DataArray over ``detection`` plus one optional labelled dimension.
        coords = dict(per_detection)
        dims = ["detection"]
        if extra_dim is not None:
            coords[extra_dim] = labels
            dims.append(extra_dim)
        return xr.DataArray(values, coords=coords, dims=dims, name=name)

    variables = {
        "score": _variable(
            np.array([0.9, 0.8], dtype=np.float64),
            "score",
        ),
        "dimensions": _variable(
            np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32),
            "dimensions",
            extra_dim="dimension",
            labels=["width", "height"],
        ),
        "classes": _variable(
            np.array([[0.43, 0.85], [0.24, 0.66]], dtype=np.float32),
            "classes",
            extra_dim="category",
            labels=["bat", "noise"],
        ),
        "features": _variable(
            np.array(
                [[7.0, 16.0, 25.0, 34.0], [3.0, 12.0, 21.0, 30.0]],
                dtype=np.float32,
            ),
            "features",
            extra_dim="feature",
            labels=["f0", "f1", "f2", "f3"],
        ),
    }
    return xr.Dataset(variables, coords=per_detection)
@pytest.fixture
def empty_detection_dataset() -> xr.Dataset:
    """Create a detection dataset that contains zero detections.

    The layout mirrors ``sample_detection_dataset``: the same variable
    names (``score``, ``dimensions``, ``classes``, ``features``), the same
    labelled dimensions, and the same per-detection coordinates ``time``
    and ``frequency``. Only the length of the ``detection`` dimension
    differs (it is 0 here).
    """
    # Coordinate names must match the populated fixture ("frequency", not
    # "freq") so both datasets expose the same structure to the decoder.
    detection_coords = {
        "time": ("detection", np.array([], dtype=np.float64)),
        "frequency": ("detection", np.array([], dtype=np.float64)),
    }

    scores = xr.DataArray(
        np.array([], dtype=np.float64),
        coords=detection_coords,
        dims=["detection"],
        # Array name agrees with the "score" key it is stored under below.
        name="score",
    )

    dimensions = xr.DataArray(
        np.empty((0, 2), dtype=np.float32),
        coords={**detection_coords, "dimension": ["width", "height"]},
        dims=["detection", "dimension"],
        name="dimensions",
    )

    classes = xr.DataArray(
        np.empty((0, 2), dtype=np.float32),
        coords={**detection_coords, "category": ["bat", "noise"]},
        dims=["detection", "category"],
        name="classes",
    )

    features = xr.DataArray(
        np.empty((0, 4), dtype=np.float32),
        coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
        dims=["detection", "feature"],
        name="features",
    )

    return xr.Dataset(
        {
            "score": scores,
            "dimensions": dimensions,
            "classes": classes,
            "features": features,
        },
        coords=detection_coords,
    )
@pytest.fixture
def sample_raw_predictions() -> List[RawPrediction]:
    """Return three hand-built ``RawPrediction`` objects.

    The first two have detection scores 0.9 and 0.8. The third is a
    low-confidence prediction (detection score 0.15, class scores 0.05 and
    0.02).
    """

    def _prediction(score, box, class_values, feature_values):
        # Wrap plain Python values in the containers RawPrediction is given.
        return RawPrediction(
            detection_score=score,
            geometry=data.BoundingBox(coordinates=box),
            class_scores=xr.DataArray(
                class_values,
                coords={"category": ["bat", "noise"]},
                dims=["category"],
            ),
            features=xr.DataArray(
                feature_values,
                coords={"feature": ["f0", "f1", "f2", "f3"]},
                dims=["feature"],
            ),
        )

    first = _prediction(
        0.9,
        [20 - 7 / 2, 300 - 16 / 2, 20 + 7 / 2, 300 + 16 / 2],
        [0.43, 0.85],
        [7.0, 16.0, 25.0, 34.0],
    )
    second = _prediction(
        0.8,
        [10 - 3 / 2, 200 - 12 / 2, 10 + 3 / 2, 200 + 12 / 2],
        [0.24, 0.66],
        [3.0, 12.0, 21.0, 30.0],
    )
    third = _prediction(
        0.15,
        [5.0, 50.0, 6.0, 60.0],
        [0.05, 0.02],
        [1.0, 2.0, 3.0, 4.0],
    )
    return [first, second, third]
def test_convert_xr_dataset_basic(
    sample_detection_dataset, dummy_geometry_builder
):
    """Check that each dataset row becomes a matching RawPrediction."""
    raw_predictions = convert_xr_dataset_to_raw_prediction(
        sample_detection_dataset, dummy_geometry_builder
    )

    assert isinstance(raw_predictions, list)
    assert len(raw_predictions) == 2

    # (detection score, bounding-box coordinates) expected per detection.
    expected = [
        (0.9, [20 - 7 / 2, 300 - 16 / 2, 20 + 7 / 2, 300 + 16 / 2]),
        (0.8, [10 - 3 / 2, 200 - 12 / 2, 10 + 3 / 2, 200 + 12 / 2]),
    ]
    class_scores = sample_detection_dataset["classes"]
    feature_values = sample_detection_dataset["features"]

    for index, (score, coordinates) in enumerate(expected):
        prediction = raw_predictions[index]
        assert isinstance(prediction, RawPrediction)
        assert prediction.detection_score == score
        assert prediction.geometry.coordinates == coordinates
        xr.testing.assert_allclose(
            prediction.class_scores,
            class_scores.sel(detection=index),
        )
        xr.testing.assert_allclose(
            prediction.features,
            feature_values.sel(detection=index),
        )
def test_convert_xr_dataset_empty(
    empty_detection_dataset, dummy_geometry_builder
):
    """Check that a dataset without detections converts to an empty list."""
    converted = convert_xr_dataset_to_raw_prediction(
        empty_detection_dataset,
        dummy_geometry_builder,
    )

    assert isinstance(converted, list)
    assert not converted
def test_convert_raw_to_sound_event_basic(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check conversion with default options: geometry, features and tags."""
    raw_pred = sample_raw_predictions[0]

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
    )

    assert isinstance(se_pred, data.SoundEventPrediction)
    assert se_pred.score == raw_pred.detection_score

    sound_event = se_pred.sound_event
    assert isinstance(sound_event, data.SoundEvent)
    assert sound_event.recording == sample_recording
    assert isinstance(sound_event.geometry, data.BoundingBox)
    assert sound_event.geometry == raw_pred.geometry

    assert len(sound_event.features) == len(raw_pred.features)
    values_by_name = {}
    for feature in sound_event.features:
        values_by_name[feature.term.name] = feature.value
    assert "batdetect2:f0" in values_by_name
    assert isinstance(values_by_name["batdetect2:f0"], float)
    assert values_by_name["batdetect2:f0"] == 7.0

    generic = generic_tags[0]
    expected_tags = {
        (generic.term.name, generic.value, 0.9),
        ("soundevent:category", "noise", 0.85),
        ("soundevent:species", "Myotis", 0.43),
    }
    actual_tags = set()
    for predicted in se_pred.tags:
        actual_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_thresholding(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that a 0.5 threshold drops the 0.43 class but keeps 0.85."""
    generic = generic_tags[0]
    expected_tags = {
        (generic.term.name, generic.value, 0.9),
        ("soundevent:category", "noise", 0.85),
    }

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=sample_raw_predictions[0],
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=0.5,
        top_class_only=False,
    )

    actual_tags = set()
    for predicted in se_pred.tags:
        actual_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_no_threshold(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that ``classification_threshold=None`` keeps every class tag."""
    generic = generic_tags[0]
    expected_tags = {
        (generic.term.name, generic.value, 0.15),
        ("soundevent:species", "Myotis", 0.05),
        ("soundevent:category", "noise", 0.02),
    }

    # The third sample prediction has very low class scores (0.05, 0.02).
    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=sample_raw_predictions[2],
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=None,
        top_class_only=False,
    )

    actual_tags = set()
    for predicted in se_pred.tags:
        actual_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_top_class(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that ``top_class_only=True`` keeps only the best class tag."""
    generic = generic_tags[0]
    expected_tags = {
        (generic.term.name, generic.value, 0.9),
        ("soundevent:category", "noise", 0.85),
    }

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=sample_raw_predictions[0],
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=True,
    )

    actual_tags = set()
    for predicted in se_pred.tags:
        actual_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_all_below_threshold(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that only generic tags remain when no class passes threshold."""
    generic = generic_tags[0]
    expected_tags = {(generic.term.name, generic.value, 0.15)}

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=sample_raw_predictions[2],
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=False,
    )

    actual_tags = set()
    for predicted in se_pred.tags:
        actual_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert actual_tags == expected_tags
def test_convert_raw_list_to_clip_basic(
    sample_raw_predictions,
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that a list of RawPredictions becomes one ClipPrediction."""
    clip_pred = convert_raw_predictions_to_clip_prediction(
        raw_predictions=sample_raw_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=False,
    )

    assert isinstance(clip_pred, data.ClipPrediction)
    assert clip_pred.clip == sample_clip

    converted = clip_pred.sound_events
    assert len(converted) == len(sample_raw_predictions)

    # Scores are compared position by position for the three predictions.
    for position in range(3):
        assert converted[position].score == (
            sample_raw_predictions[position].detection_score
        )

    third_tags = set()
    for predicted in converted[2].tags:
        third_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    generic = generic_tags[0]
    assert third_tags == {(generic.term.name, generic.value, 0.15)}
def test_convert_raw_list_to_clip_empty(
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that an empty prediction list gives a clip with no events."""
    no_predictions = []
    clip_pred = convert_raw_predictions_to_clip_prediction(
        raw_predictions=no_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
    )

    assert isinstance(clip_pred, data.ClipPrediction)
    assert clip_pred.clip == sample_clip
    event_count = len(clip_pred.sound_events)
    assert event_count == 0
def test_convert_raw_list_to_clip_passes_args(
    sample_raw_predictions,
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that ``top_class_only`` reaches the per-prediction conversion."""
    generic = generic_tags[0]
    expected_first_tags = {
        (generic.term.name, generic.value, 0.9),
        ("soundevent:category", "noise", 0.85),
    }

    clip_pred = convert_raw_predictions_to_clip_prediction(
        raw_predictions=sample_raw_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=True,
    )

    assert len(clip_pred.sound_events) == 3

    first_tags = set()
    for predicted in clip_pred.sound_events[0].tags:
        first_tags.add(
            (predicted.tag.term.name, predicted.tag.value, predicted.score)
        )
    assert first_tags == expected_first_tags
def test_get_generic_tags_basic(generic_tags):
    """Check that every generic tag is emitted with the detection score."""
    score = 0.75
    result = get_generic_tags(
        detection_score=score,
        generic_class_tags=generic_tags,
    )

    assert len(result) == len(generic_tags)
    for item in result:
        assert isinstance(item, data.PredictedTag)
        assert item.score == score
        assert item.tag in generic_tags
def test_get_prediction_features_basic():
    """Check that a feature DataArray becomes a list of named Features."""
    expected = [("feat1", 1.1), ("feat2", 2.2), ("feat3", 3.3)]
    feature_data = xr.DataArray(
        [1.1, 2.2, 3.3],
        coords={"feature": ["feat1", "feat2", "feat3"]},
        dims=["feature"],
    )

    features = get_prediction_features(feature_data)

    assert len(features) == 3
    for feature, (name, value) in zip(features, expected):
        assert isinstance(feature, data.Feature)
        assert feature.term.name == f"batdetect2:{name}"
        assert feature.value == value
def test_get_class_tags_basic(dummy_sound_event_decoder):
    """Check that each class score yields a decoded tag with that score."""
    class_scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )

    predicted_tags = get_class_tags(
        class_scores=class_scores,
        sound_event_decoder=dummy_sound_event_decoder,
    )

    assert len(predicted_tags) == 3

    tag_values = []
    tag_scores = []
    for predicted in predicted_tags:
        tag_values.append(predicted.tag.value)
        tag_scores.append(predicted.score)

    for expected_value in ("Myotis", "noise", "uncertain"):
        assert expected_value in tag_values
    for expected_score in (0.6, 0.2, 0.9):
        assert expected_score in tag_scores
def test_get_class_tags_thresholding(dummy_sound_event_decoder):
    """Check that classes scoring below the threshold produce no tags."""
    class_scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )

    predicted_tags = get_class_tags(
        class_scores=class_scores,
        sound_event_decoder=dummy_sound_event_decoder,
        threshold=0.5,
    )

    # Only the 0.6 ("bat") and 0.9 ("unknown") classes exceed 0.5.
    assert len(predicted_tags) == 2
    kept_values = []
    for predicted in predicted_tags:
        kept_values.append(predicted.tag.value)
    assert "Myotis" in kept_values
    assert "noise" not in kept_values
    assert "uncertain" in kept_values
def test_get_class_tags_top_class_only(dummy_sound_event_decoder):
    """Check that ``top_class_only`` returns just the highest-scoring class."""
    class_scores = xr.DataArray(
        [0.6, 0.2, 0.9],
        coords={"category": ["bat", "noise", "unknown"]},
        dims=["category"],
    )

    predicted_tags = get_class_tags(
        class_scores=class_scores,
        sound_event_decoder=dummy_sound_event_decoder,
        top_class_only=True,
    )

    assert len(predicted_tags) == 1
    top_tag = predicted_tags[0]
    assert top_tag.tag.value == "uncertain"
    assert top_tag.score == 0.9
def test_get_class_tags_empty(dummy_sound_event_decoder):
    """Check that an empty class-score array yields no tags."""
    no_scores = xr.DataArray(
        [],
        coords={"category": []},
        dims=["category"],
    )

    result = get_class_tags(
        class_scores=no_scores,
        sound_event_decoder=dummy_sound_event_decoder,
    )

    tag_count = len(result)
    assert tag_count == 0