fix: the case of no detections is now handled better

This commit is contained in:
Santiago Martinez 2023-05-11 13:59:20 +01:00
parent 10d090cc14
commit 04af74228b
2 changed files with 48 additions and 12 deletions

View File

@ -143,7 +143,19 @@ def load_model(
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {} predictions_m = {
"det_probs": np.array([]),
"x_pos": np.array([]),
"y_pos": np.array([]),
"bb_widths": np.array([]),
"bb_heights": np.array([]),
"start_times": np.array([]),
"end_times": np.array([]),
"low_freqs": np.array([]),
"high_freqs": np.array([]),
"class_probs": np.array([]),
}
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions]) num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0: if num_preds > 0:
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m[key] = np.hstack( predictions_m[key] = np.hstack(
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
) )
else:
# hack in case where no detected calls as we need some of the key
# names in dict
predictions_m = predictions[0]
if len(spec_feats) > 0: if len(spec_feats) > 0:
spec_feats = np.vstack(spec_feats) spec_feats = np.vstack(spec_feats)
@ -226,11 +234,19 @@ def format_single_result(
Returns: Returns:
dict: Results in the format expected by the annotation tool. dict: Results in the format expected by the annotation tool.
""" """
# Get a single class prediction for the file try:
class_overall = pp.overall_class_pred( # Get a single class prediction for the file
predictions["det_probs"], class_overall = pp.overall_class_pred(
predictions["class_probs"], predictions["det_probs"],
) predictions["class_probs"],
)
class_name = class_names[np.argmax(class_overall)]
annotations = get_annotations_from_preds(predictions, class_names)
except (np.AxisError, ValueError):
# No detections
class_overall = np.zeros(len(class_names))
class_name = "None"
annotations = []
return { return {
"id": file_id, "id": file_id,
@ -239,8 +255,8 @@ def format_single_result(
"notes": "Automatically generated.", "notes": "Automatically generated.",
"time_exp": time_exp, "time_exp": time_exp,
"duration": round(float(duration), 4), "duration": round(float(duration), 4),
"annotation": get_annotations_from_preds(predictions, class_names), "annotation": annotations,
"class_name": class_names[np.argmax(class_overall)], "class_name": class_name,
} }

View File

@ -1,11 +1,14 @@
"""Test bat detect module API.""" """Test bat detect module API."""
from pathlib import Path
import os import os
from glob import glob from glob import glob
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
import soundfile as sf
from batdetect2 import api from batdetect2 import api
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
assert "spec_slices" in results assert "spec_slices" in results
assert isinstance(results["spec_slices"], list) assert isinstance(results["spec_slices"], list)
assert len(results["spec_slices"]) == len(detections) assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path,
):
"""Test process file with empty predictions does not fail."""
# Create empty file
empty_file = tmp_path / "empty.wav"
empty_wav = np.zeros((0, 1), dtype=np.float32)
sf.write(empty_file, empty_wav, 256000)
# Process file
results = api.process_file(str(empty_file))
assert results is not None
assert len(results["pred_dict"]["annotation"]) == 0