diff --git a/batdetect2/postprocess/decoding.py b/batdetect2/postprocess/decoding.py index 3741030..8e1f570 100644 --- a/batdetect2/postprocess/decoding.py +++ b/batdetect2/postprocess/decoding.py @@ -102,10 +102,6 @@ def convert_xr_dataset_to_raw_prediction( ) start_time, low_freq, end_time, high_freq = compute_bounds(geom) - - classes = det_info.classes - features = det_info.features - detections.append( RawPrediction( detection_score=det_info.score, @@ -113,8 +109,8 @@ def convert_xr_dataset_to_raw_prediction( end_time=end_time, low_freq=low_freq, high_freq=high_freq, - class_scores=classes, - features=features, + class_scores=det_info.classes, + features=det_info.features, ) ) @@ -256,33 +252,130 @@ def convert_raw_prediction_to_sound_event_prediction( raw_prediction.high_freq, ] ), - features=[ - data.Feature( - term=data.Term( - name=f"batdetect2:{feat_name}", - label=feat_name, - definition="Automatically extracted features by BatDetect2", - ), - value=value, - ) - for feat_name, value in _iterate_over_array( - raw_prediction.features - ) - ], + features=get_prediction_features(raw_prediction.features), ) tags = [ - data.PredictedTag(tag=tag, score=raw_prediction.detection_score) + *get_generic_tags( + raw_prediction.detection_score, + generic_class_tags=generic_class_tags, + ), + *get_class_tags( + raw_prediction.class_scores, + sound_event_decoder, + top_class_only=top_class_only, + threshold=classification_threshold, + ), + ] + + return data.SoundEventPrediction( + sound_event=sound_event, + score=raw_prediction.detection_score, + tags=tags, + ) + + +def get_generic_tags( + detection_score: float, + generic_class_tags: List[data.Tag], +) -> List[data.PredictedTag]: + """Create PredictedTag objects for the generic category. + + Takes the base list of generic tags and assigns the overall detection + score to each one, wrapping them in `PredictedTag` objects. + + Parameters + ---------- + detection_score : float + The overall confidence score of the detection event. + generic_class_tags : List[data.Tag] + The list of base `soundevent.data.Tag` objects that define the + generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']). + + Returns + ------- + List[data.PredictedTag] + A list of `PredictedTag` objects for the generic category, each + assigned the `detection_score`. + """ + return [ + data.PredictedTag(tag=tag, score=detection_score) for tag in generic_class_tags ] - class_scores = raw_prediction.class_scores - if classification_threshold is not None: - class_scores = class_scores.where( - class_scores > classification_threshold, - drop=True, +def get_prediction_features(features: xr.DataArray) -> List[data.Feature]: + """Convert an extracted feature vector DataArray into soundevent Features. + + Parameters + ---------- + features : xr.DataArray + A 1D xarray DataArray containing feature values, indexed by a coordinate + named 'feature' which holds the feature names (e.g., output of selecting + features for one detection from `extract_detection_xr_dataset`). + + Returns + ------- + List[data.Feature] + A list of `soundevent.data.Feature` objects. + + Notes + ----- + - This function creates basic `Term` objects using the feature coordinate + names with a "batdetect2:" prefix. + """ + return [ + data.Feature( + term=data.Term( + name=f"batdetect2:{feat_name}", + label=feat_name, + definition="Automatically extracted features by BatDetect2", + ), + value=value, ) + for feat_name, value in _iterate_over_array(features) + ] + + +def get_class_tags( + class_scores: xr.DataArray, + sound_event_decoder: SoundEventDecoder, + top_class_only: bool = False, + threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD, +) -> List[data.PredictedTag]: + """Generate specific PredictedTags based on class scores and decoder. + + Filters class scores by the threshold, sorts remaining scores descending, + decodes the class name(s) into base tags using the `sound_event_decoder`, + and creates `PredictedTag` objects associating the class score. Stops after + the first (top) class if `top_class_only` is True. + + Parameters + ---------- + class_scores : xr.DataArray + A 1D xarray DataArray containing class probabilities/scores, indexed + by a 'category' coordinate holding the class names. + sound_event_decoder : SoundEventDecoder + Function to map a class name string to a list of base `data.Tag` + objects. + top_class_only : bool, default=False + If True, only generate tags for the single highest-scoring class above + the threshold. + threshold : float, optional + Minimum score for a class to be considered. If None, all classes are + processed (or top-1 if `top_class_only` is True). Defaults to + `DEFAULT_CLASSIFICATION_THRESHOLD`. + + Returns + ------- + List[data.PredictedTag] + A list of `PredictedTag` objects for the class(es) that passed the + threshold, ordered by score if `top_class_only` is False. + """ + tags = [] + + if threshold is not None: + class_scores = class_scores.where(class_scores > threshold, drop=True) for class_name, score in _iterate_sorted(class_scores): class_tags = sound_event_decoder(class_name) @@ -298,11 +391,7 @@ def convert_raw_prediction_to_sound_event_prediction( if top_class_only: break - return data.SoundEventPrediction( - sound_event=sound_event, - score=raw_prediction.detection_score, - tags=tags, - ) + return tags def _iterate_over_array(array: xr.DataArray): @@ -314,7 +403,7 @@ def _iterate_over_array(array: xr.DataArray): def _iterate_sorted(array: xr.DataArray): dim_name = array.dims[0] - coords = array.coords[dim_name] - indices = np.argsort(coords.values) + coords = array.coords[dim_name].values + indices = np.argsort(-array.values) for index in indices: - yield str(coords[index]), coords.values[index] + yield str(coords[index]), float(array.values[index]) diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 0c0136e..4580772 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -4,21 +4,20 @@ from typing import List, Tuple import numpy as np import pytest import xarray as xr - -# Removed dataclass import as MockRawPrediction is replaced from soundevent import data -# Import functions to test from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, convert_xr_dataset_to_raw_prediction, + get_class_tags, + get_generic_tags, + get_prediction_features, ) from batdetect2.postprocess.types import RawPrediction -# Dummy GeometryBuilder function fixture @pytest.fixture def dummy_geometry_builder(): """A simple GeometryBuilder that creates a BBox around the point.""" @@ -30,7 +29,6 @@ def dummy_geometry_builder(): time, freq = position width = dimensions.sel(dimension="width").item() height = dimensions.sel(dimension="height").item() - # Assume position is the center return data.BoundingBox( coordinates=[ time - width / 2, @@ -43,7 +41,6 @@ def dummy_geometry_builder(): return _builder -# Dummy SoundEventDecoder function fixture @pytest.fixture def dummy_sound_event_decoder(): """A simple SoundEventDecoder mapping names to tags.""" @@ -94,12 +91,9 @@ def sample_clip(sample_recording) -> data.Clip: ) -# Fixture for a detection dataset (adapted from test_extraction) @pytest.fixture def sample_detection_dataset() -> xr.Dataset: """Creates a sample detection dataset suitable for decoding.""" - # Based on test_extraction's corrected expectations - # Detections: (t=20, f=300, s=0.9), (t=10, f=200, s=0.8) expected_times = np.array([20, 10]) expected_freqs = np.array([300, 200]) detection_coords = { @@ -125,7 +119,7 @@ def sample_detection_dataset() -> xr.Dataset: classes_data = np.array( [[0.43, 0.85], [0.24, 0.66]], - dtype=np.float32, # Simplified values + dtype=np.float32, ) classes = xr.DataArray( classes_data, @@ -198,13 +192,10 @@ def empty_detection_dataset() -> xr.Dataset: ) -# Fixture for sample RawPrediction objects (using the actual type) @pytest.fixture def sample_raw_predictions() -> List[RawPrediction]: """Manually crafted RawPrediction objects using the actual type.""" - # Corresponds roughly to sample_detection_dataset after geometry building - # Det 1: t=20, f=300, s=0.9, w=7, h=16, classes=[0.43, 0.85], feats=[7, 16, 25, 34] - # Det 2: t=10, f=200, s=0.8, w=3, h=12, classes=[0.24, 0.66], feats=[ 3, 12, 21, 30] + pred1_classes = xr.DataArray( [0.43, 0.85], coords={"category": ["bat", "noise"]}, dims=["category"] ) @@ -213,12 +204,12 @@ def sample_raw_predictions() -> List[RawPrediction]: coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred1 = RawPrediction( # Use RawPrediction directly + pred1 = RawPrediction( detection_score=0.9, start_time=20 - 7 / 2, - end_time=20 + 7 / 2, # 16.5, 23.5 + end_time=20 + 7 / 2, low_freq=300 - 16 / 2, - high_freq=300 + 16 / 2, # 292, 308 + high_freq=300 + 16 / 2, class_scores=pred1_classes, features=pred1_features, ) @@ -231,25 +222,25 @@ def sample_raw_predictions() -> List[RawPrediction]: coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred2 = RawPrediction( # Use RawPrediction directly + pred2 = RawPrediction( detection_score=0.8, start_time=10 - 3 / 2, - end_time=10 + 3 / 2, # 8.5, 11.5 + end_time=10 + 3 / 2, low_freq=200 - 12 / 2, - high_freq=200 + 12 / 2, # 194, 206 + high_freq=200 + 12 / 2, class_scores=pred2_classes, features=pred2_features, ) pred3_classes = xr.DataArray( [0.05, 0.02], coords={"category": ["bat", "noise"]}, dims=["category"] - ) # Below default threshold + ) pred3_features = xr.DataArray( [1.0, 2.0, 3.0, 4.0], coords={"feature": ["f0", "f1", "f2", "f3"]}, dims=["feature"], ) - pred3 = RawPrediction( # Use RawPrediction directly + pred3 = RawPrediction( detection_score=0.15, start_time=5.0, end_time=6.0, @@ -261,9 +252,6 @@ def sample_raw_predictions() -> List[RawPrediction]: return [pred1, pred2, pred3] -# --- Tests for convert_xr_dataset_to_raw_prediction --- - - def test_convert_xr_dataset_basic( sample_detection_dataset, dummy_geometry_builder ): @@ -275,16 +263,14 @@ def test_convert_xr_dataset_basic( assert isinstance(raw_predictions, list) assert len(raw_predictions) == 2 - # Check first prediction (score=0.9) pred1 = raw_predictions[0] - assert isinstance(pred1, RawPrediction) # Check against the actual type - assert pred1.detection_score == pytest.approx(0.9) - # Check bounds derived from dummy_geometry_builder (center pos assumed) - # t=20, f=300, w=7, h=16 - assert pred1.start_time == pytest.approx(20 - 7 / 2) - assert pred1.end_time == pytest.approx(20 + 7 / 2) - assert pred1.low_freq == pytest.approx(300 - 16 / 2) - assert pred1.high_freq == pytest.approx(300 + 16 / 2) + assert isinstance(pred1, RawPrediction) + assert pred1.detection_score == 0.9 + + assert pred1.start_time == 20 - 7 / 2 + assert pred1.end_time == 20 + 7 / 2 + assert pred1.low_freq == 300 - 16 / 2 + assert pred1.high_freq == 300 + 16 / 2 xr.testing.assert_allclose( pred1.class_scores, sample_detection_dataset["classes"].sel(detection=0), @@ -293,15 +279,14 @@ def test_convert_xr_dataset_basic( pred1.features, sample_detection_dataset["features"].sel(detection=0) ) - # Check second prediction (score=0.8) pred2 = raw_predictions[1] - assert isinstance(pred2, RawPrediction) # Check against the actual type - assert pred2.detection_score == pytest.approx(0.8) - # t=10, f=200, w=3, h=12 - assert pred2.start_time == pytest.approx(10 - 3 / 2) - assert pred2.end_time == pytest.approx(10 + 3 / 2) - assert pred2.low_freq == pytest.approx(200 - 12 / 2) - assert pred2.high_freq == pytest.approx(200 + 12 / 2) + assert isinstance(pred2, RawPrediction) + assert pred2.detection_score == 0.8 + + assert pred2.start_time == 10 - 3 / 2 + assert pred2.end_time == 10 + 3 / 2 + assert pred2.low_freq == 200 - 12 / 2 + assert pred2.high_freq == 200 + 12 / 2 xr.testing.assert_allclose( pred2.class_scores, sample_detection_dataset["classes"].sel(detection=1), @@ -311,9 +296,6 @@ def test_convert_xr_dataset_basic( ) -# ...(rest of the tests remain unchanged as they accessed attributes correctly)... - - def test_convert_xr_dataset_empty( empty_detection_dataset, dummy_geometry_builder ): @@ -325,9 +307,6 @@ def test_convert_xr_dataset_empty( assert len(raw_predictions) == 0 -# --- Tests for convert_raw_prediction_to_sound_event_prediction --- - - def test_convert_raw_to_sound_event_basic( sample_raw_predictions, sample_recording, @@ -335,7 +314,7 @@ def test_convert_raw_to_sound_event_basic( generic_tags, ): """Test basic conversion, default threshold, multi-label.""" - # score=0.9, classes=[0.43(bat), 0.85(noise)] + raw_pred = sample_raw_predictions[0] se_pred = convert_raw_prediction_to_sound_event_prediction( @@ -343,14 +322,11 @@ def test_convert_raw_to_sound_event_basic( recording=sample_recording, sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, - # classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD (0.1), - # top_class_only=False, ) assert isinstance(se_pred, data.SoundEventPrediction) - assert se_pred.score == pytest.approx(raw_pred.detection_score) + assert se_pred.score == raw_pred.detection_score - # Check SoundEvent se = se_pred.sound_event assert isinstance(se, data.SoundEvent) assert se.recording == sample_recording @@ -365,27 +341,21 @@ def test_convert_raw_to_sound_event_basic( ], ) assert len(se.features) == len(raw_pred.features) - # Simple check for feature presence and value type + feat_dict = {f.term.name: f.value for f in se.features} assert "batdetect2:f0" in feat_dict and isinstance( feat_dict["batdetect2:f0"], float ) - assert feat_dict["batdetect2:f0"] == pytest.approx(7.0) + assert feat_dict["batdetect2:f0"] == 7.0 - # Check Tags - # Expected: Generic(0.9), Noise(0.85), Bat(0.43) - # Note: Order might depend on sortby implementation detail, compare as sets expected_tags = { - # Generic Tag - (generic_tags[0].key, generic_tags[0].value, 0.9), - # Noise Tag (score 0.85 > 0.1) - ("category", "noise", 0.85), - # Bat Tag (score 0.43 > 0.1) - ("species", "Myotis", 0.43), + (generic_tags[0].term.name, generic_tags[0].value, 0.9), + ("soundevent:category", "noise", 0.85), + ("soundevent:species", "Myotis", 0.43), + } + actual_tags = { + (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags } - print("expected", expected_tags) - actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} - print("actual", actual_tags) assert actual_tags == expected_tags @@ -396,9 +366,7 @@ def test_convert_raw_to_sound_event_thresholding( generic_tags, ): """Test effect of classification threshold.""" - raw_pred = sample_raw_predictions[ - 0 - ] # score=0.9, classes=[0.43(bat), 0.85(noise)] + raw_pred = sample_raw_predictions[0] high_threshold = 0.5 se_pred = convert_raw_prediction_to_sound_event_prediction( @@ -406,16 +374,17 @@ def test_convert_raw_to_sound_event_thresholding( recording=sample_recording, sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, - classification_threshold=high_threshold, # Only noise should pass + classification_threshold=high_threshold, top_class_only=False, ) - # Expected: Generic(0.9), Noise(0.85) - Bat (0.43) is below threshold expected_tags = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), - ("category", "noise", pytest.approx(0.85)), + (generic_tags[0].term.name, generic_tags[0].value, 0.9), + ("soundevent:category", "noise", 0.85), + } + actual_tags = { + (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags } - actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} assert actual_tags == expected_tags @@ -426,27 +395,25 @@ def test_convert_raw_to_sound_event_no_threshold( generic_tags, ): """Test when classification_threshold is None.""" - raw_pred = sample_raw_predictions[ - 2 - ] # score=0.15, classes=[0.05(bat), 0.02(noise)] - # Both classes are below default threshold, but should be included if None + raw_pred = sample_raw_predictions[2] se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, - classification_threshold=None, # No thresholding + classification_threshold=None, top_class_only=False, ) - # Expected: Generic(0.15), Bat(0.05), Noise(0.02) expected_tags = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)), - ("species", "Myotis", pytest.approx(0.05)), - ("category", "noise", pytest.approx(0.02)), + (generic_tags[0].term.name, generic_tags[0].value, 0.15), + ("soundevent:species", "Myotis", 0.05), + ("soundevent:category", "noise", 0.02), + } + actual_tags = { + (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags } - actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} assert actual_tags == expected_tags @@ -457,10 +424,7 @@ def test_convert_raw_to_sound_event_top_class( generic_tags, ): """Test top_class_only=True behavior.""" - raw_pred = sample_raw_predictions[ - 0 - ] # score=0.9, classes=[0.43(bat), 0.85(noise)] - # Highest score is noise (0.85) + raw_pred = sample_raw_predictions[0] se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, @@ -468,15 +432,16 @@ def test_convert_raw_to_sound_event_top_class( sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, - top_class_only=True, # Only include top class (noise) + top_class_only=True, ) - # Expected: Generic(0.9), Noise(0.85) expected_tags = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), - ("category", "noise", pytest.approx(0.85)), + (generic_tags[0].term.name, generic_tags[0].value, 0.9), + ("soundevent:category", "noise", 0.85), + } + actual_tags = { + (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags } - actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} assert actual_tags == expected_tags @@ -487,30 +452,26 @@ def test_convert_raw_to_sound_event_all_below_threshold( generic_tags, ): """Test when all class scores are below the default threshold.""" - raw_pred = sample_raw_predictions[ - 2 - ] # score=0.15, classes=[0.05(bat), 0.02(noise)] + raw_pred = sample_raw_predictions[2] se_pred = convert_raw_prediction_to_sound_event_prediction( raw_prediction=raw_pred, recording=sample_recording, sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, - classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, # 0.1 + classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only=False, ) - # Expected: Only Generic(0.15) tag, as others are below threshold expected_tags = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)), + (generic_tags[0].term.name, generic_tags[0].value, 0.15), + } + actual_tags = { + (pt.tag.term.name, pt.tag.value, pt.score) for pt in se_pred.tags } - actual_tags = {(pt.tag.key, pt.tag.value, pt.score) for pt in se_pred.tags} assert actual_tags == expected_tags -# --- Tests for convert_raw_predictions_to_clip_prediction --- - - def test_convert_raw_list_to_clip_basic( sample_raw_predictions, sample_clip, @@ -531,25 +492,22 @@ def test_convert_raw_list_to_clip_basic( assert clip_pred.clip == sample_clip assert len(clip_pred.sound_events) == len(sample_raw_predictions) - # Check if the contained sound events seem correct (basic check) - assert clip_pred.sound_events[0].score == pytest.approx( + assert clip_pred.sound_events[0].score == ( sample_raw_predictions[0].detection_score ) - assert clip_pred.sound_events[1].score == pytest.approx( + assert clip_pred.sound_events[1].score == ( sample_raw_predictions[1].detection_score ) - assert clip_pred.sound_events[2].score == pytest.approx( + assert clip_pred.sound_events[2].score == ( sample_raw_predictions[2].detection_score ) - # Check if tags were generated correctly for one event (e.g., the last one) - # Pred 3 has score 0.15, classes [0.05, 0.02]. Only generic tag expected. se_pred3_tags = { - (pt.tag.key, pt.tag.value, pt.score) + (pt.tag.term.name, pt.tag.value, pt.score) for pt in clip_pred.sound_events[2].tags } expected_tags3 = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.15)), + (generic_tags[0].term.name, generic_tags[0].value, 0.15), } assert se_pred3_tags == expected_tags3 @@ -579,26 +537,126 @@ def test_convert_raw_list_to_clip_passes_args( generic_tags, ): """Test that arguments like top_class_only are passed through.""" - # Use top_class_only = True + clip_pred = convert_raw_predictions_to_clip_prediction( raw_predictions=sample_raw_predictions, clip=sample_clip, sound_event_decoder=dummy_sound_event_decoder, generic_class_tags=generic_tags, classification_threshold=DEFAULT_CLASSIFICATION_THRESHOLD, - top_class_only=True, # <<-- Argument being tested + top_class_only=True, ) assert len(clip_pred.sound_events) == 3 - # Check tags for the first prediction (score=0.9, classes=[0.43(bat), 0.85(noise)]) - # With top_class_only=True, expect Generic(0.9) and Noise(0.85) only se_pred1_tags = { - (pt.tag.key, pt.tag.value, pt.score) + (pt.tag.term.name, pt.tag.value, pt.score) for pt in clip_pred.sound_events[0].tags } expected_tags1 = { - (generic_tags[0].key, generic_tags[0].value, pytest.approx(0.9)), - ("category", "noise", pytest.approx(0.85)), + (generic_tags[0].term.name, generic_tags[0].value, 0.9), + ("soundevent:category", "noise", 0.85), } assert se_pred1_tags == expected_tags1 + + +def test_get_generic_tags_basic(generic_tags): + """Test creation of generic tags with score.""" + detection_score = 0.75 + predicted_tags = get_generic_tags( + detection_score=detection_score, generic_class_tags=generic_tags + ) + assert len(predicted_tags) == len(generic_tags) + for predicted_tag in predicted_tags: + assert isinstance(predicted_tag, data.PredictedTag) + assert predicted_tag.score == detection_score + assert predicted_tag.tag in generic_tags + + +def test_get_prediction_features_basic(): + """Test conversion of feature DataArray to list of Features.""" + feature_data = xr.DataArray( + [1.1, 2.2, 3.3], + coords={"feature": ["feat1", "feat2", "feat3"]}, + dims=["feature"], + ) + features = get_prediction_features(feature_data) + assert len(features) == 3 + for feature, feat_name, feat_value in zip( + features, ["feat1", "feat2", "feat3"], [1.1, 2.2, 3.3] + ): + assert isinstance(feature, data.Feature) + assert feature.term.name == f"batdetect2:{feat_name}" + assert feature.value == feat_value + + +def test_get_class_tags_basic(dummy_sound_event_decoder): + """Test creation of class tags based on scores and decoder.""" + class_scores = xr.DataArray( + [0.6, 0.2, 0.9], + coords={"category": ["bat", "noise", "unknown"]}, + dims=["category"], + ) + predicted_tags = get_class_tags( + class_scores=class_scores, + sound_event_decoder=dummy_sound_event_decoder, + ) + assert len(predicted_tags) == 3 + tag_values = [pt.tag.value for pt in predicted_tags] + tag_scores = [pt.score for pt in predicted_tags] + + assert "Myotis" in tag_values + assert "noise" in tag_values + assert "uncertain" in tag_values + assert 0.6 in tag_scores + assert 0.2 in tag_scores + assert 0.9 in tag_scores + + +def test_get_class_tags_thresholding(dummy_sound_event_decoder): + """Test class tag creation with a threshold.""" + class_scores = xr.DataArray( + [0.6, 0.2, 0.9], + coords={"category": ["bat", "noise", "unknown"]}, + dims=["category"], + ) + threshold = 0.5 + predicted_tags = get_class_tags( + class_scores=class_scores, + sound_event_decoder=dummy_sound_event_decoder, + threshold=threshold, + ) + + assert len(predicted_tags) == 2 + tag_values = [pt.tag.value for pt in predicted_tags] + assert "Myotis" in tag_values + assert "noise" not in tag_values + assert "uncertain" in tag_values + + +def test_get_class_tags_top_class_only(dummy_sound_event_decoder): + """Test class tag creation with top_class_only.""" + class_scores = xr.DataArray( + [0.6, 0.2, 0.9], + coords={"category": ["bat", "noise", "unknown"]}, + dims=["category"], + ) + predicted_tags = get_class_tags( + class_scores=class_scores, + sound_event_decoder=dummy_sound_event_decoder, + top_class_only=True, + ) + + assert len(predicted_tags) == 1 + assert predicted_tags[0].tag.value == "uncertain" + assert predicted_tags[0].score == 0.9 + + +def test_get_class_tags_empty(dummy_sound_event_decoder): + """Test with empty class scores.""" + class_scores = xr.DataArray([], coords={"category": []}, dims=["category"]) + predicted_tags = get_class_tags( + class_scores=class_scores, + sound_event_decoder=dummy_sound_event_decoder, + ) + assert len(predicted_tags) == 0