Merge pull request #11 from macaodha/fix/GH-10-merge-results-index-error

fix: the case of no detections is now handled better
This commit is contained in:
Santiago Martinez Balvanera 2023-05-11 14:05:57 +01:00 committed by GitHub
commit e5370e98db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 12 deletions

View File

@ -143,7 +143,19 @@ def load_model(
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {}
predictions_m = {
"det_probs": np.array([]),
"x_pos": np.array([]),
"y_pos": np.array([]),
"bb_widths": np.array([]),
"bb_heights": np.array([]),
"start_times": np.array([]),
"end_times": np.array([]),
"low_freqs": np.array([]),
"high_freqs": np.array([]),
"class_probs": np.array([]),
}
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0:
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m[key] = np.hstack(
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
)
else:
# hack in case where no detected calls as we need some of the key
# names in dict
predictions_m = predictions[0]
if len(spec_feats) > 0:
spec_feats = np.vstack(spec_feats)
@ -226,11 +234,19 @@ def format_single_result(
Returns:
dict: Results in the format expected by the annotation tool.
"""
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
try:
# Get a single class prediction for the file
class_overall = pp.overall_class_pred(
predictions["det_probs"],
predictions["class_probs"],
)
class_name = class_names[np.argmax(class_overall)]
annotations = get_annotations_from_preds(predictions, class_names)
except (np.AxisError, ValueError):
# No detections
class_overall = np.zeros(len(class_names))
class_name = "None"
annotations = []
return {
"id": file_id,
@ -239,8 +255,8 @@ def format_single_result(
"notes": "Automatically generated.",
"time_exp": time_exp,
"duration": round(float(duration), 4),
"annotation": get_annotations_from_preds(predictions, class_names),
"class_name": class_names[np.argmax(class_overall)],
"annotation": annotations,
"class_name": class_name,
}

View File

@ -1,11 +1,14 @@
"""Test bat detect module API."""
from pathlib import Path
import os
from glob import glob
import numpy as np
import torch
from torch import nn
import soundfile as sf
from batdetect2 import api
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
assert "spec_slices" in results
assert isinstance(results["spec_slices"], list)
assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path,
):
"""Test process file with empty predictions does not fail."""
# Create empty file
empty_file = tmp_path / "empty.wav"
empty_wav = np.zeros((0, 1), dtype=np.float32)
sf.write(empty_file, empty_wav, 256000)
# Process file
results = api.process_file(str(empty_file))
assert results is not None
assert len(results["pred_dict"]["annotation"]) == 0