Device fixing

This commit is contained in:
mbsantiago 2025-08-25 23:04:13 +01:00
parent 441ccb3382
commit a267db290c

View File

@ -34,7 +34,7 @@ def extract_prediction_tensor(
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[Detections]: ) -> List[Detections]:
detection_heatmap = non_max_suppression( detection_heatmap = non_max_suppression(
output.detection_probs, output.detection_probs.detach(),
kernel_size=nms_kernel_size, kernel_size=nms_kernel_size,
) )
@ -50,17 +50,22 @@ def extract_prediction_tensor(
freqs = freqs.flatten() freqs = freqs.flatten()
times = times.flatten() times = times.flatten()
output_size_preds = output.size_preds.detach()
output_features = output.features.detach()
output_class_probs = output.class_probs.detach()
predictions = [] predictions = []
for idx, item in enumerate(detection_heatmap): for idx, item in enumerate(detection_heatmap):
item = item.squeeze().flatten() # Remove channel dim item = item.squeeze().flatten() # Remove channel dim
indices = torch.argsort(item, descending=True)[:max_detections] indices = torch.argsort(item, descending=True)[:max_detections]
indices.to(detection_heatmap)
detection_scores = item.take(indices) detection_scores = item.take(indices)
detection_freqs = freqs.take(indices) detection_freqs = freqs.take(indices)
detection_times = times.take(indices) detection_times = times.take(indices)
sizes = output.size_preds[idx, :, detection_freqs, detection_times].T sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output.features[idx, :, detection_freqs, detection_times].T features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output.class_probs[ class_scores = output_class_probs[
idx, :, detection_freqs, detection_times idx, :, detection_freqs, detection_times
].T ].T