mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Device fixing
This commit is contained in:
parent
441ccb3382
commit
a267db290c
@ -34,7 +34,7 @@ def extract_prediction_tensor(
|
|||||||
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||||
) -> List[Detections]:
|
) -> List[Detections]:
|
||||||
detection_heatmap = non_max_suppression(
|
detection_heatmap = non_max_suppression(
|
||||||
output.detection_probs,
|
output.detection_probs.detach(),
|
||||||
kernel_size=nms_kernel_size,
|
kernel_size=nms_kernel_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,17 +50,22 @@ def extract_prediction_tensor(
|
|||||||
freqs = freqs.flatten()
|
freqs = freqs.flatten()
|
||||||
times = times.flatten()
|
times = times.flatten()
|
||||||
|
|
||||||
|
output_size_preds = output.size_preds.detach()
|
||||||
|
output_features = output.features.detach()
|
||||||
|
output_class_probs = output.class_probs.detach()
|
||||||
|
|
||||||
predictions = []
|
predictions = []
|
||||||
for idx, item in enumerate(detection_heatmap):
|
for idx, item in enumerate(detection_heatmap):
|
||||||
item = item.squeeze().flatten() # Remove channel dim
|
item = item.squeeze().flatten() # Remove channel dim
|
||||||
indices = torch.argsort(item, descending=True)[:max_detections]
|
indices = torch.argsort(item, descending=True)[:max_detections]
|
||||||
|
indices.to(detection_heatmap)
|
||||||
|
|
||||||
detection_scores = item.take(indices)
|
detection_scores = item.take(indices)
|
||||||
detection_freqs = freqs.take(indices)
|
detection_freqs = freqs.take(indices)
|
||||||
detection_times = times.take(indices)
|
detection_times = times.take(indices)
|
||||||
sizes = output.size_preds[idx, :, detection_freqs, detection_times].T
|
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
|
||||||
features = output.features[idx, :, detection_freqs, detection_times].T
|
features = output_features[idx, :, detection_freqs, detection_times].T
|
||||||
class_scores = output.class_probs[
|
class_scores = output_class_probs[
|
||||||
idx, :, detection_freqs, detection_times
|
idx, :, detection_freqs, detection_times
|
||||||
].T
|
].T
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user