mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
fix: the case of no detections is now handled better
This commit is contained in:
parent
10d090cc14
commit
04af74228b
@ -143,7 +143,19 @@ def load_model(
|
|||||||
|
|
||||||
|
|
||||||
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||||
predictions_m = {}
|
predictions_m = {
|
||||||
|
"det_probs": np.array([]),
|
||||||
|
"x_pos": np.array([]),
|
||||||
|
"y_pos": np.array([]),
|
||||||
|
"bb_widths": np.array([]),
|
||||||
|
"bb_heights": np.array([]),
|
||||||
|
"start_times": np.array([]),
|
||||||
|
"end_times": np.array([]),
|
||||||
|
"low_freqs": np.array([]),
|
||||||
|
"high_freqs": np.array([]),
|
||||||
|
"class_probs": np.array([]),
|
||||||
|
}
|
||||||
|
|
||||||
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
|
||||||
|
|
||||||
if num_preds > 0:
|
if num_preds > 0:
|
||||||
@ -151,10 +163,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
|||||||
predictions_m[key] = np.hstack(
|
predictions_m[key] = np.hstack(
|
||||||
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# hack in case where no detected calls as we need some of the key
|
|
||||||
# names in dict
|
|
||||||
predictions_m = predictions[0]
|
|
||||||
|
|
||||||
if len(spec_feats) > 0:
|
if len(spec_feats) > 0:
|
||||||
spec_feats = np.vstack(spec_feats)
|
spec_feats = np.vstack(spec_feats)
|
||||||
@ -226,11 +234,19 @@ def format_single_result(
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Results in the format expected by the annotation tool.
|
dict: Results in the format expected by the annotation tool.
|
||||||
"""
|
"""
|
||||||
# Get a single class prediction for the file
|
try:
|
||||||
class_overall = pp.overall_class_pred(
|
# Get a single class prediction for the file
|
||||||
predictions["det_probs"],
|
class_overall = pp.overall_class_pred(
|
||||||
predictions["class_probs"],
|
predictions["det_probs"],
|
||||||
)
|
predictions["class_probs"],
|
||||||
|
)
|
||||||
|
class_name = class_names[np.argmax(class_overall)]
|
||||||
|
annotations = get_annotations_from_preds(predictions, class_names)
|
||||||
|
except (np.AxisError, ValueError):
|
||||||
|
# No detections
|
||||||
|
class_overall = np.zeros(len(class_names))
|
||||||
|
class_name = "None"
|
||||||
|
annotations = []
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": file_id,
|
"id": file_id,
|
||||||
@ -239,8 +255,8 @@ def format_single_result(
|
|||||||
"notes": "Automatically generated.",
|
"notes": "Automatically generated.",
|
||||||
"time_exp": time_exp,
|
"time_exp": time_exp,
|
||||||
"duration": round(float(duration), 4),
|
"duration": round(float(duration), 4),
|
||||||
"annotation": get_annotations_from_preds(predictions, class_names),
|
"annotation": annotations,
|
||||||
"class_name": class_names[np.argmax(class_overall)],
|
"class_name": class_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
"""Test bat detect module API."""
|
"""Test bat detect module API."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
|
||||||
@ -262,3 +265,20 @@ def test_process_file_with_spec_slices():
|
|||||||
assert "spec_slices" in results
|
assert "spec_slices" in results
|
||||||
assert isinstance(results["spec_slices"], list)
|
assert isinstance(results["spec_slices"], list)
|
||||||
assert len(results["spec_slices"]) == len(detections)
|
assert len(results["spec_slices"]) == len(detections)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_process_file_with_empty_predictions_does_not_fail(
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
"""Test process file with empty predictions does not fail."""
|
||||||
|
# Create empty file
|
||||||
|
empty_file = tmp_path / "empty.wav"
|
||||||
|
empty_wav = np.zeros((0, 1), dtype=np.float32)
|
||||||
|
sf.write(empty_file, empty_wav, 256000)
|
||||||
|
|
||||||
|
# Process file
|
||||||
|
results = api.process_file(str(empty_file))
|
||||||
|
|
||||||
|
assert results is not None
|
||||||
|
assert len(results["pred_dict"]["annotation"]) == 0
|
||||||
|
Loading…
Reference in New Issue
Block a user