mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Device fixing
This commit is contained in:
parent
441ccb3382
commit
a267db290c
@ -34,7 +34,7 @@ def extract_prediction_tensor(
|
||||
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
) -> List[Detections]:
|
||||
detection_heatmap = non_max_suppression(
|
||||
output.detection_probs,
|
||||
output.detection_probs.detach(),
|
||||
kernel_size=nms_kernel_size,
|
||||
)
|
||||
|
||||
@ -50,17 +50,22 @@ def extract_prediction_tensor(
|
||||
freqs = freqs.flatten()
|
||||
times = times.flatten()
|
||||
|
||||
output_size_preds = output.size_preds.detach()
|
||||
output_features = output.features.detach()
|
||||
output_class_probs = output.class_probs.detach()
|
||||
|
||||
predictions = []
|
||||
for idx, item in enumerate(detection_heatmap):
|
||||
item = item.squeeze().flatten() # Remove channel dim
|
||||
indices = torch.argsort(item, descending=True)[:max_detections]
|
||||
indices.to(detection_heatmap)
|
||||
|
||||
detection_scores = item.take(indices)
|
||||
detection_freqs = freqs.take(indices)
|
||||
detection_times = times.take(indices)
|
||||
sizes = output.size_preds[idx, :, detection_freqs, detection_times].T
|
||||
features = output.features[idx, :, detection_freqs, detection_times].T
|
||||
class_scores = output.class_probs[
|
||||
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
|
||||
features = output_features[idx, :, detection_freqs, detection_times].T
|
||||
class_scores = output_class_probs[
|
||||
idx, :, detection_freqs, detection_times
|
||||
].T
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user