mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
fix: the case of no detections is now handled better
This commit is contained in:
parent
10d090cc14
commit
04af74228b
@ -143,7 +143,19 @@ def load_model(
|
||||
|
||||
|
||||
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
predictions_m = {}
|
||||
predictions_m = {
|
||||
"det_probs": np.array([]),
|
||||
"x_pos": np.array([]),
|
||||
"y_pos": np.array([]),
|
||||
"bb_widths": np.array([]),
|
||||
"bb_heights": np.array([]),
|
||||
"start_times": np.array([]),
|
||||
"end_times": np.array([]),
|
||||
"low_freqs": np.array([]),
|
||||
"high_freqs": np.array([]),
|
||||
"class_probs": np.array([]),
|
||||
}
|
||||
|
||||
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
||||
|
||||
if num_preds > 0:
|
||||
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
predictions_m[key] = np.hstack(
|
||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||
)
|
||||
else:
|
||||
# hack in case where no detected calls as we need some of the key
|
||||
# names in dict
|
||||
predictions_m = predictions[0]
|
||||
|
||||
if len(spec_feats) > 0:
|
||||
spec_feats = np.vstack(spec_feats)
|
||||
@ -226,11 +234,19 @@ def format_single_result(
|
||||
Returns:
|
||||
dict: Results in the format expected by the annotation tool.
|
||||
"""
|
||||
try:
|
||||
# Get a single class prediction for the file
|
||||
class_overall = pp.overall_class_pred(
|
||||
predictions["det_probs"],
|
||||
predictions["class_probs"],
|
||||
)
|
||||
class_name = class_names[np.argmax(class_overall)]
|
||||
annotations = get_annotations_from_preds(predictions, class_names)
|
||||
except (np.AxisError, ValueError):
|
||||
# No detections
|
||||
class_overall = np.zeros(len(class_names))
|
||||
class_name = "None"
|
||||
annotations = []
|
||||
|
||||
return {
|
||||
"id": file_id,
|
||||
@ -239,8 +255,8 @@ def format_single_result(
|
||||
"notes": "Automatically generated.",
|
||||
"time_exp": time_exp,
|
||||
"duration": round(float(duration), 4),
|
||||
"annotation": get_annotations_from_preds(predictions, class_names),
|
||||
"class_name": class_names[np.argmax(class_overall)],
|
||||
"annotation": annotations,
|
||||
"class_name": class_name,
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
"""Test bat detect module API."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import soundfile as sf
|
||||
|
||||
from batdetect2 import api
|
||||
|
||||
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
|
||||
assert "spec_slices" in results
|
||||
assert isinstance(results["spec_slices"], list)
|
||||
assert len(results["spec_slices"]) == len(detections)
|
||||
|
||||
|
||||
|
||||
def test_process_file_with_empty_predictions_does_not_fail(
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""Test process file with empty predictions does not fail."""
|
||||
# Create empty file
|
||||
empty_file = tmp_path / "empty.wav"
|
||||
empty_wav = np.zeros((0, 1), dtype=np.float32)
|
||||
sf.write(empty_file, empty_wav, 256000)
|
||||
|
||||
# Process file
|
||||
results = api.process_file(str(empty_file))
|
||||
|
||||
assert results is not None
|
||||
assert len(results["pred_dict"]["annotation"]) == 0
|
||||
|
Loading…
Reference in New Issue
Block a user