Mirror of https://github.com/macaodha/batdetect2.git
Synced 2025-06-29 14:41:58 +02:00
326 lines · 9.7 KiB · Python
import argparse
|
|
import glob
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
|
|
sys.path.append(os.path.join("..", ".."))
|
|
import bat_detect.detector.models as models
|
|
import bat_detect.detector.parameters as parameters
|
|
import bat_detect.detector.post_process as pp
|
|
import bat_detect.train.audio_dataloader as adl
|
|
import bat_detect.train.evaluate as evl
|
|
import bat_detect.train.losses as losses
|
|
import bat_detect.train.train_model as tm
|
|
import bat_detect.train.train_utils as tu
|
|
import bat_detect.utils.detector_utils as du
|
|
import bat_detect.utils.plot_utils as pu
|
|
|
|
if __name__ == "__main__":
    # Finetune a pretrained BatDetect detection/classification model on a new
    # annotated dataset. Expects positional args: audio dir, train annotation
    # json, test annotation json, and the pretrained model checkpoint.

    info_str = "\nBatDetect - Finetune Model\n"

    print(info_str)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "audio_path", type=str, help="Input directory for audio"
    )
    parser.add_argument(
        "train_ann_path",
        type=str,
        help="Path to where train annotation file is stored",
    )
    parser.add_argument(
        "test_ann_path",
        type=str,
        help="Path to where test annotation file is stored",
    )
    parser.add_argument(
        "model_path", type=str, help="Path to pretrained model"
    )
    parser.add_argument(
        "--op_model_name",
        type=str,
        default="",
        help="Path and name for finetuned model",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=200,
        dest="num_epochs",
        help="Number of finetuning epochs",
    )
    parser.add_argument(
        "--finetune_only_last_layer",
        action="store_true",
        help="Only train final layers",
    )
    parser.add_argument(
        "--train_from_scratch",
        action="store_true",
        help="Do not use pretrained weights",
    )
    # BUGFIX: this was action="store_false", which defaults to True so the
    # final `if not args["do_not_save_images"]` skipped saving by default,
    # while passing the flag *enabled* saving - the opposite of the help
    # text. store_true makes the flag mean what it says: images are saved
    # unless --do_not_save_images is given.
    parser.add_argument(
        "--do_not_save_images",
        action="store_true",
        help="Do not save images at the end of training",
    )
    parser.add_argument(
        "--notes", type=str, default="", help="Notes to save in text file"
    )
    args = vars(parser.parse_args())

    params = parameters.get_params(True, "../../experiments/")
    if torch.cuda.is_available():
        params["device"] = "cuda"
    else:
        # only warn about speed when we actually fall back to the CPU
        params["device"] = "cpu"
        print(
            "\nNote, this will be a lot faster if you use computer with a GPU.\n"
        )

    print("\nAudio directory: " + args["audio_path"])
    print("Train file: " + args["train_ann_path"])
    print("Test file: " + args["test_ann_path"])
    print("Loading model: " + args["model_path"])

    # dataset name is derived from the train annotation file name,
    # e.g. "MyData_TRAIN.json" -> "MyData"
    dataset_name = (
        os.path.basename(args["train_ann_path"])
        .replace(".json", "")
        .replace("_TRAIN", "")
    )

    if args["train_from_scratch"]:
        print(
            "\nTraining model from scratch i.e. not using pretrained weights"
        )
        # second argument of load_model controls whether pretrained weights
        # are loaded into the network
        model, params_train = du.load_model(args["model_path"], False)
    else:
        model, params_train = du.load_model(args["model_path"], True)
    model.to(params["device"])

    params["num_epochs"] = args["num_epochs"]
    if args["op_model_name"] != "":
        params["model_file_name"] = args["op_model_name"]
    classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]

    # save notes file
    params["notes"] = args["notes"]
    if args["notes"] != "":
        tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])

    # load train annotations
    train_sets = []
    train_sets.append(
        tu.get_blank_dataset_dict(
            dataset_name, False, args["train_ann_path"], args["audio_path"]
        )
    )
    # the copy stored in params keeps only the annotation file's basename
    # (not the full path) so the saved checkpoint stays machine-independent
    params["train_sets"] = [
        tu.get_blank_dataset_dict(
            dataset_name,
            False,
            os.path.basename(args["train_ann_path"]),
            args["audio_path"],
        )
    ]

    print("\nTrain set:")
    (
        data_train,
        params["class_names"],
        params["class_inv_freq"],
    ) = tu.load_set_of_anns(
        train_sets, classes_to_ignore, params["events_of_interest"]
    )
    print("Number of files", len(data_train))

    params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
        params["class_names"]
    )
    params["class_names_short"] = tu.get_short_class_names(
        params["class_names"]
    )

    # load test annotations
    test_sets = []
    test_sets.append(
        tu.get_blank_dataset_dict(
            dataset_name, True, args["test_ann_path"], args["audio_path"]
        )
    )
    params["test_sets"] = [
        tu.get_blank_dataset_dict(
            dataset_name,
            True,
            os.path.basename(args["test_ann_path"]),
            args["audio_path"],
        )
    ]

    print("\nTest set:")
    data_test, _, _ = tu.load_set_of_anns(
        test_sets, classes_to_ignore, params["events_of_interest"]
    )
    print("Number of files", len(data_test))

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=params["num_workers"],
        pin_memory=True,
    )

    # peek at one batch to record the spectrogram input height
    # NOTE(review): assumes "spec" is (batch, chans, height, width) so that
    # shape[2] is the frequency axis - confirm against AudioLoader
    inputs_train = next(iter(train_loader))
    params["ip_height"] = inputs_train["spec"].shape[2]
    print("\ntrain batch size :", inputs_train["spec"].shape)

    # only the Net2DFast head-replacement below is supported
    assert params_train["model_name"] == "Net2DFast"
    print(
        "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
    )

    # set the number of output classes: swap the classification head for a
    # fresh conv layer sized for this dataset (+1 for the generic class)
    num_filts = model.conv_classes_op.in_channels
    k_size = model.conv_classes_op.kernel_size
    pad = model.conv_classes_op.padding
    model.conv_classes_op = torch.nn.Conv2d(
        num_filts,
        len(params["class_names"]) + 1,
        kernel_size=k_size,
        padding=pad,
    )
    model.conv_classes_op.to(params["device"])

    if args["finetune_only_last_layer"]:
        # freeze everything except the final classification/size heads
        print("\nOnly finetuning the final layers.\n")
        train_layers_i = [
            "conv_classes",
            "conv_classes_op",
            "conv_size",
            "conv_size_op",
        ]
        train_layers = [tt + ".weight" for tt in train_layers_i] + [
            tt + ".bias" for tt in train_layers_i
        ]
        for name, param in model.named_parameters():
            if name in train_layers:
                param.requires_grad = True
            else:
                param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
    # scheduler is stepped per batch, hence epochs * batches total steps
    scheduler = CosineAnnealingLR(
        optimizer, params["num_epochs"] * len(train_loader)
    )
    if params["train_loss"] == "mse":
        det_criterion = losses.mse_loss
    elif params["train_loss"] == "focal":
        det_criterion = losses.focal_loss
    else:
        # ROBUSTNESS: previously an unrecognized value left det_criterion
        # unbound and crashed with a NameError deep inside the train loop
        raise ValueError(
            "Unsupported train_loss: " + str(params["train_loss"])
        )

    # plotting
    train_plt_ls = pu.LossPlotter(
        params["experiment"] + "train_loss.png",
        params["num_epochs"] + 1,
        ["train_loss"],
        None,
        None,
        ["epoch", "train_loss"],
        logy=True,
    )
    test_plt_ls = pu.LossPlotter(
        params["experiment"] + "test_loss.png",
        params["num_epochs"] + 1,
        ["test_loss"],
        None,
        None,
        ["epoch", "test_loss"],
        logy=True,
    )
    test_plt = pu.LossPlotter(
        params["experiment"] + "test.png",
        params["num_epochs"] + 1,
        ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
        [0, 1],
        None,
        ["epoch", ""],
    )
    test_plt_class = pu.LossPlotter(
        params["experiment"] + "test_avg_prec.png",
        params["num_epochs"] + 1,
        params["class_names_short"],
        [0, 1],
        params["class_names_short"],
        ["epoch", "avg_prec"],
    )

    # main train loop
    for epoch in range(0, params["num_epochs"] + 1):

        train_loss = tm.train(
            model,
            epoch,
            train_loader,
            det_criterion,
            optimizer,
            scheduler,
            params,
        )
        train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])

        if epoch % params["num_eval_epochs"] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(
                model, epoch, test_loader, det_criterion, params
            )
            test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
            test_plt.update_and_save(
                epoch,
                [
                    test_res["avg_prec"],
                    test_res["rec_at_x"],
                    test_res["avg_prec_class"],
                    test_res["file_acc"],
                    test_res["top_class"]["avg_prec"],
                ],
            )
            test_plt_class.update_and_save(
                epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
            )
            pu.plot_pr_curve_class(
                params["experiment"], "test_pr", "test_pr", test_res
            )

            # save finetuned model (checkpoint overwritten at each eval epoch)
            print("saving model to: " + params["model_file_name"])
            op_state = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "params": params,
            }
            torch.save(op_state, params["model_file_name"])

    # save an image with associated prediction for each batch in the test set
    if not args["do_not_save_images"]:
        tm.save_images_batch(model, test_loader, params)