fix: the case of no detections is now handled better

This commit is contained in:
Santiago Martinez 2023-05-11 13:59:20 +01:00
parent 10d090cc14
commit 04af74228b
2 changed files with 48 additions and 12 deletions

View File

@ -143,7 +143,19 @@ def load_model(
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {}
predictions_m = {
"det_probs": np.array([]),
"x_pos": np.array([]),
"y_pos": np.array([]),
"bb_widths": np.array([]),
"bb_heights": np.array([]),
"start_times": np.array([]),
"end_times": np.array([]),
"low_freqs": np.array([]),
"high_freqs": np.array([]),
"class_probs": np.array([]),
}
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0:
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m[key] = np.hstack(
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
)
else:
# hack in case where no detected calls as we need some of the key
# names in dict
predictions_m = predictions[0]
if len(spec_feats) > 0:
spec_feats = np.vstack(spec_feats)
@ -226,11 +234,19 @@ def format_single_result(
Returns:
dict: Results in the format expected by the annotation tool.
"""
try:
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
class_name = class_names[np.argmax(class_overall)]
annotations = get_annotations_from_preds(predictions, class_names)
except (np.AxisError, ValueError):
# No detections
class_overall = np.zeros(len(class_names))
class_name = "None"
annotations = []
return {
"id": file_id,
@ -239,8 +255,8 @@ def format_single_result(
"notes": "Automatically generated.",
"time_exp": time_exp,
"duration": round(float(duration), 4),
"annotation": get_annotations_from_preds(predictions, class_names),
"class_name": class_names[np.argmax(class_overall)],
"annotation": annotations,
"class_name": class_name,
}

View File

@ -1,11 +1,14 @@
"""Test bat detect module API."""
from pathlib import Path
import os
from glob import glob
import numpy as np
import torch
from torch import nn
import soundfile as sf
from batdetect2 import api
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
assert "spec_slices" in results
assert isinstance(results["spec_slices"], list)
assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path,
):
"""Test process file with empty predictions does not fail."""
# Create empty file
empty_file = tmp_path / "empty.wav"
empty_wav = np.zeros((0, 1), dtype=np.float32)
sf.write(empty_file, empty_wav, 256000)
# Process file
results = api.process_file(str(empty_file))
assert results is not None
assert len(results["pred_dict"]["annotation"]) == 0