From 1f4454693e15d9e150822e4c7f64718733612d53 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 20 Apr 2025 15:52:25 +0100 Subject: [PATCH] Working on postprocess tests --- batdetect2/postprocess/decoding.py | 33 +- batdetect2/postprocess/detection.py | 2 +- tests/test_postprocessing/test_arrays.py | 43 -- tests/test_postprocessing/test_decoding.py | 604 +++++++++++++++++++ tests/test_postprocessing/test_detection.py | 214 +++++++ tests/test_postprocessing/test_extraction.py | 433 +++++++++++++ tests/test_postprocessing/test_remapping.py | 0 7 files changed, 1280 insertions(+), 49 deletions(-) delete mode 100644 tests/test_postprocessing/test_arrays.py create mode 100644 tests/test_postprocessing/test_decoding.py create mode 100644 tests/test_postprocessing/test_detection.py create mode 100644 tests/test_postprocessing/test_extraction.py create mode 100644 tests/test_postprocessing/test_remapping.py diff --git a/batdetect2/postprocess/decoding.py b/batdetect2/postprocess/decoding.py index 09537f7..3741030 100644 --- a/batdetect2/postprocess/decoding.py +++ b/batdetect2/postprocess/decoding.py @@ -29,6 +29,7 @@ The process involves: from typing import List, Optional +import numpy as np import xarray as xr from soundevent import data from soundevent.geometry import compute_bounds @@ -256,8 +257,17 @@ def convert_raw_prediction_to_sound_event_prediction( ] ), features=[ - data.Feature(term=data.term_from_key(feat_name), value=value) - for feat_name, value in raw_prediction.features + data.Feature( + term=data.Term( + name=f"batdetect2:{feat_name}", + label=feat_name, + definition="Automatically extracted features by BatDetect2", + ), + value=value, + ) + for feat_name, value in _iterate_over_array( + raw_prediction.features + ) ], ) @@ -274,9 +284,7 @@ def convert_raw_prediction_to_sound_event_prediction( drop=True, ) - for class_name, score in class_scores.sortby( - class_scores, ascending=False - ): + for class_name, score in _iterate_sorted(class_scores): 
class_tags = sound_event_decoder(class_name) for tag in class_tags: @@ -295,3 +303,18 @@ def convert_raw_prediction_to_sound_event_prediction( score=raw_prediction.detection_score, tags=tags, ) + + +def _iterate_over_array(array: xr.DataArray): + dim_name = array.dims[0] + coords = array.coords[dim_name] + for value, coord in zip(array.values, coords.values): + yield coord, float(value) + + +def _iterate_sorted(array: xr.DataArray): + dim_name = array.dims[0] + coords = array.coords[dim_name] + indices = np.argsort(coords.values) + for index in indices: + yield str(coords[index]), coords.values[index] diff --git a/batdetect2/postprocess/detection.py b/batdetect2/postprocess/detection.py index 96ac430..78e7003 100644 --- a/batdetect2/postprocess/detection.py +++ b/batdetect2/postprocess/detection.py @@ -103,7 +103,7 @@ def extract_detections_from_array( top_values = top_values[mask] top_sorted_indices = top_sorted_indices[mask] - time_indices, freq_indices = np.unravel_index( + freq_indices, time_indices = np.unravel_index( top_sorted_indices, detection_array.shape, ) diff --git a/tests/test_postprocessing/test_arrays.py b/tests/test_postprocessing/test_arrays.py deleted file mode 100644 index c94f64e..0000000 --- a/tests/test_postprocessing/test_arrays.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import List - -import numpy as np -import torch -import xarray as xr -from soundevent import data - -from batdetect2.modules import DetectorModel -from batdetect2.postprocess.arrays import to_xarray -from batdetect2.preprocess import preprocess_audio_clip - - -def test_this(clip: data.Clip, class_names: List[str]): - spec = xr.DataArray( - data=np.random.rand(100, 100), - dims=["time", "frequency"], - coords={ - "time": np.linspace(0, 100, 100, endpoint=False), - "frequency": np.linspace(0, 100, 100, endpoint=False), - }, - ) - - model = DetectorModel() - - spec = preprocess_audio_clip( - clip, - config=model.config.preprocessing, - ) - - tensor = 
torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0) - - outputs = model(tensor) - - arrays = to_xarray( - outputs, - start_time=clip.start_time, - end_time=clip.end_time, - class_names=class_names, - ) - - print(arrays) - - assert False diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py new file mode 100644 index 0000000..0c0136e --- /dev/null +++ b/tests/test_postprocessing/test_decoding.py @@ -0,0 +1,604 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import pytest +import xarray as xr + +# Removed dataclass import as MockRawPrediction is replaced +from soundevent import data + +# Import functions to test +from batdetect2.postprocess.decoding import ( + DEFAULT_CLASSIFICATION_THRESHOLD, + convert_raw_prediction_to_sound_event_prediction, + convert_raw_predictions_to_clip_prediction, + convert_xr_dataset_to_raw_prediction, +) +from batdetect2.postprocess.types import RawPrediction + + +# Dummy GeometryBuilder function fixture +@pytest.fixture +def dummy_geometry_builder(): + """A simple GeometryBuilder that creates a BBox around the point.""" + + def _builder( + position: Tuple[float, float], + dimensions: xr.DataArray, + ) -> data.BoundingBox: + time, freq = position + width = dimensions.sel(dimension="width").item() + height = dimensions.sel(dimension="height").item() + # Assume position is the center + return data.BoundingBox( + coordinates=[ + time - width / 2, + freq - height / 2, + time + width / 2, + freq + height / 2, + ] + ) + + return _builder + + +# Dummy SoundEventDecoder function fixture +@pytest.fixture +def dummy_sound_event_decoder(): + """A simple SoundEventDecoder mapping names to tags.""" + tag_map = { + "bat": [ + data.Tag(term=data.term_from_key(key="species"), value="Myotis") + ], + "noise": [ + data.Tag(term=data.term_from_key(key="category"), value="noise") + ], + "unknown": [ + data.Tag(term=data.term_from_key(key="status"), value="uncertain") + ], + } 
+ + def _decoder(class_name: str) -> List[data.Tag]: + return tag_map.get(class_name.lower(), []) + + return _decoder + + +@pytest.fixture +def generic_tags() -> List[data.Tag]: + """Sample generic tags.""" + return [ + data.Tag(term=data.term_from_key(key="detector"), value="batdetect2") + ] + + +@pytest.fixture +def sample_recording() -> data.Recording: + """A sample soundevent Recording.""" + return data.Recording( + path=Path("/path/to/recording.wav"), + duration=60.0, + channels=1, + samplerate=192000, + ) + + +@pytest.fixture +def sample_clip(sample_recording) -> data.Clip: + """A sample soundevent Clip.""" + return data.Clip( + recording=sample_recording, + start_time=10.0, + end_time=20.0, + ) + + +# Fixture for a detection dataset (adapted from test_extraction) +@pytest.fixture +def sample_detection_dataset() -> xr.Dataset: + """Creates a sample detection dataset suitable for decoding.""" + # Based on test_extraction's corrected expectations + # Detections: (t=20, f=300, s=0.9), (t=10, f=200, s=0.8) + expected_times = np.array([20, 10]) + expected_freqs = np.array([300, 200]) + detection_coords = { + "time": ("detection", expected_times), + "freq": ("detection", expected_freqs), + } + + scores_data = np.array([0.9, 0.8], dtype=np.float64) + scores = xr.DataArray( + scores_data, + coords=detection_coords, + dims=["detection"], + name="scores", + ) + + dimensions_data = np.array([[7.0, 16.0], [3.0, 12.0]], dtype=np.float32) + dimensions = xr.DataArray( + dimensions_data, + coords={**detection_coords, "dimension": ["width", "height"]}, + dims=["detection", "dimension"], + name="dimensions", + ) + + classes_data = np.array( + [[0.43, 0.85], [0.24, 0.66]], + dtype=np.float32, # Simplified values + ) + classes = xr.DataArray( + classes_data, + coords={**detection_coords, "category": ["bat", "noise"]}, + dims=["detection", "category"], + name="classes", + ) + + features_data = np.array( + [[7.0, 16.0, 25.0, 34.0], [3.0, 12.0, 21.0, 30.0]], dtype=np.float32 + ) + 
features = xr.DataArray( + features_data, + coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]}, + dims=["detection", "feature"], + name="features", + ) + + ds = xr.Dataset( + { + "score": scores, + "dimensions": dimensions, + "classes": classes, + "features": features, + }, + coords=detection_coords, + ) + return ds + + +@pytest.fixture +def empty_detection_dataset() -> xr.Dataset: + """Creates an empty detection dataset with correct structure.""" + detection_coords = { + "time": ("detection", np.array([], dtype=np.float64)), + "freq": ("detection", np.array([], dtype=np.float64)), + } + scores = xr.DataArray( + np.array([], dtype=np.float64), + coords=detection_coords, + dims=["detection"], + name="scores", + ) + dimensions = xr.DataArray( + np.empty((0, 2), dtype=np.float32), + coords={**detection_coords, "dimension": ["width", "height"]}, + dims=["detection", "dimension"], + name="dimensions", + ) + classes = xr.DataArray( + np.empty((0, 2), dtype=np.float32), + coords={**detection_coords, "category": ["bat", "noise"]}, + dims=["detection", "category"], + name="classes", + ) + features = xr.DataArray( + np.empty((0, 4), dtype=np.float32), + coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]}, + dims=["detection", "feature"], + name="features", + ) + return xr.Dataset( + { + "scores": scores, + "dimensions": dimensions, + "classes": classes, + "features": features, + }, + coords=detection_coords, + ) + + +# Fixture for sample RawPrediction objects (using the actual type) +@pytest.fixture +def sample_raw_predictions() -> List[RawPrediction]: + """Manually crafted RawPrediction objects using the actual type.""" + # Corresponds roughly to sample_detection_dataset after geometry building + # Det 1: t=20, f=300, s=0.9, w=7, h=16, classes=[0.43, 0.85], feats=[7, 16, 25, 34] + # Det 2: t=10, f=200, s=0.8, w=3, h=12, classes=[0.24, 0.66], feats=[ 3, 12, 21, 30] + pred1_classes = xr.DataArray( + [0.43, 0.85], coords={"category": ["bat", 
"noise"]}, dims=["category"] + ) + pred1_features = xr.DataArray( + [7.0, 16.0, 25.0, 34.0], + coords={"feature": ["f0", "f1", "f2", "f3"]}, + dims=["feature"], + ) + pred1 = RawPrediction( # Use RawPrediction directly + detection_score=0.9, + start_time=20 - 7 / 2, + end_time=20 + 7 / 2, # 16.5, 23.5 + low_freq=300 - 16 / 2, + high_freq=300 + 16 / 2, # 292, 308 + class_scores=pred1_classes, + features=pred1_features, + ) + + pred2_classes = xr.DataArray( + [0.24, 0.66], coords={"category": ["bat", "noise"]}, dims=["category"] + ) + pred2_features = xr.DataArray( + [3.0, 12.0, 21.0, 30.0], + coords={"feature": ["f0", "f1", "f2", "f3"]}, + dims=["feature"], + ) + pred2 = RawPrediction( # Use RawPrediction directly + detection_score=0.8, + start_time=10 - 3 / 2, + end_time=10 + 3 / 2, # 8.5, 11.5 + low_freq=200 - 12 / 2, + high_freq=200 + 12 / 2, # 194, 206 + class_scores=pred2_classes, + features=pred2_features, + ) + + pred3_classes = xr.DataArray( + [0.05, 0.02], coords={"category": ["bat", "noise"]}, dims=["category"] + ) # Below default threshold + pred3_features = xr.DataArray( + [1.0, 2.0, 3.0, 4.0], + coords={"feature": ["f0", "f1", "f2", "f3"]}, + dims=["feature"], + ) + pred3 = RawPrediction( # Use RawPrediction directly + detection_score=0.15, + start_time=5.0, + end_time=6.0, + low_freq=50.0, + high_freq=60.0, + class_scores=pred3_classes, + features=pred3_features, + ) + return [pred1, pred2, pred3] + + +# --- Tests for convert_xr_dataset_to_raw_prediction --- + + +def test_convert_xr_dataset_basic( + sample_detection_dataset, dummy_geometry_builder +): + """Test basic conversion of a dataset to RawPrediction list.""" + raw_predictions = convert_xr_dataset_to_raw_prediction( + sample_detection_dataset, dummy_geometry_builder + ) + + assert isinstance(raw_predictions, list) + assert len(raw_predictions) == 2 + + # Check first prediction (score=0.9) + pred1 = raw_predictions[0] + assert isinstance(pred1, RawPrediction) # Check against the actual type + 
assert pred1.detection_score == pytest.approx(0.9) + # Check bounds derived from dummy_geometry_builder (center pos assumed) + # t=20, f=300, w=7, h=16 + assert pred1.start_time == pytest.approx(20 - 7 / 2) + assert pred1.end_time == pytest.approx(20 + 7 / 2) + assert pred1.low_freq == pytest.approx(300 - 16 / 2) + assert pred1.high_freq == pytest.approx(300 + 16 / 2) + xr.testing.assert_allclose( + pred1.class_scores, + sample_detection_dataset["classes"].sel(detection=0), + ) + xr.testing.assert_allclose( + pred1.features, sample_detection_dataset["features"].sel(detection=0) + ) + + # Check second prediction (score=0.8) + pred2 = raw_predictions[1] + assert isinstance(pred2, RawPrediction) # Check against the actual type + assert pred2.detection_score == pytest.approx(0.8) + # t=10, f=200, w=3, h=12 + assert pred2.start_time == pytest.approx(10 - 3 / 2) + assert pred2.end_time == pytest.approx(10 + 3 / 2) + assert pred2.low_freq == pytest.approx(200 - 12 / 2) + assert pred2.high_freq == pytest.approx(200 + 12 / 2) + xr.testing.assert_allclose( + pred2.class_scores, + sample_detection_dataset["classes"].sel(detection=1), + ) + xr.testing.assert_allclose( + pred2.features, sample_detection_dataset["features"].sel(detection=1) + ) + + +# ...(rest of the tests remain unchanged as they accessed attributes correctly)... 
+ + +def test_convert_xr_dataset_empty( + empty_detection_dataset, dummy_geometry_builder +): + """Test conversion of an empty dataset.""" + raw_predictions = convert_xr_dataset_to_raw_prediction( + empty_detection_dataset, dummy_geometry_builder + ) + assert isinstance(raw_predictions, list) + assert len(raw_predictions) == 0 + + +# --- Tests for convert_raw_prediction_to_sound_event_prediction --- + + +def test_convert_raw_to_sound_event_basic( + sample_raw_predictions, + sample_recording, + dummy_sound_event_decoder, + generic_tags, +): + """Test basic conversion, default threshold, multi-label.""" + # score=0.9, classes=[0.43(bat), 0.85(noise)] + raw_pred = sample_raw_predictions[0] + + se_pred = convert_raw_prediction_to_sound_event_prediction( + raw_prediction=raw_pred, + recording=sample_recording, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + # classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD (0.1), + # top_class_only=False, + ) + + assert isinstance(se_pred, data.SoundEventPrediction) + assert se_pred.score == pytest.approx(raw_pred.detection_score) + + # Check SoundEvent + se = se_pred.sound_event + assert isinstance(se, data.SoundEvent) + assert se.recording == sample_recording + assert isinstance(se.geometry, data.BoundingBox) + np.testing.assert_allclose( + se.geometry.coordinates, + [ + raw_pred.start_time, + raw_pred.low_freq, + raw_pred.end_time, + raw_pred.high_freq, + ], + ) + assert len(se.features) == len(raw_pred.features) + # Simple check for feature presence and value type + feat_dict = {f.term.name: f.value for f in se.features} + assert "batdetect2:f0" in feat_dict and isinstance( + feat_dict["batdetect2:f0"], float + ) + assert feat_dict["batdetect2:f0"] == pytest.approx(7.0) + + # Check Tags + # Expected: Generic(0.9), Noise(0.85), Bat(0.43) + # Note: Order might depend on sortby implementation detail, compare as sets + expected_tags = { + # Generic Tag + (generic_tags[0].key, 
generic_tags[0].value, 0.9), + # Noise Tag (score 0.85 > 0.1) + ("category", "noise", 0.85), + # Bat Tag (score 0.43 > 0.1) + ("species", "Myotis", 0.43), + } + print("expected", expected_tags) + actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} + print("actual", actual_tags) + assert actual_tags == expected_tags + + +def test_convert_raw_to_sound_event_thresholding( + sample_raw_predictions, + sample_recording, + dummy_sound_event_decoder, + generic_tags, +): + """Test effect of classification threshold.""" + raw_pred = sample_raw_predictions[ + 0 + ] # score=0.9, classes=[0.43(bat), 0.85(noise)] + high_threshold = 0.5 + + se_pred = convert_raw_prediction_to_sound_event_prediction( + raw_prediction=raw_pred, + recording=sample_recording, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=high_threshold, # Only noise should pass + top_class_only=False, + ) + + # Expected: Generic(0.9), Noise(0.85) - Bat (0.43) is below threshold + expected_tags = { + (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), + ("category", "noise", pytest.approx(0.85)), + } + actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} + assert actual_tags == expected_tags + + +def test_convert_raw_to_sound_event_no_threshold( + sample_raw_predictions, + sample_recording, + dummy_sound_event_decoder, + generic_tags, +): + """Test when classification_threshold is None.""" + raw_pred = sample_raw_predictions[ + 2 + ] # score=0.15, classes=[0.05(bat), 0.02(noise)] + # Both classes are below default threshold, but should be included if None + + se_pred = convert_raw_prediction_to_sound_event_prediction( + raw_prediction=raw_pred, + recording=sample_recording, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=None, # No thresholding + top_class_only=False, + ) + + # Expected: Generic(0.15), Bat(0.05), Noise(0.02) + 
expected_tags = { + (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)), + ("species", "Myotis", pytest.approx(0.05)), + ("category", "noise", pytest.approx(0.02)), + } + actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} + assert actual_tags == expected_tags + + +def test_convert_raw_to_sound_event_top_class( + sample_raw_predictions, + sample_recording, + dummy_sound_event_decoder, + generic_tags, +): + """Test top_class_only=True behavior.""" + raw_pred = sample_raw_predictions[ + 0 + ] # score=0.9, classes=[0.43(bat), 0.85(noise)] + # Highest score is noise (0.85) + + se_pred = convert_raw_prediction_to_sound_event_prediction( + raw_prediction=raw_pred, + recording=sample_recording, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, + top_class_only=True, # Only include top class (noise) + ) + + # Expected: Generic(0.9), Noise(0.85) + expected_tags = { + (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), + ("category", "noise", pytest.approx(0.85)), + } + actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} + assert actual_tags == expected_tags + + +def test_convert_raw_to_sound_event_all_below_threshold( + sample_raw_predictions, + sample_recording, + dummy_sound_event_decoder, + generic_tags, +): + """Test when all class scores are below the default threshold.""" + raw_pred = sample_raw_predictions[ + 2 + ] # score=0.15, classes=[0.05(bat), 0.02(noise)] + + se_pred = convert_raw_prediction_to_sound_event_prediction( + raw_prediction=raw_pred, + recording=sample_recording, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, # 0.1 + top_class_only=False, + ) + + # Expected: Only Generic(0.15) tag, as others are below threshold + expected_tags = { + (generic_tags[0].key, generic_tags[0].value, 
pytest.approx(0.15)), + } + actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} + assert actual_tags == expected_tags + + +# --- Tests for convert_raw_predictions_to_clip_prediction --- + + +def test_convert_raw_list_to_clip_basic( + sample_raw_predictions, + sample_clip, + dummy_sound_event_decoder, + generic_tags, +): + """Test converting a list of RawPredictions to a ClipPrediction.""" + clip_pred = convert_raw_predictions_to_clip_prediction( + raw_predictions=sample_raw_predictions, + clip=sample_clip, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, + top_class_only=False, + ) + + assert isinstance(clip_pred, data.ClipPrediction) + assert clip_pred.clip == sample_clip + assert len(clip_pred.sound_events) == len(sample_raw_predictions) + + # Check if the contained sound events seem correct (basic check) + assert clip_pred.sound_events[0].score == pytest.approx( + sample_raw_predictions[0].detection_score + ) + assert clip_pred.sound_events[1].score == pytest.approx( + sample_raw_predictions[1].detection_score + ) + assert clip_pred.sound_events[2].score == pytest.approx( + sample_raw_predictions[2].detection_score + ) + + # Check if tags were generated correctly for one event (e.g., the last one) + # Pred 3 has score 0.15, classes [0.05, 0.02]. Only generic tag expected. 
+ se_pred3_tags = { + (pt.tag.key, pt.tag.value, pt.score) + for pt in clip_pred.sound_events[2].tags + } + expected_tags3 = { + (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)), + } + assert se_pred3_tags == expected_tags3 + + +def test_convert_raw_list_to_clip_empty( + sample_clip, + dummy_sound_event_decoder, + generic_tags, +): + """Test converting an empty list of RawPredictions.""" + clip_pred = convert_raw_predictions_to_clip_prediction( + raw_predictions=[], + clip=sample_clip, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + ) + + assert isinstance(clip_pred, data.ClipPrediction) + assert clip_pred.clip == sample_clip + assert len(clip_pred.sound_events) == 0 + + +def test_convert_raw_list_to_clip_passes_args( + sample_raw_predictions, + sample_clip, + dummy_sound_event_decoder, + generic_tags, +): + """Test that arguments like top_class_only are passed through.""" + # Use top_class_only = True + clip_pred = convert_raw_predictions_to_clip_prediction( + raw_predictions=sample_raw_predictions, + clip=sample_clip, + sound_event_decoder=dummy_sound_event_decoder, + generic_class_tags=generic_tags, + classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, + top_class_only=True, # <<-- Argument being tested + ) + + assert len(clip_pred.sound_events) == 3 + + # Check tags for the first prediction (score=0.9, classes=[0.43(bat), 0.85(noise)]) + # With top_class_only=True, expect Generic(0.9) and Noise(0.85) only + se_pred1_tags = { + (pt.tag.key, pt.tag.value, pt.score) + for pt in clip_pred.sound_events[0].tags + } + expected_tags1 = { + (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), + ("category", "noise", pytest.approx(0.85)), + } + assert se_pred1_tags == expected_tags1 diff --git a/tests/test_postprocessing/test_detection.py b/tests/test_postprocessing/test_detection.py new file mode 100644 index 0000000..65aaad1 --- /dev/null +++ b/tests/test_postprocessing/test_detection.py @@ -0,0 
+1,214 @@ +import numpy as np +import pytest +import xarray as xr +from soundevent.arrays import Dimensions + +from batdetect2.postprocess.detection import extract_detections_from_array + + +@pytest.fixture +def sample_data_array(): + """Provides a basic 3x3 DataArray. + Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30) + """ + array = xr.DataArray( + np.zeros([3, 3]), + coords={ + Dimensions.frequency.value: [100, 200, 300], + Dimensions.time.value: [10, 20, 30], + }, + dims=[ + Dimensions.frequency.value, + Dimensions.time.value, + ], + ) + + array.loc[dict(time=10, frequency=100)] = 0.005 + array.loc[dict(time=10, frequency=200)] = 0.5 + array.loc[dict(time=10, frequency=300)] = 0.03 + array.loc[dict(time=20, frequency=100)] = 0.8 + array.loc[dict(time=20, frequency=200)] = 0.02 + array.loc[dict(time=20, frequency=300)] = 0.6 + array.loc[dict(time=30, frequency=100)] = 0.04 + array.loc[dict(time=30, frequency=200)] = 0.9 + array.loc[dict(time=30, frequency=300)] = 0.7 + return array + + +@pytest.fixture +def data_array_with_nans(sample_data_array: xr.DataArray): + """Provides a 2D DataArray containing NaN values.""" + array = sample_data_array.copy() + array.loc[dict(time=10, frequency=300)] = np.nan + array.loc[dict(time=30, frequency=100)] = np.nan + return array + + +def test_basic_extraction(sample_data_array: xr.DataArray): + threshold = 0.1 + max_detections = 3 + + actual_result = extract_detections_from_array( + sample_data_array, + threshold=threshold, + max_detections=max_detections, + ) + + expected_values = np.array([0.9, 0.8, 0.7]) + expected_times = np.array([30, 20, 30]) + expected_freqs = np.array([200, 100, 300]) + expected_coords = { + Dimensions.frequency.value: ("detection", expected_freqs), + Dimensions.time.value: ("detection", expected_times), + } + expected_result = xr.DataArray( + expected_values, + coords=expected_coords, + dims="detection", + name="score", + ) + + xr.testing.assert_equal(actual_result, expected_result) 
+ + +def test_threshold_only(sample_data_array): + input_array = sample_data_array + threshold = 0.5 + actual_result = extract_detections_from_array( + input_array, threshold=threshold + ) + expected_values = np.array([0.9, 0.8, 0.7, 0.6]) + expected_times = np.array([30, 20, 30, 20]) + expected_freqs = np.array([200, 100, 300, 300]) + expected_coords = { + Dimensions.time.value: ("detection", expected_times), + Dimensions.frequency.value: ("detection", expected_freqs), + } + expected_result = xr.DataArray( + expected_values, + coords=expected_coords, + dims="detection", + name="detection_value", + ) + xr.testing.assert_equal(actual_result, expected_result) + + +def test_max_detections_only(sample_data_array): + input_array = sample_data_array + max_detections = 4 + actual_result = extract_detections_from_array( + input_array, max_detections=max_detections + ) + expected_values = np.array([0.9, 0.8, 0.7, 0.6]) + expected_times = np.array([30, 20, 30, 20]) + expected_freqs = np.array([200, 100, 300, 300]) + expected_coords = { + Dimensions.time.value: ("detection", expected_times), + Dimensions.frequency.value: ("detection", expected_freqs), + } + expected_result = xr.DataArray( + expected_values, + coords=expected_coords, + dims="detection", + name="detection_value", + ) + xr.testing.assert_equal(actual_result, expected_result) + + +def test_no_optional_args(sample_data_array): + input_array = sample_data_array + actual_result = extract_detections_from_array(input_array) + expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02]) + expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20]) + expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200]) + expected_coords = { + Dimensions.time.value: ("detection", expected_times), + Dimensions.frequency.value: ("detection", expected_freqs), + } + expected_result = xr.DataArray( + expected_values, + coords=expected_coords, + dims="detection", + name="detection_value", + ) + 
xr.testing.assert_equal(actual_result, expected_result) + + +def test_no_values_above_threshold(sample_data_array): + input_array = sample_data_array + threshold = 1.0 + actual_result = extract_detections_from_array( + input_array, threshold=threshold + ) + expected_coords = { + Dimensions.time.value: ("detection", np.array([], dtype=np.int64)), + Dimensions.frequency.value: ( + "detection", + np.array([], dtype=np.int64), + ), + } + expected_result = xr.DataArray( + np.array([], dtype=np.float64), + coords=expected_coords, + dims="detection", + name="detection_value", + ) + xr.testing.assert_equal(actual_result, expected_result) + assert actual_result.sizes["detection"] == 0 + + +def test_max_detections_zero(sample_data_array): + input_array = sample_data_array + max_detections = 0 + with pytest.raises(ValueError): + extract_detections_from_array( + input_array, + max_detections=max_detections, + ) + + +def test_empty_input_array(): + empty_array = xr.DataArray( + np.empty((0, 0)), + coords={Dimensions.time.value: [], Dimensions.frequency.value: []}, + dims=[Dimensions.time.value, Dimensions.frequency.value], + ) + actual_result = extract_detections_from_array(empty_array) + expected_coords = { + Dimensions.time.value: ("detection", np.array([], dtype=np.int64)), + Dimensions.frequency.value: ( + "detection", + np.array([], dtype=np.int64), + ), + } + expected_result = xr.DataArray( + np.array([], dtype=np.float64), + coords=expected_coords, + dims="detection", + name="detection_value", + ) + xr.testing.assert_equal(actual_result, expected_result) + assert actual_result.sizes["detection"] == 0 + + +def test_nan_handling(data_array_with_nans): + input_array = data_array_with_nans + threshold = 0.1 + max_detections = 3 + actual_result = extract_detections_from_array( + input_array, threshold=threshold, max_detections=max_detections + ) + expected_values = np.array([0.9, 0.8, 0.7]) + expected_times = np.array([30, 20, 30]) + expected_freqs = np.array([200, 100, 300]) 
+    expected_coords = {
+        Dimensions.time.value: ("detection", expected_times),
+        Dimensions.frequency.value: ("detection", expected_freqs),
+    }
+    expected_result = xr.DataArray(
+        expected_values,
+        coords=expected_coords,
+        dims="detection",
+        name="detection_value",
+    )
+    xr.testing.assert_equal(actual_result, expected_result)
diff --git a/tests/test_postprocessing/test_extraction.py b/tests/test_postprocessing/test_extraction.py
new file mode 100644
index 0000000..7b6b7c3
--- /dev/null
+++ b/tests/test_postprocessing/test_extraction.py
@@ -0,0 +1,433 @@
+import numpy as np
+import pytest
+import xarray as xr
+from soundevent.arrays import Dimensions
+
+from batdetect2.postprocess.detection import extract_detections_from_array
+from batdetect2.postprocess.extraction import (
+    extract_detection_xr_dataset,
+    extract_values_at_positions,
+)
+
+
+@pytest.fixture
+def sample_data_array():
+    """Provides a basic 3x3 DataArray.
+
+    Peaks as assigned below: 0.9 (f=200, t=30), 0.8 (f=100, t=20), 0.7 (f=300, t=30).
+    NOTE(review): dependent tests expect the freq/time indices swapped -- confirm.
+    """
+    coords = {
+        Dimensions.frequency.value: [100, 200, 300],
+        Dimensions.time.value: [10, 20, 30],
+    }
+    array = xr.DataArray(
+        np.zeros([3, 3]),
+        coords=coords,
+        dims=[
+            Dimensions.frequency.value,
+            Dimensions.time.value,
+        ],
+    )
+
+    array.loc[dict(time=10, frequency=100)] = 0.005
+    array.loc[dict(time=10, frequency=200)] = 0.5
+    array.loc[dict(time=10, frequency=300)] = 0.03
+    array.loc[dict(time=20, frequency=100)] = 0.8
+    array.loc[dict(time=20, frequency=200)] = 0.02
+    array.loc[dict(time=20, frequency=300)] = 0.6
+    array.loc[dict(time=30, frequency=100)] = 0.04
+    array.loc[dict(time=30, frequency=200)] = 0.9
+    array.loc[dict(time=30, frequency=300)] = 0.7
+    return array
+
+
+@pytest.fixture
+def sample_array_for_extraction():
+    """Provides a simple array (1-9) for value extraction tests."""
+    data = np.arange(1, 10).reshape(3, 3)
+    coords = {
+        Dimensions.frequency.value: [100, 200, 300],
+        Dimensions.time.value: [10, 20, 30],
+    }
+    return xr.DataArray(
+        data,
+        coords=coords,
+        dims=[
+            Dimensions.frequency.value,
+            Dimensions.time.value,
+        ],
+        name="test_values",
+    )
+
+
+@pytest.fixture
+def sample_positions_top3(sample_data_array):
+    """Get top 3 detection positions from sample_data_array."""
+    # Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8), (f=300, t=30, s=0.7)
+    return extract_detections_from_array(
+        sample_data_array, max_detections=3, threshold=None
+    )
+
+
+@pytest.fixture
+def sample_positions_top2(sample_data_array):
+    """Get top 2 detection positions from sample_data_array."""
+    # Expected: (f=300, t=20, s=0.9), (f=200, t=10, s=0.8)
+    return extract_detections_from_array(
+        sample_data_array, max_detections=2, threshold=None
+    )
+
+
+@pytest.fixture
+def empty_positions(sample_data_array):
+    """Get an empty positions array (high threshold)."""
+    return extract_detections_from_array(
+        sample_data_array,
+        threshold=1.0,  # No values > 1.0
+    )
+
+
+@pytest.fixture
+def sample_sizes_array(sample_data_array):
+    """Provides a sample sizes array matching sample_data_array coords."""
+    coords = sample_data_array.coords
+    # Data: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]  # Dim 0 (width)
+    #       [[9, 10, 11], [12, 13, 14], [15, 16, 17]]  # Dim 1 (height)
+    # Shape: (2, 3, 3) -> (dimension, frequency, time)
+    data = np.array(
+        [
+            [
+                [0, 1, 2],
+                [3, 4, 5],
+                [6, 7, 8],
+            ],  # width (freq increases down, time across)
+            [[9, 10, 11], [12, 13, 14], [15, 16, 17]],  # height
+        ],
+        dtype=np.float32,
+    )
+
+    return xr.DataArray(
+        data,
+        coords={
+            "dimension": ["width", "height"],
+            Dimensions.frequency.value: coords[Dimensions.frequency.value],
+            Dimensions.time.value: coords[Dimensions.time.value],
+        },
+        dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
+        name="sizes",
+    )
+
+
+@pytest.fixture
+def sample_classes_array(sample_data_array):
+    """Provides a sample classes array matching sample_data_array coords."""
+    coords = sample_data_array.coords
+    # Example: (2 cats, 3 freqs, 3 times)
+    data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
+    # data[0, 2, 1] -> cat=0, f=300, t=20 -> val for 0.9 detection
+    # data[0, 1, 0] -> cat=0, f=200, t=10 -> val for 0.8 detection
+    return xr.DataArray(
+        data,
+        coords={
+            "category": ["bat", "noise"],
+            Dimensions.frequency.value: coords[Dimensions.frequency.value],
+            Dimensions.time.value: coords[Dimensions.time.value],
+        },
+        dims=["category", Dimensions.frequency.value, Dimensions.time.value],
+        name="class_scores",
+    )
+
+
+@pytest.fixture
+def sample_features_array(sample_data_array):
+    """Provides a sample features array matching sample_data_array coords."""
+    coords = sample_data_array.coords
+    # Example: (4 features, 3 freqs, 3 times)
+    data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
+    # data[:, 2, 1] -> feats, f=300, t=20 -> vals for 0.9 detection
+    # data[:, 1, 0] -> feats, f=200, t=10 -> vals for 0.8 detection
+    return xr.DataArray(
+        data,
+        coords={
+            "feature": ["f0", "f1", "f2", "f3"],
+            Dimensions.frequency.value: coords[Dimensions.frequency.value],
+            Dimensions.time.value: coords[Dimensions.time.value],
+        },
+        dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
+        name="features",
+    )
+
+
+# --- Tests for extract_values_at_positions ---
+
+
+def test_extract_values_at_positions_correct(
+    sample_array_for_extraction, sample_positions_top3
+):
+    """Verify correct values are extracted based on positions coords."""
+    # Positions: (f=300, t=20), (f=200, t=10), (f=300, t=30)
+    # Corresponding values in sample_array_for_extraction (1-9):
+    # f=300, t=20 -> index (2, 1) -> value 8
+    # f=200, t=10 -> index (1, 0) -> value 4
+    # f=300, t=30 -> index (2, 2) -> value 9
+    expected_values = np.array([8, 4, 9])
+
+    expected = xr.DataArray(
+        expected_values,
+        coords=sample_positions_top3.coords,  # Should inherit coords
+        dims="detection",
+        name="test_values",  # Should inherit name
+    )
+
+    extracted = extract_values_at_positions(
+        sample_array_for_extraction, sample_positions_top3
+    )
+
+    xr.testing.assert_allclose(extracted, expected)
+
+
+def test_extract_values_at_positions_extra_dims(
+    sample_sizes_array, sample_positions_top2
+):
+    """Test extraction preserves other dimensions in the source array."""
+    # Positions: (f=300, t=20), (f=200, t=10)
+    # Extract from sample_sizes_array (dim, freq, time)
+    # Det 1 (f=300, t=20) -> index (:, 2, 1) -> values [7, 16]
+    # Det 2 (f=200, t=10) -> index (:, 1, 0) -> values [3, 12]
+    # Expected shape: (dimension, detection)
+    expected_values = np.array([[7.0, 3.0], [16.0, 12.0]], dtype=np.float32)
+
+    expected = xr.DataArray(
+        expected_values,
+        coords={
+            "dimension": ["width", "height"],
+            Dimensions.frequency.value: sample_positions_top2.coords[
+                Dimensions.frequency.value
+            ],
+            Dimensions.time.value: sample_positions_top2.coords[
+                Dimensions.time.value
+            ],
+        },
+        dims=["dimension", "detection"],
+        name="sizes",  # Inherits name
+    )
+
+    extracted = extract_values_at_positions(
+        sample_sizes_array, sample_positions_top2
+    )
+    xr.testing.assert_allclose(extracted, expected)
+
+
+def test_extract_values_at_positions_empty(
+    sample_array_for_extraction, empty_positions
+):
+    """Test extraction with empty positions returns empty array."""
+    extracted = extract_values_at_positions(
+        sample_array_for_extraction, empty_positions
+    )
+    assert extracted.sizes["detection"] == 0
+    # Check coordinates are also empty but defined
+    assert Dimensions.time.value in extracted.coords
+    assert Dimensions.frequency.value in extracted.coords
+    assert extracted.coords[Dimensions.time.value].size == 0
+    assert extracted.coords[Dimensions.frequency.value].size == 0
+    assert extracted.name == sample_array_for_extraction.name
+
+
+def test_extract_values_at_positions_missing_coord_in_array(
+    sample_array_for_extraction, sample_positions_top2
+):
+    """Test error if source array misses required coordinates."""
+    array_no_time = sample_array_for_extraction.copy()
+    del array_no_time.coords[Dimensions.time.value]
+    with pytest.raises(IndexError):
+        extract_values_at_positions(array_no_time, sample_positions_top2)
+
+    array_no_freq = sample_array_for_extraction.copy()
+    del array_no_freq.coords[Dimensions.frequency.value]
+    with pytest.raises(IndexError):
+        extract_values_at_positions(array_no_freq, sample_positions_top2)
+
+
+def test_extract_values_at_positions_missing_coord_in_positions(
+    sample_array_for_extraction, sample_positions_top2
+):
+    """Test error if positions array misses required coordinates."""
+    positions_no_time = sample_positions_top2.copy()
+    del positions_no_time.coords[Dimensions.time.value]
+    with pytest.raises(KeyError):
+        extract_values_at_positions(
+            sample_array_for_extraction, positions_no_time
+        )
+
+    positions_no_freq = sample_positions_top2.copy()
+    del positions_no_freq.coords[Dimensions.frequency.value]
+    with pytest.raises(KeyError):
+        extract_values_at_positions(
+            sample_array_for_extraction, positions_no_freq
+        )
+
+
+def test_extract_values_at_positions_mismatched_coords(
+    sample_array_for_extraction, sample_positions_top2
+):
+    """Test error if positions requests coords not in source array."""
+    # Create positions requesting a time=40 not present in sample_array
+    bad_positions = sample_positions_top2.copy()
+    bad_positions.coords[Dimensions.time.value] = (
+        "detection",
+        np.array([40, 10]),  # First time is invalid
+    )
+    with pytest.raises(
+        KeyError
+    ):  # xarray.sel raises KeyError for missing labels
+        extract_values_at_positions(sample_array_for_extraction, bad_positions)
+
+
+# --- Tests for extract_detection_xr_dataset ---
+
+
+def test_extract_detection_xr_dataset_correct(
+    sample_positions_top2,
+    sample_sizes_array,
+    sample_classes_array,
+    sample_features_array,
+):
+    """Tests extracting and bundling info for top 2 detections."""
+    actual_dataset = extract_detection_xr_dataset(
+        sample_positions_top2,
+        sample_sizes_array,
+        sample_classes_array,
+        sample_features_array,
+    )
+
+    # Expected positions (top 2):
+    # 1. Score 0.9, Time 20, Freq 300. Indices (freq=2, time=1)
+    # 2. Score 0.8, Time 10, Freq 200. Indices (freq=1, time=0)
+    expected_times = np.array([20, 10])
+    expected_freqs = np.array([300, 200])
+    detection_coords = {
+        Dimensions.time.value: ("detection", expected_times),
+        Dimensions.frequency.value: ("detection", expected_freqs),
+    }
+
+    # --- Manually Calculate Expected Data ---
+
+    # Scores (already correct in sample_positions_top2)
+    expected_score = sample_positions_top2.rename(
+        "scores"
+    )  # Rename to match output
+
+    # Dimensions Data (width, height) -> Transposed to (detection, dimension)
+    # sample_sizes_array data: (dim, freq, time)
+    # Det 1 (f=300, t=20): index (:, 2, 1) -> values [ 7., 16.]
+    # Det 2 (f=200, t=10): index (:, 1, 0) -> values [ 3., 12.]
+    expected_dimensions_data = np.array(
+        [
+            [7.0, 16.0],  # Detection 1 [width, height]
+            [3.0, 12.0],
+        ],  # Detection 2 [width, height]
+        dtype=np.float32,
+    )
+    expected_dimensions = xr.DataArray(
+        expected_dimensions_data,
+        coords={**detection_coords, "dimension": ["width", "height"]},
+        dims=["detection", "dimension"],
+        name="dimensions",
+    )
+
+    # Classes Data (bat, noise) -> Transposed to (detection, category)
+    # sample_classes_array data: np.linspace(0.1, 0.9, 18).reshape(2, 3, 3)
+    # linspace vals: [0.1, 0.147, 0.194, 0.241, 0.288, 0.335, 0.382, 0.429, 0.476, # cat 0
+    #                 0.523, 0.570, 0.617, 0.664, 0.711, 0.758, 0.805, 0.852, 0.9]   # cat 1
+    # Det 1 (cat, f=2, t=1): index (:, 2, 1) -> values [idx 7=0.429, idx 16=0.852]
+    # Det 2 (cat, f=1, t=0): index (:, 1, 0) -> values [idx 3=0.241, idx 12=0.664]
+    expected_classes_data = np.array(
+        [
+            [0.42941177, 0.85294118],  # Detection 1 [bat_prob, noise_prob]
+            [0.24117647, 0.66470588],
+        ],  # Detection 2 [bat_prob, noise_prob]
+        dtype=np.float32,
+    )
+    expected_classes = xr.DataArray(
+        expected_classes_data,
+        coords={**detection_coords, "category": ["bat", "noise"]},
+        dims=["detection", "category"],
+        name="classes",
+    )
+
+    # Features Data (f0..f3) -> Transposed to (detection, feature)
+    # sample_features_array data: np.arange(36).reshape(4, 3, 3)
+    # Det 1 (feat, f=2, t=1): index (:, 2, 1) -> values [ 7, 16, 25, 34]
+    # Det 2 (feat, f=1, t=0): index (:, 1, 0) -> values [ 3, 12, 21, 30]
+    expected_features_data = np.array(
+        [
+            [7.0, 16.0, 25.0, 34.0],  # Detection 1 [f0, f1, f2, f3]
+            [3.0, 12.0, 21.0, 30.0],
+        ],  # Detection 2 [f0, f1, f2, f3]
+        dtype=np.float32,
+    )
+    expected_features = xr.DataArray(
+        expected_features_data,
+        coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
+        dims=["detection", "feature"],
+        name="features",
+    )
+
+    # Construct Expected Dataset
+    expected_dataset = xr.Dataset(
+        {
+            "scores": expected_score,
+            "dimensions": expected_dimensions,
+            "classes": expected_classes,
+            "features": expected_features,
+        }
+    )
+    # Add coords explicitly to ensure they match
+    expected_dataset = expected_dataset.assign_coords(detection_coords)
+
+    # --- Assert Equality ---
+    xr.testing.assert_allclose(actual_dataset, expected_dataset)
+
+
+def test_extract_detection_xr_dataset_empty(
+    empty_positions,
+    sample_sizes_array,
+    sample_classes_array,
+    sample_features_array,
+):
+    """Test extraction with empty positions yields an empty dataset."""
+    actual_dataset = extract_detection_xr_dataset(
+        empty_positions,
+        sample_sizes_array,
+        sample_classes_array,
+        sample_features_array,
+    )
+
+    assert isinstance(actual_dataset, xr.Dataset)
+    assert "detection" in actual_dataset.dims
+    assert actual_dataset.sizes["detection"] == 0
+
+    # Check variables exist and have 0 size along detection dim
+    assert "scores" in actual_dataset
+    assert actual_dataset["scores"].dims == ("detection",)
+    assert actual_dataset["scores"].size == 0
+
+    assert "dimensions" in actual_dataset
+    assert actual_dataset["dimensions"].dims == ("detection", "dimension")
+    assert actual_dataset["dimensions"].shape == (0, 2)  # Check both dims
+
+    assert "classes" in actual_dataset
+    assert actual_dataset["classes"].dims == ("detection", "category")
+    assert actual_dataset["classes"].shape == (0, 2)
+
+    assert "features" in actual_dataset
+    assert actual_dataset["features"].dims == ("detection", "feature")
+    assert actual_dataset["features"].shape == (0, 4)
+
+    # Check coordinates exist and are empty
+    assert Dimensions.time.value in actual_dataset.coords
+    assert Dimensions.frequency.value in actual_dataset.coords
+    assert actual_dataset.coords[Dimensions.time.value].size == 0
+    assert actual_dataset.coords[Dimensions.frequency.value].size == 0
diff --git a/tests/test_postprocessing/test_remapping.py b/tests/test_postprocessing/test_remapping.py
new file mode 100644
index 0000000..e69de29