batdetect2/tests/test_postprocessing/test_decoding.py
2025-04-20 15:52:25 +01:00

605 lines
19 KiB
Python

from pathlib import Path
from typing import List, Tuple
import numpy as np
import pytest
import xarray as xr
# Removed dataclass import as MockRawPrediction is replaced
from soundevent import data
# Import functions to test
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
from batdetect2.postprocess.types import RawPrediction
# Dummy GeometryBuilder function fixture
@pytest.fixture
def dummy_geometry_builder():
    """Provide a geometry builder that boxes a detection around its centre.

    The returned callable takes a ``(time, freq)`` position and a
    ``dimensions`` array indexed by the ``dimension`` coordinate
    (``"width"`` and ``"height"``) and returns a ``data.BoundingBox``
    centred on the position.
    """

    def _centered_box(
        position: Tuple[float, float],
        dimensions: xr.DataArray,
    ) -> data.BoundingBox:
        center_time, center_freq = position
        half_width = dimensions.sel(dimension="width").item() / 2
        half_height = dimensions.sel(dimension="height").item() / 2
        # Coordinate order: start time, low freq, end time, high freq.
        corners = [
            center_time - half_width,
            center_freq - half_height,
            center_time + half_width,
            center_freq + half_height,
        ]
        return data.BoundingBox(coordinates=corners)

    return _centered_box
# Dummy SoundEventDecoder function fixture
@pytest.fixture
def dummy_sound_event_decoder():
    """Provide a decoder that maps class names to lists of tags.

    Lookup is case-insensitive (the name is lower-cased first); names
    that are not registered decode to an empty list.
    """

    def _tag(key: str, value: str) -> data.Tag:
        return data.Tag(term=data.term_from_key(key=key), value=value)

    tags_by_class = {
        "bat": [_tag("species", "Myotis")],
        "noise": [_tag("category", "noise")],
        "unknown": [_tag("status", "uncertain")],
    }

    def _decode(class_name: str) -> List[data.Tag]:
        normalized = class_name.lower()
        return tags_by_class.get(normalized, [])

    return _decode
@pytest.fixture
def generic_tags() -> List[data.Tag]:
    """Return the single generic tag (detector=batdetect2) used by the tests."""
    detector_term = data.term_from_key(key="detector")
    detector_tag = data.Tag(term=detector_term, value="batdetect2")
    return [detector_tag]
@pytest.fixture
def sample_recording() -> data.Recording:
    """Return a 60 s, single-channel, 192 kHz recording with a dummy path."""
    recording_fields = {
        "path": Path("/path/to/recording.wav"),
        "duration": 60.0,
        "channels": 1,
        "samplerate": 192000,
    }
    return data.Recording(**recording_fields)
@pytest.fixture
def sample_clip(sample_recording) -> data.Clip:
    """Return a clip spanning 10.0-20.0 of the sample recording."""
    clip_start, clip_end = 10.0, 20.0
    return data.Clip(
        recording=sample_recording,
        start_time=clip_start,
        end_time=clip_end,
    )
# Fixture for a detection dataset (adapted from test_extraction)
@pytest.fixture
def sample_detection_dataset() -> xr.Dataset:
    """Build a two-detection dataset to feed the decoding functions.

    Detections, in order: (t=20, f=300, score=0.9) and
    (t=10, f=200, score=0.8). Every variable shares the ``detection``
    dimension and carries ``time`` and ``freq`` as coordinates on it.
    The score variable is stored in the dataset under the key ``"score"``.
    """
    shared_coords = {
        "time": ("detection", np.array([20, 10])),
        "freq": ("detection", np.array([300, 200])),
    }

    def _per_detection(values, axis, labels, name) -> xr.DataArray:
        # 2-D variable: one row per detection, one column per label.
        return xr.DataArray(
            values,
            coords={**shared_coords, axis: labels},
            dims=["detection", axis],
            name=name,
        )

    score_array = xr.DataArray(
        np.array([0.9, 0.8], dtype=np.float64),
        coords=shared_coords,
        dims=["detection"],
        name="scores",
    )
    size_array = _per_detection(
        np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32),
        "dimension",
        ["width", "height"],
        "dimensions",
    )
    class_array = _per_detection(
        np.array([[0.43, 0.85], [0.24, 0.66]], dtype=np.float32),
        "category",
        ["bat", "noise"],
        "classes",
    )
    feature_array = _per_detection(
        np.array(
            [[7.0, 16.0, 25.0, 34.0], [3.0, 12.0, 21.0, 30.0]],
            dtype=np.float32,
        ),
        "feature",
        ["f0", "f1", "f2", "f3"],
        "features",
    )

    variables = {
        "score": score_array,
        "dimensions": size_array,
        "classes": class_array,
        "features": feature_array,
    }
    return xr.Dataset(variables, coords=shared_coords)
@pytest.fixture
def empty_detection_dataset() -> xr.Dataset:
    """Build a detection dataset with the usual variables but zero detections.

    The ``detection`` dimension has length 0; the ``dimension``,
    ``category`` and ``feature`` axes keep their labels.
    """
    no_values = np.array([], dtype=np.float64)
    shared_coords = {
        "time": ("detection", no_values),
        "freq": ("detection", np.array([], dtype=np.float64)),
    }

    def _empty_table(axis, labels, name) -> xr.DataArray:
        # Zero rows, one column per label on the secondary axis.
        return xr.DataArray(
            np.empty((0, len(labels)), dtype=np.float32),
            coords={**shared_coords, axis: labels},
            dims=["detection", axis],
            name=name,
        )

    # NOTE(review): this fixture registers the score variable as "scores",
    # whereas sample_detection_dataset uses "score" - confirm which key
    # the decoder expects.
    variables = {
        "scores": xr.DataArray(
            np.array([], dtype=np.float64),
            coords=shared_coords,
            dims=["detection"],
            name="scores",
        ),
        "dimensions": _empty_table(
            "dimension", ["width", "height"], "dimensions"
        ),
        "classes": _empty_table("category", ["bat", "noise"], "classes"),
        "features": _empty_table(
            "feature", ["f0", "f1", "f2", "f3"], "features"
        ),
    }
    return xr.Dataset(variables, coords=shared_coords)
# Fixture for sample RawPrediction objects (using the actual type)
@pytest.fixture
def sample_raw_predictions() -> List[RawPrediction]:
    """Return three hand-built ``RawPrediction`` objects.

    The first two mirror the detections of ``sample_detection_dataset``
    with bounds centred on (t, f):

    - t=20, f=300, score=0.9, w=7, h=16, classes=[0.43, 0.85]
    - t=10, f=200, score=0.8, w=3, h=12, classes=[0.24, 0.66]

    The third has a detection score of 0.15 and class scores
    [0.05, 0.02], which the tests use as the "below default threshold"
    case.
    """
    category_names = ["bat", "noise"]
    feature_names = ["f0", "f1", "f2", "f3"]

    def _prediction(score, bounds, class_values, feature_values):
        start, end, low, high = bounds
        return RawPrediction(
            detection_score=score,
            start_time=start,
            end_time=end,
            low_freq=low,
            high_freq=high,
            class_scores=xr.DataArray(
                class_values,
                coords={"category": category_names},
                dims=["category"],
            ),
            features=xr.DataArray(
                feature_values,
                coords={"feature": feature_names},
                dims=["feature"],
            ),
        )

    first = _prediction(
        0.9,
        # time 16.5-23.5, freq 292-308
        (20 - 7 / 2, 20 + 7 / 2, 300 - 16 / 2, 300 + 16 / 2),
        [0.43, 0.85],
        [7.0, 16.0, 25.0, 34.0],
    )
    second = _prediction(
        0.8,
        # time 8.5-11.5, freq 194-206
        (10 - 3 / 2, 10 + 3 / 2, 200 - 12 / 2, 200 + 12 / 2),
        [0.24, 0.66],
        [3.0, 12.0, 21.0, 30.0],
    )
    weak = _prediction(
        0.15,
        (5.0, 6.0, 50.0, 60.0),
        [0.05, 0.02],
        [1.0, 2.0, 3.0, 4.0],
    )
    return [first, second, weak]
# --- Tests for convert_xr_dataset_to_raw_prediction ---
def test_convert_xr_dataset_basic(
    sample_detection_dataset, dummy_geometry_builder
):
    """Check that each dataset row becomes a matching ``RawPrediction``.

    Bounds are derived from the dummy geometry builder, which treats the
    (time, freq) position as the centre of a width x height box.
    """
    raw_predictions = convert_xr_dataset_to_raw_prediction(
        sample_detection_dataset, dummy_geometry_builder
    )

    assert isinstance(raw_predictions, list)
    assert len(raw_predictions) == 2

    # One row per detection: (score, time, freq, width, height).
    expected_rows = [
        (0.9, 20, 300, 7, 16),
        (0.8, 10, 200, 3, 12),
    ]
    for index, (score, time, freq, width, height) in enumerate(
        expected_rows
    ):
        prediction = raw_predictions[index]
        assert isinstance(prediction, RawPrediction)
        assert prediction.detection_score == pytest.approx(score)
        assert prediction.start_time == pytest.approx(time - width / 2)
        assert prediction.end_time == pytest.approx(time + width / 2)
        assert prediction.low_freq == pytest.approx(freq - height / 2)
        assert prediction.high_freq == pytest.approx(freq + height / 2)
        xr.testing.assert_allclose(
            prediction.class_scores,
            sample_detection_dataset["classes"].sel(detection=index),
        )
        xr.testing.assert_allclose(
            prediction.features,
            sample_detection_dataset["features"].sel(detection=index),
        )
# Edge case: converting a dataset that contains no detections.
def test_convert_xr_dataset_empty(
    empty_detection_dataset, dummy_geometry_builder
):
    """Check that a dataset without detections converts to an empty list."""
    converted = convert_xr_dataset_to_raw_prediction(
        empty_detection_dataset, dummy_geometry_builder
    )

    assert isinstance(converted, list)
    assert converted == []
# --- Tests for convert_raw_prediction_to_sound_event_prediction ---
def test_convert_raw_to_sound_event_basic(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test basic conversion with default threshold and multi-label output.

    Uses the first raw prediction (score=0.9, classes bat=0.43,
    noise=0.85) and leaves ``classification_threshold`` and
    ``top_class_only`` at their defaults.
    """
    raw_pred = sample_raw_predictions[0]

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
    )

    assert isinstance(se_pred, data.SoundEventPrediction)
    assert se_pred.score == pytest.approx(raw_pred.detection_score)

    # The sound event must carry the recording and the raw bounds.
    se = se_pred.sound_event
    assert isinstance(se, data.SoundEvent)
    assert se.recording == sample_recording
    assert isinstance(se.geometry, data.BoundingBox)
    np.testing.assert_allclose(
        se.geometry.coordinates,
        [
            raw_pred.start_time,
            raw_pred.low_freq,
            raw_pred.end_time,
            raw_pred.high_freq,
        ],
    )

    # One feature per raw feature; spot-check the first one.
    assert len(se.features) == len(raw_pred.features)
    feat_dict = {f.term.name: f.value for f in se.features}
    assert "batdetect2:f0" in feat_dict and isinstance(
        feat_dict["batdetect2:f0"], float
    )
    assert feat_dict["batdetect2:f0"] == pytest.approx(7.0)

    # Tag order is not asserted. Scores are compared with a tolerance via a
    # mapping, because pytest.approx values are unhashable and so cannot be
    # stored in a set of tuples.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.9,
        ("category", "noise"): 0.85,
        ("species", "Myotis"): 0.43,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score for pt in se_pred.tags
    }
    assert actual_scores == pytest.approx(expected_scores)
def test_convert_raw_to_sound_event_thresholding(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test that class tags below the classification threshold are dropped.

    Uses the first raw prediction (score=0.9, classes bat=0.43,
    noise=0.85) with a threshold of 0.5, so only the noise class and the
    generic tag are expected.
    """
    raw_pred = sample_raw_predictions[0]
    high_threshold = 0.5

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=high_threshold,
        top_class_only=False,
    )

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.9,
        ("category", "noise"): 0.85,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score for pt in se_pred.tags
    }
    assert actual_scores == pytest.approx(expected_scores)
def test_convert_raw_to_sound_event_no_threshold(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test that ``classification_threshold=None`` keeps every class tag.

    Uses the third raw prediction (score=0.15, classes bat=0.05,
    noise=0.02); both class tags are expected alongside the generic tag.
    """
    raw_pred = sample_raw_predictions[2]

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=None,
        top_class_only=False,
    )

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.15,
        ("species", "Myotis"): 0.05,
        ("category", "noise"): 0.02,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score for pt in se_pred.tags
    }
    assert actual_scores == pytest.approx(expected_scores)
def test_convert_raw_to_sound_event_top_class(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test that ``top_class_only=True`` keeps only the best-scoring class.

    Uses the first raw prediction (score=0.9, classes bat=0.43,
    noise=0.85); noise is the top class, so only its tag and the generic
    tag are expected.
    """
    raw_pred = sample_raw_predictions[0]

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=True,
    )

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.9,
        ("category", "noise"): 0.85,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score for pt in se_pred.tags
    }
    assert actual_scores == pytest.approx(expected_scores)
def test_convert_raw_to_sound_event_all_below_threshold(
    sample_raw_predictions,
    sample_recording,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test that only the generic tag remains when no class passes.

    Uses the third raw prediction (score=0.15, classes bat=0.05,
    noise=0.02) with the default classification threshold.
    """
    raw_pred = sample_raw_predictions[2]

    se_pred = convert_raw_prediction_to_sound_event_prediction(
        raw_prediction=raw_pred,
        recording=sample_recording,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=False,
    )

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.15,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score for pt in se_pred.tags
    }
    assert actual_scores == pytest.approx(expected_scores)
# --- Tests for convert_raw_predictions_to_clip_prediction ---
def test_convert_raw_list_to_clip_basic(
    sample_raw_predictions,
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test converting a list of RawPredictions to a ClipPrediction.

    Checks that the clip is preserved, that one sound event is produced
    per raw prediction in the same order, and that the third event
    (score=0.15, classes [0.05, 0.02]) only carries the generic tag.
    """
    clip_pred = convert_raw_predictions_to_clip_prediction(
        raw_predictions=sample_raw_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=False,
    )

    assert isinstance(clip_pred, data.ClipPrediction)
    assert clip_pred.clip == sample_clip
    assert len(clip_pred.sound_events) == len(sample_raw_predictions)

    # Each sound event keeps the detection score of its raw prediction.
    for se_pred, raw_pred in zip(
        clip_pred.sound_events, sample_raw_predictions
    ):
        assert se_pred.score == pytest.approx(raw_pred.detection_score)

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.15,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score
        for pt in clip_pred.sound_events[2].tags
    }
    assert actual_scores == pytest.approx(expected_scores)
def test_convert_raw_list_to_clip_empty(
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Check that no raw predictions yields a ClipPrediction with no events."""
    no_predictions = []
    result = convert_raw_predictions_to_clip_prediction(
        raw_predictions=no_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
    )

    assert isinstance(result, data.ClipPrediction)
    assert result.clip == sample_clip
    assert len(result.sound_events) == 0
def test_convert_raw_list_to_clip_passes_args(
    sample_raw_predictions,
    sample_clip,
    dummy_sound_event_decoder,
    generic_tags,
):
    """Test that arguments like ``top_class_only`` are passed through.

    With ``top_class_only=True`` the first event (score=0.9, classes
    bat=0.43, noise=0.85) should carry only the generic tag and the
    noise tag.
    """
    clip_pred = convert_raw_predictions_to_clip_prediction(
        raw_predictions=sample_raw_predictions,
        clip=sample_clip,
        sound_event_decoder=dummy_sound_event_decoder,
        generic_class_tags=generic_tags,
        classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
        top_class_only=True,  # argument under test
    )

    assert len(clip_pred.sound_events) == 3

    # Compare as a mapping: pytest.approx values are unhashable, so putting
    # them inside a set of tuples raises TypeError before the assertion runs.
    expected_scores = {
        (generic_tags[0].key, generic_tags[0].value): 0.9,
        ("category", "noise"): 0.85,
    }
    actual_scores = {
        (pt.tag.key, pt.tag.value): pt.score
        for pt in clip_pred.sound_events[0].tags
    }
    assert actual_scores == pytest.approx(expected_scores)