mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Device fixing #2
This commit is contained in:
parent
a267db290c
commit
d0bab60bf3
@ -47,8 +47,8 @@ def extract_prediction_tensor(
|
|||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
|
|
||||||
freqs = freqs.flatten()
|
freqs = freqs.flatten().to(detection_heatmap)
|
||||||
times = times.flatten()
|
times = times.flatten().to(detection_heatmap)
|
||||||
|
|
||||||
output_size_preds = output.size_preds.detach()
|
output_size_preds = output.size_preds.detach()
|
||||||
output_features = output.features.detach()
|
output_features = output.features.detach()
|
||||||
@ -58,7 +58,6 @@ def extract_prediction_tensor(
|
|||||||
for idx, item in enumerate(detection_heatmap):
|
for idx, item in enumerate(detection_heatmap):
|
||||||
item = item.squeeze().flatten() # Remove channel dim
|
item = item.squeeze().flatten() # Remove channel dim
|
||||||
indices = torch.argsort(item, descending=True)[:max_detections]
|
indices = torch.argsort(item, descending=True)[:max_detections]
|
||||||
indices.to(detection_heatmap)
|
|
||||||
|
|
||||||
detection_scores = item.take(indices)
|
detection_scores = item.take(indices)
|
||||||
detection_freqs = freqs.take(indices)
|
detection_freqs = freqs.take(indices)
|
||||||
|
|||||||
@ -210,6 +210,9 @@ def generate_heatmaps(
|
|||||||
indexing="ij",
|
indexing="ij",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
freqs = freqs.to(spec)
|
||||||
|
times = times.to(spec)
|
||||||
|
|
||||||
for sound_event_annotation in clip_annotation.sound_events:
|
for sound_event_annotation in clip_annotation.sound_events:
|
||||||
geom = sound_event_annotation.sound_event.geometry
|
geom = sound_event_annotation.sound_event.geometry
|
||||||
if geom is None:
|
if geom is None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user