mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Device fixing #2
This commit is contained in:
parent
a267db290c
commit
d0bab60bf3
@ -47,8 +47,8 @@ def extract_prediction_tensor(
|
||||
indexing="ij",
|
||||
)
|
||||
|
||||
freqs = freqs.flatten()
|
||||
times = times.flatten()
|
||||
freqs = freqs.flatten().to(detection_heatmap)
|
||||
times = times.flatten().to(detection_heatmap)
|
||||
|
||||
output_size_preds = output.size_preds.detach()
|
||||
output_features = output.features.detach()
|
||||
@ -58,7 +58,6 @@ def extract_prediction_tensor(
|
||||
for idx, item in enumerate(detection_heatmap):
|
||||
item = item.squeeze().flatten() # Remove channel dim
|
||||
indices = torch.argsort(item, descending=True)[:max_detections]
|
||||
indices.to(detection_heatmap)
|
||||
|
||||
detection_scores = item.take(indices)
|
||||
detection_freqs = freqs.take(indices)
|
||||
|
||||
@ -210,6 +210,9 @@ def generate_heatmaps(
|
||||
indexing="ij",
|
||||
)
|
||||
|
||||
freqs = freqs.to(spec)
|
||||
times = times.to(spec)
|
||||
|
||||
for sound_event_annotation in clip_annotation.sound_events:
|
||||
geom = sound_event_annotation.sound_event.geometry
|
||||
if geom is None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user