From a267db290cc140460b2b18d951a444962aaf4d3e Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 23:04:13 +0100 Subject: [PATCH] Device fixing --- src/batdetect2/postprocess/extraction.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 361d936..416824c 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -34,7 +34,7 @@ def extract_prediction_tensor( nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, ) -> List[Detections]: detection_heatmap = non_max_suppression( - output.detection_probs, + output.detection_probs.detach(), kernel_size=nms_kernel_size, ) @@ -50,17 +50,22 @@ def extract_prediction_tensor( freqs = freqs.flatten() times = times.flatten() + output_size_preds = output.size_preds.detach() + output_features = output.features.detach() + output_class_probs = output.class_probs.detach() + predictions = [] for idx, item in enumerate(detection_heatmap): item = item.squeeze().flatten() # Remove channel dim indices = torch.argsort(item, descending=True)[:max_detections] + indices.to(detection_heatmap) detection_scores = item.take(indices) detection_freqs = freqs.take(indices) detection_times = times.take(indices) - sizes = output.size_preds[idx, :, detection_freqs, detection_times].T - features = output.features[idx, :, detection_freqs, detection_times].T - class_scores = output.class_probs[ + sizes = output_size_preds[idx, :, detection_freqs, detection_times].T + features = output_features[idx, :, detection_freqs, detection_times].T + class_scores = output_class_probs[ idx, :, detection_freqs, detection_times ].T