fix: move tensors to cpu when converting to numpy

This commit is contained in:
Santiago Martinez 2023-04-29 15:30:25 +01:00
parent 68d3842931
commit 57f2b50239
4 changed files with 16 additions and 16 deletions

View File

@@ -19,13 +19,15 @@ def x_coords_to_time(
) -> float: ) -> float:
"""Convert x coordinates of spectrogram to time in seconds. """Convert x coordinates of spectrogram to time in seconds.
Args: Parameters
----------
x_pos: X position of the detection in pixels. x_pos: X position of the detection in pixels.
sampling_rate: Sampling rate of the audio in Hz. sampling_rate: Sampling rate of the audio in Hz.
fft_win_length: Length of the FFT window in seconds. fft_win_length: Length of the FFT window in seconds.
fft_overlap: Overlap of the FFT windows in seconds. fft_overlap: Overlap of the FFT windows in seconds.
Returns: Returns
-------
Time in seconds. Time in seconds.
""" """
nfft = int(fft_win_length * sampling_rate) nfft = int(fft_win_length * sampling_rate)
@@ -134,12 +136,12 @@ def run_nms(
y_pos[num_detection, valid_inds], y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds], x_pos[num_detection, valid_inds],
].transpose(0, 1) ].transpose(0, 1)
feat = feat.detach().numpy().astype(np.float32) feat = feat.detach().cpu().numpy().astype(np.float32)
feats.append(feat) feats.append(feat)
# convert to numpy # convert to numpy
for key, value in pred.items(): for key, value in pred.items():
pred[key] = value.detach().numpy().astype(np.float32) pred[key] = value.detach().cpu().numpy().astype(np.float32)
preds.append(pred) # type: ignore preds.append(pred) # type: ignore

View File

@@ -265,7 +265,7 @@ def detection(
# Add class label # Add class label
txt = " ".join([sp[:3] for sp in det["class"].split(" ")]) txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
font_info = { font_info = {
"color": "white", "color": edgecolor,
"size": 10, "size": 10,
"weight": "bold", "weight": "bold",
"alpha": rect.get_alpha(), "alpha": rect.get_alpha(),

View File

@@ -7,15 +7,14 @@ import numpy as np
import torch import torch
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.detector import models
from batdetect2.detector import parameters
from batdetect2.train import losses
import batdetect2.detector.post_process as pp import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu import batdetect2.train.train_utils as tu
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2.detector import models, parameters
from batdetect2.train import losses
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@@ -84,7 +83,6 @@ def save_image(
def loss_fun( def loss_fun(
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
): ):
# detection loss # detection loss
loss = params["det_loss_weight"] * det_criterion( loss = params["det_loss_weight"] * det_criterion(
outputs["pred_det"], gt_det outputs["pred_det"], gt_det
@@ -108,7 +106,6 @@ def loss_fun(
def train( def train(
model, epoch, data_loader, det_criterion, optimizer, scheduler, params model, epoch, data_loader, det_criterion, optimizer, scheduler, params
): ):
model.train() model.train()
train_loss = tu.AverageMeter() train_loss = tu.AverageMeter()
@@ -119,7 +116,6 @@ def train(
print("\nEpoch", epoch) print("\nEpoch", epoch)
for batch_idx, inputs in enumerate(data_loader): for batch_idx, inputs in enumerate(data_loader):
data = inputs["spec"].to(params["device"]) data = inputs["spec"].to(params["device"])
gt_det = inputs["y_2d_det"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"])
gt_size = inputs["y_2d_size"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"])
@@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params):
with torch.no_grad(): with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader): for batch_idx, inputs in enumerate(data_loader):
data = inputs["spec"].to(params["device"]) data = inputs["spec"].to(params["device"])
gt_det = inputs["y_2d_det"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"])
gt_size = inputs["y_2d_size"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"])
@@ -279,7 +274,7 @@ def parse_gt_data(inputs):
is_valid = inputs["is_valid"][ind] == 1 is_valid = inputs["is_valid"][ind] == 1
gt = {} gt = {}
for kk in keys: for kk in keys:
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32)
gt["duration"] = inputs["duration"][ind].item() gt["duration"] = inputs["duration"][ind].item()
gt["file_id"] = inputs["file_id"][ind].item() gt["file_id"] = inputs["file_id"][ind].item()
gt["class_id_file"] = inputs["class_id_file"][ind].item() gt["class_id_file"] = inputs["class_id_file"][ind].item()
@@ -318,8 +313,7 @@ def select_model(params):
return model return model
if __name__ == "__main__": def main():
plt.close("all") plt.close("all")
params = parameters.get_params(True) params = parameters.get_params(True)
@@ -501,7 +495,6 @@ if __name__ == "__main__":
# #
# main train loop # main train loop
for epoch in range(0, params["num_epochs"] + 1): for epoch in range(0, params["num_epochs"] + 1):
train_loss = train( train_loss = train(
model, model,
epoch, epoch,
@@ -550,3 +543,7 @@ if __name__ == "__main__":
# TODO: args variable does not exist # TODO: args variable does not exist
# if not args["do_not_save_images"]: # if not args["do_not_save_images"]:
# save_images_batch(model, test_loader, params) # save_images_batch(model, test_loader, params)
if __name__ == "__main__":
main()

View File

@@ -731,6 +731,7 @@ def process_file(
spec_slices = [] spec_slices = []
# load audio file # load audio file
print("time_exp_fact", config.get("time_expansion", 1) or 1)
sampling_rate, audio_full = au.load_audio( sampling_rate, audio_full = au.load_audio(
audio_file, audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1, time_exp_fact=config.get("time_expansion", 1) or 1,