From 9d4a9fc35cdff8f4a8378a27c69c4c2d6c7b3990 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 23:08:49 +0100 Subject: [PATCH] Device fixing #3 --- src/batdetect2/postprocess/extraction.py | 4 ++-- src/batdetect2/train/labels.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index bb4d1ae..f0635e3 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -47,8 +47,8 @@ def extract_prediction_tensor( indexing="ij", ) - freqs = freqs.flatten().to(detection_heatmap) - times = times.flatten().to(detection_heatmap) + freqs = freqs.flatten().to(detection_heatmap.device) + times = times.flatten().to(detection_heatmap.device) output_size_preds = output.size_preds.detach() output_features = output.features.detach() diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index c5af1f5..3dd9da6 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -210,8 +210,8 @@ def generate_heatmaps( indexing="ij", ) - freqs = freqs.to(spec) - times = times.to(spec) + freqs = freqs.to(spec.device) + times = times.to(spec.device) for sound_event_annotation in clip_annotation.sound_events: geom = sound_event_annotation.sound_event.geometry