Device fixing #2

This commit is contained in:
mbsantiago 2025-08-25 23:07:09 +01:00
parent a267db290c
commit d0bab60bf3
2 changed files with 5 additions and 3 deletions

View File

@ -47,8 +47,8 @@ def extract_prediction_tensor(
indexing="ij", indexing="ij",
) )
freqs = freqs.flatten() freqs = freqs.flatten().to(detection_heatmap)
times = times.flatten() times = times.flatten().to(detection_heatmap)
output_size_preds = output.size_preds.detach() output_size_preds = output.size_preds.detach()
output_features = output.features.detach() output_features = output.features.detach()
@ -58,7 +58,6 @@ def extract_prediction_tensor(
for idx, item in enumerate(detection_heatmap): for idx, item in enumerate(detection_heatmap):
item = item.squeeze().flatten() # Remove channel dim item = item.squeeze().flatten() # Remove channel dim
indices = torch.argsort(item, descending=True)[:max_detections] indices = torch.argsort(item, descending=True)[:max_detections]
indices.to(detection_heatmap)
detection_scores = item.take(indices) detection_scores = item.take(indices)
detection_freqs = freqs.take(indices) detection_freqs = freqs.take(indices)

View File

@ -210,6 +210,9 @@ def generate_heatmaps(
indexing="ij", indexing="ij",
) )
freqs = freqs.to(spec)
times = times.to(spec)
for sound_event_annotation in clip_annotation.sound_events: for sound_event_annotation in clip_annotation.sound_events:
geom = sound_event_annotation.sound_event.geometry geom = sound_event_annotation.sound_event.geometry
if geom is None: if geom is None: