Working on postprocess tests

This commit is contained in:
mbsantiago 2025-04-20 15:52:25 +01:00
parent bcf339c40d
commit 1f4454693e
7 changed files with 1280 additions and 49 deletions

View File

@ -29,6 +29,7 @@ The process involves:
from typing import List, Optional from typing import List, Optional
import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
@ -256,8 +257,17 @@ def convert_raw_prediction_to_sound_event_prediction(
] ]
), ),
features=[ features=[
data.Feature(term=data.term_from_key(feat_name), value=value) data.Feature(
for feat_name, value in raw_prediction.features term=data.Term(
name=f"batdetect2:{feat_name}",
label=feat_name,
definition="Automatically extracted features by BatDetect2",
),
value=value,
)
for feat_name, value in _iterate_over_array(
raw_prediction.features
)
], ],
) )
@ -274,9 +284,7 @@ def convert_raw_prediction_to_sound_event_prediction(
drop=True, drop=True,
) )
for class_name, score in class_scores.sortby( for class_name, score in _iterate_sorted(class_scores):
class_scores, ascending=False
):
class_tags = sound_event_decoder(class_name) class_tags = sound_event_decoder(class_name)
for tag in class_tags: for tag in class_tags:
@ -295,3 +303,18 @@ def convert_raw_prediction_to_sound_event_prediction(
score=raw_prediction.detection_score, score=raw_prediction.detection_score,
tags=tags, tags=tags,
) )
def _iterate_over_array(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name]
for value, coord in zip(array.values, coords.values):
yield coord, float(value)
def _iterate_sorted(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name]
indices = np.argsort(coords.values)
for index in indices:
yield str(coords[index]), coords.values[index]

View File

@ -103,7 +103,7 @@ def extract_detections_from_array(
top_values = top_values[mask] top_values = top_values[mask]
top_sorted_indices = top_sorted_indices[mask] top_sorted_indices = top_sorted_indices[mask]
time_indices, freq_indices = np.unravel_index( freq_indices, time_indices = np.unravel_index(
top_sorted_indices, top_sorted_indices,
detection_array.shape, detection_array.shape,
) )

View File

@ -1,43 +0,0 @@
from typing import List
import numpy as np
import torch
import xarray as xr
from soundevent import data
from batdetect2.modules import DetectorModel
from batdetect2.postprocess.arrays import to_xarray
from batdetect2.preprocess import preprocess_audio_clip
def test_this(clip: data.Clip, class_names: List[str]):
spec = xr.DataArray(
data=np.random.rand(100, 100),
dims=["time", "frequency"],
coords={
"time": np.linspace(0, 100, 100, endpoint=False),
"frequency": np.linspace(0, 100, 100, endpoint=False),
},
)
model = DetectorModel()
spec = preprocess_audio_clip(
clip,
config=model.config.preprocessing,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = model(tensor)
arrays = to_xarray(
outputs,
start_time=clip.start_time,
end_time=clip.end_time,
class_names=class_names,
)
print(arrays)
assert False

View File

@ -0,0 +1,604 @@
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pytest
import xarray as xr
# Removed dataclass import as MockRawPrediction is replaced
from soundevent import data
# Import functions to test
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
from batdetect2.postprocess.types import RawPrediction
# Dummy GeometryBuilder function fixture
@pytest.fixture
def dummy_geometry_builder():
"""A simple GeometryBuilder that creates a BBox around the point."""
def _builder(
position: Tuple[float, float],
dimensions: xr.DataArray,
) -> data.BoundingBox:
time, freq = position
width = dimensions.sel(dimension="width").item()
height = dimensions.sel(dimension="height").item()
# Assume position is the center
return data.BoundingBox(
coordinates=[
time - width / 2,
freq - height / 2,
time + width / 2,
freq + height / 2,
]
)
return _builder
# Dummy SoundEventDecoder function fixture
@pytest.fixture
def dummy_sound_event_decoder():
"""A simple SoundEventDecoder mapping names to tags."""
tag_map = {
"bat": [
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
],
"noise": [
data.Tag(term=data.term_from_key(key="category"), value="noise")
],
"unknown": [
data.Tag(term=data.term_from_key(key="status"), value="uncertain")
],
}
def _decoder(class_name: str) -> List[data.Tag]:
return tag_map.get(class_name.lower(), [])
return _decoder
@pytest.fixture
def generic_tags() -> List[data.Tag]:
"""Sample generic tags."""
return [
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
]
@pytest.fixture
def sample_recording() -> data.Recording:
"""A sample soundevent Recording."""
return data.Recording(
path=Path("/path/to/recording.wav"),
duration=60.0,
channels=1,
samplerate=192000,
)
@pytest.fixture
def sample_clip(sample_recording) -> data.Clip:
"""A sample soundevent Clip."""
return data.Clip(
recording=sample_recording,
start_time=10.0,
end_time=20.0,
)
# Fixture for a detection dataset (adapted from test_extraction)
@pytest.fixture
def sample_detection_dataset() -> xr.Dataset:
"""Creates a sample detection dataset suitable for decoding."""
# Based on test_extraction's corrected expectations
# Detections: (t=20, f=300, s=0.9), (t=10, f=200, s=0.8)
expected_times = np.array([20, 10])
expected_freqs = np.array([300, 200])
detection_coords = {
"time": ("detection", expected_times),
"freq": ("detection", expected_freqs),
}
scores_data = np.array([0.9, 0.8], dtype=np.float64)
scores = xr.DataArray(
scores_data,
coords=detection_coords,
dims=["detection"],
name="scores",
)
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
dimensions = xr.DataArray(
dimensions_data,
coords={**detection_coords, "dimension": ["width", "height"]},
dims=["detection", "dimension"],
name="dimensions",
)
classes_data = np.array(
[[0.43, 0.85], [0.24, 0.66]],
dtype=np.float32, # Simplified values
)
classes = xr.DataArray(
classes_data,
coords={**detection_coords, "category": ["bat", "noise"]},
dims=["detection", "category"],
name="classes",
)
features_data = np.array(
[[7.0, 16.0, 25.0, 34.0], [3.0, 12.0, 21.0, 30.0]], dtype=np.float32
)
features = xr.DataArray(
features_data,
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
dims=["detection", "feature"],
name="features",
)
ds = xr.Dataset(
{
"score": scores,
"dimensions": dimensions,
"classes": classes,
"features": features,
},
coords=detection_coords,
)
return ds
@pytest.fixture
def empty_detection_dataset() -> xr.Dataset:
"""Creates an empty detection dataset with correct structure."""
detection_coords = {
"time": ("detection", np.array([], dtype=np.float64)),
"freq": ("detection", np.array([], dtype=np.float64)),
}
scores = xr.DataArray(
np.array([], dtype=np.float64),
coords=detection_coords,
dims=["detection"],
name="scores",
)
dimensions = xr.DataArray(
np.empty((0, 2), dtype=np.float32),
coords={**detection_coords, "dimension": ["width", "height"]},
dims=["detection", "dimension"],
name="dimensions",
)
classes = xr.DataArray(
np.empty((0, 2), dtype=np.float32),
coords={**detection_coords, "category": ["bat", "noise"]},
dims=["detection", "category"],
name="classes",
)
features = xr.DataArray(
np.empty((0, 4), dtype=np.float32),
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
dims=["detection", "feature"],
name="features",
)
return xr.Dataset(
{
"scores": scores,
"dimensions": dimensions,
"classes": classes,
"features": features,
},
coords=detection_coords,
)
# Fixture for sample RawPrediction objects (using the actual type)
@pytest.fixture
def sample_raw_predictions() -> List[RawPrediction]:
"""Manually crafted RawPrediction objects using the actual type."""
# Corresponds roughly to sample_detection_dataset after geometry building
# Det 1: t=20, f=300, s=0.9, w=7, h=16, classes=[0.43, 0.85], feats=[7, 16, 25, 34]
# Det 2: t=10, f=200, s=0.8, w=3, h=12, classes=[0.24, 0.66], feats=[ 3, 12, 21, 30]
pred1_classes = xr.DataArray(
[0.43, 0.85], coords={"category": ["bat", "noise"]}, dims=["category"]
)
pred1_features = xr.DataArray(
[7.0, 16.0, 25.0, 34.0],
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred1 = RawPrediction( # Use RawPrediction directly
detection_score=0.9,
start_time=20 - 7 / 2,
end_time=20 + 7 / 2, # 16.5, 23.5
low_freq=300 - 16 / 2,
high_freq=300 + 16 / 2, # 292, 308
class_scores=pred1_classes,
features=pred1_features,
)
pred2_classes = xr.DataArray(
[0.24, 0.66], coords={"category": ["bat", "noise"]}, dims=["category"]
)
pred2_features = xr.DataArray(
[3.0, 12.0, 21.0, 30.0],
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred2 = RawPrediction( # Use RawPrediction directly
detection_score=0.8,
start_time=10 - 3 / 2,
end_time=10 + 3 / 2, # 8.5, 11.5
low_freq=200 - 12 / 2,
high_freq=200 + 12 / 2, # 194, 206
class_scores=pred2_classes,
features=pred2_features,
)
pred3_classes = xr.DataArray(
[0.05, 0.02], coords={"category": ["bat", "noise"]}, dims=["category"]
) # Below default threshold
pred3_features = xr.DataArray(
[1.0, 2.0, 3.0, 4.0],
coords={"feature": ["f0", "f1", "f2", "f3"]},
dims=["feature"],
)
pred3 = RawPrediction( # Use RawPrediction directly
detection_score=0.15,
start_time=5.0,
end_time=6.0,
low_freq=50.0,
high_freq=60.0,
class_scores=pred3_classes,
features=pred3_features,
)
return [pred1, pred2, pred3]
# --- Tests for convert_xr_dataset_to_raw_prediction ---
def test_convert_xr_dataset_basic(
sample_detection_dataset, dummy_geometry_builder
):
"""Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset, dummy_geometry_builder
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 2
# Check first prediction (score=0.9)
pred1 = raw_predictions[0]
assert isinstance(pred1, RawPrediction) # Check against the actual type
assert pred1.detection_score == pytest.approx(0.9)
# Check bounds derived from dummy_geometry_builder (center pos assumed)
# t=20, f=300, w=7, h=16
assert pred1.start_time == pytest.approx(20 - 7 / 2)
assert pred1.end_time == pytest.approx(20 + 7 / 2)
assert pred1.low_freq == pytest.approx(300 - 16 / 2)
assert pred1.high_freq == pytest.approx(300 + 16 / 2)
xr.testing.assert_allclose(
pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0),
)
xr.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0)
)
# Check second prediction (score=0.8)
pred2 = raw_predictions[1]
assert isinstance(pred2, RawPrediction) # Check against the actual type
assert pred2.detection_score == pytest.approx(0.8)
# t=10, f=200, w=3, h=12
assert pred2.start_time == pytest.approx(10 - 3 / 2)
assert pred2.end_time == pytest.approx(10 + 3 / 2)
assert pred2.low_freq == pytest.approx(200 - 12 / 2)
assert pred2.high_freq == pytest.approx(200 + 12 / 2)
xr.testing.assert_allclose(
pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1),
)
xr.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1)
)
# ...(rest of the tests remain unchanged as they accessed attributes correctly)...
def test_convert_xr_dataset_empty(
empty_detection_dataset, dummy_geometry_builder
):
"""Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset, dummy_geometry_builder
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0
# --- Tests for convert_raw_prediction_to_sound_event_prediction ---
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test basic conversion, default threshold, multi-label."""
# score=0.9, classes=[0.43(bat), 0.85(noise)]
raw_pred = sample_raw_predictions[0]
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
# classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD (0.1),
# top_class_only=False,
)
assert isinstance(se_pred, data.SoundEventPrediction)
assert se_pred.score == pytest.approx(raw_pred.detection_score)
# Check SoundEvent
se = se_pred.sound_event
assert isinstance(se, data.SoundEvent)
assert se.recording == sample_recording
assert isinstance(se.geometry, data.BoundingBox)
np.testing.assert_allclose(
se.geometry.coordinates,
[
raw_pred.start_time,
raw_pred.low_freq,
raw_pred.end_time,
raw_pred.high_freq,
],
)
assert len(se.features) == len(raw_pred.features)
# Simple check for feature presence and value type
feat_dict = {f.term.name: f.value for f in se.features}
assert "batdetect2:f0" in feat_dict and isinstance(
feat_dict["batdetect2:f0"], float
)
assert feat_dict["batdetect2:f0"] == pytest.approx(7.0)
# Check Tags
# Expected: Generic(0.9), Noise(0.85), Bat(0.43)
# Note: Order might depend on sortby implementation detail, compare as sets
expected_tags = {
# Generic Tag
(generic_tags[0].key, generic_tags[0].value, 0.9),
# Noise Tag (score 0.85 > 0.1)
("category", "noise", 0.85),
# Bat Tag (score 0.43 > 0.1)
("species", "Myotis", 0.43),
}
print("expected", expected_tags)
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
print("actual", actual_tags)
assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test effect of classification threshold."""
raw_pred = sample_raw_predictions[
0
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
high_threshold = 0.5
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=high_threshold, # Only noise should pass
top_class_only=False,
)
# Expected: Generic(0.9), Noise(0.85) - Bat (0.43) is below threshold
expected_tags = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
("category", "noise", pytest.approx(0.85)),
}
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test when classification_threshold is None."""
raw_pred = sample_raw_predictions[
2
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
# Both classes are below default threshold, but should be included if None
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=None, # No thresholding
top_class_only=False,
)
# Expected: Generic(0.15), Bat(0.05), Noise(0.02)
expected_tags = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
("species", "Myotis", pytest.approx(0.05)),
("category", "noise", pytest.approx(0.02)),
}
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test top_class_only=True behavior."""
raw_pred = sample_raw_predictions[
0
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
# Highest score is noise (0.85)
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, # Only include top class (noise)
)
# Expected: Generic(0.9), Noise(0.85)
expected_tags = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
("category", "noise", pytest.approx(0.85)),
}
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
assert actual_tags == expected_tags
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions,
sample_recording,
dummy_sound_event_decoder,
generic_tags,
):
"""Test when all class scores are below the default threshold."""
raw_pred = sample_raw_predictions[
2
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
se_pred = convert_raw_prediction_to_sound_event_prediction(
raw_prediction=raw_pred,
recording=sample_recording,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, # 0.1
top_class_only=False,
)
# Expected: Only Generic(0.15) tag, as others are below threshold
expected_tags = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
}
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
assert actual_tags == expected_tags
# --- Tests for convert_raw_predictions_to_clip_prediction ---
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions,
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting a list of RawPredictions to a ClipPrediction."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=False,
)
assert isinstance(clip_pred, data.ClipPrediction)
assert clip_pred.clip == sample_clip
assert len(clip_pred.sound_events) == len(sample_raw_predictions)
# Check if the contained sound events seem correct (basic check)
assert clip_pred.sound_events[0].score == pytest.approx(
sample_raw_predictions[0].detection_score
)
assert clip_pred.sound_events[1].score == pytest.approx(
sample_raw_predictions[1].detection_score
)
assert clip_pred.sound_events[2].score == pytest.approx(
sample_raw_predictions[2].detection_score
)
# Check if tags were generated correctly for one event (e.g., the last one)
# Pred 3 has score 0.15, classes [0.05, 0.02]. Only generic tag expected.
se_pred3_tags = {
(pt.tag.key, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[2].tags
}
expected_tags3 = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
}
assert se_pred3_tags == expected_tags3
def test_convert_raw_list_to_clip_empty(
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test converting an empty list of RawPredictions."""
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=[],
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
)
assert isinstance(clip_pred, data.ClipPrediction)
assert clip_pred.clip == sample_clip
assert len(clip_pred.sound_events) == 0
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions,
sample_clip,
dummy_sound_event_decoder,
generic_tags,
):
"""Test that arguments like top_class_only are passed through."""
# Use top_class_only = True
clip_pred = convert_raw_predictions_to_clip_prediction(
raw_predictions=sample_raw_predictions,
clip=sample_clip,
sound_event_decoder=dummy_sound_event_decoder,
generic_class_tags=generic_tags,
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only=True, # <<-- Argument being tested
)
assert len(clip_pred.sound_events) == 3
# Check tags for the first prediction (score=0.9, classes=[0.43(bat), 0.85(noise)])
# With top_class_only=True, expect Generic(0.9) and Noise(0.85) only
se_pred1_tags = {
(pt.tag.key, pt.tag.value, pt.score)
for pt in clip_pred.sound_events[0].tags
}
expected_tags1 = {
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
("category", "noise", pytest.approx(0.85)),
}
assert se_pred1_tags == expected_tags1

View File

@ -0,0 +1,214 @@
import numpy as np
import pytest
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.postprocess.detection import extract_detections_from_array
@pytest.fixture
def sample_data_array():
"""Provides a basic 3x3 DataArray.
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
"""
array = xr.DataArray(
np.zeros([3, 3]),
coords={
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
},
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
)
array.loc[dict(time=10, frequency=100)] = 0.005
array.loc[dict(time=10, frequency=200)] = 0.5
array.loc[dict(time=10, frequency=300)] = 0.03
array.loc[dict(time=20, frequency=100)] = 0.8
array.loc[dict(time=20, frequency=200)] = 0.02
array.loc[dict(time=20, frequency=300)] = 0.6
array.loc[dict(time=30, frequency=100)] = 0.04
array.loc[dict(time=30, frequency=200)] = 0.9
array.loc[dict(time=30, frequency=300)] = 0.7
return array
@pytest.fixture
def data_array_with_nans(sample_data_array: xr.DataArray):
"""Provides a 2D DataArray containing NaN values."""
array = sample_data_array.copy()
array.loc[dict(time=10, frequency=300)] = np.nan
array.loc[dict(time=30, frequency=100)] = np.nan
return array
def test_basic_extraction(sample_data_array: xr.DataArray):
threshold = 0.1
max_detections = 3
actual_result = extract_detections_from_array(
sample_data_array,
threshold=threshold,
max_detections=max_detections,
)
expected_values = np.array([0.9, 0.8, 0.7])
expected_times = np.array([30, 20, 30])
expected_freqs = np.array([200, 100, 300])
expected_coords = {
Dimensions.frequency.value: ("detection", expected_freqs),
Dimensions.time.value: ("detection", expected_times),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="score",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_threshold_only(sample_data_array):
input_array = sample_data_array
threshold = 0.5
actual_result = extract_detections_from_array(
input_array, threshold=threshold
)
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
expected_times = np.array([30, 20, 30, 20])
expected_freqs = np.array([200, 100, 300, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_max_detections_only(sample_data_array):
input_array = sample_data_array
max_detections = 4
actual_result = extract_detections_from_array(
input_array, max_detections=max_detections
)
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
expected_times = np.array([30, 20, 30, 20])
expected_freqs = np.array([200, 100, 300, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_no_optional_args(sample_data_array):
input_array = sample_data_array
actual_result = extract_detections_from_array(input_array)
expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02])
expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20])
expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_no_values_above_threshold(sample_data_array):
input_array = sample_data_array
threshold = 1.0
actual_result = extract_detections_from_array(
input_array, threshold=threshold
)
expected_coords = {
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
Dimensions.frequency.value: (
"detection",
np.array([], dtype=np.int64),
),
}
expected_result = xr.DataArray(
np.array([], dtype=np.float64),
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
assert actual_result.sizes["detection"] == 0
def test_max_detections_zero(sample_data_array):
input_array = sample_data_array
max_detections = 0
with pytest.raises(ValueError):
extract_detections_from_array(
input_array,
max_detections=max_detections,
)
def test_empty_input_array():
empty_array = xr.DataArray(
np.empty((0, 0)),
coords={Dimensions.time.value: [], Dimensions.frequency.value: []},
dims=[Dimensions.time.value, Dimensions.frequency.value],
)
actual_result = extract_detections_from_array(empty_array)
expected_coords = {
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
Dimensions.frequency.value: (
"detection",
np.array([], dtype=np.int64),
),
}
expected_result = xr.DataArray(
np.array([], dtype=np.float64),
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
assert actual_result.sizes["detection"] == 0
def test_nan_handling(data_array_with_nans):
input_array = data_array_with_nans
threshold = 0.1
max_detections = 3
actual_result = extract_detections_from_array(
input_array, threshold=threshold, max_detections=max_detections
)
expected_values = np.array([0.9, 0.8, 0.7])
expected_times = np.array([30, 20, 30])
expected_freqs = np.array([200, 100, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)

View File

@ -0,0 +1,433 @@
import numpy as np
import pytest
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.postprocess.detection import extract_detections_from_array
from batdetect2.postprocess.extraction import (
extract_detection_xr_dataset,
extract_values_at_positions,
)
@pytest.fixture
def sample_data_array():
"""Provides a basic 3x3 DataArray.
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
"""
coords = {
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
}
array = xr.DataArray(
np.zeros([3, 3]),
coords=coords,
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
)
array.loc[dict(time=10, frequency=100)] = 0.005
array.loc[dict(time=10, frequency=200)] = 0.5
array.loc[dict(time=10, frequency=300)] = 0.03
array.loc[dict(time=20, frequency=100)] = 0.8
array.loc[dict(time=20, frequency=200)] = 0.02
array.loc[dict(time=20, frequency=300)] = 0.6
array.loc[dict(time=30, frequency=100)] = 0.04
array.loc[dict(time=30, frequency=200)] = 0.9
array.loc[dict(time=30, frequency=300)] = 0.7
return array
@pytest.fixture
def sample_array_for_extraction():
"""Provides a simple array (1-9) for value extraction tests."""
data = np.arange(1, 10).reshape(3, 3)
coords = {
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
}
return xr.DataArray(
data,
coords=coords,
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
name="test_values",
)
@pytest.fixture
def sample_positions_top3(sample_data_array):
"""Get top 3 detection positions from sample_data_array."""
# Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8), (f=300, t=30, s=0.7)
return extract_detections_from_array(
sample_data_array, max_detections=3, threshold=None
)
@pytest.fixture
def sample_positions_top2(sample_data_array):
"""Get top 2 detection positions from sample_data_array."""
# Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8)
return extract_detections_from_array(
sample_data_array, max_detections=2, threshold=None
)
@pytest.fixture
def empty_positions(sample_data_array):
"""Get an empty positions array (high threshold)."""
return extract_detections_from_array(
sample_data_array,
threshold=1.0, # No values > 1.0
)
@pytest.fixture
def sample_sizes_array(sample_data_array):
"""Provides a sample sizes array matching sample_data_array coords."""
coords = sample_data_array.coords
# Data: [[0, 1, 2], [3, 4, 5]] # Dim 0 (width)
# [[9,10,11], [12,13,14]] # Dim 1 (height)
# Reshaped: (2, 3, 3) -> (dim, freq, time)
data = np.array(
[
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
], # width (freq increases down, time across)
[[9, 10, 11], [12, 13, 14], [15, 16, 17]], # height
],
dtype=np.float32,
)
return xr.DataArray(
data,
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
name="sizes",
)
@pytest.fixture
def sample_classes_array(sample_data_array):
"""Provides a sample classes array matching sample_data_array coords."""
coords = sample_data_array.coords
# Example: (2 cats, 3 freqs, 3 times)
data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
# data[0, 2, 1] -> cat=0, f=300, t=20 -> val for 0.9 detection
# data[0, 1, 0] -> cat=0, f=200, t=10 -> val for 0.8 detection
return xr.DataArray(
data,
coords={
"category": ["bat", "noise"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["category", Dimensions.frequency.value, Dimensions.time.value],
name="class_scores",
)
@pytest.fixture
def sample_features_array(sample_data_array):
"""Provides a sample features array matching sample_data_array coords."""
coords = sample_data_array.coords
# Example: (4 features, 3 freqs, 3 times)
data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
# data[:, 2, 1] -> feats, f=300, t=20 -> vals for 0.9 detection
# data[:, 1, 0] -> feats, f=200, t=10 -> vals for 0.8 detection
return xr.DataArray(
data,
coords={
"feature": ["f0", "f1", "f2", "f3"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
name="features",
)
# --- Tests for extract_values_at_positions ---
def test_extract_values_at_positions_correct(
sample_array_for_extraction, sample_positions_top3
):
"""Verify correct values are extracted based on positions coords."""
# Positions: (f=300, t=20), (f=200, t=10), (f=300, t=30)
# Corresponding values in sample_array_for_extraction (1-9):
# f=300, t=20 -> index (2, 1) -> value 8
# f=200, t=10 -> index (1, 0) -> value 4
# f=300, t=30 -> index (2, 2) -> value 9
expected_values = np.array([8, 4, 9])
print(sample_positions_top3)
expected = xr.DataArray(
expected_values,
coords=sample_positions_top3.coords, # Should inherit coords
dims="detection",
name="test_values", # Should inherit name
)
extracted = extract_values_at_positions(
sample_array_for_extraction, sample_positions_top3
)
xr.testing.assert_allclose(extracted, expected)
def test_extract_values_at_positions_extra_dims(
sample_sizes_array, sample_positions_top2
):
"""Test extraction preserves other dimensions in the source array."""
# Positions: (f=300, t=20), (f=200, t=10)
# Extract from sample_sizes_array (dim, freq, time)
# Det 1 (f=300, t=20) -> index (:, 2, 1) -> values [7, 16]
# Det 2 (f=200, t=10) -> index (:, 1, 0) -> values [3, 12]
# Expected shape: (dimension, detection)
expected_values = np.array([[7.0, 3.0], [16.0, 12.0]], dtype=np.float32)
expected = xr.DataArray(
expected_values,
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: sample_positions_top2.coords[
Dimensions.frequency.value
],
Dimensions.time.value: sample_positions_top2.coords[
Dimensions.time.value
],
},
dims=["dimension", "detection"],
name="sizes", # Inherits name
)
extracted = extract_values_at_positions(
sample_sizes_array, sample_positions_top2
)
xr.testing.assert_allclose(extracted, expected)
def test_extract_values_at_positions_empty(
sample_array_for_extraction, empty_positions
):
"""Test extraction with empty positions returns empty array."""
extracted = extract_values_at_positions(
sample_array_for_extraction, empty_positions
)
assert extracted.sizes["detection"] == 0
# Check coordinates are also empty but defined
assert Dimensions.time.value in extracted.coords
assert Dimensions.frequency.value in extracted.coords
assert extracted.coords[Dimensions.time.value].size == 0
assert extracted.coords[Dimensions.frequency.value].size == 0
assert extracted.name == sample_array_for_extraction.name
def test_extract_values_at_positions_missing_coord_in_array(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if source array misses required coordinates."""
array_no_time = sample_array_for_extraction.copy()
del array_no_time.coords[Dimensions.time.value]
with pytest.raises(IndexError):
extract_values_at_positions(array_no_time, sample_positions_top2)
array_no_freq = sample_array_for_extraction.copy()
del array_no_freq.coords[Dimensions.frequency.value]
with pytest.raises(IndexError):
extract_values_at_positions(array_no_freq, sample_positions_top2)
def test_extract_values_at_positions_missing_coord_in_positions(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if positions array misses required coordinates."""
positions_no_time = sample_positions_top2.copy()
del positions_no_time.coords[Dimensions.time.value]
with pytest.raises(KeyError):
extract_values_at_positions(
sample_array_for_extraction, positions_no_time
)
positions_no_freq = sample_positions_top2.copy()
del positions_no_freq.coords[Dimensions.frequency.value]
with pytest.raises(KeyError):
extract_values_at_positions(
sample_array_for_extraction, positions_no_freq
)
def test_extract_values_at_positions_mismatched_coords(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if positions requests coords not in source array."""
# Create positions requesting a time=40 not present in sample_array
bad_positions = sample_positions_top2.copy()
bad_positions.coords[Dimensions.time.value] = (
"detection",
np.array([40, 10]), # First time is invalid
)
with pytest.raises(
KeyError
): # xarray.sel raises KeyError for missing labels
extract_values_at_positions(sample_array_for_extraction, bad_positions)
# --- Tests for extract_detection_xr_dataset ---
def test_extract_detection_xr_dataset_correct(
sample_positions_top2,
sample_sizes_array,
sample_classes_array,
sample_features_array,
):
"""Tests extracting and bundling info for top 2 detections."""
actual_dataset = extract_detection_xr_dataset(
sample_positions_top2,
sample_sizes_array,
sample_classes_array,
sample_features_array,
)
# Expected positions (top 2):
# 1. Score 0.9, Time 20, Freq 300. Indices (freq=2, time=1)
# 2. Score 0.8, Time 10, Freq 200. Indices (freq=1, time=0)
expected_times = np.array([20, 10])
expected_freqs = np.array([300, 200])
detection_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
# --- Manually Calculate Expected Data ---
# Scores (already correct in sample_positions_top2)
expected_score = sample_positions_top2.rename(
"scores"
) # Rename to match output
# Dimensions Data (width, height) -> Transposed to (detection, dimension)
# sample_sizes_array data: (dim, freq, time)
# Det 1 (f=300, t=20): index (:, 2, 1) -> values [ 7., 16.]
# Det 2 (f=200, t=10): index (:, 1, 0) -> values [ 3., 12.]
expected_dimensions_data = np.array(
[
[7.0, 16.0], # Detection 1 [width, height]
[3.0, 12.0],
], # Detection 2 [width, height]
dtype=np.float32,
)
expected_dimensions = xr.DataArray(
expected_dimensions_data,
coords={**detection_coords, "dimension": ["width", "height"]},
dims=["detection", "dimension"],
name="dimensions",
)
# Classes Data (bat, noise) -> Transposed to (detection, category)
# sample_classes_array data: np.linspace(0.1, 0.9, 18).reshape(2, 3, 3)
# linspace vals: [0.1, 0.147, 0.194, 0.241, 0.288, 0.335, 0.382, 0.429, 0.476, # cat 0
# 0.523, 0.570, 0.617, 0.664, 0.711, 0.758, 0.805, 0.852, 0.9] # cat 1
# Det 1 (cat, f=2, t=1): index (:, 2, 1) -> values [idx 7=0.429, idx 16=0.852]
# Det 2 (cat, f=1, t=0): index (:, 1, 0) -> values [idx 3=0.241, idx 12=0.664]
expected_classes_data = np.array(
[
[0.42941177, 0.85294118], # Detection 1 [bat_prob, noise_prob]
[0.24117647, 0.66470588],
], # Detection 2 [bat_prob, noise_prob]
dtype=np.float32,
)
expected_classes = xr.DataArray(
expected_classes_data,
coords={**detection_coords, "category": ["bat", "noise"]},
dims=["detection", "category"],
name="classes",
)
# Features Data (f0..f3) -> Transposed to (detection, feature)
# sample_features_array data: np.arange(36).reshape(4, 3, 3)
# Det 1 (feat, f=2, t=1): index (:, 2, 1) -> values [ 7, 16, 25, 34]
# Det 2 (feat, f=1, t=0): index (:, 1, 0) -> values [ 3, 12, 21, 30]
expected_features_data = np.array(
[
[7.0, 16.0, 25.0, 34.0], # Detection 1 [f0, f1, f2, f3]
[3.0, 12.0, 21.0, 30.0],
], # Detection 2 [f0, f1, f2, f3]
dtype=np.float32,
)
expected_features = xr.DataArray(
expected_features_data,
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
dims=["detection", "feature"],
name="features",
)
# Construct Expected Dataset
expected_dataset = xr.Dataset(
{
"scores": expected_score,
"dimensions": expected_dimensions,
"classes": expected_classes,
"features": expected_features,
}
)
# Add coords explicitly to ensure they match
expected_dataset = expected_dataset.assign_coords(detection_coords)
# --- Assert Equality ---
xr.testing.assert_allclose(actual_dataset, expected_dataset)
def test_extract_detection_xr_dataset_empty(
empty_positions,
sample_sizes_array,
sample_classes_array,
sample_features_array,
):
"""Test extraction with empty positions yields an empty dataset."""
actual_dataset = extract_detection_xr_dataset(
empty_positions,
sample_sizes_array,
sample_classes_array,
sample_features_array,
)
assert isinstance(actual_dataset, xr.Dataset)
assert "detection" in actual_dataset.dims
assert actual_dataset.dims["detection"] == 0
# Check variables exist and have 0 size along detection dim
assert "scores" in actual_dataset
assert actual_dataset["scores"].dims == ("detection",)
assert actual_dataset["scores"].size == 0
assert "dimensions" in actual_dataset
assert actual_dataset["dimensions"].dims == ("detection", "dimension")
assert actual_dataset["dimensions"].shape == (0, 2) # Check both dims
assert "classes" in actual_dataset
assert actual_dataset["classes"].dims == ("detection", "category")
assert actual_dataset["classes"].shape == (0, 2)
assert "features" in actual_dataset
assert actual_dataset["features"].dims == ("detection", "feature")
assert actual_dataset["features"].shape == (0, 4)
# Check coordinates exist and are empty
assert Dimensions.time.value in actual_dataset.coords
assert Dimensions.frequency.value in actual_dataset.coords
assert actual_dataset.coords[Dimensions.time.value].size == 0
assert actual_dataset.coords[Dimensions.frequency.value].size == 0