From 57f2b5023904d9eb10db38c118b0dfda9ec886b5 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sat, 29 Apr 2023 15:30:25 +0100 Subject: [PATCH] fix: move tensors to cpu when converting to numpy --- batdetect2/detector/post_process.py | 10 ++++++---- batdetect2/plot.py | 2 +- batdetect2/train/train_model.py | 19 ++++++++----------- batdetect2/utils/detector_utils.py | 1 + 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/batdetect2/detector/post_process.py b/batdetect2/detector/post_process.py index a2ba353..b47eec6 100644 --- a/batdetect2/detector/post_process.py +++ b/batdetect2/detector/post_process.py @@ -19,13 +19,15 @@ def x_coords_to_time( ) -> float: """Convert x coordinates of spectrogram to time in seconds. - Args: + Parameters + ---------- x_pos: X position of the detection in pixels. sampling_rate: Sampling rate of the audio in Hz. fft_win_length: Length of the FFT window in seconds. fft_overlap: Overlap of the FFT windows in seconds. - Returns: + Returns + ------- Time in seconds. """ nfft = int(fft_win_length * sampling_rate) @@ -134,12 +136,12 @@ def run_nms( y_pos[num_detection, valid_inds], x_pos[num_detection, valid_inds], ].transpose(0, 1) - feat = feat.detach().numpy().astype(np.float32) + feat = feat.detach().cpu().numpy().astype(np.float32) feats.append(feat) # convert to numpy for key, value in pred.items(): - pred[key] = value.detach().numpy().astype(np.float32) + pred[key] = value.detach().cpu().numpy().astype(np.float32) preds.append(pred) # type: ignore diff --git a/batdetect2/plot.py b/batdetect2/plot.py index cdcdbd8..62c4919 100644 --- a/batdetect2/plot.py +++ b/batdetect2/plot.py @@ -265,7 +265,7 @@ def detection( # Add class label txt = " ".join([sp[:3] for sp in det["class"].split(" ")]) font_info = { - "color": "white", + "color": edgecolor, "size": 10, "weight": "bold", "alpha": rect.get_alpha(), diff --git a/batdetect2/train/train_model.py b/batdetect2/train/train_model.py index 759c2d7..e38de39 100644 --- a/batdetect2/train/train_model.py +++ b/batdetect2/train/train_model.py @@ -7,15 +7,14 @@ import numpy as np import torch from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.detector import models -from batdetect2.detector import parameters -from batdetect2.train import losses import batdetect2.detector.post_process as pp import batdetect2.train.audio_dataloader as adl import batdetect2.train.evaluate as evl import batdetect2.train.train_split as ts import batdetect2.train.train_utils as tu import batdetect2.utils.plot_utils as pu +from batdetect2.detector import models, parameters +from batdetect2.train import losses warnings.filterwarnings("ignore", category=UserWarning) @@ -84,7 +83,6 @@ def save_image( def loss_fun( outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq ): - # detection loss loss = params["det_loss_weight"] * det_criterion( outputs["pred_det"], gt_det @@ -108,7 +106,6 @@ def loss_fun( def train( model, epoch, data_loader, det_criterion, optimizer, scheduler, params ): - model.train() train_loss = tu.AverageMeter() @@ -119,7 +116,6 @@ def train( print("\nEpoch", epoch) for batch_idx, inputs in enumerate(data_loader): - data = inputs["spec"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"]) @@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params): with torch.no_grad(): for batch_idx, inputs in enumerate(data_loader): - data = inputs["spec"].to(params["device"]) gt_det = inputs["y_2d_det"].to(params["device"]) gt_size = inputs["y_2d_size"].to(params["device"]) @@ -279,7 +274,7 @@ def parse_gt_data(inputs): is_valid = inputs["is_valid"][ind] == 1 gt = {} for kk in keys: - gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) + gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32) gt["duration"] = inputs["duration"][ind].item() gt["file_id"] = inputs["file_id"][ind].item() gt["class_id_file"] = inputs["class_id_file"][ind].item() @@ -318,8 +313,7 @@ def select_model(params): return model -if __name__ == "__main__": - +def main(): plt.close("all") params = parameters.get_params(True) @@ -501,7 +495,6 @@ if __name__ == "__main__": # # main train loop for epoch in range(0, params["num_epochs"] + 1): - train_loss = train( model, epoch, @@ -550,3 +543,7 @@ if __name__ == "__main__": # TODO: args variable does not exist # if not args["do_not_save_images"]: # save_images_batch(model, test_loader, params) + + +if __name__ == "__main__": + main() diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index d6d2b13..dd010f9 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -731,6 +731,7 @@ def process_file( spec_slices = [] # load audio file + print("time_exp_fact", config.get("time_expansion", 1) or 1) sampling_rate, audio_full = au.load_audio( audio_file, time_exp_fact=config.get("time_expansion", 1) or 1,