diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 499b2de..c1a283c 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -34,12 +34,12 @@ def convert_detections_to_raw_predictions( predictions = [] for score, class_scores, time, freq, dims, feats in zip( - detections.scores, - detections.class_scores, - detections.times, - detections.frequencies, - detections.sizes, - detections.features, + detections.scores.cpu().numpy(), + detections.class_scores.cpu().numpy(), + detections.times.cpu().numpy(), + detections.frequencies.cpu().numpy(), + detections.sizes.cpu().numpy(), + detections.features.cpu().numpy(), ): highest_scoring_class = targets.class_names[class_scores.argmax()]