Device fixing #3

This commit is contained in:
mbsantiago 2025-08-25 23:08:49 +01:00
parent d0bab60bf3
commit 9d4a9fc35c
2 changed files with 4 additions and 4 deletions

View File

@ -47,8 +47,8 @@ def extract_prediction_tensor(
indexing="ij", indexing="ij",
) )
freqs = freqs.flatten().to(detection_heatmap) freqs = freqs.flatten().to(detection_heatmap.device)
times = times.flatten().to(detection_heatmap) times = times.flatten().to(detection_heatmap.device)
output_size_preds = output.size_preds.detach() output_size_preds = output.size_preds.detach()
output_features = output.features.detach() output_features = output.features.detach()

View File

@ -210,8 +210,8 @@ def generate_heatmaps(
indexing="ij", indexing="ij",
) )
freqs = freqs.to(spec) freqs = freqs.to(spec.device)
times = times.to(spec) times = times.to(spec.device)
for sound_event_annotation in clip_annotation.sound_events: for sound_event_annotation in clip_annotation.sound_events:
geom = sound_event_annotation.sound_event.geometry geom = sound_event_annotation.sound_event.geometry