diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index bb4d1ae..f0635e3 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -47,8 +47,8 @@ def extract_prediction_tensor( indexing="ij", ) - freqs = freqs.flatten().to(detection_heatmap) - times = times.flatten().to(detection_heatmap) + freqs = freqs.flatten().to(detection_heatmap.device) + times = times.flatten().to(detection_heatmap.device) output_size_preds = output.size_preds.detach() output_features = output.features.detach() diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index c5af1f5..3dd9da6 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -210,8 +210,8 @@ def generate_heatmaps( indexing="ij", ) - freqs = freqs.to(spec) - times = times.to(spec) + freqs = freqs.to(spec.device) + times = times.to(spec.device) for sound_event_annotation in clip_annotation.sound_events: geom = sound_event_annotation.sound_event.geometry