import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

sys.path.append(os.path.join("..", ".."))

import warnings

import bat_detect.detector.models as models
import bat_detect.detector.parameters as parameters
import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.losses as losses
import bat_detect.train.train_split as ts
import bat_detect.train.train_utils as tu
import bat_detect.utils.plot_utils as pu

warnings.filterwarnings("ignore", category=UserWarning)


def save_images_batch(model, data_loader, params):
    print("\nsaving images ...")

    is_train_state = data_loader.dataset.is_train
    data_loader.dataset.is_train = False
    data_loader.dataset.return_spec_for_viz = True
    model.eval()

    ind = 0  # first image in each batch
    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            data = inputs["spec"].to(params["device"])
            outputs = model(data)

            spec_viz = inputs["spec_for_viz"].data.cpu().numpy()
            orig_index = inputs["file_id"][ind]
            plot_title = data_loader.dataset.data_anns[orig_index]["id"]
            op_file_name = (
                params["op_im_dir_test"]
                + data_loader.dataset.data_anns[orig_index]["id"]
                + ".jpg"
            )
            save_image(
                spec_viz,
                outputs,
                ind,
                inputs,
                params,
                op_file_name,
                plot_title,
            )

    data_loader.dataset.is_train = is_train_state
    data_loader.dataset.return_spec_for_viz = False


def save_image(
    spec_viz, outputs, ind, inputs, params, op_file_name, plot_title
):
    pred_nms, _ = pp.run_nms(outputs, params, inputs["sampling_rate"].float())
    pred_hm = outputs["pred_det"][ind, 0, :].data.cpu().numpy()
    spec_viz = spec_viz[ind, 0, :]
    gt = parse_gt_data(inputs)[ind]
    sampling_rate = inputs["sampling_rate"][ind].item()
    duration = inputs["duration"][ind].item()

    pu.plot_spec(
        spec_viz,
        sampling_rate,
        duration,
        gt,
        pred_nms[ind],
        params,
        plot_title,
        op_file_name,
        pred_hm,
        plot_boxes=True,
        fixed_aspect=False,
    )


def loss_fun(
    outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
):
    # detection loss
    loss = params["det_loss_weight"] * det_criterion(
        outputs["pred_det"], gt_det
    )

    # bounding box size loss
    loss += params["size_loss_weight"] * losses.bbox_size_loss(
        outputs["pred_size"], gt_size
    )

    # classification loss
    valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
    p_class = outputs["pred_class"][:, :-1, :]
    loss += params["class_loss_weight"] * det_criterion(
        p_class, gt_class[:, :-1, :], valid_mask=valid_mask
    )

    return loss
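

# Illustrative sketch only, not used by the training code below: loss_fun
# calls the selected detection criterion (losses.mse_loss or losses.focal_loss,
# chosen in the main block) both as det_criterion(pred, gt) and as
# det_criterion(pred, gt, valid_mask=mask). The minimal stand-in below shows
# that interface with a masked mean squared error; the real implementations in
# bat_detect.train.losses may weight and reduce the error differently.
def _example_masked_mse(pred, gt, valid_mask=None):
    # per-element squared error on the prediction heatmap
    err = (pred - gt) ** 2
    if valid_mask is None:
        return err.mean()
    # average only over positions flagged as valid, guarding against an
    # all-zero mask
    return (err * valid_mask).sum() / valid_mask.sum().clamp(min=1.0)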


def train(
    model, epoch, data_loader, det_criterion, optimizer, scheduler, params
):
    model.train()

    train_loss = tu.AverageMeter()
    class_inv_freq = torch.from_numpy(
        np.array(params["class_inv_freq"], dtype=np.float32)
    ).to(params["device"])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    print("\nEpoch", epoch)
    for batch_idx, inputs in enumerate(data_loader):
        data = inputs["spec"].to(params["device"])
        gt_det = inputs["y_2d_det"].to(params["device"])
        gt_size = inputs["y_2d_size"].to(params["device"])
        gt_class = inputs["y_2d_classes"].to(params["device"])

        optimizer.zero_grad()
        outputs = model(data)

        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()

        if batch_idx % 50 == 0 and batch_idx != 0:
            print(
                "[{}/{}]\tLoss: {:.4f}".format(
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    train_loss.avg,
                )
            )

    print("Train loss : {:.4f}".format(train_loss.avg))

    res = {}
    res["train_loss"] = float(train_loss.avg)
    return res


def test(model, epoch, data_loader, det_criterion, params):
    model.eval()
    predictions = []
    ground_truths = []
    test_loss = tu.AverageMeter()

    class_inv_freq = torch.from_numpy(
        np.array(params["class_inv_freq"], dtype=np.float32)
    ).to(params["device"])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            data = inputs["spec"].to(params["device"])
            gt_det = inputs["y_2d_det"].to(params["device"])
            gt_size = inputs["y_2d_size"].to(params["device"])
            gt_class = inputs["y_2d_classes"].to(params["device"])

            outputs = model(data)

            # if the model needs a fixed-size input, run this instead
            # data = torch.cat(torch.split(data, int(params['spec_train_width']*params['resize_factor']), 3), 0)
            # outputs = model(data)
            # for kk in ['pred_det', 'pred_size', 'pred_class']:
            #     outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)

            if params["save_test_image_during_train"] and batch_idx == 0:
                # for visualization - save the first prediction
                ind = 0
                orig_index = inputs["file_id"][ind]
                plot_title = data_loader.dataset.data_anns[orig_index]["id"]
                op_file_name = (
                    params["op_im_dir"]
                    + str(orig_index.item()).zfill(4)
                    + "_"
                    + str(epoch).zfill(4)
                    + "_pred.jpg"
                )
                save_image(
                    data,
                    outputs,
                    ind,
                    inputs,
                    params,
                    op_file_name,
                    plot_title,
                )

            loss = loss_fun(
                outputs,
                gt_det,
                gt_size,
                gt_class,
                det_criterion,
                params,
                class_inv_freq,
            )
            test_loss.update(loss.item(), data.shape[0])

            # do NMS
            pred_nms, _ = pp.run_nms(
                outputs, params, inputs["sampling_rate"].float()
            )
            predictions.extend(pred_nms)
            ground_truths.extend(parse_gt_data(inputs))

    res_det = evl.evaluate_predictions(
        ground_truths,
        predictions,
        params["class_names"],
        params["detection_overlap"],
        params["ignore_start_end"],
    )

    print("\nTest loss : {:.4f}".format(test_loss.avg))
    print("Rec at 0.95 (det) : {:.4f}".format(res_det["rec_at_x"]))
    print("Avg prec (cls) : {:.4f}".format(res_det["avg_prec"]))
    print(
        "File acc (cls) : {:.2f} - for {} out of {}".format(
            res_det["file_acc"],
            res_det["num_valid_files"],
            res_det["num_total_files"],
        )
    )
    print("Cls Avg prec (cls) : {:.4f}".format(res_det["avg_prec_class"]))

    print("\nPer class average precision")
    str_len = np.max([len(rs["name"]) for rs in res_det["class_pr"]]) + 5
    for cc, rs in enumerate(res_det["class_pr"]):
        if rs["num_gt"] > 0:
            print(
                str(cc).ljust(5)
                + rs["name"].ljust(str_len)
                + "{:.4f}".format(rs["avg_prec"])
            )

    res = {}
    res["test_loss"] = float(test_loss.avg)

    return res_det, res


def parse_gt_data(inputs):
    # reads the torch arrays into a dictionary of numpy arrays, taking care
    # to remove padding data, i.e. entries that are not valid
    keys = [
        "start_times",
        "end_times",
        "low_freqs",
        "high_freqs",
        "class_ids",
        "individual_ids",
    ]
    batch_data = []
    for ind in range(inputs["start_times"].shape[0]):
        is_valid = inputs["is_valid"][ind] == 1
        gt = {}
        for kk in keys:
            gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
        gt["duration"] = inputs["duration"][ind].item()
        gt["file_id"] = inputs["file_id"][ind].item()
        gt["class_id_file"] = inputs["class_id_file"][ind].item()
        batch_data.append(gt)
    return batch_data


def select_model(params):
    num_classes = len(params["class_names"])
    if params["model_name"] == "Net2DFast":
        model = models.Net2DFast(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoAttn":
        model = models.Net2DFastNoAttn(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoCoordConv":
        model = models.Net2DFastNoCoordConv(
            params["num_filters"],
            num_classes=num_classes,
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    else:
        # fail loudly rather than falling through and returning an unbound name
        raise ValueError("No valid network specified")
    return model
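

# Example invocation (illustrative only - the script name, paths, and note
# text below are placeholders, not values taken from the repository):
#
#   python train_model.py /path/to/audio_root /path/to/annotations \
#       --train_split diff --notes "baseline run"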


if __name__ == "__main__":
    plt.close("all")

    params = parameters.get_params(True)

    if torch.cuda.is_available():
        params["device"] = "cuda"
    else:
        params["device"] = "cpu"

    # setup arg parser and populate it with existing parameters - will not
    # work with lists
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=str, help="Path to root of datasets")
    parser.add_argument(
        "ann_dir", type=str, help="Path to extracted annotations"
    )
    parser.add_argument(
        "--train_split",
        type=str,
        default="diff",  # diff, same
        help="Which train split to use",
    )
    parser.add_argument(
        "--notes", type=str, default="", help="Notes to save in text file"
    )
    parser.add_argument(
        "--do_not_save_images",
        action="store_false",
        help="Do not save images at the end of training",
    )
    parser.add_argument(
        "--standardize_classs_names_ip",
        type=str,
        default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
        help="Will set low and high frequency the same for these classes. "
        'Separate names with ";"',
    )
    for key, val in params.items():
        parser.add_argument("--" + key, type=type(val), default=val)
    params = vars(parser.parse_args())

    # save notes file
    if params["notes"] != "":
        tu.write_notes_file(
            params["experiment"] + "notes.txt", params["notes"]
        )

    # load the training and test meta data - there are different splits defined
    train_sets, test_sets = ts.get_train_test_data(
        params["ann_dir"], params["data_dir"], params["train_split"]
    )
    train_sets_no_path, test_sets_no_path = ts.get_train_test_data(
        "", "", params["train_split"]
    )

    # keep track of what we have trained on
    params["train_sets"] = train_sets_no_path
    params["test_sets"] = test_sets_no_path

    # load train annotations - merge them all together
    print("\nTraining on:")
    for tt in train_sets:
        print(tt["ann_path"])
    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
    (
        data_train,
        params["class_names"],
        params["class_inv_freq"],
    ) = tu.load_set_of_anns(
        train_sets,
        classes_to_ignore,
        params["events_of_interest"],
        params["convert_to_genus"],
    )
    params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
        params["class_names"]
    )
    params["class_names_short"] = tu.get_short_class_names(
        params["class_names"]
    )

    # standardize the low and high frequency value for specified classes
    params["standardize_classs_names"] = params[
        "standardize_classs_names_ip"
    ].split(";")
    for cc in params["standardize_classs_names"]:
        if cc in params["class_names"]:
            data_train = tu.standardize_low_freq(data_train, cc)
        else:
            print(cc, "not found")

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    # test set
    print("\nTesting on:")
    for tt in test_sets:
        print(tt["ann_path"])
    data_test, _, _ = tu.load_set_of_anns(
        test_sets,
        classes_to_ignore,
        params["events_of_interest"],
        params["convert_to_genus"],
    )
    data_train = tu.remove_dupes(data_train, data_test)
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    # batch size of 1 because of variable file length
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    inputs_train = next(iter(train_loader))
    # TODO remove params['ip_height'], this is just legacy
    params["ip_height"] = int(params["spec_height"] * params["resize_factor"])
    print("\ntrain batch spec size :", inputs_train["spec"].shape)
    print("class target size :", inputs_train["y_2d_classes"].shape)

    # select network
    model = select_model(params)
    model = model.to(params["device"])

    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
    # scheduler.step() is called once per batch in train(), so T_max spans the
    # full number of optimization steps
    scheduler = CosineAnnealingLR(
        optimizer, params["num_epochs"] * len(train_loader)
    )
    if params["train_loss"] == "mse":
        det_criterion = losses.mse_loss
    elif params["train_loss"] == "focal":
        det_criterion = losses.focal_loss

    # save parameters to file
    with open(params["experiment"] + "params.json", "w") as da:
        json.dump(params, da, indent=2, sort_keys=True)

    # plotting
    train_plt_ls = pu.LossPlotter(
        params["experiment"] + "train_loss.png",
        params["num_epochs"] + 1,
        ["train_loss"],
        None,
        None,
        ["epoch", "train_loss"],
        logy=True,
    )
    test_plt_ls = pu.LossPlotter(
        params["experiment"] + "test_loss.png",
        params["num_epochs"] + 1,
        ["test_loss"],
        None,
        None,
        ["epoch", "test_loss"],
        logy=True,
    )
params["experiment"] + "test.png", params["num_epochs"] + 1, ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], [0, 1], None, ["epoch", ""], ) test_plt_class = pu.LossPlotter( params["experiment"] + "test_avg_prec.png", params["num_epochs"] + 1, params["class_names_short"], [0, 1], params["class_names_short"], ["epoch", "avg_prec"], ) # # main train loop for epoch in range(0, params["num_epochs"] + 1): train_loss = train( model, epoch, train_loader, det_criterion, optimizer, scheduler, params, ) train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) if epoch % params["num_eval_epochs"] == 0: # detection accuracy on test set test_res, test_loss = test( model, epoch, test_loader, det_criterion, params ) test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) test_plt.update_and_save( epoch, [ test_res["avg_prec"], test_res["rec_at_x"], test_res["avg_prec_class"], test_res["file_acc"], test_res["top_class"]["avg_prec"], ], ) test_plt_class.update_and_save( epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] ) pu.plot_pr_curve_class( params["experiment"], "test_pr", "test_pr", test_res ) # save trained model print("saving model to: " + params["model_file_name"]) op_state = { "epoch": epoch + 1, "state_dict": model.state_dict(), #'optimizer' : optimizer.state_dict(), "params": params, } torch.save(op_state, params["model_file_name"]) # save an image with associated prediction for each batch in the test set if not args["do_not_save_images"]: save_images_batch(model, test_loader, params)