Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 22:51:58 +02:00)
fix: move tensors to cpu when converting to numpy
This commit is contained in:
parent 68d3842931
commit 57f2b50239
@@ -19,13 +19,15 @@ def x_coords_to_time(
 ) -> float:
     """Convert x coordinates of spectrogram to time in seconds.
 
-    Args:
+    Parameters
+    ----------
     x_pos: X position of the detection in pixels.
     sampling_rate: Sampling rate of the audio in Hz.
     fft_win_length: Length of the FFT window in seconds.
     fft_overlap: Overlap of the FFT windows in seconds.
 
-    Returns:
+    Returns
+    -------
     Time in seconds.
     """
     nfft = int(fft_win_length * sampling_rate)
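For context, `x_coords_to_time` maps a spectrogram column index back to a timestamp. Below is a minimal sketch of the arithmetic, assuming the hop length is the window length minus the overlap; everything past the `nfft` line is an illustration, not the repository's exact implementation:

    def x_coords_to_time_sketch(
        x_pos: float,
        sampling_rate: int,
        fft_win_length: float,
        fft_overlap: float,
    ) -> float:
        # Window length in samples, as in the function above.
        nfft = int(fft_win_length * sampling_rate)
        # Assumed hop size: the non-overlapping part of each FFT window.
        hop = nfft * (1.0 - fft_overlap)
        # Column index times hop size gives an offset in samples; divide by
        # the sampling rate to get seconds.
        return x_pos * hop / sampling_rate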
@@ -134,12 +136,12 @@ def run_nms(
             y_pos[num_detection, valid_inds],
             x_pos[num_detection, valid_inds],
         ].transpose(0, 1)
-        feat = feat.detach().numpy().astype(np.float32)
+        feat = feat.detach().cpu().numpy().astype(np.float32)
         feats.append(feat)
 
         # convert to numpy
         for key, value in pred.items():
-            pred[key] = value.detach().numpy().astype(np.float32)
+            pred[key] = value.detach().cpu().numpy().astype(np.float32)
 
         preds.append(pred)  # type: ignore
 
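This is the core of the commit: `torch.Tensor.numpy()` refuses to convert a tensor that lives on a CUDA device, so the tensor has to be detached from the autograd graph and copied to host memory first. A minimal sketch of the idiom (the `to_numpy` helper and the dummy `pred` dict are illustrative, not batdetect2 code):

    import numpy as np
    import torch

    def to_numpy(t: torch.Tensor) -> np.ndarray:
        # .detach() drops the autograd graph, .cpu() copies the tensor to host
        # memory when it sits on a GPU (a no-op otherwise); only then can
        # .numpy() hand back an ndarray instead of raising a TypeError.
        return t.detach().cpu().numpy().astype(np.float32)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = {"det_probs": torch.rand(1, 128, device=device)}
    pred_np = {key: to_numpy(value) for key, value in pred.items()}

On a CPU-only machine the extra `.cpu()` call is a no-op, which is why the fix is safe to apply unconditionally.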
@@ -265,7 +265,7 @@ def detection(
     # Add class label
     txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
     font_info = {
-        "color": "white",
+        "color": edgecolor,
         "size": 10,
         "weight": "bold",
         "alpha": rect.get_alpha(),
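The plotting hunk switches the class label from plain white to the colour of its bounding box, so each label stays visually tied to its detection. A small, self-contained matplotlib sketch of the same idea (the box coordinates and label text are made up):

    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    fig, ax = plt.subplots()
    edgecolor = "tab:orange"
    rect = Rectangle((0.2, 0.3), 0.3, 0.2, fill=False, edgecolor=edgecolor, alpha=0.8)
    ax.add_patch(rect)

    # Reuse the box's edge colour and transparency for the label, as in the diff.
    font_info = {
        "color": edgecolor,
        "size": 10,
        "weight": "bold",
        "alpha": rect.get_alpha(),
    }
    ax.text(0.2, 0.52, "Pip pip", fontdict=font_info)
    plt.show()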
@@ -7,15 +7,14 @@ import numpy as np
 import torch
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.detector import models
-from batdetect2.detector import parameters
-from batdetect2.train import losses
 import batdetect2.detector.post_process as pp
 import batdetect2.train.audio_dataloader as adl
 import batdetect2.train.evaluate as evl
 import batdetect2.train.train_split as ts
 import batdetect2.train.train_utils as tu
 import batdetect2.utils.plot_utils as pu
+from batdetect2.detector import models, parameters
+from batdetect2.train import losses
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
@@ -84,7 +83,6 @@ def save_image(
 def loss_fun(
     outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
 ):
-
     # detection loss
     loss = params["det_loss_weight"] * det_criterion(
         outputs["pred_det"], gt_det
@@ -108,7 +106,6 @@ def loss_fun(
 def train(
     model, epoch, data_loader, det_criterion, optimizer, scheduler, params
 ):
-
     model.train()
 
     train_loss = tu.AverageMeter()
@@ -119,7 +116,6 @@ def train(
 
     print("\nEpoch", epoch)
     for batch_idx, inputs in enumerate(data_loader):
-
         data = inputs["spec"].to(params["device"])
         gt_det = inputs["y_2d_det"].to(params["device"])
         gt_size = inputs["y_2d_size"].to(params["device"])
@@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params):
 
     with torch.no_grad():
         for batch_idx, inputs in enumerate(data_loader):
-
             data = inputs["spec"].to(params["device"])
             gt_det = inputs["y_2d_det"].to(params["device"])
             gt_size = inputs["y_2d_size"].to(params["device"])
@@ -279,7 +274,7 @@ def parse_gt_data(inputs):
     is_valid = inputs["is_valid"][ind] == 1
     gt = {}
     for kk in keys:
-        gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
+        gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32)
     gt["duration"] = inputs["duration"][ind].item()
     gt["file_id"] = inputs["file_id"][ind].item()
     gt["class_id_file"] = inputs["class_id_file"][ind].item()
@@ -318,8 +313,7 @@ def select_model(params):
     return model
 
 
-if __name__ == "__main__":
-
+def main():
     plt.close("all")
 
     params = parameters.get_params(True)
@@ -501,7 +495,6 @@ if __name__ == "__main__":
     #
     # main train loop
     for epoch in range(0, params["num_epochs"] + 1):
-
         train_loss = train(
             model,
             epoch,
@@ -550,3 +543,7 @@ if __name__ == "__main__":
     # TODO: args variable does not exist
     # if not args["do_not_save_images"]:
     #     save_images_batch(model, test_loader, params)
+
+
+if __name__ == "__main__":
+    main()
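Taken together, the train-script hunks move the former module-level body into a `main()` function and call it from the entry-point guard, so importing the module no longer starts a training run. A stripped-down sketch of the resulting structure (placeholder body only):

    def main():
        # Training setup and loop live here instead of at module level.
        print("training...")


    if __name__ == "__main__":
        main()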
@@ -731,6 +731,7 @@ def process_file(
     spec_slices = []
 
     # load audio file
+    print("time_exp_fact", config.get("time_expansion", 1) or 1)
     sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
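A side note on the added debug print: `config.get("time_expansion", 1) or 1` guards against two different cases, because `dict.get` only falls back to its default when the key is missing, not when the stored value is `None`. A small illustration with a made-up config dict:

    config = {"time_expansion": None}

    # dict.get's default only applies when the key is absent, so this is None.
    print(config.get("time_expansion", 1))       # None

    # The trailing "or 1" also covers falsy stored values such as None or 0.
    print(config.get("time_expansion", 1) or 1)  # 1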