diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 416824c..bb4d1ae 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -47,8 +47,8 @@ def extract_prediction_tensor( indexing="ij", ) - freqs = freqs.flatten() - times = times.flatten() + freqs = freqs.flatten().to(detection_heatmap) + times = times.flatten().to(detection_heatmap) output_size_preds = output.size_preds.detach() output_features = output.features.detach() @@ -58,7 +58,6 @@ def extract_prediction_tensor( for idx, item in enumerate(detection_heatmap): item = item.squeeze().flatten() # Remove channel dim indices = torch.argsort(item, descending=True)[:max_detections] - indices.to(detection_heatmap) detection_scores = item.take(indices) detection_freqs = freqs.take(indices) diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 9a668db..c5af1f5 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -210,6 +210,9 @@ def generate_heatmaps( indexing="ij", ) + freqs = freqs.to(spec) + times = times.to(spec) + for sound_event_annotation in clip_annotation.sound_events: geom = sound_event_annotation.sound_event.geometry if geom is None: