Mirror of https://github.com/macaodha/batdetect2.git, synced 2025-06-29 22:51:58 +02:00
fix: move tensors to cpu when converting to numpy
parent 68d3842931
commit 57f2b50239
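The substance of the fix: torch.Tensor.numpy() only works on CPU tensors, so when the model runs on a GPU the old detach().numpy() calls fail with a TypeError. Inserting .cpu() copies the tensor to host memory first (a no-op when it is already on the CPU). A minimal sketch of the pattern, illustrative only and not code from this repository:

import torch

def to_numpy(t: torch.Tensor):
    # detach() drops the autograd graph, cpu() moves the data to host
    # memory if it lives on a GPU (no-op otherwise), and numpy() then
    # exposes that buffer as a NumPy array.
    return t.detach().cpu().numpy()

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, device=device)
print(to_numpy(x).shape)  # (4,)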
@@ -19,13 +19,15 @@ def x_coords_to_time(
 ) -> float:
     """Convert x coordinates of spectrogram to time in seconds.
 
-    Args:
+    Parameters
+    ----------
     x_pos: X position of the detection in pixels.
     sampling_rate: Sampling rate of the audio in Hz.
     fft_win_length: Length of the FFT window in seconds.
     fft_overlap: Overlap of the FFT windows in seconds.
 
-    Returns:
+    Returns
+    -------
     Time in seconds.
     """
     nfft = int(fft_win_length * sampling_rate)
@@ -134,12 +136,12 @@ def run_nms(
                 y_pos[num_detection, valid_inds],
                 x_pos[num_detection, valid_inds],
             ].transpose(0, 1)
-            feat = feat.detach().numpy().astype(np.float32)
+            feat = feat.detach().cpu().numpy().astype(np.float32)
             feats.append(feat)
 
         # convert to numpy
         for key, value in pred.items():
-            pred[key] = value.detach().numpy().astype(np.float32)
+            pred[key] = value.detach().cpu().numpy().astype(np.float32)
 
         preds.append(pred)  # type: ignore
 
@@ -265,7 +265,7 @@ def detection(
     # Add class label
     txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
     font_info = {
-        "color": "white",
+        "color": edgecolor,
         "size": 10,
         "weight": "bold",
         "alpha": rect.get_alpha(),
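For context on the plotting change: the class label is now drawn in the same colour as its detection box instead of always white, presumably so each label visually matches its box. A hedged matplotlib sketch of the idea, with made-up coordinates and label, not the repository's plotting code:

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig, ax = plt.subplots()
edgecolor = "orange"  # hypothetical per-class colour
rect = Rectangle((0.2, 0.3), 0.3, 0.2, fill=False, edgecolor=edgecolor, alpha=0.8)
ax.add_patch(rect)

# The label reuses the colour and alpha of the box edge.
font_info = {"color": edgecolor, "size": 10, "weight": "bold", "alpha": rect.get_alpha()}
ax.text(0.2, 0.52, "Pip pyg", **font_info)
plt.show()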
@@ -7,15 +7,14 @@ import numpy as np
 import torch
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.detector import models
-from batdetect2.detector import parameters
-from batdetect2.train import losses
 import batdetect2.detector.post_process as pp
 import batdetect2.train.audio_dataloader as adl
 import batdetect2.train.evaluate as evl
 import batdetect2.train.train_split as ts
 import batdetect2.train.train_utils as tu
 import batdetect2.utils.plot_utils as pu
+from batdetect2.detector import models, parameters
+from batdetect2.train import losses
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
@@ -84,7 +83,6 @@ def save_image(
 def loss_fun(
     outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
 ):
-
     # detection loss
     loss = params["det_loss_weight"] * det_criterion(
         outputs["pred_det"], gt_det
@@ -108,7 +106,6 @@ def loss_fun(
 def train(
     model, epoch, data_loader, det_criterion, optimizer, scheduler, params
 ):
-
     model.train()
 
     train_loss = tu.AverageMeter()
@@ -119,7 +116,6 @@ def train(
 
     print("\nEpoch", epoch)
     for batch_idx, inputs in enumerate(data_loader):
-
         data = inputs["spec"].to(params["device"])
         gt_det = inputs["y_2d_det"].to(params["device"])
         gt_size = inputs["y_2d_size"].to(params["device"])
@@ -172,7 +168,6 @@ def test(model, epoch, data_loader, det_criterion, params):
 
     with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
-
            data = inputs["spec"].to(params["device"])
            gt_det = inputs["y_2d_det"].to(params["device"])
            gt_size = inputs["y_2d_size"].to(params["device"])
@@ -279,7 +274,7 @@ def parse_gt_data(inputs):
         is_valid = inputs["is_valid"][ind] == 1
         gt = {}
         for kk in keys:
-            gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
+            gt[kk] = inputs[kk][ind][is_valid].cpu().numpy().astype(np.float32)
         gt["duration"] = inputs["duration"][ind].item()
         gt["file_id"] = inputs["file_id"][ind].item()
         gt["class_id_file"] = inputs["class_id_file"][ind].item()
@@ -318,8 +313,7 @@ def select_model(params):
     return model
 
 
-if __name__ == "__main__":
-
+def main():
     plt.close("all")
 
     params = parameters.get_params(True)
@@ -501,7 +495,6 @@ if __name__ == "__main__":
     #
     # main train loop
     for epoch in range(0, params["num_epochs"] + 1):
-
         train_loss = train(
             model,
             epoch,
@@ -550,3 +543,7 @@ if __name__ == "__main__":
     # TODO: args variable does not exist
     # if not args["do_not_save_images"]:
     # save_images_batch(model, test_loader, params)
+
+
+if __name__ == "__main__":
+    main()
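This last train-script hunk moves the module-level training code into a main() function behind a thin guard, so the module can be imported (for example by tooling or a console-script entry point) without immediately starting a training run. A stripped-down sketch of the pattern, not the actual script:

def main():
    # everything that previously ran at import time under the old
    # `if __name__ == "__main__":` block now lives in this function
    print("setting up and training the model...")

if __name__ == "__main__":
    main()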
@@ -731,6 +731,7 @@ def process_file(
     spec_slices = []
 
     # load audio file
+    print("time_exp_fact", config.get("time_expansion", 1) or 1)
     sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
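On the added debug print: the expression config.get("time_expansion", 1) or 1 falls back to 1 not only when the key is missing but also when it is present with a falsy value such as None or 0. A tiny illustration with a hypothetical config dict:

config = {"time_expansion": None}  # hypothetical config

# dict.get() alone would return None here; the trailing `or 1`
# also catches an explicit None (or 0) value.
time_exp_fact = config.get("time_expansion", 1) or 1
print("time_exp_fact", time_exp_fact)  # time_exp_fact 1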