diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index d6d2b13..8074d80 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -143,7 +143,19 @@ def load_model( def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): - predictions_m = {} + predictions_m = { + "det_probs": np.array([]), + "x_pos": np.array([]), + "y_pos": np.array([]), + "bb_widths": np.array([]), + "bb_heights": np.array([]), + "start_times": np.array([]), + "end_times": np.array([]), + "low_freqs": np.array([]), + "high_freqs": np.array([]), + "class_probs": np.array([]), + } + num_preds = np.sum([len(pp["det_probs"]) for pp in predictions]) if num_preds > 0: @@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): predictions_m[key] = np.hstack( [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] ) - else: - # hack in case where no detected calls as we need some of the key - # names in dict - predictions_m = predictions[0] if len(spec_feats) > 0: spec_feats = np.vstack(spec_feats) @@ -226,11 +234,19 @@ def format_single_result( Returns: dict: Results in the format expected by the annotation tool. """ - # Get a single class prediction for the file - class_overall = pp.overall_class_pred( - predictions["det_probs"], - predictions["class_probs"], - ) + try: + # Get a single class prediction for the file + class_overall = pp.overall_class_pred( + predictions["det_probs"], + predictions["class_probs"], + ) + class_name = class_names[np.argmax(class_overall)] + annotations = get_annotations_from_preds(predictions, class_names) + except (np.AxisError, ValueError): + # No detections + class_overall = np.zeros(len(class_names)) + class_name = "None" + annotations = [] return { "id": file_id, @@ -239,8 +255,8 @@ def format_single_result( "notes": "Automatically generated.", "time_exp": time_exp, "duration": round(float(duration), 4), - "annotation": get_annotations_from_preds(predictions, class_names), - "class_name": class_names[np.argmax(class_overall)], + "annotation": annotations, + "class_name": class_name, } diff --git a/tests/test_api.py b/tests/test_api.py index 942a1f1..d28c733 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,11 +1,14 @@ """Test bat detect module API.""" +from pathlib import Path + import os from glob import glob import numpy as np import torch from torch import nn +import soundfile as sf from batdetect2 import api @@ -262,3 +265,20 @@ def test_process_file_with_spec_slices(): assert "spec_slices" in results assert isinstance(results["spec_slices"], list) assert len(results["spec_slices"]) == len(detections) + + + +def test_process_file_with_empty_predictions_does_not_fail( + tmp_path: Path, +): + """Test process file with empty predictions does not fail.""" + # Create empty file + empty_file = tmp_path / "empty.wav" + empty_wav = np.zeros((0, 1), dtype=np.float32) + sf.write(empty_file, empty_wav, 256000) + + # Process file + results = api.process_file(str(empty_file)) + + assert results is not None + assert len(results["pred_dict"]["annotation"]) == 0