mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Working on postprocess tests
This commit is contained in:
parent
bcf339c40d
commit
1f4454693e
@ -29,6 +29,7 @@ The process involves:
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
@ -256,8 +257,17 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
]
|
||||
),
|
||||
features=[
|
||||
data.Feature(term=data.term_from_key(feat_name), value=value)
|
||||
for feat_name, value in raw_prediction.features
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
name=f"batdetect2:{feat_name}",
|
||||
label=feat_name,
|
||||
definition="Automatically extracted features by BatDetect2",
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
for feat_name, value in _iterate_over_array(
|
||||
raw_prediction.features
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@ -274,9 +284,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
drop=True,
|
||||
)
|
||||
|
||||
for class_name, score in class_scores.sortby(
|
||||
class_scores, ascending=False
|
||||
):
|
||||
for class_name, score in _iterate_sorted(class_scores):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
|
||||
for tag in class_tags:
|
||||
@ -295,3 +303,18 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
score=raw_prediction.detection_score,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def _iterate_over_array(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name]
|
||||
for value, coord in zip(array.values, coords.values):
|
||||
yield coord, float(value)
|
||||
|
||||
|
||||
def _iterate_sorted(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name]
|
||||
indices = np.argsort(coords.values)
|
||||
for index in indices:
|
||||
yield str(coords[index]), coords.values[index]
|
||||
|
@ -103,7 +103,7 @@ def extract_detections_from_array(
|
||||
top_values = top_values[mask]
|
||||
top_sorted_indices = top_sorted_indices[mask]
|
||||
|
||||
time_indices, freq_indices = np.unravel_index(
|
||||
freq_indices, time_indices = np.unravel_index(
|
||||
top_sorted_indices,
|
||||
detection_array.shape,
|
||||
)
|
||||
|
@ -1,43 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.modules import DetectorModel
|
||||
from batdetect2.postprocess.arrays import to_xarray
|
||||
from batdetect2.preprocess import preprocess_audio_clip
|
||||
|
||||
|
||||
def test_this(clip: data.Clip, class_names: List[str]):
|
||||
spec = xr.DataArray(
|
||||
data=np.random.rand(100, 100),
|
||||
dims=["time", "frequency"],
|
||||
coords={
|
||||
"time": np.linspace(0, 100, 100, endpoint=False),
|
||||
"frequency": np.linspace(0, 100, 100, endpoint=False),
|
||||
},
|
||||
)
|
||||
|
||||
model = DetectorModel()
|
||||
|
||||
spec = preprocess_audio_clip(
|
||||
clip,
|
||||
config=model.config.preprocessing,
|
||||
)
|
||||
|
||||
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
outputs = model(tensor)
|
||||
|
||||
arrays = to_xarray(
|
||||
outputs,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
class_names=class_names,
|
||||
)
|
||||
|
||||
print(arrays)
|
||||
|
||||
assert False
|
604
tests/test_postprocessing/test_decoding.py
Normal file
604
tests/test_postprocessing/test_decoding.py
Normal file
@ -0,0 +1,604 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
|
||||
# Removed dataclass import as MockRawPrediction is replaced
|
||||
from soundevent import data
|
||||
|
||||
# Import functions to test
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
|
||||
|
||||
# Dummy GeometryBuilder function fixture
|
||||
@pytest.fixture
|
||||
def dummy_geometry_builder():
|
||||
"""A simple GeometryBuilder that creates a BBox around the point."""
|
||||
|
||||
def _builder(
|
||||
position: Tuple[float, float],
|
||||
dimensions: xr.DataArray,
|
||||
) -> data.BoundingBox:
|
||||
time, freq = position
|
||||
width = dimensions.sel(dimension="width").item()
|
||||
height = dimensions.sel(dimension="height").item()
|
||||
# Assume position is the center
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - width / 2,
|
||||
freq - height / 2,
|
||||
time + width / 2,
|
||||
freq + height / 2,
|
||||
]
|
||||
)
|
||||
|
||||
return _builder
|
||||
|
||||
|
||||
# Dummy SoundEventDecoder function fixture
|
||||
@pytest.fixture
|
||||
def dummy_sound_event_decoder():
|
||||
"""A simple SoundEventDecoder mapping names to tags."""
|
||||
tag_map = {
|
||||
"bat": [
|
||||
data.Tag(term=data.term_from_key(key="species"), value="Myotis")
|
||||
],
|
||||
"noise": [
|
||||
data.Tag(term=data.term_from_key(key="category"), value="noise")
|
||||
],
|
||||
"unknown": [
|
||||
data.Tag(term=data.term_from_key(key="status"), value="uncertain")
|
||||
],
|
||||
}
|
||||
|
||||
def _decoder(class_name: str) -> List[data.Tag]:
|
||||
return tag_map.get(class_name.lower(), [])
|
||||
|
||||
return _decoder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generic_tags() -> List[data.Tag]:
|
||||
"""Sample generic tags."""
|
||||
return [
|
||||
data.Tag(term=data.term_from_key(key="detector"), value="batdetect2")
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_recording() -> data.Recording:
|
||||
"""A sample soundevent Recording."""
|
||||
return data.Recording(
|
||||
path=Path("/path/to/recording.wav"),
|
||||
duration=60.0,
|
||||
channels=1,
|
||||
samplerate=192000,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_clip(sample_recording) -> data.Clip:
|
||||
"""A sample soundevent Clip."""
|
||||
return data.Clip(
|
||||
recording=sample_recording,
|
||||
start_time=10.0,
|
||||
end_time=20.0,
|
||||
)
|
||||
|
||||
|
||||
# Fixture for a detection dataset (adapted from test_extraction)
|
||||
@pytest.fixture
|
||||
def sample_detection_dataset() -> xr.Dataset:
|
||||
"""Creates a sample detection dataset suitable for decoding."""
|
||||
# Based on test_extraction's corrected expectations
|
||||
# Detections: (t=20, f=300, s=0.9), (t=10, f=200, s=0.8)
|
||||
expected_times = np.array([20, 10])
|
||||
expected_freqs = np.array([300, 200])
|
||||
detection_coords = {
|
||||
"time": ("detection", expected_times),
|
||||
"freq": ("detection", expected_freqs),
|
||||
}
|
||||
|
||||
scores_data = np.array([0.9, 0.8], dtype=np.float64)
|
||||
scores = xr.DataArray(
|
||||
scores_data,
|
||||
coords=detection_coords,
|
||||
dims=["detection"],
|
||||
name="scores",
|
||||
)
|
||||
|
||||
dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32)
|
||||
dimensions = xr.DataArray(
|
||||
dimensions_data,
|
||||
coords={**detection_coords, "dimension": ["width", "height"]},
|
||||
dims=["detection", "dimension"],
|
||||
name="dimensions",
|
||||
)
|
||||
|
||||
classes_data = np.array(
|
||||
[[0.43, 0.85], [0.24, 0.66]],
|
||||
dtype=np.float32, # Simplified values
|
||||
)
|
||||
classes = xr.DataArray(
|
||||
classes_data,
|
||||
coords={**detection_coords, "category": ["bat", "noise"]},
|
||||
dims=["detection", "category"],
|
||||
name="classes",
|
||||
)
|
||||
|
||||
features_data = np.array(
|
||||
[[7.0, 16.0, 25.0, 34.0], [3.0, 12.0, 21.0, 30.0]], dtype=np.float32
|
||||
)
|
||||
features = xr.DataArray(
|
||||
features_data,
|
||||
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["detection", "feature"],
|
||||
name="features",
|
||||
)
|
||||
|
||||
ds = xr.Dataset(
|
||||
{
|
||||
"score": scores,
|
||||
"dimensions": dimensions,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
},
|
||||
coords=detection_coords,
|
||||
)
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_detection_dataset() -> xr.Dataset:
|
||||
"""Creates an empty detection dataset with correct structure."""
|
||||
detection_coords = {
|
||||
"time": ("detection", np.array([], dtype=np.float64)),
|
||||
"freq": ("detection", np.array([], dtype=np.float64)),
|
||||
}
|
||||
scores = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
coords=detection_coords,
|
||||
dims=["detection"],
|
||||
name="scores",
|
||||
)
|
||||
dimensions = xr.DataArray(
|
||||
np.empty((0, 2), dtype=np.float32),
|
||||
coords={**detection_coords, "dimension": ["width", "height"]},
|
||||
dims=["detection", "dimension"],
|
||||
name="dimensions",
|
||||
)
|
||||
classes = xr.DataArray(
|
||||
np.empty((0, 2), dtype=np.float32),
|
||||
coords={**detection_coords, "category": ["bat", "noise"]},
|
||||
dims=["detection", "category"],
|
||||
name="classes",
|
||||
)
|
||||
features = xr.DataArray(
|
||||
np.empty((0, 4), dtype=np.float32),
|
||||
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["detection", "feature"],
|
||||
name="features",
|
||||
)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": scores,
|
||||
"dimensions": dimensions,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
},
|
||||
coords=detection_coords,
|
||||
)
|
||||
|
||||
|
||||
# Fixture for sample RawPrediction objects (using the actual type)
|
||||
@pytest.fixture
|
||||
def sample_raw_predictions() -> List[RawPrediction]:
|
||||
"""Manually crafted RawPrediction objects using the actual type."""
|
||||
# Corresponds roughly to sample_detection_dataset after geometry building
|
||||
# Det 1: t=20, f=300, s=0.9, w=7, h=16, classes=[0.43, 0.85], feats=[7, 16, 25, 34]
|
||||
# Det 2: t=10, f=200, s=0.8, w=3, h=12, classes=[0.24, 0.66], feats=[ 3, 12, 21, 30]
|
||||
pred1_classes = xr.DataArray(
|
||||
[0.43, 0.85], coords={"category": ["bat", "noise"]}, dims=["category"]
|
||||
)
|
||||
pred1_features = xr.DataArray(
|
||||
[7.0, 16.0, 25.0, 34.0],
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred1 = RawPrediction( # Use RawPrediction directly
|
||||
detection_score=0.9,
|
||||
start_time=20 - 7 / 2,
|
||||
end_time=20 + 7 / 2, # 16.5, 23.5
|
||||
low_freq=300 - 16 / 2,
|
||||
high_freq=300 + 16 / 2, # 292, 308
|
||||
class_scores=pred1_classes,
|
||||
features=pred1_features,
|
||||
)
|
||||
|
||||
pred2_classes = xr.DataArray(
|
||||
[0.24, 0.66], coords={"category": ["bat", "noise"]}, dims=["category"]
|
||||
)
|
||||
pred2_features = xr.DataArray(
|
||||
[3.0, 12.0, 21.0, 30.0],
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred2 = RawPrediction( # Use RawPrediction directly
|
||||
detection_score=0.8,
|
||||
start_time=10 - 3 / 2,
|
||||
end_time=10 + 3 / 2, # 8.5, 11.5
|
||||
low_freq=200 - 12 / 2,
|
||||
high_freq=200 + 12 / 2, # 194, 206
|
||||
class_scores=pred2_classes,
|
||||
features=pred2_features,
|
||||
)
|
||||
|
||||
pred3_classes = xr.DataArray(
|
||||
[0.05, 0.02], coords={"category": ["bat", "noise"]}, dims=["category"]
|
||||
) # Below default threshold
|
||||
pred3_features = xr.DataArray(
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
coords={"feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["feature"],
|
||||
)
|
||||
pred3 = RawPrediction( # Use RawPrediction directly
|
||||
detection_score=0.15,
|
||||
start_time=5.0,
|
||||
end_time=6.0,
|
||||
low_freq=50.0,
|
||||
high_freq=60.0,
|
||||
class_scores=pred3_classes,
|
||||
features=pred3_features,
|
||||
)
|
||||
return [pred1, pred2, pred3]
|
||||
|
||||
|
||||
# --- Tests for convert_xr_dataset_to_raw_prediction ---
|
||||
|
||||
|
||||
def test_convert_xr_dataset_basic(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
"""Test basic conversion of a dataset to RawPrediction list."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
sample_detection_dataset, dummy_geometry_builder
|
||||
)
|
||||
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 2
|
||||
|
||||
# Check first prediction (score=0.9)
|
||||
pred1 = raw_predictions[0]
|
||||
assert isinstance(pred1, RawPrediction) # Check against the actual type
|
||||
assert pred1.detection_score == pytest.approx(0.9)
|
||||
# Check bounds derived from dummy_geometry_builder (center pos assumed)
|
||||
# t=20, f=300, w=7, h=16
|
||||
assert pred1.start_time == pytest.approx(20 - 7 / 2)
|
||||
assert pred1.end_time == pytest.approx(20 + 7 / 2)
|
||||
assert pred1.low_freq == pytest.approx(300 - 16 / 2)
|
||||
assert pred1.high_freq == pytest.approx(300 + 16 / 2)
|
||||
xr.testing.assert_allclose(
|
||||
pred1.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=0),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
||||
)
|
||||
|
||||
# Check second prediction (score=0.8)
|
||||
pred2 = raw_predictions[1]
|
||||
assert isinstance(pred2, RawPrediction) # Check against the actual type
|
||||
assert pred2.detection_score == pytest.approx(0.8)
|
||||
# t=10, f=200, w=3, h=12
|
||||
assert pred2.start_time == pytest.approx(10 - 3 / 2)
|
||||
assert pred2.end_time == pytest.approx(10 + 3 / 2)
|
||||
assert pred2.low_freq == pytest.approx(200 - 12 / 2)
|
||||
assert pred2.high_freq == pytest.approx(200 + 12 / 2)
|
||||
xr.testing.assert_allclose(
|
||||
pred2.class_scores,
|
||||
sample_detection_dataset["classes"].sel(detection=1),
|
||||
)
|
||||
xr.testing.assert_allclose(
|
||||
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
||||
)
|
||||
|
||||
|
||||
# ...(rest of the tests remain unchanged as they accessed attributes correctly)...
|
||||
|
||||
|
||||
def test_convert_xr_dataset_empty(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
):
|
||||
"""Test conversion of an empty dataset."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
empty_detection_dataset, dummy_geometry_builder
|
||||
)
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 0
|
||||
|
||||
|
||||
# --- Tests for convert_raw_prediction_to_sound_event_prediction ---
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test basic conversion, default threshold, multi-label."""
|
||||
# score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
raw_pred = sample_raw_predictions[0]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
# classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD (0.1),
|
||||
# top_class_only=False,
|
||||
)
|
||||
|
||||
assert isinstance(se_pred, data.SoundEventPrediction)
|
||||
assert se_pred.score == pytest.approx(raw_pred.detection_score)
|
||||
|
||||
# Check SoundEvent
|
||||
se = se_pred.sound_event
|
||||
assert isinstance(se, data.SoundEvent)
|
||||
assert se.recording == sample_recording
|
||||
assert isinstance(se.geometry, data.BoundingBox)
|
||||
np.testing.assert_allclose(
|
||||
se.geometry.coordinates,
|
||||
[
|
||||
raw_pred.start_time,
|
||||
raw_pred.low_freq,
|
||||
raw_pred.end_time,
|
||||
raw_pred.high_freq,
|
||||
],
|
||||
)
|
||||
assert len(se.features) == len(raw_pred.features)
|
||||
# Simple check for feature presence and value type
|
||||
feat_dict = {f.term.name: f.value for f in se.features}
|
||||
assert "batdetect2:f0" in feat_dict and isinstance(
|
||||
feat_dict["batdetect2:f0"], float
|
||||
)
|
||||
assert feat_dict["batdetect2:f0"] == pytest.approx(7.0)
|
||||
|
||||
# Check Tags
|
||||
# Expected: Generic(0.9), Noise(0.85), Bat(0.43)
|
||||
# Note: Order might depend on sortby implementation detail, compare as sets
|
||||
expected_tags = {
|
||||
# Generic Tag
|
||||
(generic_tags[0].key, generic_tags[0].value, 0.9),
|
||||
# Noise Tag (score 0.85 > 0.1)
|
||||
("category", "noise", 0.85),
|
||||
# Bat Tag (score 0.43 > 0.1)
|
||||
("species", "Myotis", 0.43),
|
||||
}
|
||||
print("expected", expected_tags)
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
print("actual", actual_tags)
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_thresholding(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test effect of classification threshold."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
0
|
||||
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
high_threshold = 0.5
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=high_threshold, # Only noise should pass
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Generic(0.9), Noise(0.85) - Bat (0.43) is below threshold
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_no_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test when classification_threshold is None."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
2
|
||||
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
|
||||
# Both classes are below default threshold, but should be included if None
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=None, # No thresholding
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Generic(0.15), Bat(0.05), Noise(0.02)
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
("species", "Myotis", pytest.approx(0.05)),
|
||||
("category", "noise", pytest.approx(0.02)),
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_top_class(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test top_class_only=True behavior."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
0
|
||||
] # score=0.9, classes=[0.43(bat), 0.85(noise)]
|
||||
# Highest score is noise (0.85)
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True, # Only include top class (noise)
|
||||
)
|
||||
|
||||
# Expected: Generic(0.9), Noise(0.85)
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_all_below_threshold(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test when all class scores are below the default threshold."""
|
||||
raw_pred = sample_raw_predictions[
|
||||
2
|
||||
] # score=0.15, classes=[0.05(bat), 0.02(noise)]
|
||||
|
||||
se_pred = convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction=raw_pred,
|
||||
recording=sample_recording,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, # 0.1
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
# Expected: Only Generic(0.15) tag, as others are below threshold
|
||||
expected_tags = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
}
|
||||
actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags}
|
||||
assert actual_tags == expected_tags
|
||||
|
||||
|
||||
# --- Tests for convert_raw_predictions_to_clip_prediction ---
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_basic(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test converting a list of RawPredictions to a ClipPrediction."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=False,
|
||||
)
|
||||
|
||||
assert isinstance(clip_pred, data.ClipPrediction)
|
||||
assert clip_pred.clip == sample_clip
|
||||
assert len(clip_pred.sound_events) == len(sample_raw_predictions)
|
||||
|
||||
# Check if the contained sound events seem correct (basic check)
|
||||
assert clip_pred.sound_events[0].score == pytest.approx(
|
||||
sample_raw_predictions[0].detection_score
|
||||
)
|
||||
assert clip_pred.sound_events[1].score == pytest.approx(
|
||||
sample_raw_predictions[1].detection_score
|
||||
)
|
||||
assert clip_pred.sound_events[2].score == pytest.approx(
|
||||
sample_raw_predictions[2].detection_score
|
||||
)
|
||||
|
||||
# Check if tags were generated correctly for one event (e.g., the last one)
|
||||
# Pred 3 has score 0.15, classes [0.05, 0.02]. Only generic tag expected.
|
||||
se_pred3_tags = {
|
||||
(pt.tag.key, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[2].tags
|
||||
}
|
||||
expected_tags3 = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)),
|
||||
}
|
||||
assert se_pred3_tags == expected_tags3
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_empty(
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test converting an empty list of RawPredictions."""
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=[],
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
)
|
||||
|
||||
assert isinstance(clip_pred, data.ClipPrediction)
|
||||
assert clip_pred.clip == sample_clip
|
||||
assert len(clip_pred.sound_events) == 0
|
||||
|
||||
|
||||
def test_convert_raw_list_to_clip_passes_args(
|
||||
sample_raw_predictions,
|
||||
sample_clip,
|
||||
dummy_sound_event_decoder,
|
||||
generic_tags,
|
||||
):
|
||||
"""Test that arguments like top_class_only are passed through."""
|
||||
# Use top_class_only = True
|
||||
clip_pred = convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions=sample_raw_predictions,
|
||||
clip=sample_clip,
|
||||
sound_event_decoder=dummy_sound_event_decoder,
|
||||
generic_class_tags=generic_tags,
|
||||
classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only=True, # <<-- Argument being tested
|
||||
)
|
||||
|
||||
assert len(clip_pred.sound_events) == 3
|
||||
|
||||
# Check tags for the first prediction (score=0.9, classes=[0.43(bat), 0.85(noise)])
|
||||
# With top_class_only=True, expect Generic(0.9) and Noise(0.85) only
|
||||
se_pred1_tags = {
|
||||
(pt.tag.key, pt.tag.value, pt.score)
|
||||
for pt in clip_pred.sound_events[0].tags
|
||||
}
|
||||
expected_tags1 = {
|
||||
(generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)),
|
||||
("category", "noise", pytest.approx(0.85)),
|
||||
}
|
||||
assert se_pred1_tags == expected_tags1
|
214
tests/test_postprocessing/test_detection.py
Normal file
214
tests/test_postprocessing/test_detection.py
Normal file
@ -0,0 +1,214 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_array():
|
||||
"""Provides a basic 3x3 DataArray.
|
||||
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
|
||||
"""
|
||||
array = xr.DataArray(
|
||||
np.zeros([3, 3]),
|
||||
coords={
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
},
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
)
|
||||
|
||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
||||
return array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_array_with_nans(sample_data_array: xr.DataArray):
|
||||
"""Provides a 2D DataArray containing NaN values."""
|
||||
array = sample_data_array.copy()
|
||||
array.loc[dict(time=10, frequency=300)] = np.nan
|
||||
array.loc[dict(time=30, frequency=100)] = np.nan
|
||||
return array
|
||||
|
||||
|
||||
def test_basic_extraction(sample_data_array: xr.DataArray):
|
||||
threshold = 0.1
|
||||
max_detections = 3
|
||||
|
||||
actual_result = extract_detections_from_array(
|
||||
sample_data_array,
|
||||
threshold=threshold,
|
||||
max_detections=max_detections,
|
||||
)
|
||||
|
||||
expected_values = np.array([0.9, 0.8, 0.7])
|
||||
expected_times = np.array([30, 20, 30])
|
||||
expected_freqs = np.array([200, 100, 300])
|
||||
expected_coords = {
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="score",
|
||||
)
|
||||
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_threshold_only(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
threshold = 0.5
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
||||
expected_times = np.array([30, 20, 30, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_max_detections_only(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
max_detections = 4
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, max_detections=max_detections
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
||||
expected_times = np.array([30, 20, 30, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_no_optional_args(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
actual_result = extract_detections_from_array(input_array)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02])
|
||||
expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_no_values_above_threshold(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
threshold = 1.0
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold
|
||||
)
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
||||
Dimensions.frequency.value: (
|
||||
"detection",
|
||||
np.array([], dtype=np.int64),
|
||||
),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
assert actual_result.sizes["detection"] == 0
|
||||
|
||||
|
||||
def test_max_detections_zero(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
max_detections = 0
|
||||
with pytest.raises(ValueError):
|
||||
extract_detections_from_array(
|
||||
input_array,
|
||||
max_detections=max_detections,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_input_array():
|
||||
empty_array = xr.DataArray(
|
||||
np.empty((0, 0)),
|
||||
coords={Dimensions.time.value: [], Dimensions.frequency.value: []},
|
||||
dims=[Dimensions.time.value, Dimensions.frequency.value],
|
||||
)
|
||||
actual_result = extract_detections_from_array(empty_array)
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
||||
Dimensions.frequency.value: (
|
||||
"detection",
|
||||
np.array([], dtype=np.int64),
|
||||
),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
assert actual_result.sizes["detection"] == 0
|
||||
|
||||
|
||||
def test_nan_handling(data_array_with_nans):
|
||||
input_array = data_array_with_nans
|
||||
threshold = 0.1
|
||||
max_detections = 3
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold, max_detections=max_detections
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7])
|
||||
expected_times = np.array([30, 20, 30])
|
||||
expected_freqs = np.array([200, 100, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
433
tests/test_postprocessing/test_extraction.py
Normal file
433
tests/test_postprocessing/test_extraction.py
Normal file
@ -0,0 +1,433 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
extract_values_at_positions,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_array():
|
||||
"""Provides a basic 3x3 DataArray.
|
||||
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
|
||||
"""
|
||||
coords = {
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
}
|
||||
array = xr.DataArray(
|
||||
np.zeros([3, 3]),
|
||||
coords=coords,
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
)
|
||||
|
||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
||||
return array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array_for_extraction():
|
||||
"""Provides a simple array (1-9) for value extraction tests."""
|
||||
data = np.arange(1, 10).reshape(3, 3)
|
||||
coords = {
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
}
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords=coords,
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
name="test_values",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_positions_top3(sample_data_array):
|
||||
"""Get top 3 detection positions from sample_data_array."""
|
||||
# Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8), (f=300, t=30, s=0.7)
|
||||
return extract_detections_from_array(
|
||||
sample_data_array, max_detections=3, threshold=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_positions_top2(sample_data_array):
|
||||
"""Get top 2 detection positions from sample_data_array."""
|
||||
# Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8)
|
||||
return extract_detections_from_array(
|
||||
sample_data_array, max_detections=2, threshold=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_positions(sample_data_array):
|
||||
"""Get an empty positions array (high threshold)."""
|
||||
return extract_detections_from_array(
|
||||
sample_data_array,
|
||||
threshold=1.0, # No values > 1.0
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sizes_array(sample_data_array):
|
||||
"""Provides a sample sizes array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
# Data: [[0, 1, 2], [3, 4, 5]] # Dim 0 (width)
|
||||
# [[9,10,11], [12,13,14]] # Dim 1 (height)
|
||||
# Reshaped: (2, 3, 3) -> (dim, freq, time)
|
||||
data = np.array(
|
||||
[
|
||||
[
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
], # width (freq increases down, time across)
|
||||
[[9, 10, 11], [12, 13, 14], [15, 16, 17]], # height
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="sizes",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_classes_array(sample_data_array):
|
||||
"""Provides a sample classes array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
# Example: (2 cats, 3 freqs, 3 times)
|
||||
data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
|
||||
# data[0, 2, 1] -> cat=0, f=300, t=20 -> val for 0.9 detection
|
||||
# data[0, 1, 0] -> cat=0, f=200, t=10 -> val for 0.8 detection
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"category": ["bat", "noise"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["category", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="class_scores",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_features_array(sample_data_array):
|
||||
"""Provides a sample features array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
# Example: (4 features, 3 freqs, 3 times)
|
||||
data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
|
||||
# data[:, 2, 1] -> feats, f=300, t=20 -> vals for 0.9 detection
|
||||
# data[:, 1, 0] -> feats, f=200, t=10 -> vals for 0.8 detection
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"feature": ["f0", "f1", "f2", "f3"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="features",
|
||||
)
|
||||
|
||||
|
||||
# --- Tests for extract_values_at_positions ---
|
||||
|
||||
|
||||
def test_extract_values_at_positions_correct(
|
||||
sample_array_for_extraction, sample_positions_top3
|
||||
):
|
||||
"""Verify correct values are extracted based on positions coords."""
|
||||
# Positions: (f=300, t=20), (f=200, t=10), (f=300, t=30)
|
||||
# Corresponding values in sample_array_for_extraction (1-9):
|
||||
# f=300, t=20 -> index (2, 1) -> value 8
|
||||
# f=200, t=10 -> index (1, 0) -> value 4
|
||||
# f=300, t=30 -> index (2, 2) -> value 9
|
||||
expected_values = np.array([8, 4, 9])
|
||||
|
||||
print(sample_positions_top3)
|
||||
|
||||
expected = xr.DataArray(
|
||||
expected_values,
|
||||
coords=sample_positions_top3.coords, # Should inherit coords
|
||||
dims="detection",
|
||||
name="test_values", # Should inherit name
|
||||
)
|
||||
|
||||
extracted = extract_values_at_positions(
|
||||
sample_array_for_extraction, sample_positions_top3
|
||||
)
|
||||
|
||||
xr.testing.assert_allclose(extracted, expected)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_extra_dims(
|
||||
sample_sizes_array, sample_positions_top2
|
||||
):
|
||||
"""Test extraction preserves other dimensions in the source array."""
|
||||
# Positions: (f=300, t=20), (f=200, t=10)
|
||||
# Extract from sample_sizes_array (dim, freq, time)
|
||||
# Det 1 (f=300, t=20) -> index (:, 2, 1) -> values [7, 16]
|
||||
# Det 2 (f=200, t=10) -> index (:, 1, 0) -> values [3, 12]
|
||||
# Expected shape: (dimension, detection)
|
||||
expected_values = np.array([[7.0, 3.0], [16.0, 12.0]], dtype=np.float32)
|
||||
|
||||
expected = xr.DataArray(
|
||||
expected_values,
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: sample_positions_top2.coords[
|
||||
Dimensions.frequency.value
|
||||
],
|
||||
Dimensions.time.value: sample_positions_top2.coords[
|
||||
Dimensions.time.value
|
||||
],
|
||||
},
|
||||
dims=["dimension", "detection"],
|
||||
name="sizes", # Inherits name
|
||||
)
|
||||
|
||||
extracted = extract_values_at_positions(
|
||||
sample_sizes_array, sample_positions_top2
|
||||
)
|
||||
xr.testing.assert_allclose(extracted, expected)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_empty(
|
||||
sample_array_for_extraction, empty_positions
|
||||
):
|
||||
"""Test extraction with empty positions returns empty array."""
|
||||
extracted = extract_values_at_positions(
|
||||
sample_array_for_extraction, empty_positions
|
||||
)
|
||||
assert extracted.sizes["detection"] == 0
|
||||
# Check coordinates are also empty but defined
|
||||
assert Dimensions.time.value in extracted.coords
|
||||
assert Dimensions.frequency.value in extracted.coords
|
||||
assert extracted.coords[Dimensions.time.value].size == 0
|
||||
assert extracted.coords[Dimensions.frequency.value].size == 0
|
||||
assert extracted.name == sample_array_for_extraction.name
|
||||
|
||||
|
||||
def test_extract_values_at_positions_missing_coord_in_array(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if source array misses required coordinates."""
|
||||
array_no_time = sample_array_for_extraction.copy()
|
||||
del array_no_time.coords[Dimensions.time.value]
|
||||
with pytest.raises(IndexError):
|
||||
extract_values_at_positions(array_no_time, sample_positions_top2)
|
||||
|
||||
array_no_freq = sample_array_for_extraction.copy()
|
||||
del array_no_freq.coords[Dimensions.frequency.value]
|
||||
with pytest.raises(IndexError):
|
||||
extract_values_at_positions(array_no_freq, sample_positions_top2)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_missing_coord_in_positions(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if positions array misses required coordinates."""
|
||||
positions_no_time = sample_positions_top2.copy()
|
||||
del positions_no_time.coords[Dimensions.time.value]
|
||||
with pytest.raises(KeyError):
|
||||
extract_values_at_positions(
|
||||
sample_array_for_extraction, positions_no_time
|
||||
)
|
||||
|
||||
positions_no_freq = sample_positions_top2.copy()
|
||||
del positions_no_freq.coords[Dimensions.frequency.value]
|
||||
with pytest.raises(KeyError):
|
||||
extract_values_at_positions(
|
||||
sample_array_for_extraction, positions_no_freq
|
||||
)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_mismatched_coords(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if positions requests coords not in source array."""
|
||||
# Create positions requesting a time=40 not present in sample_array
|
||||
bad_positions = sample_positions_top2.copy()
|
||||
bad_positions.coords[Dimensions.time.value] = (
|
||||
"detection",
|
||||
np.array([40, 10]), # First time is invalid
|
||||
)
|
||||
with pytest.raises(
|
||||
KeyError
|
||||
): # xarray.sel raises KeyError for missing labels
|
||||
extract_values_at_positions(sample_array_for_extraction, bad_positions)
|
||||
|
||||
|
||||
# --- Tests for extract_detection_xr_dataset ---
|
||||
|
||||
|
||||
def test_extract_detection_xr_dataset_correct(
|
||||
sample_positions_top2,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
):
|
||||
"""Tests extracting and bundling info for top 2 detections."""
|
||||
actual_dataset = extract_detection_xr_dataset(
|
||||
sample_positions_top2,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
)
|
||||
|
||||
# Expected positions (top 2):
|
||||
# 1. Score 0.9, Time 20, Freq 300. Indices (freq=2, time=1)
|
||||
# 2. Score 0.8, Time 10, Freq 200. Indices (freq=1, time=0)
|
||||
expected_times = np.array([20, 10])
|
||||
expected_freqs = np.array([300, 200])
|
||||
detection_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
|
||||
# --- Manually Calculate Expected Data ---
|
||||
|
||||
# Scores (already correct in sample_positions_top2)
|
||||
expected_score = sample_positions_top2.rename(
|
||||
"scores"
|
||||
) # Rename to match output
|
||||
|
||||
# Dimensions Data (width, height) -> Transposed to (detection, dimension)
|
||||
# sample_sizes_array data: (dim, freq, time)
|
||||
# Det 1 (f=300, t=20): index (:, 2, 1) -> values [ 7., 16.]
|
||||
# Det 2 (f=200, t=10): index (:, 1, 0) -> values [ 3., 12.]
|
||||
expected_dimensions_data = np.array(
|
||||
[
|
||||
[7.0, 16.0], # Detection 1 [width, height]
|
||||
[3.0, 12.0],
|
||||
], # Detection 2 [width, height]
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_dimensions = xr.DataArray(
|
||||
expected_dimensions_data,
|
||||
coords={**detection_coords, "dimension": ["width", "height"]},
|
||||
dims=["detection", "dimension"],
|
||||
name="dimensions",
|
||||
)
|
||||
|
||||
# Classes Data (bat, noise) -> Transposed to (detection, category)
|
||||
# sample_classes_array data: np.linspace(0.1, 0.9, 18).reshape(2, 3, 3)
|
||||
# linspace vals: [0.1, 0.147, 0.194, 0.241, 0.288, 0.335, 0.382, 0.429, 0.476, # cat 0
|
||||
# 0.523, 0.570, 0.617, 0.664, 0.711, 0.758, 0.805, 0.852, 0.9] # cat 1
|
||||
# Det 1 (cat, f=2, t=1): index (:, 2, 1) -> values [idx 7=0.429, idx 16=0.852]
|
||||
# Det 2 (cat, f=1, t=0): index (:, 1, 0) -> values [idx 3=0.241, idx 12=0.664]
|
||||
expected_classes_data = np.array(
|
||||
[
|
||||
[0.42941177, 0.85294118], # Detection 1 [bat_prob, noise_prob]
|
||||
[0.24117647, 0.66470588],
|
||||
], # Detection 2 [bat_prob, noise_prob]
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_classes = xr.DataArray(
|
||||
expected_classes_data,
|
||||
coords={**detection_coords, "category": ["bat", "noise"]},
|
||||
dims=["detection", "category"],
|
||||
name="classes",
|
||||
)
|
||||
|
||||
# Features Data (f0..f3) -> Transposed to (detection, feature)
|
||||
# sample_features_array data: np.arange(36).reshape(4, 3, 3)
|
||||
# Det 1 (feat, f=2, t=1): index (:, 2, 1) -> values [ 7, 16, 25, 34]
|
||||
# Det 2 (feat, f=1, t=0): index (:, 1, 0) -> values [ 3, 12, 21, 30]
|
||||
expected_features_data = np.array(
|
||||
[
|
||||
[7.0, 16.0, 25.0, 34.0], # Detection 1 [f0, f1, f2, f3]
|
||||
[3.0, 12.0, 21.0, 30.0],
|
||||
], # Detection 2 [f0, f1, f2, f3]
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_features = xr.DataArray(
|
||||
expected_features_data,
|
||||
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["detection", "feature"],
|
||||
name="features",
|
||||
)
|
||||
|
||||
# Construct Expected Dataset
|
||||
expected_dataset = xr.Dataset(
|
||||
{
|
||||
"scores": expected_score,
|
||||
"dimensions": expected_dimensions,
|
||||
"classes": expected_classes,
|
||||
"features": expected_features,
|
||||
}
|
||||
)
|
||||
# Add coords explicitly to ensure they match
|
||||
expected_dataset = expected_dataset.assign_coords(detection_coords)
|
||||
|
||||
# --- Assert Equality ---
|
||||
xr.testing.assert_allclose(actual_dataset, expected_dataset)
|
||||
|
||||
|
||||
def test_extract_detection_xr_dataset_empty(
|
||||
empty_positions,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
):
|
||||
"""Test extraction with empty positions yields an empty dataset."""
|
||||
actual_dataset = extract_detection_xr_dataset(
|
||||
empty_positions,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
)
|
||||
|
||||
assert isinstance(actual_dataset, xr.Dataset)
|
||||
assert "detection" in actual_dataset.dims
|
||||
assert actual_dataset.dims["detection"] == 0
|
||||
|
||||
# Check variables exist and have 0 size along detection dim
|
||||
assert "scores" in actual_dataset
|
||||
assert actual_dataset["scores"].dims == ("detection",)
|
||||
assert actual_dataset["scores"].size == 0
|
||||
|
||||
assert "dimensions" in actual_dataset
|
||||
assert actual_dataset["dimensions"].dims == ("detection", "dimension")
|
||||
assert actual_dataset["dimensions"].shape == (0, 2) # Check both dims
|
||||
|
||||
assert "classes" in actual_dataset
|
||||
assert actual_dataset["classes"].dims == ("detection", "category")
|
||||
assert actual_dataset["classes"].shape == (0, 2)
|
||||
|
||||
assert "features" in actual_dataset
|
||||
assert actual_dataset["features"].dims == ("detection", "feature")
|
||||
assert actual_dataset["features"].shape == (0, 4)
|
||||
|
||||
# Check coordinates exist and are empty
|
||||
assert Dimensions.time.value in actual_dataset.coords
|
||||
assert Dimensions.frequency.value in actual_dataset.coords
|
||||
assert actual_dataset.coords[Dimensions.time.value].size == 0
|
||||
assert actual_dataset.coords[Dimensions.frequency.value].size == 0
|
0
tests/test_postprocessing/test_remapping.py
Normal file
0
tests/test_postprocessing/test_remapping.py
Normal file
Loading…
Reference in New Issue
Block a user