diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..0fd091c
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Format code with Black and isort
+3c17a2337166245de8df778fe174aad997e14e8f
diff --git a/.ropeproject/autoimport.db b/.ropeproject/autoimport.db
new file mode 100644
index 0000000..585d383
Binary files /dev/null and b/.ropeproject/autoimport.db differ
diff --git a/app.py b/app.py
index ae44690..1c884f0 100644
--- a/app.py
+++ b/app.py
@@ -1,84 +1,121 @@
-import gradio as gr
 import os
+
+import gradio as gr
 import matplotlib.pyplot as plt
-import pandas as pd
 import numpy as np
+import pandas as pd
 
-import bat_detect.utils.detector_utils as du
 import bat_detect.utils.audio_utils as au
+import bat_detect.utils.detector_utils as du
 import bat_detect.utils.plot_utils as viz
 
-
 # setup the arguments
 args = {}
 args = du.get_default_bd_args()
-args['detection_threshold'] = 0.3
-args['time_expansion_factor'] = 1
-args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
+args["detection_threshold"] = 0.3
+args["time_expansion_factor"] = 1
+args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
 max_duration = 2.0
 
 # load the model
-model, params = du.load_model(args['model_path'])
+model, params = du.load_model(args["model_path"])
 
-df = gr.Dataframe(
-        headers=["species", "time", "detection_prob", "species_prob"],
-        datatype=["str", "str", "str", "str"],
-        row_count=1,
-        col_count=(4, "fixed"),
-        label='Predictions'
-    )
-
-examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
-            ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
-            ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
+df = gr.Dataframe(
+    headers=["species", "time", "detection_prob", "species_prob"],
+    datatype=["str", "str", "str", "str"],
+    row_count=1,
+    col_count=(4, "fixed"),
+    label="Predictions",
+)
+
+examples = [
+    ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
+    ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
+    ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
+]
 
 
 def make_prediction(file_name=None, detection_threshold=0.3):
-
+
     if file_name is not None:
         audio_file = file_name
     else:
         return "You must provide an input audio file."
-
-    if detection_threshold is not None and detection_threshold != '':
-        args['detection_threshold'] = float(detection_threshold)
-
+
+    if detection_threshold is not None and detection_threshold != "":
+        args["detection_threshold"] = float(detection_threshold)
+
     # process the file to generate predictions
-    results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
-
-    anns = [ann for ann in results['pred_dict']['annotation']]
-    clss = [aa['class'] for aa in anns]
-    st_time = [aa['start_time'] for aa in anns]
-    cls_prob = [aa['class_prob'] for aa in anns]
-    det_prob = [aa['det_prob'] for aa in anns]
-    data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
-
+    results = du.process_file(
+        audio_file, model, params, args, max_duration=max_duration
+    )
+
+    anns = [ann for ann in results["pred_dict"]["annotation"]]
+    clss = [aa["class"] for aa in anns]
+    st_time = [aa["start_time"] for aa in anns]
+    cls_prob = [aa["class_prob"] for aa in anns]
+    det_prob = [aa["det_prob"] for aa in anns]
+    data = {
+        "species": clss,
+        "time": st_time,
+        "detection_prob": det_prob,
+        "species_prob": cls_prob,
+    }
+
     df = pd.DataFrame(data=data)
     im = generate_results_image(audio_file, anns)
-
+
     return [df, im]
 
 
-def generate_results_image(audio_file, anns):
-
+def generate_results_image(audio_file, anns):
+
     # load audio
-    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
-                           params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)
+    sampling_rate, audio = au.load_audio_file(
+        audio_file,
+        args["time_expansion_factor"],
+        params["target_samp_rate"],
+        params["scale_raw_audio"],
+        max_duration=max_duration,
+    )
     duration = audio.shape[0] / sampling_rate
-
+
     # generate spec
-    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
+    spec, spec_viz = au.generate_spectrogram(
+        audio, sampling_rate, params, True, False
+    )
 
     # create fig
-    plt.close('all')
-    fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
-    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
-    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
-    plt.ylabel('Freq - kHz')
-    plt.xlabel('Time - secs')
+    plt.close("all")
+    fig = plt.figure(
+        1,
+        figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
+        dpi=100,
+        frameon=False,
+    )
+    spec_duration = au.x_coords_to_time(
+        spec.shape[1],
+        sampling_rate,
+        params["fft_win_length"],
+        params["fft_overlap"],
+    )
+    viz.create_box_image(
+        spec,
+        fig,
+        anns,
+        0,
+        spec_duration,
+        spec_duration,
+        params,
+        spec.max() * 1.1,
+        False,
+        True,
+    )
+    plt.ylabel("Freq - kHz")
+    plt.xlabel("Time - secs")
     plt.tight_layout()
-
+
     # convert fig to image
     fig.canvas.draw()
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
@@ -88,21 +125,23 @@ def generate_results_image(audio_file, anns):
     return im
 
 
-descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
-            "This model is only trained on bat species from the UK. If the input " \
-            "file is longer than 2 seconds, only the first 2 seconds will be processed." \
-            "Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
+descr_txt = (
+    "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
+    "This model is only trained on bat species from the UK. If the input "
+    "file is longer than 2 seconds, only the first 2 seconds will be processed. "
+    "Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
+)
 
 gr.Interface(
-    fn = make_prediction,
-    inputs = [gr.Audio(source="upload", type="filepath", optional=True),
-              gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
-    outputs = [df, gr.Image(label="Visualisation")],
-    theme = "huggingface",
-    title = "BatDetect2 Demo",
-    description = descr_txt,
-    examples = examples,
-    allow_flagging = 'never',
+    fn=make_prediction,
+    inputs=[
+        gr.Audio(source="upload", type="filepath", optional=True),
+        gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
+    ],
+    outputs=[df, gr.Image(label="Visualisation")],
+    theme="huggingface",
+    title="BatDetect2 Demo",
+    description=descr_txt,
+    examples=examples,
+    allow_flagging="never",
 ).launch()
-
-
diff --git a/bat_detect/detector/compute_features.py b/bat_detect/detector/compute_features.py
index b6f4b8b..368c2db 100644
--- a/bat_detect/detector/compute_features.py
+++ b/bat_detect/detector/compute_features.py
@@ -2,8 +2,10 @@ import numpy as np
 
 
 def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
-    spec_ind = spec_height-spec_ind
-    return round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
+    spec_ind = spec_height - spec_ind
+    return round(
+        (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
+    )
 
 
 def extract_spec_slices(spec, pred_nms, params):
@@ -11,28 +13,40 @@ def extract_spec_slices(spec, pred_nms, params):
     Extracts spectrogram slices from spectrogram based on detected call locations.
     """
 
-    x_pos = pred_nms['x_pos']
-    y_pos = pred_nms['y_pos']
-    bb_width = pred_nms['bb_width']
-    bb_height = pred_nms['bb_height']
-    slices = []
+    x_pos = pred_nms["x_pos"]
+    y_pos = pred_nms["y_pos"]
+    bb_width = pred_nms["bb_width"]
+    bb_height = pred_nms["bb_height"]
+    slices = []
 
     # add 20% padding either side of call
-    pad = bb_width*0.2
-    x_pos_pad = x_pos - pad
-    bb_width_pad = bb_width + 2*pad
+    pad = bb_width * 0.2
+    x_pos_pad = x_pos - pad
+    bb_width_pad = bb_width + 2 * pad
 
-    for ff in range(len(pred_nms['det_probs'])):
+    for ff in range(len(pred_nms["det_probs"])):
         x_start = int(np.maximum(0, x_pos_pad[ff]))
-        x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos_pad[ff] + bb_width_pad[ff])))
+        x_end = int(
+            np.minimum(
+                spec.shape[1] - 1, np.round(x_pos_pad[ff] + bb_width_pad[ff])
+            )
+        )
         slices.append(spec[:, x_start:x_end].astype(np.float16))
     return slices
 
 
 def get_feature_names():
-    feature_names = ['duration', 'low_freq_bb', 'high_freq_bb', 'bandwidth',
-                     'max_power_bb', 'max_power', 'max_power_first',
-                     'max_power_second', 'call_interval']
+    feature_names = [
+        "duration",
+        "low_freq_bb",
+        "high_freq_bb",
+        "bandwidth",
+        "max_power_bb",
+        "max_power",
+        "max_power_first",
+        "max_power_second",
+        "call_interval",
+    ]
     return feature_names
 
 
@@ -45,40 +59,76 @@ def get_feats(spec, pred_nms, params):
     https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
     """
 
-    x_pos = pred_nms['x_pos']
-    y_pos = pred_nms['y_pos']
-    bb_width = pred_nms['bb_width']
-    bb_height = pred_nms['bb_height']
+    x_pos = pred_nms["x_pos"]
+    y_pos = pred_nms["y_pos"]
+    bb_width = pred_nms["bb_width"]
+    bb_height = pred_nms["bb_height"]
 
-    feature_names = get_feature_names()
-    num_detections = len(pred_nms['det_probs'])
-    features = np.ones((num_detections, len(feature_names)), dtype=np.float32)*-1
+    feature_names = get_feature_names()
+    num_detections = len(pred_nms["det_probs"])
+    features = (
+        np.ones((num_detections, len(feature_names)), dtype=np.float32)
+        * -1
+    )
 
     for ff in range(num_detections):
         x_start = int(np.maximum(0, x_pos[ff]))
-        x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos[ff] + bb_width[ff])))
+        x_end = int(
+            np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
+        )
 
         # y low is the lowest freq but it will have a higher value due to array starting at 0 at top
-        y_low = int(np.minimum(spec.shape[0]-1, y_pos[ff]))
-        y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
+        y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
+        y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
 
         spec_slice = spec[:, x_start:x_end]
         if spec_slice.shape[1] > 1:
-            features[ff, 0] = round(pred_nms['end_times'][ff] - pred_nms['start_times'][ff], 5)
-            features[ff, 1] = int(pred_nms['low_freqs'][ff])
-            features[ff, 2] = int(pred_nms['high_freqs'][ff])
-            features[ff, 3] = int(pred_nms['high_freqs'][ff] - pred_nms['low_freqs'][ff])
-            features[ff, 4] = int(convert_int_to_freq(y_high+spec_slice[y_high:y_low, :].sum(1).argmax(),
-                                  spec.shape[0], params['min_freq'], params['max_freq']))
-            features[ff, 5] = int(convert_int_to_freq(spec_slice.sum(1).argmax(),
-                                  spec.shape[0], params['min_freq'], params['max_freq']))
-            hlf_val = spec_slice.shape[1]//2
+            features[ff, 0] = round(
+                pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
+            )
+            features[ff, 1] = int(pred_nms["low_freqs"][ff])
+            features[ff, 2] = int(pred_nms["high_freqs"][ff])
+            features[ff, 3] = int(
+                pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
+            )
+            features[ff, 4] = int(
+                convert_int_to_freq(
+                    y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
+                    spec.shape[0],
+                    params["min_freq"],
+                    params["max_freq"],
+                )
+            )
+            features[ff, 5] = int(
+                convert_int_to_freq(
+                    spec_slice.sum(1).argmax(),
+                    spec.shape[0],
+                    params["min_freq"],
+                    params["max_freq"],
+                )
+            )
+            hlf_val = spec_slice.shape[1] // 2
 
-            features[ff, 6] = int(convert_int_to_freq(spec_slice[:, :hlf_val].sum(1).argmax(),
-                                  spec.shape[0], params['min_freq'], params['max_freq']))
-            features[ff, 7] = int(convert_int_to_freq(spec_slice[:, hlf_val:].sum(1).argmax(),
-                                  spec.shape[0], params['min_freq'], params['max_freq']))
+            features[ff, 6] = int(
+                convert_int_to_freq(
+                    spec_slice[:, :hlf_val].sum(1).argmax(),
+                    spec.shape[0],
+                    params["min_freq"],
+                    params["max_freq"],
+                )
+            )
+            features[ff, 7] = int(
+                convert_int_to_freq(
+                    spec_slice[:, hlf_val:].sum(1).argmax(),
+                    spec.shape[0],
+                    params["min_freq"],
+                    params["max_freq"],
+                )
+            )
 
             if ff > 0:
-                features[ff, 8] = round(pred_nms['start_times'][ff] - pred_nms['start_times'][ff-1], 5)
+                features[ff, 8] = round(
+                    pred_nms["start_times"][ff]
+                    - pred_nms["start_times"][ff - 1],
+                    5,
+                )
 
     return features
diff --git a/bat_detect/detector/model_helpers.py b/bat_detect/detector/model_helpers.py
index c91ef04..e237f7c 100644
--- a/bat_detect/detector/model_helpers.py
+++ b/bat_detect/detector/model_helpers.py
@@ -1,47 +1,71 @@
-import torch.nn as nn
-import torch
-import torch.nn.functional as F
-import numpy as np
 import math
 
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 
 class SelfAttention(nn.Module):
     def __init__(self, ip_dim, att_dim):
         super(SelfAttention, self).__init__()
         # Note, does not encode position information (absolute or relative)
         self.temperature = 1.0
-        self.att_dim = att_dim
+        self.att_dim = att_dim
         self.key_fun = nn.Linear(ip_dim, att_dim)
         self.val_fun = nn.Linear(ip_dim, att_dim)
         self.que_fun = nn.Linear(ip_dim, att_dim)
         self.pro_fun = nn.Linear(att_dim, ip_dim)
 
     def forward(self, x):
-        x = x.squeeze(2).permute(0,2,1)
+        x = x.squeeze(2).permute(0, 2, 1)
 
-        kk = torch.matmul(x, self.key_fun.weight.T) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
-        qq = torch.matmul(x, self.que_fun.weight.T) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
-        vv = torch.matmul(x, self.val_fun.weight.T) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
+        kk = torch.matmul(
+            x, self.key_fun.weight.T
+        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
+        qq = torch.matmul(
+            x, self.que_fun.weight.T
+        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
+        vv = torch.matmul(
+            x, self.val_fun.weight.T
+        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
 
-        kk_qq = torch.bmm(kk, qq.permute(0,2,1)) / (self.temperature*self.att_dim)
-        att_weights = F.softmax(kk_qq, 1)  # each col of each attention matrix sums to 1
-        att = torch.bmm(vv.permute(0,2,1), att_weights)
+        kk_qq = torch.bmm(kk, qq.permute(0, 2, 1)) / (
+            self.temperature * self.att_dim
+        )
+        att_weights = F.softmax(
+            kk_qq, 1
+        )  # each col of each attention matrix sums to 1
+        att = torch.bmm(vv.permute(0, 2, 1), att_weights)
 
-        op = torch.matmul(att.permute(0,2,1), self.pro_fun.weight.T) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
-        op = op.permute(0,2,1).unsqueeze(2)
+        op = torch.matmul(
+            att.permute(0, 2, 1), self.pro_fun.weight.T
+        ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
+        op = op.permute(0, 2, 1).unsqueeze(2)
 
         return op
 
 
 class ConvBlockDownCoordF(nn.Module):
-    def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1):
+    def __init__(
+        self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
+    ):
         super(ConvBlockDownCoordF, self).__init__()
-        self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height)[None, None, ..., None], requires_grad=False)
-        self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
+        self.coords = nn.Parameter(
+            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
+            requires_grad=False,
+        )
+        self.conv = nn.Conv2d(
+            in_chn + 1,
+            out_chn,
+            kernel_size=k_size,
+            padding=pad_size,
+            stride=stride,
+        )
         self.conv_bn = nn.BatchNorm2d(out_chn)
 
     def forward(self, x):
-        freq_info = self.coords.repeat(x.shape[0],1,1,x.shape[3])
+        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
         x = F.relu(self.conv_bn(x), inplace=True)
@@ -49,9 +73,17 @@ class ConvBlockDownCoordF(nn.Module):
 
 
 class ConvBlockDownStandard(nn.Module):
-    def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1):
+    def __init__(
+        self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
+    ):
         super(ConvBlockDownStandard, self).__init__()
-        self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
+        self.conv = nn.Conv2d(
+            in_chn,
+            out_chn,
+            kernel_size=k_size,
+            padding=pad_size,
+            stride=stride,
+        )
         self.conv_bn = nn.BatchNorm2d(out_chn)
 
     def forward(self, x):
@@ -61,17 +93,41 @@ class ConvBlockDownStandard(nn.Module):
 
 
 class ConvBlockUpF(nn.Module):
-    def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
+    def __init__(
+        self,
+        in_chn,
+        out_chn,
+        ip_height,
+        k_size=3,
+        pad_size=1,
+        up_mode="bilinear",
+        up_scale=(2, 2),
+    ):
         super(ConvBlockUpF, self).__init__()
         self.up_scale = up_scale
         self.up_mode = up_mode
-        self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height*up_scale[0])[None, None, ..., None], requires_grad=False)
-        self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size)
+        self.coords = nn.Parameter(
+            torch.linspace(-1, 1, ip_height * up_scale[0])[
+                None, None, ..., None
+            ],
+            requires_grad=False,
+        )
+        self.conv = nn.Conv2d(
+            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
+        )
         self.conv_bn = nn.BatchNorm2d(out_chn)
 
     def forward(self, x):
-        op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
-        freq_info = self.coords.repeat(op.shape[0],1,1,op.shape[3])
+        op = F.interpolate(
+            x,
+            size=(
+                x.shape[-2] * self.up_scale[0],
+                x.shape[-1] * self.up_scale[1],
+            ),
+            mode=self.up_mode,
+            align_corners=False,
+        )
+        freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
         op = torch.cat((op, freq_info), 1)
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
@@ -79,15 +135,34 @@
 
 class ConvBlockUpStandard(nn.Module):
-    def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
+    def __init__(
+        self,
+        in_chn,
+        out_chn,
+        ip_height=None,
+        k_size=3,
+        pad_size=1,
+        up_mode="bilinear",
+        up_scale=(2, 2),
+    ):
         super(ConvBlockUpStandard, self).__init__()
         self.up_scale = up_scale
         self.up_mode = up_mode
-        self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size)
+        self.conv = nn.Conv2d(
+            in_chn, out_chn, kernel_size=k_size, padding=pad_size
+        )
        self.conv_bn = nn.BatchNorm2d(out_chn)
 
     def forward(self, x):
-        op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
+        op = F.interpolate(
+            x,
+            size=(
+                x.shape[-2] * self.up_scale[0],
+                x.shape[-1] * self.up_scale[1],
+            ),
+            mode=self.up_mode,
+            align_corners=False,
+        )
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
         return op
diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py
index fc7b5b4..b39cbf4 100644
--- a/bat_detect/detector/models.py
+++ b/bat_detect/detector/models.py
@@ -1,52 +1,97 @@
-import torch.nn as nn
-import torch
-import torch.nn.functional as F
 import numpy as np
-from .model_helpers import *
-
-import torchvision
-
+import torch
 import torch.fft
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
 from torch import nn
 
+from .model_helpers import *
+
 
 class Net2DFast(nn.Module):
-    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+    def __init__(
+        self,
+        num_filts,
+        num_classes=0,
+        emb_dim=0,
+        ip_height=128,
+        resize_factor=0.5,
+    ):
         super(Net2DFast, self).__init__()
         self.num_classes = num_classes
         self.emb_dim = emb_dim
         self.num_filts = num_filts
         self.resize_factor = resize_factor
         self.ip_height_rs = ip_height
-        self.bneck_height = self.ip_height_rs//32
+        self.bneck_height = self.ip_height_rs // 32
 
         # encoder
-        self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
-        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+        self.conv_dn_0 = ConvBlockDownCoordF(
+            1,
+            num_filts // 4,
+            self.ip_height_rs,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_1 = ConvBlockDownCoordF(
+            num_filts // 4,
+            num_filts // 2,
+            self.ip_height_rs // 2,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_2 = ConvBlockDownCoordF(
+            num_filts // 2,
+            num_filts,
+            self.ip_height_rs // 4,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
 
         # bottleneck
-        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
-        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
-        self.att = SelfAttention(num_filts*2, num_filts*2)
+        self.conv_1d = nn.Conv2d(
+            num_filts * 2,
+            num_filts * 2,
+            (self.ip_height_rs // 8, 1),
+            padding=0,
+        )
+        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
+        self.att = SelfAttention(num_filts * 2, num_filts * 2)
 
         # decoder
-        self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
-        self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
-        self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
+        self.conv_up_2 = ConvBlockUpF(
+            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+        )
+        self.conv_up_3 = ConvBlockUpF(
+            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+        )
+        self.conv_up_4 = ConvBlockUpF(
+            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+        )
 
         # output
         # +1 to include background class for class output
-        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
-        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
-        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
-        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+        self.conv_op = nn.Conv2d(
+            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+        )
+        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+        self.conv_size_op = nn.Conv2d(
+            num_filts // 4, 2, kernel_size=1, padding=0
+        )
+        self.conv_classes_op = nn.Conv2d(
+            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+        )
 
         if self.emb_dim > 0:
-            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
-
+            self.conv_emb = nn.Conv2d(
+                num_filts, self.emb_dim, kernel_size=1, padding=0
+            )
 
     def forward(self, ip, return_feats=False):
 
@@ -59,33 +104,40 @@ class Net2DFast(nn.Module):
         # bottleneck
         x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
         x = self.att(x)
-        x = x.repeat([1,1,self.bneck_height*4,1])
+        x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
         # decoder
-        x = self.conv_up_2(x+x3)
-        x = self.conv_up_3(x+x2)
-        x = self.conv_up_4(x+x1)
+        x = self.conv_up_2(x + x3)
+        x = self.conv_up_3(x + x2)
+        x = self.conv_up_4(x + x1)
 
         # output
         x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
-        cls = self.conv_classes_op(x)
+        cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
         op = {}
-        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
-        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
-        op['pred_class'] = comb
-        op['pred_class_un_norm'] = cls
+        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+        op["pred_class"] = comb
+        op["pred_class_un_norm"] = cls
         if self.emb_dim > 0:
-            op['pred_emb'] = self.conv_emb(x)
+            op["pred_emb"] = self.conv_emb(x)
         if return_feats:
-            op['features'] = x
+            op["features"] = x
 
         return op
 
 
 class Net2DFastNoAttn(nn.Module):
-    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+    def __init__(
+        self,
+        num_filts,
+        num_classes=0,
+        emb_dim=0,
+        ip_height=128,
+        resize_factor=0.5,
+    ):
         super(Net2DFastNoAttn, self).__init__()
 
         self.num_classes = num_classes
@@ -93,31 +145,70 @@ class Net2DFastNoAttn(nn.Module):
         self.num_filts = num_filts
         self.resize_factor = resize_factor
         self.ip_height_rs = ip_height
-        self.bneck_height = self.ip_height_rs//32
+        self.bneck_height = self.ip_height_rs // 32
 
-        self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
-        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+        self.conv_dn_0 = ConvBlockDownCoordF(
+            1,
+            num_filts // 4,
+            self.ip_height_rs,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_1 = ConvBlockDownCoordF(
+            num_filts // 4,
+            num_filts // 2,
+            self.ip_height_rs // 2,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_2 = ConvBlockDownCoordF(
+            num_filts // 2,
+            num_filts,
+            self.ip_height_rs // 4,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
 
-        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
-        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
+        self.conv_1d = nn.Conv2d(
+            num_filts * 2,
+            num_filts * 2,
+            (self.ip_height_rs // 8, 1),
+            padding=0,
+        )
+        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
 
-
-        self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
-        self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
-        self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
+        self.conv_up_2 = ConvBlockUpF(
+            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+        )
+        self.conv_up_3 = ConvBlockUpF(
+            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+        )
+        self.conv_up_4 = ConvBlockUpF(
+            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+        )
 
         # output
         # +1 to include background class for class output
-        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
-        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
-        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
-        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+        self.conv_op = nn.Conv2d(
+            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+        )
+        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+        self.conv_size_op = nn.Conv2d(
+            num_filts // 4, 2, kernel_size=1, padding=0
+        )
+        self.conv_classes_op = nn.Conv2d(
+            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+        )
 
         if self.emb_dim > 0:
-            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
+            self.conv_emb = nn.Conv2d(
+                num_filts, self.emb_dim, kernel_size=1, padding=0
+            )
 
     def forward(self, ip, return_feats=False):
 
@@ -127,31 +218,38 @@
         x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
 
         x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
-        x = x.repeat([1,1,self.bneck_height*4,1])
+        x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
-        x = self.conv_up_2(x+x3)
-        x = self.conv_up_3(x+x2)
-        x = self.conv_up_4(x+x1)
+        x = self.conv_up_2(x + x3)
+        x = self.conv_up_3(x + x2)
+        x = self.conv_up_4(x + x1)
 
         x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
-        cls = self.conv_classes_op(x)
+        cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
         op = {}
-        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
-        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
-        op['pred_class'] = comb
-        op['pred_class_un_norm'] = cls
+        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+        op["pred_class"] = comb
+        op["pred_class_un_norm"] = cls
         if self.emb_dim > 0:
-            op['pred_emb'] = self.conv_emb(x)
+            op["pred_emb"] = self.conv_emb(x)
         if return_feats:
-            op['features'] = x
+            op["features"] = x
 
         return op
 
 
 class Net2DFastNoCoordConv(nn.Module):
-    def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+    def __init__(
+        self,
+        num_filts,
+        num_classes=0,
+        emb_dim=0,
+        ip_height=128,
+        resize_factor=0.5,
+    ):
         super(Net2DFastNoCoordConv, self).__init__()
 
         self.num_classes = num_classes
@@ -159,32 +257,72 @@
         self.num_filts = num_filts
         self.resize_factor = resize_factor
         self.ip_height_rs = ip_height
-        self.bneck_height = self.ip_height_rs//32
+        self.bneck_height = self.ip_height_rs // 32
 
-        self.conv_dn_0 = ConvBlockDownStandard(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_1 = ConvBlockDownStandard(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_2 = ConvBlockDownStandard(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
-        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
-        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+        self.conv_dn_0 = ConvBlockDownStandard(
+            1,
+            num_filts // 4,
+            self.ip_height_rs,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_1 = ConvBlockDownStandard(
+            num_filts // 4,
+            num_filts // 2,
+            self.ip_height_rs // 2,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_2 = ConvBlockDownStandard(
+            num_filts // 2,
+            num_filts,
+            self.ip_height_rs // 4,
+            k_size=3,
+            pad_size=1,
+            stride=1,
+        )
+        self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+        self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
 
-        self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
-        self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
+        self.conv_1d = nn.Conv2d(
+            num_filts * 2,
+            num_filts * 2,
+            (self.ip_height_rs // 8, 1),
+            padding=0,
+        )
+        self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
 
-        self.att = SelfAttention(num_filts*2, num_filts*2)
+        self.att = SelfAttention(num_filts * 2, num_filts * 2)
 
-        self.conv_up_2 = ConvBlockUpStandard(num_filts*2, num_filts//2, self.ip_height_rs//8)
-        self.conv_up_3 = ConvBlockUpStandard(num_filts//2, num_filts//4, self.ip_height_rs//4)
-        self.conv_up_4 = ConvBlockUpStandard(num_filts//4, num_filts//4, self.ip_height_rs//2)
+        self.conv_up_2 = ConvBlockUpStandard(
+            num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+        )
+        self.conv_up_3 = ConvBlockUpStandard(
+            num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+        )
+        self.conv_up_4 = ConvBlockUpStandard(
+            num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+        )
 
         # output
         # +1 to include background class for class output
-        self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
-        self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
-        self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
-        self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+        self.conv_op = nn.Conv2d(
+            num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+        )
+        self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+        self.conv_size_op = nn.Conv2d(
+            num_filts // 4, 2, kernel_size=1, padding=0
+        )
+        self.conv_classes_op = nn.Conv2d(
+            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+        )
 
         if self.emb_dim > 0:
-            self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
+            self.conv_emb = nn.Conv2d(
+                num_filts, self.emb_dim, kernel_size=1, padding=0
+            )
 
     def forward(self, ip, return_feats=False):
 
@@ -195,24 +333,24 @@
         x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
         x = self.att(x)
-        x = x.repeat([1,1,self.bneck_height*4,1])
+        x = x.repeat([1, 1, self.bneck_height * 4, 1])
 
-        x = self.conv_up_2(x+x3)
-        x = self.conv_up_3(x+x2)
-        x = self.conv_up_4(x+x1)
+        x = self.conv_up_2(x + x3)
+        x = self.conv_up_3(x + x2)
+        x = self.conv_up_4(x + x1)
 
         x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
-        cls = self.conv_classes_op(x)
+        cls = self.conv_classes_op(x)
         comb = torch.softmax(cls, 1)
 
         op = {}
-        op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
-        op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
-        op['pred_class'] = comb
-        op['pred_class_un_norm'] = cls
+        op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+        op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+        op["pred_class"] = comb
+        op["pred_class_un_norm"] = cls
         if self.emb_dim > 0:
-            op['pred_emb'] = self.conv_emb(x)
+            op["pred_emb"] = self.conv_emb(x)
         if return_feats:
-            op['features'] = x
+            op["features"] = x
 
         return op
diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py
index 10276eb..d93ac8c 100644
--- a/bat_detect/detector/parameters.py
+++ b/bat_detect/detector/parameters.py
@@ -1,108 +1,164 @@
-import numpy as np
-import os
 import datetime
+import os
+
+import numpy as np
 
 
 def mk_dir(path):
     if not os.path.isdir(path):
         os.makedirs(path)
-
-
-def get_params(make_dirs=False, exps_dir='../../experiments/'):
+
+
+def get_params(make_dirs=False, exps_dir="../../experiments/"):
 
     params = {}
 
-    params['model_name'] = 'Net2DFast'  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
-    params['num_filters'] = 128
+    params[
+        "model_name"
+    ] = "Net2DFast"  # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
+    params["num_filters"] = 128
 
     now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
-    model_name = now_str + '.pth.tar'
-    params['experiment'] = os.path.join(exps_dir, now_str, '')
-    params['model_file_name'] = os.path.join(params['experiment'], model_name)
-    params['op_im_dir'] = os.path.join(params['experiment'], 'op_ims', '')
-    params['op_im_dir_test'] = os.path.join(params['experiment'], 'op_ims_test', '')
-    #params['notes'] = ''  # can save notes about an experiment here
-
+    model_name = now_str + ".pth.tar"
+    params["experiment"] = os.path.join(exps_dir, now_str, "")
+    params["model_file_name"] = os.path.join(params["experiment"], model_name)
+    params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
+    params["op_im_dir_test"] = os.path.join(
+        params["experiment"], "op_ims_test", ""
+    )
+    # params['notes'] = ''  # can save notes about an experiment here
 
     # spec parameters
-    params['target_samp_rate'] = 256000  # resamples all audio so that it is at this rate
-    params['fft_win_length'] = 512 / 256000.0  # in milliseconds, amount of time per stft time step
-    params['fft_overlap'] = 0.75  # stft window overlap
+    params[
+        "target_samp_rate"
+    ] = 256000  # resamples all audio so that it is at this rate
params["fft_win_length"] = ( + 512 / 256000.0 + ) # in milliseconds, amount of time per stft time step + params["fft_overlap"] = 0.75 # stft window overlap - params['max_freq'] = 120000 # in Hz, everything above this will be discarded - params['min_freq'] = 10000 # in Hz, everything below this will be discarded + params[ + "max_freq" + ] = 120000 # in Hz, everything above this will be discarded + params[ + "min_freq" + ] = 10000 # in Hz, everything below this will be discarded - params['resize_factor'] = 0.5 # resize so the spectrogram at the input of the network - params['spec_height'] = 256 # units are number of frequency bins (before resizing is performed) - params['spec_train_width'] = 512 # units are number of time steps (before resizing is performed) - params['spec_divide_factor'] = 32 # spectrogram should be divisible by this amount in width and height + params[ + "resize_factor" + ] = 0.5 # resize so the spectrogram at the input of the network + params[ + "spec_height" + ] = 256 # units are number of frequency bins (before resizing is performed) + params[ + "spec_train_width" + ] = 512 # units are number of time steps (before resizing is performed) + params[ + "spec_divide_factor" + ] = 32 # spectrogram should be divisible by this amount in width and height # spec processing params - params['denoise_spec_avg'] = True # removes the mean for each frequency band - params['scale_raw_audio'] = False # scales the raw audio to [-1, 1] - params['max_scale_spec'] = False # scales the spectrogram so that it is max 1 - params['spec_scale'] = 'pcen' # 'log', 'pcen', 'none' + params[ + "denoise_spec_avg" + ] = True # removes the mean for each frequency band + params["scale_raw_audio"] = False # scales the raw audio to [-1, 1] + params[ + "max_scale_spec" + ] = False # scales the spectrogram so that it is max 1 + params["spec_scale"] = "pcen" # 'log', 'pcen', 'none' # detection params - params['detection_overlap'] = 0.01 # has to be within this number of ms to count as detection - params['ignore_start_end'] = 0.01 # if start of GT calls are within this time from the start/end of file ignore - params['detection_threshold'] = 0.01 # the smaller this is the better the recall will be - params['nms_kernel_size'] = 9 - params['nms_top_k_per_sec'] = 200 # keep top K highest predictions per second of audio - params['target_sigma'] = 2.0 + params[ + "detection_overlap" + ] = 0.01 # has to be within this number of ms to count as detection + params[ + "ignore_start_end" + ] = 0.01 # if start of GT calls are within this time from the start/end of file ignore + params[ + "detection_threshold" + ] = 0.01 # the smaller this is the better the recall will be + params["nms_kernel_size"] = 9 + params[ + "nms_top_k_per_sec" + ] = 200 # keep top K highest predictions per second of audio + params["target_sigma"] = 2.0 # augmentation params - params['aug_prob'] = 0.20 # augmentations will be performed with this probability - params['augment_at_train'] = True - params['augment_at_train_combine'] = True - params['echo_max_delay'] = 0.005 # simulate echo by adding copy of raw audio - params['stretch_squeeze_delta'] = 0.04 # stretch or squeeze spec - params['mask_max_time_perc'] = 0.05 # max mask size - here percentage, not ideal - params['mask_max_freq_perc'] = 0.10 # max mask size - here percentage, not ideal - params['spec_amp_scaling'] = 2.0 # multiply the "volume" by 0:X times current amount - params['aug_sampling_rates'] = [220500, 256000, 300000, 312500, 384000, 441000, 500000] + params[ + "aug_prob" + ] = 0.20 # 
+    ] = 0.20  # augmentations will be performed with this probability
+    params["augment_at_train"] = True
+    params["augment_at_train_combine"] = True
+    params[
+        "echo_max_delay"
+    ] = 0.005  # simulate echo by adding copy of raw audio
+    params["stretch_squeeze_delta"] = 0.04  # stretch or squeeze spec
+    params[
+        "mask_max_time_perc"
+    ] = 0.05  # max mask size - here percentage, not ideal
+    params[
+        "mask_max_freq_perc"
+    ] = 0.10  # max mask size - here percentage, not ideal
+    params[
+        "spec_amp_scaling"
+    ] = 2.0  # multiply the "volume" by 0:X times current amount
+    params["aug_sampling_rates"] = [
+        220500,
+        256000,
+        300000,
+        312500,
+        384000,
+        441000,
+        500000,
+    ]
 
     # loss params
-    params['train_loss'] = 'focal'  # mse or focal
-    params['det_loss_weight'] = 1.0  # weight for the detection part of the loss
-    params['size_loss_weight'] = 0.1  # weight for the bbox size loss
-    params['class_loss_weight'] = 2.0  # weight for the classification loss
-    params['individual_loss_weight'] = 0.0  # not used
-    if params['individual_loss_weight'] == 0.0:
-        params['emb_dim'] = 0  # number of dimensions used for individual id embedding
+    params["train_loss"] = "focal"  # mse or focal
+    params[
+        "det_loss_weight"
+    ] = 1.0  # weight for the detection part of the loss
+    params["size_loss_weight"] = 0.1  # weight for the bbox size loss
+    params["class_loss_weight"] = 2.0  # weight for the classification loss
+    params["individual_loss_weight"] = 0.0  # not used
+    if params["individual_loss_weight"] == 0.0:
+        params[
+            "emb_dim"
+        ] = 0  # number of dimensions used for individual id embedding
     else:
-        params['emb_dim'] = 3
+        params["emb_dim"] = 3
 
     # train params
-    params['lr'] = 0.001
-    params['batch_size'] = 8
-    params['num_workers'] = 4
-    params['num_epochs'] = 200
-    params['num_eval_epochs'] = 5  # run evaluation every X epochs
-    params['device'] = 'cuda'
-    params['save_test_image_during_train'] = False
-    params['save_test_image_after_train'] = True
+    params["lr"] = 0.001
+    params["batch_size"] = 8
+    params["num_workers"] = 4
+    params["num_epochs"] = 200
+    params["num_eval_epochs"] = 5  # run evaluation every X epochs
+    params["device"] = "cuda"
+    params["save_test_image_during_train"] = False
+    params["save_test_image_after_train"] = True
 
-    params['convert_to_genus'] = False
-    params['genus_mapping'] = []
-    params['class_names'] = []
-    params['classes_to_ignore'] = ['', ' ', 'Unknown', 'Not Bat']
-    params['generic_class'] = ['Bat']
-    params['events_of_interest'] = ['Echolocation']  # will ignore all other types of events e.g. social calls
+    params["convert_to_genus"] = False
+    params["genus_mapping"] = []
+    params["class_names"] = []
+    params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
+    params["generic_class"] = ["Bat"]
+    params["events_of_interest"] = [
+        "Echolocation"
+    ]  # will ignore all other types of events e.g. social calls
 
     # the classes in this list are standardized during training so that the same low and high freq are used
-    params['standardize_classs_names'] = []
+    params["standardize_classs_names"] = []
 
     # create directories
     if make_dirs:
-        print('Model name : ' + params['model_name'])
-        print('Model file : ' + params['model_file_name'])
-        print('Experiment : ' + params['experiment'])
+        print("Model name : " + params["model_name"])
+        print("Model file : " + params["model_file_name"])
+        print("Experiment : " + params["experiment"])
 
-        mk_dir(params['experiment'])
-        if params['save_test_image_during_train']:
-            mk_dir(params['op_im_dir'])
-        if params['save_test_image_after_train']:
-            mk_dir(params['op_im_dir_test'])
-        mk_dir(os.path.dirname(params['model_file_name']))
+        mk_dir(params["experiment"])
+        if params["save_test_image_during_train"]:
+            mk_dir(params["op_im_dir"])
+        if params["save_test_image_after_train"]:
+            mk_dir(params["op_im_dir_test"])
+        mk_dir(os.path.dirname(params["model_file_name"]))
 
     return params
diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py
index 757831f..2745cdf 100644
--- a/bat_detect/detector/post_process.py
+++ b/bat_detect/detector/post_process.py
@@ -1,35 +1,42 @@
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 
-np.seterr(divide='ignore', invalid='ignore')
+
+np.seterr(divide="ignore", invalid="ignore")
 
 
 def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
-    nfft = int(fft_win_length*sampling_rate)
-    noverlap = int(fft_overlap*nfft)
-    return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
-    #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
+    nfft = int(fft_win_length * sampling_rate)
+    noverlap = int(fft_overlap * nfft)
+    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
+    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
 
 
 def overall_class_pred(det_prob, class_prob):
-    weighted_pred = (class_prob*det_prob).sum(1)
+    weighted_pred = (class_prob * det_prob).sum(1)
     return weighted_pred / weighted_pred.sum()
 
 
 def run_nms(outputs, params, sampling_rate):
 
-    pred_det = outputs['pred_det']  # probability of box
-    pred_size = outputs['pred_size']  # box size
+    pred_det = outputs["pred_det"]  # probability of box
+    pred_size = outputs["pred_size"]  # box size
 
-    pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size'])
-    freq_rescale = (params['max_freq'] - params['min_freq']) /pred_det.shape[-2]
+    pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
+    freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
+        -2
+    ]
 
     # NOTE there will be small differences depending on which sampling rate is chosen
     # as we are choosing the same sampling rate for the entire batch
-    duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(),
-                                params['fft_win_length'], params['fft_overlap'])
-    top_k = int(duration * params['nms_top_k_per_sec'])
+    duration = x_coords_to_time(
+        pred_det.shape[-1],
+        sampling_rate[0].item(),
+        params["fft_win_length"],
+        params["fft_overlap"],
+    )
+    top_k = int(duration * params["nms_top_k_per_sec"])
     scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
 
     # loop over batch to save outputs
@@ -38,30 +45,47 @@
     for ii in range(pred_det_nms.shape[0]):
         # get valid indices
         inds_ord = torch.argsort(x_pos[ii, :])
-        valid_inds = scores[ii, inds_ord] > params['detection_threshold']
+        valid_inds = scores[ii, inds_ord] > params["detection_threshold"]
         valid_inds = inds_ord[valid_inds]
 
         # create result dictionary
         pred = {}
-        pred['det_probs'] = scores[ii, valid_inds]
-        pred['x_pos'] = x_pos[ii, valid_inds]
-        pred['y_pos'] = y_pos[ii, valid_inds]
-        pred['bb_width'] = pred_size[ii, 0, pred['y_pos'], pred['x_pos']]
-        pred['bb_height'] = pred_size[ii, 1, pred['y_pos'], pred['x_pos']]
-        pred['start_times'] = x_coords_to_time(pred['x_pos'].float() / params['resize_factor'],
-                              sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
-        pred['end_times'] = x_coords_to_time((pred['x_pos'].float()+pred['bb_width']) / params['resize_factor'],
-                            sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
-        pred['low_freqs'] = (pred_size[ii].shape[1] - pred['y_pos'].float())*freq_rescale + params['min_freq']
-        pred['high_freqs'] = pred['low_freqs'] + pred['bb_height']*freq_rescale
+        pred["det_probs"] = scores[ii, valid_inds]
+        pred["x_pos"] = x_pos[ii, valid_inds]
+        pred["y_pos"] = y_pos[ii, valid_inds]
+        pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]]
+        pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]]
+        pred["start_times"] = x_coords_to_time(
+            pred["x_pos"].float() / params["resize_factor"],
+            sampling_rate[ii].item(),
+            params["fft_win_length"],
+            params["fft_overlap"],
+        )
+        pred["end_times"] = x_coords_to_time(
+            (pred["x_pos"].float() + pred["bb_width"])
+            / params["resize_factor"],
+            sampling_rate[ii].item(),
+            params["fft_win_length"],
+            params["fft_overlap"],
+        )
+        pred["low_freqs"] = (
+            pred_size[ii].shape[1] - pred["y_pos"].float()
+        ) * freq_rescale + params["min_freq"]
+        pred["high_freqs"] = (
+            pred["low_freqs"] + pred["bb_height"] * freq_rescale
+        )
 
         # extract the per class votes
-        if 'pred_class' in outputs:
-            pred['class_probs'] = outputs['pred_class'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]]
+        if "pred_class" in outputs:
+            pred["class_probs"] = outputs["pred_class"][
+                ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+            ]
 
         # extract the model features
-        if 'features' in outputs:
-            feat = outputs['features'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]].transpose(0, 1)
+        if "features" in outputs:
+            feat = outputs["features"][
+                ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+            ].transpose(0, 1)
             feat = feat.cpu().numpy().astype(np.float32)
             feats.append(feat)
@@ -82,7 +106,9 @@ def non_max_suppression(heat, kernel_size):
     pad_h = (kernel_size_h - 1) // 2
     pad_w = (kernel_size_w - 1) // 2
 
-    hmax = nn.functional.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w))
+    hmax = nn.functional.max_pool2d(
+        heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
+    )
     keep = (hmax == heat).float()
     return heat * keep
@@ -94,7 +120,7 @@ def get_topk_scores(scores, K):
 
     topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
     topk_inds = topk_inds % (height * width)
-    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long()
-    topk_xs = (topk_inds % width).long()
+    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
+    topk_xs = (topk_inds % width).long()
 
     return topk_scores, topk_ys, topk_xs
diff --git a/bat_detect/evaluate/evaluate_models.py b/bat_detect/evaluate/evaluate_models.py
index 0fc8ae9..6b7c460 100644
--- a/bat_detect/evaluate/evaluate_models.py
+++ b/bat_detect/evaluate/evaluate_models.py
@@ -2,67 +2,76 @@
 Evaluates trained model on test set and generates plots.
 """
-import numpy as np
-import sys
-import os
+import argparse
 import copy
 import json
+import os
+import sys
+
+import numpy as np
 import pandas as pd
 from sklearn.ensemble import RandomForestClassifier
-import argparse
 
-sys.path.append('../../')
-import bat_detect.utils.detector_utils as du
-import bat_detect.train.train_utils as tu
+sys.path.append("../../")
 import bat_detect.detector.parameters as parameters
 import bat_detect.train.evaluate as evl
+import bat_detect.train.train_utils as tu
+import bat_detect.utils.detector_utils as du
 import bat_detect.utils.plot_utils as pu
 
 
 def get_blank_annotation(ip_str):
 
     res = {}
-    res['class_name'] = ''
-    res['duration'] = -1
-    res['id'] = ''# fileName
-    res['issues'] = False
-    res['notes'] = ip_str
-    res['time_exp'] = 1
-    res['annotated'] = False
-    res['annotation'] = []
+    res["class_name"] = ""
+    res["duration"] = -1
+    res["id"] = ""  # fileName
+    res["issues"] = False
+    res["notes"] = ip_str
+    res["time_exp"] = 1
+    res["annotated"] = False
+    res["annotation"] = []
 
     ann = {}
-    ann['class'] = ''
-    ann['event'] = 'Echolocation'
-    ann['individual'] = -1
-    ann['start_time'] = -1
-    ann['end_time'] = -1
-    ann['low_freq'] = -1
-    ann['high_freq'] = -1
-    ann['confidence'] = -1
+    ann["class"] = ""
+    ann["event"] = "Echolocation"
+    ann["individual"] = -1
+    ann["start_time"] = -1
+    ann["end_time"] = -1
+    ann["low_freq"] = -1
+    ann["high_freq"] = -1
+    ann["confidence"] = -1
 
     return copy.deepcopy(res), copy.deepcopy(ann)
 
 
 def create_genus_mapping(gt_test, preds, class_names):
     # rolls the per class predictions and ground truth back up to genus level
-    class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
-    genus_to_cls_map = [np.where(np.array(cls_to_genus) == cc)[0] for cc in range(len(class_names_genus))]
+    class_names_genus, cls_to_genus = np.unique(
+        [cc.split(" ")[0] for cc in class_names], return_inverse=True
+    )
+    genus_to_cls_map = [
+        np.where(np.array(cls_to_genus) == cc)[0]
+        for cc in range(len(class_names_genus))
+    ]
 
     gt_test_g = []
     for gg in gt_test:
        gg_g = copy.deepcopy(gg)
-        inds = np.where(gg_g['class_ids']!=-1)[0]
-        gg_g['class_ids'][inds] = cls_to_genus[gg_g['class_ids'][inds]]
+        inds = np.where(gg_g["class_ids"] != -1)[0]
+        gg_g["class_ids"][inds] = cls_to_genus[gg_g["class_ids"][inds]]
         gt_test_g.append(gg_g)
 
     # note, will have entries greater than one as we are summing across the respective classes
     preds_g = []
     for pp in preds:
         pp_g = copy.deepcopy(pp)
-        pp_g['class_probs'] = np.zeros((len(class_names_genus), pp_g['class_probs'].shape[1]), dtype=np.float32)
+        pp_g["class_probs"] = np.zeros(
+            (len(class_names_genus), pp_g["class_probs"].shape[1]),
+            dtype=np.float32,
+        )
         for cc, inds in enumerate(genus_to_cls_map):
-            pp_g['class_probs'][cc, :] = pp['class_probs'][inds, :].sum(0)
+            pp_g["class_probs"][cc, :] = pp["class_probs"][inds, :].sum(0)
         preds_g.append(pp_g)
 
     return class_names_genus, preds_g, gt_test_g
@@ -70,56 +79,70 @@ def create_genus_mapping(gt_test, preds, class_names):
 
 def load_tadarida_pred(ip_dir, dataset, file_of_interest):
 
-    res, ann = get_blank_annotation('Generated by Tadarida')
+    res, ann = get_blank_annotation("Generated by Tadarida")
 
     # create the annotations in the correct format
-    da_c = pd.read_csv(ip_dir + dataset + '/' + file_of_interest.replace('.wav', '.ta').replace('.WAV', '.ta'), sep='\t')
+    da_c = pd.read_csv(
+        ip_dir
+        + dataset
+        + "/"
+        + file_of_interest.replace(".wav", ".ta").replace(".WAV", ".ta"),
+        sep="\t",
+    )
 
     res_c = copy.deepcopy(res)
-    res_c['id'] = file_of_interest
-    res_c['dataset'] = dataset
-    res_c['feats'] = da_c.iloc[:, 6:].values.astype(np.float32)
+    res_c["id"] = file_of_interest
+    res_c["dataset"] = dataset
+    res_c["feats"] = da_c.iloc[:, 6:].values.astype(np.float32)
 
     if da_c.shape[0] > 0:
-        res_c['class_name'] = ''
-        res_c['class_prob'] = 0.0
+        res_c["class_name"] = ""
+        res_c["class_prob"] = 0.0
 
         for aa in range(da_c.shape[0]):
             ann_c = copy.deepcopy(ann)
-            ann_c['class'] = 'Not Bat'  # will assign to class later
-            ann_c['start_time'] = np.round(da_c.iloc[aa]['StTime']/1000.0 ,5)
-            ann_c['end_time'] = np.round((da_c.iloc[aa]['StTime'] + da_c.iloc[aa]['Dur'])/1000.0, 5)
-            ann_c['low_freq'] = np.round(da_c.iloc[aa]['Fmin'] * 1000.0, 2)
-            ann_c['high_freq'] = np.round(da_c.iloc[aa]['Fmax'] * 1000.0, 2)
-            ann_c['det_prob'] = 0.0
-            res_c['annotation'].append(ann_c)
+            ann_c["class"] = "Not Bat"  # will assign to class later
+            ann_c["start_time"] = np.round(da_c.iloc[aa]["StTime"] / 1000.0, 5)
+            ann_c["end_time"] = np.round(
+                (da_c.iloc[aa]["StTime"] + da_c.iloc[aa]["Dur"]) / 1000.0, 5
+            )
+            ann_c["low_freq"] = np.round(da_c.iloc[aa]["Fmin"] * 1000.0, 2)
+            ann_c["high_freq"] = np.round(da_c.iloc[aa]["Fmax"] * 1000.0, 2)
+            ann_c["det_prob"] = 0.0
+            res_c["annotation"].append(ann_c)
 
     return res_c
 
 
-def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_accepted_species=True):
+def load_sonobat_meta(
+    ip_dir,
+    datasets,
+    region_classifier,
+    class_names,
+    only_accepted_species=True,
+):
 
     sp_dict = {}
     for ss in class_names:
-        sp_key = ss.split(' ')[0][:3] + ss.split(' ')[1][:3]
+        sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
         sp_dict[sp_key] = ss
-    sp_dict['x'] = ''  # not bat
-    sp_dict['Bat'] = 'Bat'
+    sp_dict["x"] = ""  # not bat
+    sp_dict["Bat"] = "Bat"
 
     sonobat_meta = {}
     for tt in datasets:
-        dataset = tt['dataset_name']
-        sb_ip_dir = ip_dir + dataset + '/' + region_classifier + '/'
+        dataset = tt["dataset_name"]
+        sb_ip_dir = ip_dir + dataset + "/" + region_classifier + "/"
 
         # load the call level predictions
-        ip_file_p = sb_ip_dir + dataset + '_Parameters_v4.5.0.txt'
-        #ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
-        da = pd.read_csv(ip_file_p, sep='\t')
+        ip_file_p = sb_ip_dir + dataset + "_Parameters_v4.5.0.txt"
+        # ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
+        da = pd.read_csv(ip_file_p, sep="\t")
 
         # load the file level predictions
-        ip_file_b = sb_ip_dir + dataset + '_SonoBatch_v4.5.0.txt'
-        #ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
+        ip_file_b = sb_ip_dir + dataset + "_SonoBatch_v4.5.0.txt"
+        # ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
         with open(ip_file_b) as f:
             lines = f.readlines()
@@ -129,7 +152,7 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
         file_res = {}
         for ll in lines:
             # note this does not seem to parse the file very well
-            ll_data = ll.split('\t')
+            ll_data = ll.split("\t")
 
             # there are sometimes many different species names per file
             if only_accepted_species:
@@ -137,20 +160,24 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
                 ind = 4
             else:
                 # choosing "~Spp" if "SppAccp" does not exist
-                if ll_data[4] != 'x':
-                    ind = 4  # choosing "SppAccp", along with "Prob" here
+                if ll_data[4] != "x":
+                    ind = 4  # choosing "SppAccp", along with "Prob" here
                 else:
                     ind = 8  # choosing "~Spp", along with "~Prob" here
 
             sp_name_1 = sp_dict[ll_data[ind]]
-            prob_1 = ll_data[ind+1]
-            if prob_1 == 'x':
+            prob_1 = ll_data[ind + 1]
+            if prob_1 == "x":
                 prob_1 = 0.0
-            file_res[ll_data[1]] = {'id':ll_data[1], 'species_1':sp_name_1, 'prob_1':prob_1}
+            file_res[ll_data[1]] = {
+                "id": ll_data[1],
+                "species_1": sp_name_1,
+                "prob_1": prob_1,
+            }
 
         sonobat_meta[dataset] = {}
-        sonobat_meta[dataset]['file_res'] = file_res
-        sonobat_meta[dataset]['call_info'] = da
+        sonobat_meta[dataset]["file_res"] = file_res
+        sonobat_meta[dataset]["call_info"] = da
 
     return sonobat_meta
@@ -158,34 +185,38 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
 
 def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
 
     # create the annotations in the correct format
-    res, ann = get_blank_annotation('Generated by Sonobat')
+    res, ann = get_blank_annotation("Generated by Sonobat")
 
     res_c = copy.deepcopy(res)
-    res_c['id'] = id
-    res_c['dataset'] = dataset
+    res_c["id"] = id
+    res_c["dataset"] = dataset
 
-    da = sb_meta[dataset]['call_info']
-    da_c = da[da['Filename'] == id]
+    da = sb_meta[dataset]["call_info"]
+    da_c = da[da["Filename"] == id]
 
-    file_res = sb_meta[dataset]['file_res']
-    res_c['feats'] = np.zeros((0,0))
+    file_res = sb_meta[dataset]["file_res"]
+    res_c["feats"] = np.zeros((0, 0))
 
     if da_c.shape[0] > 0:
-        res_c['class_name'] = file_res[id]['species_1']
-        res_c['class_prob'] = file_res[id]['prob_1']
-        res_c['feats'] = da_c.iloc[:, 3:105].values.astype(np.float32)
+        res_c["class_name"] = file_res[id]["species_1"]
+        res_c["class_prob"] = file_res[id]["prob_1"]
+        res_c["feats"] = da_c.iloc[:, 3:105].values.astype(np.float32)
 
         for aa in range(da_c.shape[0]):
             ann_c = copy.deepcopy(ann)
             if set_class_name is None:
-                ann_c['class'] = file_res[id]['species_1']
+                ann_c["class"] = file_res[id]["species_1"]
             else:
-                ann_c['class'] = set_class_name
-            ann_c['start_time'] = np.round(da_c.iloc[aa]['TimeInFile'] / 1000.0 ,5)
-            ann_c['end_time'] = np.round(ann_c['start_time'] + da_c.iloc[aa]['CallDuration']/1000.0, 5)
-            ann_c['low_freq'] = np.round(da_c.iloc[aa]['LowFreq'] * 1000.0, 2)
-            ann_c['high_freq'] = np.round(da_c.iloc[aa]['HiFreq'] * 1000.0, 2)
-            ann_c['det_prob'] = np.round(da_c.iloc[aa]['Quality'], 3)
-            res_c['annotation'].append(ann_c)
+                ann_c["class"] = set_class_name
+            ann_c["start_time"] = np.round(
+                da_c.iloc[aa]["TimeInFile"] / 1000.0, 5
+            )
+            ann_c["end_time"] = np.round(
+                ann_c["start_time"] + da_c.iloc[aa]["CallDuration"] / 1000.0, 5
+            )
+            ann_c["low_freq"] = np.round(da_c.iloc[aa]["LowFreq"] * 1000.0, 2)
+            ann_c["high_freq"] = np.round(da_c.iloc[aa]["HiFreq"] * 1000.0, 2)
+            ann_c["det_prob"] = np.round(da_c.iloc[aa]["Quality"], 3)
+            res_c["annotation"].append(ann_c)
 
     return res_c
@@ -193,8 +224,18 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
 
 def bb_overlap(bb_g_in, bb_p_in):
 
     freq_scale = 10000000.0  # ensure that both axes are roughly the same range
-    bb_g = [bb_g_in['start_time'], bb_g_in['low_freq']/freq_scale, bb_g_in['end_time'], bb_g_in['high_freq']/freq_scale]
-    bb_p = [bb_p_in['start_time'], bb_p_in['low_freq']/freq_scale, bb_p_in['end_time'], bb_p_in['high_freq']/freq_scale]
+    bb_g = [
+        bb_g_in["start_time"],
+        bb_g_in["low_freq"] / freq_scale,
+        bb_g_in["end_time"],
+        bb_g_in["high_freq"] / freq_scale,
+    ]
+    bb_p = [
+        bb_p_in["start_time"],
+        bb_p_in["low_freq"] / freq_scale,
+        bb_p_in["end_time"],
+        bb_p_in["high_freq"] / freq_scale,
+    ]
 
     xA = max(bb_g[0], bb_p[0])
     yA = max(bb_g[1], bb_p[1])
@@ -220,13 +261,15 @@ def assign_to_gt(gt, pred, iou_thresh):
     # this will edit pred in place
-    num_preds = len(pred['annotation'])
-    num_gts = len(gt['annotation'])
num_preds = len(pred["annotation"]) + num_gts = len(gt["annotation"]) if num_preds > 0 and num_gts > 0: iou_m = np.zeros((num_preds, num_gts)) for ii in range(num_preds): for jj in range(num_gts): - iou_m[ii, jj] = bb_overlap(gt['annotation'][jj], pred['annotation'][ii]) + iou_m[ii, jj] = bb_overlap( + gt["annotation"][jj], pred["annotation"][ii] + ) # greedily assign detections to ground truths # needs to be greater than some threshold and we cannot assign GT @@ -235,7 +278,9 @@ def assign_to_gt(gt, pred, iou_thresh): for jj in range(num_gts): max_iou = np.argmax(iou_m[:, jj]) if iou_m[max_iou, jj] > iou_thresh: - pred['annotation'][max_iou]['class'] = gt['annotation'][jj]['class'] + pred["annotation"][max_iou]["class"] = gt["annotation"][jj][ + "class" + ] iou_m[max_iou, :] = -1.0 return pred @@ -244,27 +289,39 @@ def assign_to_gt(gt, pred, iou_thresh): def parse_data(data, class_names, non_event_classes, is_pred=False): class_names_all = class_names + non_event_classes - data['class_names'] = np.array([aa['class'] for aa in data['annotation']]) - data['start_times'] = np.array([aa['start_time'] for aa in data['annotation']]) - data['end_times'] = np.array([aa['end_time'] for aa in data['annotation']]) - data['high_freqs'] = np.array([float(aa['high_freq']) for aa in data['annotation']]) - data['low_freqs'] = np.array([float(aa['low_freq']) for aa in data['annotation']]) + data["class_names"] = np.array([aa["class"] for aa in data["annotation"]]) + data["start_times"] = np.array( + [aa["start_time"] for aa in data["annotation"]] + ) + data["end_times"] = np.array([aa["end_time"] for aa in data["annotation"]]) + data["high_freqs"] = np.array( + [float(aa["high_freq"]) for aa in data["annotation"]] + ) + data["low_freqs"] = np.array( + [float(aa["low_freq"]) for aa in data["annotation"]] + ) if is_pred: # when loading predictions - data['det_probs'] = np.array([float(aa['det_prob']) for aa in data['annotation']]) - data['class_probs'] = np.zeros((len(class_names)+1, len(data['annotation']))) - data['class_ids'] = np.array([class_names_all.index(aa['class']) for aa in data['annotation']]).astype(np.int32) + data["det_probs"] = np.array( + [float(aa["det_prob"]) for aa in data["annotation"]] + ) + data["class_probs"] = np.zeros( + (len(class_names) + 1, len(data["annotation"])) + ) + data["class_ids"] = np.array( + [class_names_all.index(aa["class"]) for aa in data["annotation"]] + ).astype(np.int32) else: # when loading ground truth # if the class label is not in the set of interest then set to -1 labels = [] - for aa in data['annotation']: - if aa['class'] in class_names: - labels.append(class_names_all.index(aa['class'])) + for aa in data["annotation"]: + if aa["class"] in class_names: + labels.append(class_names_all.index(aa["class"])) else: labels.append(-1) - data['class_ids'] = np.array(labels).astype(np.int32) + data["class_ids"] = np.array(labels).astype(np.int32) return data @@ -272,12 +329,17 @@ def parse_data(data, class_names, non_event_classes, is_pred=False): def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore): gt_data = [] for dd in datasets: - print('\n' + dd['dataset_name']) - gt_dataset = tu.load_set_of_anns([dd], events_of_interest=events_of_interest, verbose=True) - gt_dataset = [parse_data(gg, class_names, classes_to_ignore, False) for gg in gt_dataset] + print("\n" + dd["dataset_name"]) + gt_dataset = tu.load_set_of_anns( + [dd], events_of_interest=events_of_interest, verbose=True + ) + gt_dataset = [ + parse_data(gg, class_names, 
classes_to_ignore, False) + for gg in gt_dataset + ] for gt in gt_dataset: - gt['dataset_name'] = dd['dataset_name'] + gt["dataset_name"] = dd["dataset_name"] gt_data.extend(gt_dataset) @@ -300,69 +362,103 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001): clf = RandomForestClassifier(random_state=seed, n_jobs=-1) clf.fit(x_train, y_train) y_pred = clf.predict(x_train) - tr_acc = (y_pred==y_train).mean() - #print('Train acc', round(tr_acc*100, 2)) + tr_acc = (y_pred == y_train).mean() + # print('Train acc', round(tr_acc*100, 2)) return clf, un_train_class def eval_rf_model(clf, pred, un_train_class, num_classes): # stores the prediction in place - if pred['feats'].shape[0] > 0: - pred['class_probs'] = np.zeros((num_classes, pred['feats'].shape[0])) - pred['class_probs'][un_train_class, :] = clf.predict_proba(pred['feats']).T - pred['det_probs'] = pred['class_probs'][:-1, :].sum(0) + if pred["feats"].shape[0] > 0: + pred["class_probs"] = np.zeros((num_classes, pred["feats"].shape[0])) + pred["class_probs"][un_train_class, :] = clf.predict_proba( + pred["feats"] + ).T + pred["det_probs"] = pred["class_probs"][:-1, :].sum(0) else: - pred['class_probs'] = np.zeros((num_classes, 0)) - pred['det_probs'] = np.zeros(0) + pred["class_probs"] = np.zeros((num_classes, 0)) + pred["det_probs"] = np.zeros(0) return pred def save_summary_to_json(op_dir, mod_name, results): op = {} - op['avg_prec'] = round(results['avg_prec'], 3) - op['avg_prec_class'] = round(results['avg_prec_class'], 3) - op['top_class'] = round(results['top_class']['avg_prec'], 3) - op['file_acc'] = round(results['file_acc'], 3) - op['model'] = mod_name + op["avg_prec"] = round(results["avg_prec"], 3) + op["avg_prec_class"] = round(results["avg_prec_class"], 3) + op["top_class"] = round(results["top_class"]["avg_prec"], 3) + op["file_acc"] = round(results["file_acc"], 3) + op["model"] = mod_name - op['per_class'] = {} - for cc in results['class_pr']: - op['per_class'][cc['name']] = cc['avg_prec'] + op["per_class"] = {} + for cc in results["class_pr"]: + op["per_class"][cc["name"]] = cc["avg_prec"] - op_file_name = os.path.join(op_dir, mod_name + '_results.json') - with open(op_file_name, 'w') as da: + op_file_name = os.path.join(op_dir, mod_name + "_results.json") + with open(op_file_name, "w") as da: json.dump(op, da, indent=2) -def print_results(model_name, mod_str, results, op_dir, class_names, file_type, title_text=''): - print('\nResults - ' + model_name) - print('avg_prec ', round(results['avg_prec'], 3)) - print('avg_prec_class', round(results['avg_prec_class'], 3)) - print('top_class ', round(results['top_class']['avg_prec'], 3)) - print('file_acc ', round(results['file_acc'], 3)) +def print_results( + model_name, mod_str, results, op_dir, class_names, file_type, title_text="" +): + print("\nResults - " + model_name) + print("avg_prec ", round(results["avg_prec"], 3)) + print("avg_prec_class", round(results["avg_prec_class"], 3)) + print("top_class ", round(results["top_class"]["avg_prec"], 3)) + print("file_acc ", round(results["file_acc"], 3)) - print('\nSaving ' + model_name + ' results to: ' + op_dir) + print("\nSaving " + model_name + " results to: " + op_dir) save_summary_to_json(op_dir, mod_str, results) - pu.plot_pr_curve(op_dir, mod_str+'_test_all_det', mod_str+'_test_all_det', results, file_type, title_text + 'Detection PR') - pu.plot_pr_curve(op_dir, mod_str+'_test_all_top_class', mod_str+'_test_all_top_class', results['top_class'], file_type, title_text + 'Top Class') - pu.plot_pr_curve_class(op_dir, 
mod_str+'_test_all_class', mod_str+'_test_all_class', results, file_type, title_text + 'Per-Class PR') - pu.plot_confusion_matrix(op_dir, mod_str+'_confusion', results['gt_valid_file'], results['pred_valid_file'], - results['file_acc'], class_names, True, file_type, title_text + 'Confusion Matrix') + pu.plot_pr_curve( + op_dir, + mod_str + "_test_all_det", + mod_str + "_test_all_det", + results, + file_type, + title_text + "Detection PR", + ) + pu.plot_pr_curve( + op_dir, + mod_str + "_test_all_top_class", + mod_str + "_test_all_top_class", + results["top_class"], + file_type, + title_text + "Top Class", + ) + pu.plot_pr_curve_class( + op_dir, + mod_str + "_test_all_class", + mod_str + "_test_all_class", + results, + file_type, + title_text + "Per-Class PR", + ) + pu.plot_confusion_matrix( + op_dir, + mod_str + "_confusion", + results["gt_valid_file"], + results["pred_valid_file"], + results["file_acc"], + class_names, + True, + file_type, + title_text + "Confusion Matrix", + ) def add_root_path_back(data_sets, ann_path, wav_path): for dd in data_sets: - dd['ann_path'] = os.path.join(ann_path, dd['ann_path']) - dd['wav_path'] = os.path.join(wav_path, dd['wav_path']) + dd["ann_path"] = os.path.join(ann_path, dd["ann_path"]) + dd["wav_path"] = os.path.join(wav_path, dd["wav_path"]) return data_sets def check_classes_in_train(gt_list, class_names): - num_gt_total = np.sum([gg['start_times'].shape[0] for gg in gt_list]) + num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list]) num_with_no_class = 0 for gt in gt_list: - for cc in gt['class_names']: + for cc in gt["class_names"]: if cc not in class_names: num_with_no_class += 1 return num_with_no_class @@ -371,195 +467,335 @@ def check_classes_in_train(gt_list, class_names): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('op_dir', type=str, default='plots/results_compare/', - help='Output directory for plots') - parser.add_argument('data_dir', type=str, - help='Path to root of datasets') - parser.add_argument('ann_dir', type=str, - help='Path to extracted annotations') - parser.add_argument('bd_model_path', type=str, - help='Path to BatDetect model') - parser.add_argument('--test_file', type=str, default='', - help='Path to json file used for evaluation.') - parser.add_argument('--sb_ip_dir', type=str, default='', - help='Path to sonobat predictions') - parser.add_argument('--sb_region_classifier', type=str, default='south', - help='Path to sonobat predictions') - parser.add_argument('--td_ip_dir', type=str, default='', - help='Path to tadarida_D predictions') - parser.add_argument('--iou_thresh', type=float, default=0.01, - help='IOU threshold for assigning predictions to ground truth') - parser.add_argument('--file_type', type=str, default='png', - help='Type of image to save - png or pdf') - parser.add_argument('--title_text', type=str, default='', - help='Text to add as title of plots') - parser.add_argument('--rand_seed', type=int, default=2001, - help='Random seed') + parser.add_argument( + "op_dir", + type=str, + default="plots/results_compare/", + help="Output directory for plots", + ) + parser.add_argument("data_dir", type=str, help="Path to root of datasets") + parser.add_argument( + "ann_dir", type=str, help="Path to extracted annotations" + ) + parser.add_argument( + "bd_model_path", type=str, help="Path to BatDetect model" + ) + parser.add_argument( + "--test_file", + type=str, + default="", + help="Path to json file used for evaluation.", + ) + parser.add_argument( + "--sb_ip_dir", 
type=str, default="", help="Path to sonobat predictions" + ) + parser.add_argument( + "--sb_region_classifier", + type=str, + default="south", + help="Path to sonobat predictions", + ) + parser.add_argument( + "--td_ip_dir", + type=str, + default="", + help="Path to tadarida_D predictions", + ) + parser.add_argument( + "--iou_thresh", + type=float, + default=0.01, + help="IOU threshold for assigning predictions to ground truth", + ) + parser.add_argument( + "--file_type", + type=str, + default="png", + help="Type of image to save - png or pdf", + ) + parser.add_argument( + "--title_text", + type=str, + default="", + help="Text to add as title of plots", + ) + parser.add_argument( + "--rand_seed", type=int, default=2001, help="Random seed" + ) args = vars(parser.parse_args()) - np.random.seed(args['rand_seed']) - - if not os.path.isdir(args['op_dir']): - os.makedirs(args['op_dir']) + np.random.seed(args["rand_seed"]) + if not os.path.isdir(args["op_dir"]): + os.makedirs(args["op_dir"]) # load the model params_eval = parameters.get_params(False) - _, params_bd = du.load_model(args['bd_model_path']) + _, params_bd = du.load_model(args["bd_model_path"]) - class_names = params_bd['class_names'] + class_names = params_bd["class_names"] num_classes = len(class_names) + 1 # num classes plus background class - classes_to_ignore = ['Not Bat', 'Bat', 'Unknown'] - events_of_interest = ['Echolocation'] + classes_to_ignore = ["Not Bat", "Bat", "Unknown"] + events_of_interest = ["Echolocation"] # load test data - if args['test_file'] == '': + if args["test_file"] == "": # load the test files of interest from the trained model - test_sets = add_root_path_back(params_bd['test_sets'], args['ann_dir'], args['data_dir']) - test_sets = [dd for dd in test_sets if not dd['is_binary']] # exclude bat/not datasets + test_sets = add_root_path_back( + params_bd["test_sets"], args["ann_dir"], args["data_dir"] + ) + test_sets = [ + dd for dd in test_sets if not dd["is_binary"] + ] # exclude bat/not datasets else: # user specified annotation file to evaluate test_dict = {} - test_dict['dataset_name'] = args['test_file'].replace('.json', '') - test_dict['is_test'] = True - test_dict['is_binary'] = True - test_dict['ann_path'] = os.path.join(args['ann_dir'], args['test_file']) - test_dict['wav_path'] = args['data_dir'] + test_dict["dataset_name"] = args["test_file"].replace(".json", "") + test_dict["is_test"] = True + test_dict["is_binary"] = True + test_dict["ann_path"] = os.path.join( + args["ann_dir"], args["test_file"] + ) + test_dict["wav_path"] = args["data_dir"] test_sets = [test_dict] # load the gt for the test set - gt_test = load_gt_data(test_sets, events_of_interest, class_names, classes_to_ignore) - total_num_calls = np.sum([gg['start_times'].shape[0] for gg in gt_test]) - print('\nTotal number of test files:', len(gt_test)) - print('Total number of test calls:', np.sum([gg['start_times'].shape[0] for gg in gt_test])) + gt_test = load_gt_data( + test_sets, events_of_interest, class_names, classes_to_ignore + ) + total_num_calls = np.sum([gg["start_times"].shape[0] for gg in gt_test]) + print("\nTotal number of test files:", len(gt_test)) + print( + "Total number of test calls:", + np.sum([gg["start_times"].shape[0] for gg in gt_test]), + ) # check if test contains classes not in the train set num_with_no_class = check_classes_in_train(gt_test, class_names) if total_num_calls == num_with_no_class: - print('Classes from the test set are not in the train set.') + print("Classes from the test set are not in the 
train set.") assert False # only need the train data if evaluating Sonobat or Tadarida - if args['sb_ip_dir'] != '' or args['td_ip_dir'] != '': - train_sets = add_root_path_back(params_bd['train_sets'], args['ann_dir'], args['data_dir']) - train_sets = [dd for dd in train_sets if not dd['is_binary']] # exclude bat/not datasets - gt_train = load_gt_data(train_sets, events_of_interest, class_names, classes_to_ignore) - + if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "": + train_sets = add_root_path_back( + params_bd["train_sets"], args["ann_dir"], args["data_dir"] + ) + train_sets = [ + dd for dd in train_sets if not dd["is_binary"] + ] # exclude bat/not datasets + gt_train = load_gt_data( + train_sets, events_of_interest, class_names, classes_to_ignore + ) # # evaluate Sonobat by training random forest classifier # # NOTE: Sonobat may only make predictions for a subset of the files # - if args['sb_ip_dir'] != '': - sb_meta = load_sonobat_meta(args['sb_ip_dir'], train_sets + test_sets, args['sb_region_classifier'], class_names) + if args["sb_ip_dir"] != "": + sb_meta = load_sonobat_meta( + args["sb_ip_dir"], + train_sets + test_sets, + args["sb_region_classifier"], + class_names, + ) preds_sb = [] keep_inds_sb = [] for ii, gt in enumerate(gt_test): - sb_pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta) - if sb_pred['class_name'] != '': - sb_pred = parse_data(sb_pred, class_names, classes_to_ignore, True) - sb_pred['class_probs'][sb_pred['class_ids'], np.arange(sb_pred['class_probs'].shape[1])] = sb_pred['det_probs'] + sb_pred = load_sonobat_preds(gt["dataset_name"], gt["id"], sb_meta) + if sb_pred["class_name"] != "": + sb_pred = parse_data( + sb_pred, class_names, classes_to_ignore, True + ) + sb_pred["class_probs"][ + sb_pred["class_ids"], + np.arange(sb_pred["class_probs"].shape[1]), + ] = sb_pred["det_probs"] preds_sb.append(sb_pred) keep_inds_sb.append(ii) - results_sb = evl.evaluate_predictions([gt_test[ii] for ii in keep_inds_sb], preds_sb, class_names, - params_eval['detection_overlap'], params_eval['ignore_start_end']) - print_results('Sonobat', 'sb', results_sb, args['op_dir'], class_names, - args['file_type'], args['title_text'] + ' - Species - ') - print('Only reporting results for', len(keep_inds_sb), 'files, out of', len(gt_test)) - + results_sb = evl.evaluate_predictions( + [gt_test[ii] for ii in keep_inds_sb], + preds_sb, + class_names, + params_eval["detection_overlap"], + params_eval["ignore_start_end"], + ) + print_results( + "Sonobat", + "sb", + results_sb, + args["op_dir"], + class_names, + args["file_type"], + args["title_text"] + " - Species - ", + ) + print( + "Only reporting results for", + len(keep_inds_sb), + "files, out of", + len(gt_test), + ) # train our own random forest on sonobat features x_train = [] y_train = [] for gt in gt_train: - pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat') + pred = load_sonobat_preds( + gt["dataset_name"], gt["id"], sb_meta, "Not Bat" + ) - if len(pred['annotation']) > 0: + if len(pred["annotation"]) > 0: # compute detection overlap with ground truth to determine which are the TP detections - assign_to_gt(gt, pred, args['iou_thresh']) + assign_to_gt(gt, pred, args["iou_thresh"]) pred = parse_data(pred, class_names, classes_to_ignore, True) - x_train.append(pred['feats']) - y_train.append(pred['class_ids']) + x_train.append(pred["feats"]) + y_train.append(pred["class_ids"]) # train random forest on tadarida predictions - clf_sb, un_train_class = train_rf_model(x_train, y_train, 
num_classes, args['rand_seed']) + clf_sb, un_train_class = train_rf_model( + x_train, y_train, num_classes, args["rand_seed"] + ) # run the model on the test set preds_sb_rf = [] for gt in gt_test: - pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat') + pred = load_sonobat_preds( + gt["dataset_name"], gt["id"], sb_meta, "Not Bat" + ) pred = parse_data(pred, class_names, classes_to_ignore, True) pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes) preds_sb_rf.append(pred) - results_sb_rf = evl.evaluate_predictions(gt_test, preds_sb_rf, class_names, - params_eval['detection_overlap'], params_eval['ignore_start_end']) - print_results('Sonobat RF', 'sb_rf', results_sb_rf, args['op_dir'], class_names, - args['file_type'], args['title_text'] + ' - Species - ') - print('\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n') - + results_sb_rf = evl.evaluate_predictions( + gt_test, + preds_sb_rf, + class_names, + params_eval["detection_overlap"], + params_eval["ignore_start_end"], + ) + print_results( + "Sonobat RF", + "sb_rf", + results_sb_rf, + args["op_dir"], + class_names, + args["file_type"], + args["title_text"] + " - Species - ", + ) + print( + "\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n" + ) # # evaluate Tadarida-D by training random forest classifier # - if args['td_ip_dir'] != '': + if args["td_ip_dir"] != "": x_train = [] y_train = [] for gt in gt_train: - pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id']) + pred = load_tadarida_pred( + args["td_ip_dir"], gt["dataset_name"], gt["id"] + ) # compute detection overlap with ground truth to determine which are the TP detections - assign_to_gt(gt, pred, args['iou_thresh']) + assign_to_gt(gt, pred, args["iou_thresh"]) pred = parse_data(pred, class_names, classes_to_ignore, True) - x_train.append(pred['feats']) - y_train.append(pred['class_ids']) + x_train.append(pred["feats"]) + y_train.append(pred["class_ids"]) # train random forest on Tadarida-D predictions - clf_td, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed']) + clf_td, un_train_class = train_rf_model( + x_train, y_train, num_classes, args["rand_seed"] + ) # run the model on the test set preds_td = [] for gt in gt_test: - pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id']) + pred = load_tadarida_pred( + args["td_ip_dir"], gt["dataset_name"], gt["id"] + ) pred = parse_data(pred, class_names, classes_to_ignore, True) pred = eval_rf_model(clf_td, pred, un_train_class, num_classes) preds_td.append(pred) - results_td = evl.evaluate_predictions(gt_test, preds_td, class_names, - params_eval['detection_overlap'], params_eval['ignore_start_end']) - print_results('Tadarida', 'td_rf', results_td, args['op_dir'], class_names, - args['file_type'], args['title_text'] + ' - Species - ') - + results_td = evl.evaluate_predictions( + gt_test, + preds_td, + class_names, + params_eval["detection_overlap"], + params_eval["ignore_start_end"], + ) + print_results( + "Tadarida", + "td_rf", + results_td, + args["op_dir"], + class_names, + args["file_type"], + args["title_text"] + " - Species - ", + ) # # evaluate BatDetect # - if args['bd_model_path'] != '': + if args["bd_model_path"] != "": # load model bd_args = du.get_default_bd_args() - model, params_bd = du.load_model(args['bd_model_path']) + model, params_bd = du.load_model(args["bd_model_path"]) # check if the class names are 
the same - if params_bd['class_names'] != class_names: - print('Warning: Class names are not the same as the trained model') + if params_bd["class_names"] != class_names: + print("Warning: Class names are not the same as the trained model") assert False preds_bd = [] for ii, gg in enumerate(gt_test): - pred = du.process_file(gg['file_path'], model, params_bd, bd_args, return_raw_preds=True) + pred = du.process_file( + gg["file_path"], + model, + params_bd, + bd_args, + return_raw_preds=True, + ) preds_bd.append(pred) - results_bd = evl.evaluate_predictions(gt_test, preds_bd, class_names, - params_eval['detection_overlap'], params_eval['ignore_start_end']) - print_results('BatDetect', 'bd', results_bd, args['op_dir'], - class_names, args['file_type'], args['title_text'] + ' - Species - ') + results_bd = evl.evaluate_predictions( + gt_test, + preds_bd, + class_names, + params_eval["detection_overlap"], + params_eval["ignore_start_end"], + ) + print_results( + "BatDetect", + "bd", + results_bd, + args["op_dir"], + class_names, + args["file_type"], + args["title_text"] + " - Species - ", + ) # evaluate genus level - class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(gt_test, preds_bd, class_names) - results_bd_genus = evl.evaluate_predictions(gt_test_g, preds_bd_g, class_names_genus, - params_eval['detection_overlap'], params_eval['ignore_start_end']) - print_results('BatDetect Genus', 'bd_genus', results_bd_genus, args['op_dir'], - class_names_genus, args['file_type'], args['title_text'] + ' - Genus - ') + class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping( + gt_test, preds_bd, class_names + ) + results_bd_genus = evl.evaluate_predictions( + gt_test_g, + preds_bd_g, + class_names_genus, + params_eval["detection_overlap"], + params_eval["ignore_start_end"], + ) + print_results( + "BatDetect Genus", + "bd_genus", + results_bd_genus, + args["op_dir"], + class_names_genus, + args["file_type"], + args["title_text"] + " - Genus - ", + ) diff --git a/bat_detect/finetune/finetune_model.py b/bat_detect/finetune/finetune_model.py index 4fecc48..8c20e22 100644 --- a/bat_detect/finetune/finetune_model.py +++ b/bat_detect/finetune/finetune_model.py @@ -1,183 +1,325 @@ -import numpy as np -import matplotlib.pyplot as plt +import argparse +import glob +import json import os +import sys + +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn.functional as F from torch.optim.lr_scheduler import CosineAnnealingLR -import json -import argparse -import glob -import sys -sys.path.append(os.path.join('..', '..')) -import bat_detect.train.train_model as tm +sys.path.append(os.path.join("..", "..")) +import bat_detect.detector.models as models +import bat_detect.detector.parameters as parameters +import bat_detect.detector.post_process as pp import bat_detect.train.audio_dataloader as adl import bat_detect.train.evaluate as evl -import bat_detect.train.train_utils as tu import bat_detect.train.losses as losses - -import bat_detect.detector.parameters as parameters -import bat_detect.detector.models as models -import bat_detect.detector.post_process as pp -import bat_detect.utils.plot_utils as pu +import bat_detect.train.train_model as tm +import bat_detect.train.train_utils as tu import bat_detect.utils.detector_utils as du - +import bat_detect.utils.plot_utils as pu if __name__ == "__main__": - info_str = '\nBatDetect - Finetune Model\n' + info_str = "\nBatDetect - Finetune Model\n" print(info_str) parser = argparse.ArgumentParser() - 
parser.add_argument('audio_path', type=str, help='Input directory for audio') - parser.add_argument('train_ann_path', type=str, - help='Path to where train annotation file is stored') - parser.add_argument('test_ann_path', type=str, - help='Path to where test annotation file is stored') - parser.add_argument('model_path', type=str, - help='Path to pretrained model') - parser.add_argument('--op_model_name', type=str, default='', - help='Path and name for finetuned model') - parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs', - help='Number of finetuning epochs') - parser.add_argument('--finetune_only_last_layer', action='store_true', - help='Only train final layers') - parser.add_argument('--train_from_scratch', action='store_true', - help='Do not use pretrained weights') - parser.add_argument('--do_not_save_images', action='store_false', - help='Do not save images at the end of training') - parser.add_argument('--notes', type=str, default='', - help='Notes to save in text file') + parser.add_argument( + "audio_path", type=str, help="Input directory for audio" + ) + parser.add_argument( + "train_ann_path", + type=str, + help="Path to where train annotation file is stored", + ) + parser.add_argument( + "test_ann_path", + type=str, + help="Path to where test annotation file is stored", + ) + parser.add_argument( + "model_path", type=str, help="Path to pretrained model" + ) + parser.add_argument( + "--op_model_name", + type=str, + default="", + help="Path and name for finetuned model", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=200, + dest="num_epochs", + help="Number of finetuning epochs", + ) + parser.add_argument( + "--finetune_only_last_layer", + action="store_true", + help="Only train final layers", + ) + parser.add_argument( + "--train_from_scratch", + action="store_true", + help="Do not use pretrained weights", + ) + parser.add_argument( + "--do_not_save_images", + action="store_false", + help="Do not save images at the end of training", + ) + parser.add_argument( + "--notes", type=str, default="", help="Notes to save in text file" + ) args = vars(parser.parse_args()) - params = parameters.get_params(True, '../../experiments/') + params = parameters.get_params(True, "../../experiments/") if torch.cuda.is_available(): - params['device'] = 'cuda' + params["device"] = "cuda" else: - params['device'] = 'cpu' - print('\nNote, this will be a lot faster if you use computer with a GPU.\n') + params["device"] = "cpu" + print( + "\nNote, this will be a lot faster if you use computer with a GPU.\n" + ) - print('\nAudio directory: ' + args['audio_path']) - print('Train file: ' + args['train_ann_path']) - print('Test file: ' + args['test_ann_path']) - print('Loading model: ' + args['model_path']) + print("\nAudio directory: " + args["audio_path"]) + print("Train file: " + args["train_ann_path"]) + print("Test file: " + args["test_ann_path"]) + print("Loading model: " + args["model_path"]) - dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '') + dataset_name = ( + os.path.basename(args["train_ann_path"]) + .replace(".json", "") + .replace("_TRAIN", "") + ) - if args['train_from_scratch']: - print('\nTraining model from scratch i.e. not using pretrained weights') - model, params_train = du.load_model(args['model_path'], False) + if args["train_from_scratch"]: + print( + "\nTraining model from scratch i.e. 
not using pretrained weights" + ) + model, params_train = du.load_model(args["model_path"], False) else: - model, params_train = du.load_model(args['model_path'], True) - model.to(params['device']) + model, params_train = du.load_model(args["model_path"], True) + model.to(params["device"]) - params['num_epochs'] = args['num_epochs'] - if args['op_model_name'] != '': - params['model_file_name'] = args['op_model_name'] - classes_to_ignore = params['classes_to_ignore']+params['generic_class'] + params["num_epochs"] = args["num_epochs"] + if args["op_model_name"] != "": + params["model_file_name"] = args["op_model_name"] + classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] # save notes file - params['notes'] = args['notes'] - if args['notes'] != '': - tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes']) - + params["notes"] = args["notes"] + if args["notes"] != "": + tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"]) # load train annotations train_sets = [] - train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path'])) - params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])] + train_sets.append( + tu.get_blank_dataset_dict( + dataset_name, False, args["train_ann_path"], args["audio_path"] + ) + ) + params["train_sets"] = [ + tu.get_blank_dataset_dict( + dataset_name, + False, + os.path.basename(args["train_ann_path"]), + args["audio_path"], + ) + ] - print('\nTrain set:') - data_train, params['class_names'], params['class_inv_freq'] = \ - tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest']) - print('Number of files', len(data_train)) + print("\nTrain set:") + ( + data_train, + params["class_names"], + params["class_inv_freq"], + ) = tu.load_set_of_anns( + train_sets, classes_to_ignore, params["events_of_interest"] + ) + print("Number of files", len(data_train)) - params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names']) - params['class_names_short'] = tu.get_short_class_names(params['class_names']) + params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( + params["class_names"] + ) + params["class_names_short"] = tu.get_short_class_names( + params["class_names"] + ) # load test annotations test_sets = [] - test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path'])) - params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])] + test_sets.append( + tu.get_blank_dataset_dict( + dataset_name, True, args["test_ann_path"], args["audio_path"] + ) + ) + params["test_sets"] = [ + tu.get_blank_dataset_dict( + dataset_name, + True, + os.path.basename(args["test_ann_path"]), + args["audio_path"], + ) + ] - print('\nTest set:') - data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest']) - print('Number of files', len(data_test)) + print("\nTest set:") + data_test, _, _ = tu.load_set_of_anns( + test_sets, classes_to_ignore, params["events_of_interest"] + ) + print("Number of files", len(data_test)) # train loader train_dataset = adl.AudioLoader(data_train, params, is_train=True) - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'], - shuffle=True, num_workers=params['num_workers'], pin_memory=True) + train_loader = torch.utils.data.DataLoader( + 
train_dataset, + batch_size=params["batch_size"], + shuffle=True, + num_workers=params["num_workers"], + pin_memory=True, + ) # test loader - batch size of one because of variable file length test_dataset = adl.AudioLoader(data_test, params, is_train=False) - test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, - shuffle=False, num_workers=params['num_workers'], pin_memory=True) + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=params["num_workers"], + pin_memory=True, + ) inputs_train = next(iter(train_loader)) - params['ip_height'] = inputs_train['spec'].shape[2] - print('\ntrain batch size :', inputs_train['spec'].shape) + params["ip_height"] = inputs_train["spec"].shape[2] + print("\ntrain batch size :", inputs_train["spec"].shape) - assert(params_train['model_name'] == 'Net2DFast') - print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n') + assert params_train["model_name"] == "Net2DFast" + print( + "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n" + ) # set the number of output classes num_filts = model.conv_classes_op.in_channels k_size = model.conv_classes_op.kernel_size pad = model.conv_classes_op.padding - model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1, kernel_size=k_size, padding=pad) - model.conv_classes_op.to(params['device']) + model.conv_classes_op = torch.nn.Conv2d( + num_filts, + len(params["class_names"]) + 1, + kernel_size=k_size, + padding=pad, + ) + model.conv_classes_op.to(params["device"]) - if args['finetune_only_last_layer']: - print('\nOnly finetuning the final layers.\n') - train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op'] - train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i] + if args["finetune_only_last_layer"]: + print("\nOnly finetuning the final layers.\n") + train_layers_i = [ + "conv_classes", + "conv_classes_op", + "conv_size", + "conv_size_op", + ] + train_layers = [tt + ".weight" for tt in train_layers_i] + [ + tt + ".bias" for tt in train_layers_i + ] for name, param in model.named_parameters(): if name in train_layers: param.requires_grad = True else: param.requires_grad = False - optimizer = torch.optim.Adam(model.parameters(), lr=params['lr']) - scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader)) - if params['train_loss'] == 'mse': + optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"]) + scheduler = CosineAnnealingLR( + optimizer, params["num_epochs"] * len(train_loader) + ) + if params["train_loss"] == "mse": det_criterion = losses.mse_loss - elif params['train_loss'] == 'focal': + elif params["train_loss"] == "focal": det_criterion = losses.focal_loss # plotting - train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1, - ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True) - test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1, - ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True) - test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1, - ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', '']) - test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1, - 
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec']) + train_plt_ls = pu.LossPlotter( + params["experiment"] + "train_loss.png", + params["num_epochs"] + 1, + ["train_loss"], + None, + None, + ["epoch", "train_loss"], + logy=True, + ) + test_plt_ls = pu.LossPlotter( + params["experiment"] + "test_loss.png", + params["num_epochs"] + 1, + ["test_loss"], + None, + None, + ["epoch", "test_loss"], + logy=True, + ) + test_plt = pu.LossPlotter( + params["experiment"] + "test.png", + params["num_epochs"] + 1, + ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], + [0, 1], + None, + ["epoch", ""], + ) + test_plt_class = pu.LossPlotter( + params["experiment"] + "test_avg_prec.png", + params["num_epochs"] + 1, + params["class_names_short"], + [0, 1], + params["class_names_short"], + ["epoch", "avg_prec"], + ) # main train loop - for epoch in range(0, params['num_epochs']+1): + for epoch in range(0, params["num_epochs"] + 1): - train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params) - train_plt_ls.update_and_save(epoch, [train_loss['train_loss']]) + train_loss = tm.train( + model, + epoch, + train_loader, + det_criterion, + optimizer, + scheduler, + params, + ) + train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) - if epoch % params['num_eval_epochs'] == 0: + if epoch % params["num_eval_epochs"] == 0: # detection accuracy on test set - test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params) - test_plt_ls.update_and_save(epoch, [test_loss['test_loss']]) - test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'], - test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']]) - test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']]) - pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res) + test_res, test_loss = tm.test( + model, epoch, test_loader, det_criterion, params + ) + test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) + test_plt.update_and_save( + epoch, + [ + test_res["avg_prec"], + test_res["rec_at_x"], + test_res["avg_prec_class"], + test_res["file_acc"], + test_res["top_class"]["avg_prec"], + ], + ) + test_plt_class.update_and_save( + epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] + ) + pu.plot_pr_curve_class( + params["experiment"], "test_pr", "test_pr", test_res + ) # save finetuned model - print('saving model to: ' + params['model_file_name']) - op_state = {'epoch': epoch + 1, - 'state_dict': model.state_dict(), - 'params' : params} - torch.save(op_state, params['model_file_name']) - + print("saving model to: " + params["model_file_name"]) + op_state = { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "params": params, + } + torch.save(op_state, params["model_file_name"]) # save an image with associated prediction for each batch in the test set - if not args['do_not_save_images']: + if not args["do_not_save_images"]: tm.save_images_batch(model, test_loader, params) diff --git a/bat_detect/finetune/prep_data_finetune.py b/bat_detect/finetune/prep_data_finetune.py index 3e86cd4..bf86e97 100644 --- a/bat_detect/finetune/prep_data_finetune.py +++ b/bat_detect/finetune/prep_data_finetune.py @@ -1,32 +1,33 @@ -import numpy as np import argparse -import os import json - +import os import sys -sys.path.append(os.path.join('..', '..')) + +import numpy as np + +sys.path.append(os.path.join("..", "..")) import bat_detect.train.train_utils as tu def 
print_dataset_stats(data, split_name, classes_to_ignore): - print('\nSplit:', split_name) - print('Num files:', len(data)) + print("\nSplit:", split_name) + print("Num files:", len(data)) class_cnts = {} for dd in data: - for aa in dd['annotation']: - if aa['class'] not in classes_to_ignore: - if aa['class'] in class_cnts: - class_cnts[aa['class']] += 1 + for aa in dd["annotation"]: + if aa["class"] not in classes_to_ignore: + if aa["class"] in class_cnts: + class_cnts[aa["class"]] += 1 else: - class_cnts[aa['class']] = 1 + class_cnts[aa["class"]] = 1 if len(class_cnts) == 0: class_names = [] else: class_names = np.sort([*class_cnts]).tolist() - print('Class count:') + print("Class count:") str_len = np.max([len(cc) for cc in class_names]) + 5 for ii, cc in enumerate(class_names): @@ -41,111 +42,169 @@ def load_file_names(file_name): with open(file_name) as da: files = [line.rstrip() for line in da.readlines()] for ff in files: - if ff.lower()[-3:] != 'wav': - print('Error: Filenames need to end in .wav - ', ff) - assert(False) + if ff.lower()[-3:] != "wav": + print("Error: Filenames need to end in .wav - ", ff) + assert False else: - print('Error: Input file not found - ', file_name) - assert(False) + print("Error: Input file not found - ", file_name) + assert False return files if __name__ == "__main__": - info_str = '\nBatDetect - Prepare Data for Finetuning\n' + info_str = "\nBatDetect - Prepare Data for Finetuning\n" print(info_str) parser = argparse.ArgumentParser() - parser.add_argument('dataset_name', type=str, help='Name to call your dataset') - parser.add_argument('audio_dir', type=str, help='Input directory for audio') - parser.add_argument('ann_dir', type=str, help='Input directory for where the audio annotations are stored') - parser.add_argument('op_dir', type=str, help='Path where the train and test splits will be stored') - parser.add_argument('--percent_val', type=float, default=0.20, - help='Hold out this much data for validation. Should be number between 0 and 1') - parser.add_argument('--rand_seed', type=int, default=2001, - help='Random seed used for creating the validation split') - parser.add_argument('--train_file', type=str, default='', - help='Text file where each line is a wav file in train split') - parser.add_argument('--test_file', type=str, default='', - help='Text file where each line is a wav file in test split') - parser.add_argument('--input_class_names', type=str, default='', - help='Specify names of classes that you want to change. Separate with ";"') - parser.add_argument('--output_class_names', type=str, default='', - help='New class names to use instead. One to one mapping with "--input_class_names". \ - Separate with ";"') + parser.add_argument( + "dataset_name", type=str, help="Name to call your dataset" + ) + parser.add_argument( + "audio_dir", type=str, help="Input directory for audio" + ) + parser.add_argument( + "ann_dir", + type=str, + help="Input directory for where the audio annotations are stored", + ) + parser.add_argument( + "op_dir", + type=str, + help="Path where the train and test splits will be stored", + ) + parser.add_argument( + "--percent_val", + type=float, + default=0.20, + help="Hold out this much data for validation. 
Should be number between 0 and 1", + ) + parser.add_argument( + "--rand_seed", + type=int, + default=2001, + help="Random seed used for creating the validation split", + ) + parser.add_argument( + "--train_file", + type=str, + default="", + help="Text file where each line is a wav file in train split", + ) + parser.add_argument( + "--test_file", + type=str, + default="", + help="Text file where each line is a wav file in test split", + ) + parser.add_argument( + "--input_class_names", + type=str, + default="", + help='Specify names of classes that you want to change. Separate with ";"', + ) + parser.add_argument( + "--output_class_names", + type=str, + default="", + help='New class names to use instead. One to one mapping with "--input_class_names". \ + Separate with ";"', + ) args = vars(parser.parse_args()) + np.random.seed(args["rand_seed"]) - np.random.seed(args['rand_seed']) + classes_to_ignore = ["", " ", "Unknown", "Not Bat"] + generic_class = ["Bat"] + events_of_interest = ["Echolocation"] - classes_to_ignore = ['', ' ', 'Unknown', 'Not Bat'] - generic_class = ['Bat'] - events_of_interest = ['Echolocation'] - - if args['input_class_names'] != '' and args['output_class_names'] != '': + if args["input_class_names"] != "" and args["output_class_names"] != "": # change the names of the classes - ip_names = args['input_class_names'].split(';') - op_names = args['output_class_names'].split(';') + ip_names = args["input_class_names"].split(";") + op_names = args["output_class_names"].split(";") name_dict = dict(zip(ip_names, op_names)) else: name_dict = False # load annotations - data_all, _, _ = tu.load_set_of_anns({'ann_path': args['ann_dir'], 'wav_path': args['audio_dir']}, - classes_to_ignore, events_of_interest, False, False, - list_of_anns=True, filter_issues=True, name_replace=name_dict) + data_all, _, _ = tu.load_set_of_anns( + {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]}, + classes_to_ignore, + events_of_interest, + False, + False, + list_of_anns=True, + filter_issues=True, + name_replace=name_dict, + ) - print('Dataset name: ' + args['dataset_name']) - print('Audio directory: ' + args['audio_dir']) - print('Annotation directory: ' + args['ann_dir']) - print('Ouput directory: ' + args['op_dir']) - print('Num annotated files: ' + str(len(data_all))) + print("Dataset name: " + args["dataset_name"]) + print("Audio directory: " + args["audio_dir"]) + print("Annotation directory: " + args["ann_dir"]) + print("Ouput directory: " + args["op_dir"]) + print("Num annotated files: " + str(len(data_all))) - if args['train_file'] != '' and args['test_file'] != '': + if args["train_file"] != "" and args["test_file"] != "": # user has specifed the train / test split - train_files = load_file_names(args['train_file']) - test_files = load_file_names(args['test_file']) - file_names_all = [dd['id'] for dd in data_all] - train_inds = [file_names_all.index(ff) for ff in train_files if ff in file_names_all] - test_inds = [file_names_all.index(ff) for ff in test_files if ff in file_names_all] + train_files = load_file_names(args["train_file"]) + test_files = load_file_names(args["test_file"]) + file_names_all = [dd["id"] for dd in data_all] + train_inds = [ + file_names_all.index(ff) + for ff in train_files + if ff in file_names_all + ] + test_inds = [ + file_names_all.index(ff) + for ff in test_files + if ff in file_names_all + ] else: # split the data into train and test at the file level num_exs = len(data_all) - test_inds = np.random.choice(np.arange(num_exs), 
int(num_exs*args['percent_val']), replace=False) + test_inds = np.random.choice( + np.arange(num_exs), + int(num_exs * args["percent_val"]), + replace=False, + ) test_inds = np.sort(test_inds) train_inds = np.setdiff1d(np.arange(num_exs), test_inds) data_train = [data_all[ii] for ii in train_inds] data_test = [data_all[ii] for ii in test_inds] - if not os.path.isdir(args['op_dir']): - os.makedirs(args['op_dir']) - op_name = os.path.join(args['op_dir'], args['dataset_name']) - op_name_train = op_name + '_TRAIN.json' - op_name_test = op_name + '_TEST.json' + if not os.path.isdir(args["op_dir"]): + os.makedirs(args["op_dir"]) + op_name = os.path.join(args["op_dir"], args["dataset_name"]) + op_name_train = op_name + "_TRAIN.json" + op_name_test = op_name + "_TEST.json" - class_un_train = print_dataset_stats(data_train, 'Train', classes_to_ignore) - class_un_test = print_dataset_stats(data_test, 'Test', classes_to_ignore) + class_un_train = print_dataset_stats( + data_train, "Train", classes_to_ignore + ) + class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore) if len(data_train) > 0 and len(data_test) > 0: if class_un_train != class_un_test: - print('\nError: some classes are not in both the training and test sets.\ - \nTry a different random seed "--rand_seed".') + print( + '\nError: some classes are not in both the training and test sets.\ + \nTry a different random seed "--rand_seed".' + ) assert False - print('\n') + print("\n") if len(data_train) == 0: - print('No train annotations to save') + print("No train annotations to save") else: - print('Saving: ', op_name_train) - with open(op_name_train, 'w') as da: + print("Saving: ", op_name_train) + with open(op_name_train, "w") as da: json.dump(data_train, da, indent=2) if len(data_test) == 0: - print('No test annotations to save') + print("No test annotations to save") else: - print('Saving: ', op_name_test) - with open(op_name_test, 'w') as da: + print("Saving: ", op_name_test) + with open(op_name_test, "w") as da: json.dump(data_test, da, indent=2) diff --git a/bat_detect/train/audio_dataloader.py b/bat_detect/train/audio_dataloader.py index a36ec0b..ffd8086 100644 --- a/bat_detect/train/audio_dataloader.py +++ b/bat_detect/train/audio_dataloader.py @@ -1,71 +1,101 @@ -import torch -import random -import numpy as np import copy +import os +import random +import sys + import librosa +import numpy as np +import torch import torch.nn.functional as F import torchaudio -import os -import sys -sys.path.append(os.path.join('..', '..')) +sys.path.append(os.path.join("..", "..")) import bat_detect.utils.audio_utils as au def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params): # spec may be resized on input into the network - num_classes = len(params['class_names']) - op_height = spec_op_shape[0] - op_width = spec_op_shape[1] - freq_per_bin = (params['max_freq'] - params['min_freq']) / op_height + num_classes = len(params["class_names"]) + op_height = spec_op_shape[0] + op_width = spec_op_shape[1] + freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height # start and end times - x_pos_start = au.time_to_x_coords(ann['start_times'], sampling_rate, - params['fft_win_length'], params['fft_overlap']) - x_pos_start = (params['resize_factor']*x_pos_start).astype(np.int) - x_pos_end = au.time_to_x_coords(ann['end_times'], sampling_rate, - params['fft_win_length'], params['fft_overlap']) - x_pos_end = (params['resize_factor']*x_pos_end).astype(np.int) + x_pos_start = au.time_to_x_coords( + ann["start_times"], + 
sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + ) + x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int) + x_pos_end = au.time_to_x_coords( + ann["end_times"], + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + ) + x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int) # location on y axis i.e. frequency - y_pos_low = (ann['low_freqs'] - params['min_freq']) / freq_per_bin - y_pos_low = (op_height - y_pos_low).astype(np.int) - y_pos_high = (ann['high_freqs'] - params['min_freq']) / freq_per_bin + y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin + y_pos_low = (op_height - y_pos_low).astype(np.int) + y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin y_pos_high = (op_height - y_pos_high).astype(np.int) - bb_widths = x_pos_end - x_pos_start - bb_heights = (y_pos_low - y_pos_high) + bb_widths = x_pos_end - x_pos_start + bb_heights = y_pos_low - y_pos_high - valid_inds = np.where((x_pos_start >= 0) & (x_pos_start < op_width) & - (y_pos_low >= 0) & (y_pos_low < (op_height-1)))[0] + valid_inds = np.where( + (x_pos_start >= 0) + & (x_pos_start < op_width) + & (y_pos_low >= 0) + & (y_pos_low < (op_height - 1)) + )[0] ann_aug = {} - ann_aug['x_inds'] = x_pos_start[valid_inds] - ann_aug['y_inds'] = y_pos_low[valid_inds] - keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids'] + ann_aug["x_inds"] = x_pos_start[valid_inds] + ann_aug["y_inds"] = y_pos_low[valid_inds] + keys = [ + "start_times", + "end_times", + "high_freqs", + "low_freqs", + "class_ids", + "individual_ids", + ] for kk in keys: ann_aug[kk] = ann[kk][valid_inds] # if the number of calls is only 1, then it is unique # TODO would be better if we found these unique calls at the merging stage - if len(ann_aug['individual_ids']) == 1: - ann_aug['individual_ids'][0] = 0 + if len(ann_aug["individual_ids"]) == 1: + ann_aug["individual_ids"][0] = 0 - y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) + y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32) y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32) # num classes and "background" class - y_2d_classes = np.zeros((num_classes+1, op_height, op_width), dtype=np.float32) + y_2d_classes = np.zeros( + (num_classes + 1, op_height, op_width), dtype=np.float32 + ) # create 2D ground truth heatmaps for ii in valid_inds: - draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma']) - #draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) + draw_gaussian( + y_2d_det[0, :], + (x_pos_start[ii], y_pos_low[ii]), + params["target_sigma"], + ) + # draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii] y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii] - cls_id = ann['class_ids'][ii] + cls_id = ann["class_ids"][ii] if cls_id > -1: - draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma']) - #draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) + draw_gaussian( + y_2d_classes[cls_id, :], + (x_pos_start[ii], y_pos_low[ii]), + params["target_sigma"], + ) + # draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2) # be careful as this will have a 1.0 
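In generate_gt_heatmaps above, each annotated call is converted to a pixel location (start time on the x axis, low frequency on the y axis) and stamped onto the detection and per-class target maps as a small 2-D Gaussian via draw_gaussian. A self-contained sketch of that idea, simplified for illustration (it rebuilds the full grid per call and skips draw_gaussian's window-clipping bookkeeping):

import numpy as np

def place_gaussian(heatmap, cx, cy, sigma):
    # stamp exp(-(dx^2 + dy^2) / (2 * sigma^2)) centred at pixel (cx, cy)
    h, w = heatmap.shape
    yy, xx = np.mgrid[0:h, 0:w]
    g = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma**2))
    np.maximum(heatmap, g, out=heatmap)  # max, not add: overlaps stay <= 1.0

y_2d_det = np.zeros((128, 512), dtype=np.float32)
place_gaussian(y_2d_det, cx=40, cy=100, sigma=2.0)  # one call at x=40, y=100
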
places where we have event but dont know gt class # this will be masked in training anyway @@ -96,20 +126,24 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None): x = np.arange(0, size, 1, np.float32) y = x[:, np.newaxis] x0 = y0 = size // 2 - #g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) - g = np.exp(- ((x - x0) ** 2)/(2 * sigmax ** 2) - ((y - y0) ** 2)/(2 * sigmay ** 2)) + # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) + g = np.exp( + -((x - x0) ** 2) / (2 * sigmax**2) + - ((y - y0) ** 2) / (2 * sigmay**2) + ) g_x = max(0, -ul[0]), min(br[0], h) - ul[0] g_y = max(0, -ul[1]), min(br[1], w) - ul[1] img_x = max(0, ul[0]), min(br[0], h) img_y = max(0, ul[1]), min(br[1], w) - heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( - heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], - g[g_y[0]:g_y[1], g_x[0]:g_x[1]]) + heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum( + heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]], + g[g_y[0] : g_y[1], g_x[0] : g_x[1]], + ) return True def pad_aray(ip_array, pad_size): - return np.hstack((ip_array, np.ones(pad_size, dtype=np.int)*-1)) + return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1)) def warp_spec_aug(spec, ann, return_spec_for_viz, params): @@ -121,24 +155,37 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params): if return_spec_for_viz: assert False - delta = params['stretch_squeeze_delta'] + delta = params["stretch_squeeze_delta"] op_size = (spec.shape[1], spec.shape[2]) - resize_fract_r = np.random.rand()*delta*2 - delta + 1.0 - resize_amt = int(spec.shape[2]*resize_fract_r) + resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0 + resize_amt = int(spec.shape[2] * resize_fract_r) if resize_amt >= spec.shape[2]: - spec_r = torch.cat((spec, torch.zeros((1, spec.shape[1], resize_amt-spec.shape[2]), dtype=spec.dtype)), 2) + spec_r = torch.cat( + ( + spec, + torch.zeros( + (1, spec.shape[1], resize_amt - spec.shape[2]), + dtype=spec.dtype, + ), + ), + 2, + ) else: spec_r = spec[:, :, :resize_amt] - spec = F.interpolate(spec_r.unsqueeze(0), size=op_size, mode='bilinear', align_corners=False).squeeze(0) - ann['start_times'] *= (1.0/resize_fract_r) - ann['end_times'] *= (1.0/resize_fract_r) + spec = F.interpolate( + spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False + ).squeeze(0) + ann["start_times"] *= 1.0 / resize_fract_r + ann["end_times"] *= 1.0 / resize_fract_r return spec def mask_time_aug(spec, params): # Mask out a random block of time - repeat up to 3 times # SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition - fm = torchaudio.transforms.TimeMasking(int(spec.shape[1]*params['mask_max_time_perc'])) + fm = torchaudio.transforms.TimeMasking( + int(spec.shape[1] * params["mask_max_time_perc"]) + ) for ii in range(np.random.randint(1, 4)): spec = fm(spec) return spec @@ -147,40 +194,59 @@ def mask_time_aug(spec, params): def mask_freq_aug(spec, params): # Mask out a random frequncy range - repeat up to 3 times # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition - fm = torchaudio.transforms.FrequencyMasking(int(spec.shape[1]*params['mask_max_freq_perc'])) + fm = torchaudio.transforms.FrequencyMasking( + int(spec.shape[1] * params["mask_max_freq_perc"]) + ) for ii in range(np.random.randint(1, 4)): spec = fm(spec) return spec def scale_vol_aug(spec, params): - return spec * np.random.random()*params['spec_amp_scaling'] + return spec * np.random.random() * params["spec_amp_scaling"] def 
echo_aug(audio, sampling_rate, params): - sample_offset = int(params['echo_max_delay']*np.random.random()*sampling_rate) + 1 - audio[:-sample_offset] += np.random.random()*audio[sample_offset:] + sample_offset = ( + int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1 + ) + audio[:-sample_offset] += np.random.random() * audio[sample_offset:] return audio def resample_aug(audio, sampling_rate, params): sampling_rate_old = sampling_rate - sampling_rate = np.random.choice(params['aug_sampling_rates']) - audio = librosa.resample(audio, sampling_rate_old, sampling_rate, res_type='polyphase') + sampling_rate = np.random.choice(params["aug_sampling_rates"]) + audio = librosa.resample( + audio, sampling_rate_old, sampling_rate, res_type="polyphase" + ) - audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], - params['fft_overlap'], params['resize_factor'], - params['spec_divide_factor'], params['spec_train_width']) + audio = au.pad_audio( + audio, + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + params["resize_factor"], + params["spec_divide_factor"], + params["spec_train_width"], + ) duration = audio.shape[0] / float(sampling_rate) return audio, sampling_rate, duration def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): if sampling_rate != sampling_rate2: - audio2 = librosa.resample(audio2, sampling_rate2, sampling_rate, res_type='polyphase') + audio2 = librosa.resample( + audio2, sampling_rate2, sampling_rate, res_type="polyphase" + ) sampling_rate2 = sampling_rate if audio2.shape[0] < num_samples: - audio2 = np.hstack((audio2, np.zeros((num_samples-audio2.shape[0]), dtype=audio2.dtype))) + audio2 = np.hstack( + ( + audio2, + np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype), + ) + ) elif audio2.shape[0] > num_samples: audio2 = audio2[:num_samples] return audio2, sampling_rate2 @@ -189,33 +255,43 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2): def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2): # resample so they are the same - audio2, sampling_rate2 = resample_audio(audio.shape[0], sampling_rate, audio2, sampling_rate2) + audio2, sampling_rate2 = resample_audio( + audio.shape[0], sampling_rate, audio2, sampling_rate2 + ) # # set mean and std to be the same # audio2 = (audio2 - audio2.mean()) # audio2 = (audio2/audio2.std())*audio.std() # audio2 = audio2 + audio.mean() - if ann['annotated'] and (ann2['annotated']) and \ - (sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]): - comb_weight = 0.3 + np.random.random()*0.4 - audio = comb_weight*audio + (1-comb_weight)*audio2 - inds = np.argsort(np.hstack((ann['start_times'], ann2['start_times']))) + if ( + ann["annotated"] + and (ann2["annotated"]) + and (sampling_rate2 == sampling_rate) + and (audio.shape[0] == audio2.shape[0]) + ): + comb_weight = 0.3 + np.random.random() * 0.4 + audio = comb_weight * audio + (1 - comb_weight) * audio2 + inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"]))) for kk in ann.keys(): # when combining calls from different files, assume they come from different individuals - if kk == 'individual_ids': - if (ann[kk]>-1).sum() > 0: - ann2[kk][ann2[kk]>-1] += np.max(ann[kk][ann[kk]>-1]) + 1 + if kk == "individual_ids": + if (ann[kk] > -1).sum() > 0: + ann2[kk][ann2[kk] > -1] += ( + np.max(ann[kk][ann[kk] > -1]) + 1 + ) - if (kk != 'class_id_file') and (kk != 'annotated'): + if (kk != "class_id_file") and (kk != "annotated"): ann[kk] = 
np.hstack((ann[kk], ann2[kk]))[inds] return audio, ann class AudioLoader(torch.utils.data.Dataset): - def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False): + def __init__( + self, data_anns_ip, params, dataset_name=None, is_train=False + ): self.data_anns = [] self.is_train = is_train @@ -227,53 +303,70 @@ class AudioLoader(torch.utils.data.Dataset): # filter out unused annotation here filtered_annotations = [] - for ii, aa in enumerate(dd['annotation']): + for ii, aa in enumerate(dd["annotation"]): - if 'individual' in aa.keys(): - aa['individual'] = int(aa['individual']) + if "individual" in aa.keys(): + aa["individual"] = int(aa["individual"]) # if only one call labeled it has to be from the same individual - if len(dd['annotation']) == 1: - aa['individual'] = 0 + if len(dd["annotation"]) == 1: + aa["individual"] = 0 # convert class name into class label - if aa['class'] in self.params['class_names']: - aa['class_id'] = self.params['class_names'].index(aa['class']) + if aa["class"] in self.params["class_names"]: + aa["class_id"] = self.params["class_names"].index( + aa["class"] + ) else: - aa['class_id'] = -1 + aa["class_id"] = -1 - if aa['class'] not in self.params['classes_to_ignore']: + if aa["class"] not in self.params["classes_to_ignore"]: filtered_annotations.append(aa) - dd['annotation'] = filtered_annotations - dd['start_times'] = np.array([aa['start_time'] for aa in dd['annotation']]) - dd['end_times'] = np.array([aa['end_time'] for aa in dd['annotation']]) - dd['high_freqs'] = np.array([float(aa['high_freq']) for aa in dd['annotation']]) - dd['low_freqs'] = np.array([float(aa['low_freq']) for aa in dd['annotation']]) - dd['class_ids'] = np.array([aa['class_id'] for aa in dd['annotation']]).astype(np.int) - dd['individual_ids'] = np.array([aa['individual'] for aa in dd['annotation']]).astype(np.int) + dd["annotation"] = filtered_annotations + dd["start_times"] = np.array( + [aa["start_time"] for aa in dd["annotation"]] + ) + dd["end_times"] = np.array( + [aa["end_time"] for aa in dd["annotation"]] + ) + dd["high_freqs"] = np.array( + [float(aa["high_freq"]) for aa in dd["annotation"]] + ) + dd["low_freqs"] = np.array( + [float(aa["low_freq"]) for aa in dd["annotation"]] + ) + dd["class_ids"] = np.array( + [aa["class_id"] for aa in dd["annotation"]] + ).astype(np.int) + dd["individual_ids"] = np.array( + [aa["individual"] for aa in dd["annotation"]] + ).astype(np.int) # file level class name - dd['class_id_file'] = -1 - if 'class_name' in dd.keys(): - if dd['class_name'] in self.params['class_names']: - dd['class_id_file'] = self.params['class_names'].index(dd['class_name']) + dd["class_id_file"] = -1 + if "class_name" in dd.keys(): + if dd["class_name"] in self.params["class_names"]: + dd["class_id_file"] = self.params["class_names"].index( + dd["class_name"] + ) self.data_anns.append(dd) - ann_cnt = [len(aa['annotation']) for aa in self.data_anns] - self.max_num_anns = 2*np.max(ann_cnt) # x2 because we may be combining files during training + ann_cnt = [len(aa["annotation"]) for aa in self.data_anns] + self.max_num_anns = 2 * np.max( + ann_cnt + ) # x2 because we may be combining files during training - print('\n') + print("\n") if dataset_name is not None: - print('Dataset : ' + dataset_name) + print("Dataset : " + dataset_name) if self.is_train: - print('Split type : train') + print("Split type : train") else: - print('Split type : test') - print('Num files : ' + str(len(self.data_anns))) - print('Num calls : ' + str(np.sum(ann_cnt))) - + print("Split 
type : test") + print("Num files : " + str(len(self.data_anns))) + print("Num calls : " + str(np.sum(ann_cnt))) def get_file_and_anns(self, index=None): @@ -281,110 +374,171 @@ class AudioLoader(torch.utils.data.Dataset): if index == None: index = np.random.randint(0, len(self.data_anns)) - audio_file = self.data_anns[index]['file_path'] - sampling_rate, audio_raw = au.load_audio_file(audio_file, self.data_anns[index]['time_exp'], - self.params['target_samp_rate'], self.params['scale_raw_audio']) + audio_file = self.data_anns[index]["file_path"] + sampling_rate, audio_raw = au.load_audio_file( + audio_file, + self.data_anns[index]["time_exp"], + self.params["target_samp_rate"], + self.params["scale_raw_audio"], + ) # copy annotation ann = {} - ann['annotated'] = self.data_anns[index]['annotated'] - ann['class_id_file'] = self.data_anns[index]['class_id_file'] - keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids'] + ann["annotated"] = self.data_anns[index]["annotated"] + ann["class_id_file"] = self.data_anns[index]["class_id_file"] + keys = [ + "start_times", + "end_times", + "high_freqs", + "low_freqs", + "class_ids", + "individual_ids", + ] for kk in keys: ann[kk] = self.data_anns[index][kk].copy() # if train then grab a random crop if self.is_train: - nfft = int(self.params['fft_win_length']*sampling_rate) - noverlap = int(self.params['fft_overlap']*nfft) - length_samples = self.params['spec_train_width']*(nfft - noverlap) + noverlap + nfft = int(self.params["fft_win_length"] * sampling_rate) + noverlap = int(self.params["fft_overlap"] * nfft) + length_samples = ( + self.params["spec_train_width"] * (nfft - noverlap) + noverlap + ) if audio_raw.shape[0] - length_samples > 0: - sample_crop = np.random.randint(audio_raw.shape[0] - length_samples) + sample_crop = np.random.randint( + audio_raw.shape[0] - length_samples + ) else: sample_crop = 0 - audio_raw = audio_raw[sample_crop:sample_crop+length_samples] - ann['start_times'] = ann['start_times'] - sample_crop/float(sampling_rate) - ann['end_times'] = ann['end_times'] - sample_crop/float(sampling_rate) + audio_raw = audio_raw[sample_crop : sample_crop + length_samples] + ann["start_times"] = ann["start_times"] - sample_crop / float( + sampling_rate + ) + ann["end_times"] = ann["end_times"] - sample_crop / float( + sampling_rate + ) # pad audio if self.is_train: - op_spec_target_size = self.params['spec_train_width'] + op_spec_target_size = self.params["spec_train_width"] else: op_spec_target_size = None - audio_raw = au.pad_audio(audio_raw, sampling_rate, self.params['fft_win_length'], - self.params['fft_overlap'], self.params['resize_factor'], - self.params['spec_divide_factor'], op_spec_target_size) + audio_raw = au.pad_audio( + audio_raw, + sampling_rate, + self.params["fft_win_length"], + self.params["fft_overlap"], + self.params["resize_factor"], + self.params["spec_divide_factor"], + op_spec_target_size, + ) duration = audio_raw.shape[0] / float(sampling_rate) # sort based on time - inds = np.argsort(ann['start_times']) + inds = np.argsort(ann["start_times"]) for kk in ann.keys(): - if (kk != 'class_id_file') and (kk != 'annotated'): + if (kk != "class_id_file") and (kk != "annotated"): ann[kk] = ann[kk][inds] return audio_raw, sampling_rate, duration, ann - def __getitem__(self, index): # load audio file audio, sampling_rate, duration, ann = self.get_file_and_anns(index) # augment on raw audio - if self.is_train and self.params['augment_at_train']: + if self.is_train and 
self.params["augment_at_train"]: # augment - combine with random audio file - if self.params['augment_at_train_combine'] and np.random.random() < self.params['aug_prob']: - audio2, sampling_rate2, duration2, ann2 = self.get_file_and_anns() - audio, ann = combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2) + if ( + self.params["augment_at_train_combine"] + and np.random.random() < self.params["aug_prob"] + ): + ( + audio2, + sampling_rate2, + duration2, + ann2, + ) = self.get_file_and_anns() + audio, ann = combine_audio_aug( + audio, sampling_rate, ann, audio2, sampling_rate2, ann2 + ) # simulate echo by adding delayed copy of the file - if np.random.random() < self.params['aug_prob']: + if np.random.random() < self.params["aug_prob"]: audio = echo_aug(audio, sampling_rate, self.params) # resample the audio - #if np.random.random() < self.params['aug_prob']: + # if np.random.random() < self.params['aug_prob']: # audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params) # create spectrogram - spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, self.params, self.return_spec_for_viz) - rsf = self.params['resize_factor'] - spec_op_shape = (int(self.params['spec_height']*rsf), int(spec.shape[1]*rsf)) + spec, spec_for_viz = au.generate_spectrogram( + audio, sampling_rate, self.params, self.return_spec_for_viz + ) + rsf = self.params["resize_factor"] + spec_op_shape = ( + int(self.params["spec_height"] * rsf), + int(spec.shape[1] * rsf), + ) # resize the spec spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0) - spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False).squeeze(0) + spec = F.interpolate( + spec, size=spec_op_shape, mode="bilinear", align_corners=False + ).squeeze(0) # augment spectrogram - if self.is_train and self.params['augment_at_train']: + if self.is_train and self.params["augment_at_train"]: - if np.random.random() < self.params['aug_prob']: + if np.random.random() < self.params["aug_prob"]: spec = scale_vol_aug(spec, self.params) - if np.random.random() < self.params['aug_prob']: - spec = warp_spec_aug(spec, ann, self.return_spec_for_viz, self.params) + if np.random.random() < self.params["aug_prob"]: + spec = warp_spec_aug( + spec, ann, self.return_spec_for_viz, self.params + ) - if np.random.random() < self.params['aug_prob']: + if np.random.random() < self.params["aug_prob"]: spec = mask_time_aug(spec, self.params) - if np.random.random() < self.params['aug_prob']: + if np.random.random() < self.params["aug_prob"]: spec = mask_freq_aug(spec, self.params) outputs = {} - outputs['spec'] = spec + outputs["spec"] = spec if self.return_spec_for_viz: - outputs['spec_for_viz'] = torch.from_numpy(spec_for_viz).unsqueeze(0) + outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze( + 0 + ) # create ground truth heatmaps - outputs['y_2d_det'], outputs['y_2d_size'], outputs['y_2d_classes'], ann_aug =\ - generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params) + ( + outputs["y_2d_det"], + outputs["y_2d_size"], + outputs["y_2d_classes"], + ann_aug, + ) = generate_gt_heatmaps( + spec_op_shape, sampling_rate, ann, self.params + ) # hack to get around requirement that all vectors are the same length in # the output batch - pad_size = self.max_num_anns-len(ann_aug['individual_ids']) - outputs['is_valid'] = pad_aray(np.ones(len(ann_aug['individual_ids'])), pad_size) - keys = ['class_ids', 'individual_ids', 'x_inds', 'y_inds', - 'start_times', 'end_times', 'low_freqs', 
'high_freqs'] + pad_size = self.max_num_anns - len(ann_aug["individual_ids"]) + outputs["is_valid"] = pad_aray( + np.ones(len(ann_aug["individual_ids"])), pad_size + ) + keys = [ + "class_ids", + "individual_ids", + "x_inds", + "y_inds", + "start_times", + "end_times", + "low_freqs", + "high_freqs", + ] for kk in keys: outputs[kk] = pad_aray(ann_aug[kk], pad_size) @@ -394,14 +548,13 @@ class AudioLoader(torch.utils.data.Dataset): outputs[kk] = torch.from_numpy(outputs[kk]) # scalars - outputs['class_id_file'] = ann['class_id_file'] - outputs['annotated'] = ann['annotated'] - outputs['duration'] = duration - outputs['sampling_rate'] = sampling_rate - outputs['file_id'] = index + outputs["class_id_file"] = ann["class_id_file"] + outputs["annotated"] = ann["annotated"] + outputs["duration"] = duration + outputs["sampling_rate"] = sampling_rate + outputs["file_id"] = index return outputs - def __len__(self): return len(self.data_anns) diff --git a/bat_detect/train/evaluate.py b/bat_detect/train/evaluate.py index b88719f..a926fbb 100755 --- a/bat_detect/train/evaluate.py +++ b/bat_detect/train/evaluate.py @@ -1,6 +1,10 @@ import numpy as np -from sklearn.metrics import roc_curve, auc -from sklearn.metrics import accuracy_score, balanced_accuracy_score +from sklearn.metrics import ( + accuracy_score, + auc, + balanced_accuracy_score, + roc_curve, +) def compute_error_auc(op_str, gt, pred, prob): @@ -13,8 +17,11 @@ def compute_error_auc(op_str, gt, pred, prob): fpr, tpr, thresholds = roc_curve(gt, pred) roc_auc = auc(fpr, tpr) - print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc)) - #return class_acc, roc_auc + print( + op_str + + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc) + ) + # return class_acc, roc_auc def calc_average_precision(recall, precision): @@ -25,10 +32,10 @@ def calc_average_precision(recall, precision): # pascal 12 way mprec = np.hstack((0, precision, 0)) mrec = np.hstack((0, recall, 1)) - for ii in range(mprec.shape[0]-2, -1,-1): - mprec[ii] = np.maximum(mprec[ii], mprec[ii+1]) - inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1 - ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum() + for ii in range(mprec.shape[0] - 2, -1, -1): + mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1]) + inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1 + ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum() return float(ave_prec) @@ -37,7 +44,7 @@ def calc_recall_at_x(recall, precision, x=0.95): precision[np.isnan(precision)] = 0 recall[np.isnan(recall)] = 0 - inds = np.where(precision[::-1]>x)[0] + inds = np.where(precision[::-1] > x)[0] if len(inds) > 0: return float(recall[::-1][inds[0]]) else: @@ -51,7 +58,15 @@ def compute_affinity_1d(pred_box, gt_boxes, threshold): return valid_detection, np.argmin(score)
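For intuition, a small worked example (illustrative numbers, not from the repository) of the Pascal VOC-2012 style interpolation implemented by calc_average_precision above. Suppose three detections, sorted by confidence, of which the first and third match the two ground-truth calls:

import numpy as np

from bat_detect.train.evaluate import calc_average_precision

# cumulative recall and precision after each detection, highest confidence first
recall = np.array([0.5, 0.5, 1.0])           # true positives so far / 2 ground truths
precision = np.array([1.0, 0.5, 2.0 / 3.0])  # true positives so far / detections so far

# precision is first made monotonically non-increasing from the right, then
# the area under the stepped curve is summed: 0.5 * 1.0 + 0.5 * (2 / 3) = 0.8333...
print(calc_average_precision(recall, precision))  # -> 0.83333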
-def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end): +def compute_pre_rec( + gts, + preds, + eval_mode, + class_of_interest, + num_classes, + threshold, + ignore_start_end, +): """ Computes precision and recall. Assumes that each file has been exhaustively annotated. Will not count predicted detection with a start time that is within @@ -78,26 +93,40 @@ for pid, pp in enumerate(preds): # filter predicted calls that are too near the start or end of the file - file_dur = gts[pid]['duration'] - valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end)) + file_dur = gts[pid]["duration"] + valid_inds = (pp["start_times"] >= ignore_start_end) & ( + pp["start_times"] <= (file_dur - ignore_start_end) + ) - pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds], - pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T) + pred_boxes.append( + np.vstack( + ( + pp["start_times"][valid_inds], + pp["end_times"][valid_inds], + pp["low_freqs"][valid_inds], + pp["high_freqs"][valid_inds], + ) + ).T + ) - if eval_mode == 'detection': + if eval_mode == "detection": # overall detection - confidence.append(pp['det_probs'][valid_inds]) - elif eval_mode == 'per_class': + confidence.append(pp["det_probs"][valid_inds]) + elif eval_mode == "per_class": # per class - confidence.append(pp['class_probs'].T[valid_inds, class_of_interest]) - elif eval_mode == 'top_class': + confidence.append( + pp["class_probs"].T[valid_inds, class_of_interest] + ) + elif eval_mode == "top_class": # per class - note that sometimes 'class_probs' can be num_classes+1 in size - top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1) - confidence.append(pp['class_probs'].T[valid_inds, top_class]) + top_class = np.argmax( + pp["class_probs"].T[valid_inds, :num_classes], 1 + ) + confidence.append(pp["class_probs"].T[valid_inds, top_class]) pred_class.append(top_class) # be careful, assuming the order in the list is same as GT - file_ids.append([pid]*valid_inds.sum()) + file_ids.append([pid] * valid_inds.sum()) confidence = np.hstack(confidence) file_ids = np.hstack(file_ids).astype(np.int) @@ -105,7 +134,6 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres if len(pred_class) > 0: pred_class = np.hstack(pred_class) - # extract relevant ground truth boxes gt_boxes = [] gt_assigned = [] @@ -115,32 +143,42 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres for gg in gts: # filter ground truth calls that are too near the start or end of the file - file_dur = gg['duration'] - valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end)) + file_dur = gg["duration"] + valid_inds = (gg["start_times"] >= ignore_start_end) & ( + gg["start_times"] <= (file_dur - ignore_start_end) + ) # note, files with the incorrect duration will cause a problem - if (gg['start_times'] > file_dur).sum() > 0: - print('Error: file duration incorrect for', gg['id']) - assert(False) + if (gg["start_times"] > file_dur).sum() > 0: + print("Error: file duration incorrect for", gg["id"]) + assert False - boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds], - gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T - gen_class = gg['class_ids'][valid_inds] == -1 - class_ids = gg['class_ids'][valid_inds] + boxes = np.vstack( + ( + gg["start_times"][valid_inds], + gg["end_times"][valid_inds], + gg["low_freqs"][valid_inds], + gg["high_freqs"][valid_inds], + ) + ).T + gen_class = gg["class_ids"][valid_inds] == -1 + class_ids = gg["class_ids"][valid_inds] # keep track of the number of relevant ground truth calls - if 
eval_mode == 'detection': + if eval_mode == "detection": # all valid ones - num_positives += len(gg['start_times'][valid_inds]) - elif eval_mode == 'per_class': + num_positives += len(gg["start_times"][valid_inds]) + elif eval_mode == "per_class": # all valid ones with class of interest - num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum() - elif eval_mode == 'top_class': + num_positives += ( + gg["class_ids"][valid_inds] == class_of_interest + ).sum() + elif eval_mode == "top_class": # all valid ones with non generic class - num_positives += (gg['class_ids'][valid_inds] > -1).sum() + num_positives += (gg["class_ids"][valid_inds] > -1).sum() # find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1) - if eval_mode == 'per_class': + if eval_mode == "per_class": class_inds = (class_ids == class_of_interest) | (class_ids == -1) boxes = boxes[class_inds, :] gen_class = gen_class[class_inds] @@ -151,25 +189,27 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres gt_generic_class.append(gen_class) gt_class.append(class_ids) - # loop through detections and keep track of those that have been assigned - true_pos = np.zeros(confidence.shape[0]) - valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True - sorted_inds = np.argsort(confidence)[::-1] # sort high to low + true_pos = np.zeros(confidence.shape[0]) + valid_inds = np.ones(confidence.shape[0]) == 1 # initialize to True + sorted_inds = np.argsort(confidence)[::-1] # sort high to low for ii, ind in enumerate(sorted_inds): gt_id = file_ids[ind] valid_det = False if gt_boxes[gt_id].shape[0] > 0: # compute overlap - valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id], - threshold) + valid_det, det_ind = compute_affinity_1d( + pred_boxes[ind], gt_boxes[gt_id], threshold + ) # valid detection that has not already been assigned if valid_det and (gt_assigned[gt_id][det_ind] == 0): count_as_true_pos = True - if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]): + if eval_mode == "top_class" and ( + gt_class[gt_id][det_ind] != pred_class[ind] + ): # needs to be the same class count_as_true_pos = False @@ -181,40 +221,43 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres # if event is generic class (i.e. 
gt_generic_class[gt_id][det_ind] is True) # and eval_mode != 'detection', then ignore it if gt_generic_class[gt_id][det_ind]: - if eval_mode == 'per_class' or eval_mode == 'top_class': + if eval_mode == "per_class" or eval_mode == "top_class": valid_inds[ii] = False - # store threshold values - used for plotting conf_sorted = np.sort(confidence)[::-1][valid_inds] thresholds = np.linspace(0.1, 0.9, 9) thresholds_inds = np.zeros(len(thresholds), dtype=np.int) for ii, tt in enumerate(thresholds): thresholds_inds[ii] = np.argmin(conf_sorted > tt) - thresholds_inds[thresholds_inds==0] = -1 + thresholds_inds[thresholds_inds == 0] = -1 # compute precision and recall - true_pos = true_pos[valid_inds] - false_pos_c = np.cumsum(1-true_pos) - true_pos_c = np.cumsum(true_pos) + true_pos = true_pos[valid_inds] + false_pos_c = np.cumsum(1 - true_pos) + true_pos_c = np.cumsum(true_pos) recall = true_pos_c / num_positives - precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps) + precision = true_pos_c / np.maximum( + true_pos_c + false_pos_c, np.finfo(np.float64).eps + ) results = {} - results['recall'] = recall - results['precision'] = precision - results['num_gt'] = num_positives + results["recall"] = recall + results["precision"] = precision + results["num_gt"] = num_positives - results['thresholds'] = thresholds - results['thresholds_inds'] = thresholds_inds + results["thresholds"] = thresholds + results["thresholds_inds"] = thresholds_inds if num_positives == 0: - results['avg_prec'] = np.nan - results['rec_at_x'] = np.nan + results["avg_prec"] = np.nan + results["rec_at_x"] = np.nan else: - results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5) - results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5) + results["avg_prec"] = np.round( + calc_average_precision(recall, precision), 5 + ) + results["rec_at_x"] = np.round(calc_recall_at_x(recall, precision), 5) return results @@ -230,19 +273,19 @@ def compute_file_accuracy_simple(gts, preds, num_classes): gt_valid = [] pred_valid = [] for ii in range(len(gts)): - gt_class = np.unique(gts[ii]['class_ids']) + gt_class = np.unique(gts[ii]["class_ids"]) if len(gt_class) == 1 and gt_class[0] != -1: gt_valid.append(gt_class[0]) - pred = preds[ii]['class_probs'][:num_classes, :].T + pred = preds[ii]["class_probs"][:num_classes, :].T pred_valid.append(np.argmax(pred.mean(0))) acc = (np.array(gt_valid) == np.array(pred_valid)).mean() res = {} - res['num_valid_files'] = len(gt_valid) - res['num_total_files'] = len(gts) - res['gt_valid_file'] = gt_valid - res['pred_valid_file'] = pred_valid - res['file_acc'] = np.round(acc, 5) + res["num_valid_files"] = len(gt_valid) + res["num_total_files"] = len(gts) + res["gt_valid_file"] = gt_valid + res["pred_valid_file"] = pred_valid + res["file_acc"] = np.round(acc, 5) return res @@ -256,12 +299,20 @@ def compute_file_accuracy(gts, preds, num_classes): # compute min and max scoring range - then threshold min_val = 0 - mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0] + mins = [ + pp["class_probs"].min() + for pp in preds + if pp["class_probs"].shape[1] > 0 + ] if len(mins) > 0: min_val = np.min(mins) max_val = 1.0 - maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0] + maxes = [ + pp["class_probs"].max() + for pp in preds + if pp["class_probs"].shape[1] > 0 + ] if len(maxes) > 0: max_val = np.max(maxes) @@ -272,33 +323,37 @@ def compute_file_accuracy(gts, preds, num_classes): gt_valid = [] 
pred_valid_all = [] for ii in range(len(gts)): - gt_class = np.unique(gts[ii]['class_ids']) + gt_class = np.unique(gts[ii]["class_ids"]) if len(gt_class) == 1 and gt_class[0] != -1: gt_valid.append(gt_class[0]) - pred = preds[ii]['class_probs'][:num_classes, :].T + pred = preds[ii]["class_probs"][:num_classes, :].T p_class = np.zeros(len(thresh)) for tt in range(len(thresh)): - p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax() + p_class[tt] = (pred * (pred >= thresh[tt])).sum(0).argmax() pred_valid_all.append(p_class) # pick the result corresponding to the overall best threshold pred_valid_all = np.vstack(pred_valid_all) - acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0) + acc_per_thresh = ( + np.array(gt_valid)[..., np.newaxis] == pred_valid_all + ).mean(0) best_thresh = np.argmax(acc_per_thresh) best_acc = acc_per_thresh[best_thresh] pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist() res = {} - res['num_valid_files'] = len(gt_valid) - res['num_total_files'] = len(gts) - res['gt_valid_file'] = gt_valid - res['pred_valid_file'] = pred_valid - res['file_acc'] = np.round(best_acc, 5) + res["num_valid_files"] = len(gt_valid) + res["num_total_files"] = len(gts) + res["gt_valid_file"] = gt_valid + res["pred_valid_file"] = pred_valid + res["file_acc"] = np.round(best_acc, 5) return res -def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0): +def evaluate_predictions( + gts, preds, class_names, detection_overlap, ignore_start_end=0.0 +): """ Computes metrics derived from the precision and recall. Assumes that gts and preds are both lists of the same lengths, with ground @@ -307,24 +362,50 @@ def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_star Returns the overall detection results, and per class results """ - assert(len(gts) == len(preds)) + assert len(gts) == len(preds) num_classes = len(class_names) # evaluate detection on its own i.e. 
ignoring class - det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end) - top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end) - det_results['top_class'] = top_class + det_results = compute_pre_rec( + gts, + preds, + "detection", + None, + num_classes, + detection_overlap, + ignore_start_end, + ) + top_class = compute_pre_rec( + gts, + preds, + "top_class", + None, + num_classes, + detection_overlap, + ignore_start_end, + ) + det_results["top_class"] = top_class # per class evaluation - det_results['class_pr'] = [] + det_results["class_pr"] = [] for cc in range(num_classes): - res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end) - res['name'] = class_names[cc] - det_results['class_pr'].append(res) + res = compute_pre_rec( + gts, + preds, + "per_class", + cc, + num_classes, + detection_overlap, + ignore_start_end, + ) + res["name"] = class_names[cc] + det_results["class_pr"].append(res) # ignores classes that are not present in the test set - det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0]) - det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5) + det_results["avg_prec_class"] = np.mean( + [rs["avg_prec"] for rs in det_results["class_pr"] if rs["num_gt"] > 0] + ) + det_results["avg_prec_class"] = np.round(det_results["avg_prec_class"], 5) # file level evaluation res_file = compute_file_accuracy(gts, preds, num_classes) diff --git a/bat_detect/train/losses.py b/bat_detect/train/losses.py index aaef2c4..02bfdd6 100644 --- a/bat_detect/train/losses.py +++ b/bat_detect/train/losses.py @@ -7,7 +7,9 @@ def bbox_size_loss(pred_size, gt_size): """ Bounding box size loss. Only compute loss where there is a bounding box. """ gt_size_mask = (gt_size > 0).float() - return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5)) + return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / ( + gt_size_mask.sum() + 1e-5 + ) def focal_loss(pred, gt, weights=None, valid_mask=None): @@ -24,20 +26,25 @@ def focal_loss(pred, gt, weights=None, valid_mask=None): neg_inds = gt.lt(1).float() pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds - neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds + neg_loss = ( + torch.log(1 - pred + eps) + * torch.pow(pred, alpha) + * torch.pow(1 - gt, beta) + * neg_inds + ) if weights is not None: - pos_loss = pos_loss*weights - #neg_loss = neg_loss*weights + pos_loss = pos_loss * weights + # neg_loss = neg_loss*weights if valid_mask is not None: - pos_loss = pos_loss*valid_mask - neg_loss = neg_loss*valid_mask + pos_loss = pos_loss * valid_mask + neg_loss = neg_loss * valid_mask pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() - num_pos = pos_inds.float().sum() + num_pos = pos_inds.float().sum() if num_pos == 0: loss = -neg_loss else:
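For context, focal_loss above is the penalty-reduced focal loss commonly used for detection heatmaps (CornerNet-style): cells with gt == 1 are positives, values in (0, 1) form the Gaussian shoulder written by draw_gaussian, and torch.pow(1 - gt, beta) shrinks the penalty for confident predictions right next to a peak. A minimal sketch with made-up tensors, assuming the module's own alpha/beta defaults:

import torch

from bat_detect.train.losses import focal_loss

pred = torch.tensor([[0.90, 0.20, 0.60]])  # predicted heatmap probabilities
gt = torch.tensor([[1.00, 0.00, 0.80]])    # peak, background, near-peak shoulder

# the 0.80 cell counts as a negative, but its penalty is scaled by
# (1 - 0.8) ** beta, so near-peak mistakes contribute very little
print(focal_loss(pred, gt))  # scalar loss tensor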
""" if valid_mask is None: - op = ((gt-pred)**2).mean() + op = ((gt - pred) ** 2).mean() else: - op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum() + op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() return op diff --git a/bat_detect/train/train_model.py b/bat_detect/train/train_model.py index d955216..2fd33fe 100644 --- a/bat_detect/train/train_model.py +++ b/bat_detect/train/train_model.py @@ -1,32 +1,33 @@ -import numpy as np -import matplotlib.pyplot as plt +import argparse +import json import os +import sys + +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn.functional as F from torch.optim.lr_scheduler import CosineAnnealingLR -import json -import argparse -import sys -sys.path.append(os.path.join('..', '..')) - -import bat_detect.detector.parameters as parameters -import bat_detect.detector.models as models -import bat_detect.detector.post_process as pp -import bat_detect.utils.plot_utils as pu - -import bat_detect.train.audio_dataloader as adl -import bat_detect.train.evaluate as evl -import bat_detect.train.train_utils as tu -import bat_detect.train.train_split as ts -import bat_detect.train.losses as losses +sys.path.append(os.path.join("..", "..")) import warnings + +import bat_detect.detector.models as models +import bat_detect.detector.parameters as parameters +import bat_detect.detector.post_process as pp +import bat_detect.train.audio_dataloader as adl +import bat_detect.train.evaluate as evl +import bat_detect.train.losses as losses +import bat_detect.train.train_split as ts +import bat_detect.train.train_utils as tu +import bat_detect.utils.plot_utils as pu + warnings.filterwarnings("ignore", category=UserWarning) def save_images_batch(model, data_loader, params): - print('\nsaving images ...') + print("\nsaving images ...") is_train_state = data_loader.dataset.is_train data_loader.dataset.is_train = False @@ -36,67 +37,112 @@ def save_images_batch(model, data_loader, params): ind = 0 # first image in each batch with torch.no_grad(): for batch_idx, inputs in enumerate(data_loader): - data = inputs['spec'].to(params['device']) + data = inputs["spec"].to(params["device"]) outputs = model(data) - spec_viz = inputs['spec_for_viz'].data.cpu().numpy() - orig_index = inputs['file_id'][ind] - plot_title = data_loader.dataset.data_anns[orig_index]['id'] - op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg' - save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title) + spec_viz = inputs["spec_for_viz"].data.cpu().numpy() + orig_index = inputs["file_id"][ind] + plot_title = data_loader.dataset.data_anns[orig_index]["id"] + op_file_name = ( + params["op_im_dir_test"] + + data_loader.dataset.data_anns[orig_index]["id"] + + ".jpg" + ) + save_image( + spec_viz, + outputs, + ind, + inputs, + params, + op_file_name, + plot_title, + ) data_loader.dataset.is_train = is_train_state data_loader.dataset.return_spec_for_viz = False -def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title): - pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) - pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy() +def save_image( + spec_viz, outputs, ind, inputs, params, op_file_name, plot_title +): + pred_nms, _ = pp.run_nms(outputs, params, inputs["sampling_rate"].float()) + pred_hm = outputs["pred_det"][ind, 0, :].data.cpu().numpy() spec_viz = spec_viz[ind, 0, :] - gt = parse_gt_data(inputs)[ind] - sampling_rate = 
inputs['sampling_rate'][ind].item() - duration = inputs['duration'][ind].item() + gt = parse_gt_data(inputs)[ind] + sampling_rate = inputs["sampling_rate"][ind].item() + duration = inputs["duration"][ind].item() - pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind], - params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False) + pu.plot_spec( + spec_viz, + sampling_rate, + duration, + gt, + pred_nms[ind], + params, + plot_title, + op_file_name, + pred_hm, + plot_boxes=True, + fixed_aspect=False, + ) -def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq): +def loss_fun( + outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq +): # detection loss - loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det) + loss = params["det_loss_weight"] * det_criterion( + outputs["pred_det"], gt_det + ) # bounding box size loss - loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size) + loss += params["size_loss_weight"] * losses.bbox_size_loss( + outputs["pred_size"], gt_size + ) # classification loss valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1) - p_class = outputs['pred_class'][:, :-1, :] - loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask) + p_class = outputs["pred_class"][:, :-1, :] + loss += params["class_loss_weight"] * det_criterion( + p_class, gt_class[:, :-1, :], valid_mask=valid_mask + ) return loss -def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params): +def train( + model, epoch, data_loader, det_criterion, optimizer, scheduler, params +): model.train() train_loss = tu.AverageMeter() - class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) + class_inv_freq = torch.from_numpy( + np.array(params["class_inv_freq"], dtype=np.float32) + ).to(params["device"]) class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) - print('\nEpoch', epoch) + print("\nEpoch", epoch) for batch_idx, inputs in enumerate(data_loader): - data = inputs['spec'].to(params['device']) - gt_det = inputs['y_2d_det'].to(params['device']) - gt_size = inputs['y_2d_size'].to(params['device']) - gt_class = inputs['y_2d_classes'].to(params['device']) + data = inputs["spec"].to(params["device"]) + gt_det = inputs["y_2d_det"].to(params["device"]) + gt_size = inputs["y_2d_size"].to(params["device"]) + gt_class = inputs["y_2d_classes"].to(params["device"]) optimizer.zero_grad() outputs = model(data) - loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) + loss = loss_fun( + outputs, + gt_det, + gt_size, + gt_class, + det_criterion, + params, + class_inv_freq, + ) train_loss.update(loss.item(), data.shape[0]) loss.backward() @@ -104,13 +150,18 @@ def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params scheduler.step() if batch_idx % 50 == 0 and batch_idx != 0: - print('[{}/{}]\tLoss: {:.4f}'.format( - batch_idx * len(data), len(data_loader.dataset), train_loss.avg)) + print( + "[{}/{}]\tLoss: {:.4f}".format( + batch_idx * len(data), + len(data_loader.dataset), + train_loss.avg, + ) + ) - print('Train loss : {:.4f}'.format(train_loss.avg)) + print("Train loss : {:.4f}".format(train_loss.avg)) res = {} - res['train_loss'] = float(train_loss.avg) + res["train_loss"] = float(train_loss.avg) return res @@ -120,16 +171,18 @@ def test(model, epoch, data_loader, det_criterion, 
params): ground_truths = [] test_loss = tu.AverageMeter() - class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) + class_inv_freq = torch.from_numpy( + np.array(params["class_inv_freq"], dtype=np.float32) + ).to(params["device"]) class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) with torch.no_grad(): for batch_idx, inputs in enumerate(data_loader): - data = inputs['spec'].to(params['device']) - gt_det = inputs['y_2d_det'].to(params['device']) - gt_size = inputs['y_2d_size'].to(params['device']) - gt_class = inputs['y_2d_classes'].to(params['device']) + data = inputs["spec"].to(params["device"]) + gt_det = inputs["y_2d_det"].to(params["device"]) + gt_size = inputs["y_2d_size"].to(params["device"]) + gt_class = inputs["y_2d_classes"].to(params["device"]) outputs = model(data) @@ -139,41 +192,79 @@ def test(model, epoch, data_loader, det_criterion, params): # for kk in ['pred_det', 'pred_size', 'pred_class']: # outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0) - if params['save_test_image_during_train'] and batch_idx == 0: + if params["save_test_image_during_train"] and batch_idx == 0: # for visualization - save the first prediction ind = 0 - orig_index = inputs['file_id'][ind] - plot_title = data_loader.dataset.data_anns[orig_index]['id'] - op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg' - save_image(data, outputs, ind, inputs, params, op_file_name, plot_title) + orig_index = inputs["file_id"][ind] + plot_title = data_loader.dataset.data_anns[orig_index]["id"] + op_file_name = ( + params["op_im_dir"] + + str(orig_index.item()).zfill(4) + + "_" + + str(epoch).zfill(4) + + "_pred.jpg" + ) + save_image( + data, + outputs, + ind, + inputs, + params, + op_file_name, + plot_title, + ) - loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) + loss = loss_fun( + outputs, + gt_det, + gt_size, + gt_class, + det_criterion, + params, + class_inv_freq, + ) test_loss.update(loss.item(), data.shape[0]) # do NMS - pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) + pred_nms, _ = pp.run_nms( + outputs, params, inputs["sampling_rate"].float() + ) predictions.extend(pred_nms) ground_truths.extend(parse_gt_data(inputs)) - res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'], - params['detection_overlap'], params['ignore_start_end']) + res_det = evl.evaluate_predictions( + ground_truths, + predictions, + params["class_names"], + params["detection_overlap"], + params["ignore_start_end"], + ) - print('\nTest loss : {:.4f}'.format(test_loss.avg)) - print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x'])) - print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec'])) - print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'], - res_det['num_valid_files'], res_det['num_total_files'])) - print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class'])) + print("\nTest loss : {:.4f}".format(test_loss.avg)) + print("Rec at 0.95 (det) : {:.4f}".format(res_det["rec_at_x"])) + print("Avg prec (cls) : {:.4f}".format(res_det["avg_prec"])) + print( + "File acc (cls) : {:.2f} - for {} out of {}".format( + res_det["file_acc"], + res_det["num_valid_files"], + res_det["num_total_files"], + ) + ) + print("Cls Avg prec (cls) : {:.4f}".format(res_det["avg_prec_class"])) - print('\nPer class average precision') - str_len = np.max([len(rs['name']) for rs 
in res_det['class_pr']]) + 5 - for cc, rs in enumerate(res_det['class_pr']): - if rs['num_gt'] > 0: - print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec'])) + print("\nPer class average precision") + str_len = np.max([len(rs["name"]) for rs in res_det["class_pr"]]) + 5 + for cc, rs in enumerate(res_det["class_pr"]): + if rs["num_gt"] > 0: + print( + str(cc).ljust(5) + + rs["name"].ljust(str_len) + + "{:.4f}".format(rs["avg_prec"]) + ) res = {} - res['test_loss'] = float(test_loss.avg) + res["test_loss"] = float(test_loss.avg) return res_det, res @@ -181,176 +272,288 @@ def test(model, epoch, data_loader, det_criterion, params): def parse_gt_data(inputs): # reads the torch arrays into a dictionary of numpy arrays, taking care to # remove padding data i.e. not valid ones - keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids'] + keys = [ + "start_times", + "end_times", + "low_freqs", + "high_freqs", + "class_ids", + "individual_ids", + ] batch_data = [] - for ind in range(inputs['start_times'].shape[0]): - is_valid = inputs['is_valid'][ind]==1 + for ind in range(inputs["start_times"].shape[0]): + is_valid = inputs["is_valid"][ind] == 1 gt = {} for kk in keys: gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) - gt['duration'] = inputs['duration'][ind].item() - gt['file_id'] = inputs['file_id'][ind].item() - gt['class_id_file'] = inputs['class_id_file'][ind].item() + gt["duration"] = inputs["duration"][ind].item() + gt["file_id"] = inputs["file_id"][ind].item() + gt["class_id_file"] = inputs["class_id_file"][ind].item() batch_data.append(gt) return batch_data def select_model(params): - num_classes = len(params['class_names']) - if params['model_name'] == 'Net2DFast': - model = models.Net2DFast(params['num_filters'], num_classes=num_classes, - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) - elif params['model_name'] == 'Net2DFastNoAttn': - model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes, - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) - elif params['model_name'] == 'Net2DFastNoCoordConv': - model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes, - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) + num_classes = len(params["class_names"]) + if params["model_name"] == "Net2DFast": + model = models.Net2DFast( + params["num_filters"], + num_classes=num_classes, + emb_dim=params["emb_dim"], + ip_height=params["ip_height"], + resize_factor=params["resize_factor"], + ) + elif params["model_name"] == "Net2DFastNoAttn": + model = models.Net2DFastNoAttn( + params["num_filters"], + num_classes=num_classes, + emb_dim=params["emb_dim"], + ip_height=params["ip_height"], + resize_factor=params["resize_factor"], + ) + elif params["model_name"] == "Net2DFastNoCoordConv": + model = models.Net2DFastNoCoordConv( + params["num_filters"], + num_classes=num_classes, + emb_dim=params["emb_dim"], + ip_height=params["ip_height"], + resize_factor=params["resize_factor"], + ) else: - print('No valid network specified') + print("No valid network specified") return model if __name__ == "__main__": - plt.close('all') + plt.close("all") params = parameters.get_params(True) if torch.cuda.is_available(): - params['device'] = 'cuda' + params["device"] = "cuda" else: - params['device'] = 'cpu' + params["device"] = "cpu" # setup arg parser 
and populate it with existing parameters - will not work with lists parser = argparse.ArgumentParser() - parser.add_argument('data_dir', type=str, - help='Path to root of datasets') - parser.add_argument('ann_dir', type=str, - help='Path to extracted annotations') - parser.add_argument('--train_split', type=str, default='diff', # diff, same - help='Which train split to use') - parser.add_argument('--notes', type=str, default='', - help='Notes to save in text file') - parser.add_argument('--do_not_save_images', action='store_false', - help='Do not save images at the end of training') - parser.add_argument('--standardize_classs_names_ip', type=str, - default='Rhinolophus ferrumequinum;Rhinolophus hipposideros', - help='Will set low and high frequency the same for these classes. Separate names with ";"') + parser.add_argument("data_dir", type=str, help="Path to root of datasets") + parser.add_argument( + "ann_dir", type=str, help="Path to extracted annotations" + ) + parser.add_argument( + "--train_split", + type=str, + default="diff", # diff, same + help="Which train split to use", + ) + parser.add_argument( + "--notes", type=str, default="", help="Notes to save in text file" + ) + parser.add_argument( + "--do_not_save_images", + action="store_false", + help="Do not save images at the end of training", + ) + parser.add_argument( + "--standardize_classs_names_ip", + type=str, + default="Rhinolophus ferrumequinum;Rhinolophus hipposideros", + help='Will set low and high frequency the same for these classes. Separate names with ";"', + ) for key, val in params.items(): - parser.add_argument('--'+key, type=type(val), default=val) + parser.add_argument("--" + key, type=type(val), default=val) params = vars(parser.parse_args()) # save notes file - if params['notes'] != '': - tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes']) + if params["notes"] != "": + tu.write_notes_file( + params["experiment"] + "notes.txt", params["notes"] + ) # load the training and test meta data - there are different splits defined - train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split']) - train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split']) + train_sets, test_sets = ts.get_train_test_data( + params["ann_dir"], params["data_dir"], params["train_split"] + ) + train_sets_no_path, test_sets_no_path = ts.get_train_test_data( + "", "", params["train_split"] + ) # keep track of what we have trained on - params['train_sets'] = train_sets_no_path - params['test_sets'] = test_sets_no_path + params["train_sets"] = train_sets_no_path + params["test_sets"] = test_sets_no_path # load train annotations - merge them all together - print('\nTraining on:') + print("\nTraining on:") for tt in train_sets: - print(tt['ann_path']) - classes_to_ignore = params['classes_to_ignore']+params['generic_class'] - data_train, params['class_names'], params['class_inv_freq'] = \ - tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) - params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names']) - params['class_names_short'] = tu.get_short_class_names(params['class_names']) + print(tt["ann_path"]) + classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] + ( + data_train, + params["class_names"], + params["class_inv_freq"], + ) = tu.load_set_of_anns( + train_sets, + classes_to_ignore, + params["events_of_interest"], + 
params["convert_to_genus"], + ) + params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( + params["class_names"] + ) + params["class_names_short"] = tu.get_short_class_names( + params["class_names"] + ) # standardize the low and high frequency value for specified classes - params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';') - for cc in params['standardize_classs_names']: - if cc in params['class_names']: + params["standardize_classs_names"] = params[ + "standardize_classs_names_ip" + ].split(";") + for cc in params["standardize_classs_names"]: + if cc in params["class_names"]: data_train = tu.standardize_low_freq(data_train, cc) else: - print(cc, 'not found') + print(cc, "not found") # train loader train_dataset = adl.AudioLoader(data_train, params, is_train=True) - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'], - shuffle=True, num_workers=params['num_workers'], pin_memory=True) - + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=params["batch_size"], + shuffle=True, + num_workers=params["num_workers"], + pin_memory=True, + ) # test set - print('\nTesting on:') + print("\nTesting on:") for tt in test_sets: - print(tt['ann_path']) - data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) + print(tt["ann_path"]) + data_test, _, _ = tu.load_set_of_anns( + test_sets, + classes_to_ignore, + params["events_of_interest"], + params["convert_to_genus"], + ) data_train = tu.remove_dupes(data_train, data_test) test_dataset = adl.AudioLoader(data_test, params, is_train=False) # batch size of 1 because of variable file length - test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, - shuffle=False, num_workers=params['num_workers'], pin_memory=True) - + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=params["num_workers"], + pin_memory=True, + ) inputs_train = next(iter(train_loader)) # TODO remove params['ip_height'], this is just legacy - params['ip_height'] = int(params['spec_height']*params['resize_factor']) - print('\ntrain batch spec size :', inputs_train['spec'].shape) - print('class target size :', inputs_train['y_2d_classes'].shape) + params["ip_height"] = int(params["spec_height"] * params["resize_factor"]) + print("\ntrain batch spec size :", inputs_train["spec"].shape) + print("class target size :", inputs_train["y_2d_classes"].shape) # select network model = select_model(params) - model = model.to(params['device']) + model = model.to(params["device"]) - optimizer = torch.optim.Adam(model.parameters(), lr=params['lr']) - #optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9) - scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader)) - if params['train_loss'] == 'mse': + optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"]) + # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9) + scheduler = CosineAnnealingLR( + optimizer, params["num_epochs"] * len(train_loader) + ) + if params["train_loss"] == "mse": det_criterion = losses.mse_loss - elif params['train_loss'] == 'focal': + elif params["train_loss"] == "focal": det_criterion = losses.focal_loss # save parameters to file - with open(params['experiment'] + 'params.json', 'w') as da: + with open(params["experiment"] + "params.json", "w") as da: json.dump(params, da, indent=2, sort_keys=True) # 
plotting - train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1, - ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True) - test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1, - ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True) - test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1, - ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', '']) - test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1, - params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec']) - + train_plt_ls = pu.LossPlotter( + params["experiment"] + "train_loss.png", + params["num_epochs"] + 1, + ["train_loss"], + None, + None, + ["epoch", "train_loss"], + logy=True, + ) + test_plt_ls = pu.LossPlotter( + params["experiment"] + "test_loss.png", + params["num_epochs"] + 1, + ["test_loss"], + None, + None, + ["epoch", "test_loss"], + logy=True, + ) + test_plt = pu.LossPlotter( + params["experiment"] + "test.png", + params["num_epochs"] + 1, + ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], + [0, 1], + None, + ["epoch", ""], + ) + test_plt_class = pu.LossPlotter( + params["experiment"] + "test_avg_prec.png", + params["num_epochs"] + 1, + params["class_names_short"], + [0, 1], + params["class_names_short"], + ["epoch", "avg_prec"], + ) # # main train loop - for epoch in range(0, params['num_epochs']+1): + for epoch in range(0, params["num_epochs"] + 1): - train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params) - train_plt_ls.update_and_save(epoch, [train_loss['train_loss']]) + train_loss = train( + model, + epoch, + train_loader, + det_criterion, + optimizer, + scheduler, + params, + ) + train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) - if epoch % params['num_eval_epochs'] == 0: + if epoch % params["num_eval_epochs"] == 0: # detection accuracy on test set - test_res, test_loss = test(model, epoch, test_loader, det_criterion, params) - test_plt_ls.update_and_save(epoch, [test_loss['test_loss']]) - test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'], - test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']]) - test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']]) - pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res) - + test_res, test_loss = test( + model, epoch, test_loader, det_criterion, params + ) + test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) + test_plt.update_and_save( + epoch, + [ + test_res["avg_prec"], + test_res["rec_at_x"], + test_res["avg_prec_class"], + test_res["file_acc"], + test_res["top_class"]["avg_prec"], + ], + ) + test_plt_class.update_and_save( + epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] + ) + pu.plot_pr_curve_class( + params["experiment"], "test_pr", "test_pr", test_res + ) # save trained model - print('saving model to: ' + params['model_file_name']) - op_state = {'epoch': epoch + 1, - 'state_dict': model.state_dict(), - #'optimizer' : optimizer.state_dict(), - 'params' : params} - torch.save(op_state, params['model_file_name']) - + print("saving model to: " + params["model_file_name"]) + op_state = { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + #'optimizer' : optimizer.state_dict(), + "params": params, + } + torch.save(op_state, params["model_file_name"]) # save an image with associated prediction for each batch in the test set - if not args['do_not_save_images']: + if not params["do_not_save_images"]: save_images_batch(model, test_loader, params)
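One note on the schedule used above: CosineAnnealingLR is constructed with T_max = num_epochs * len(train_loader) and scheduler.step() is called once per batch, so the learning rate traces a single cosine decay across the entire run instead of restarting each epoch. A minimal sketch with stand-in values:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(4, 2)            # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs, batches_per_epoch = 10, 100  # illustrative sizes
scheduler = CosineAnnealingLR(optimizer, num_epochs * batches_per_epoch)

for step in range(num_epochs * batches_per_epoch):
    optimizer.step()   # parameter update would happen here
    scheduler.step()   # per-batch step -> smooth decay towards 0
print(optimizer.param_groups[0]["lr"])   # close to 0 by the end of training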
params["model_file_name"]) # save an image with associated prediction for each batch in the test set - if not args['do_not_save_images']: + if not args["do_not_save_images"]: save_images_batch(model, test_loader, params) diff --git a/bat_detect/train/train_split.py b/bat_detect/train/train_split.py index 20972bd..01b5c03 100644 --- a/bat_detect/train/train_split.py +++ b/bat_detect/train/train_split.py @@ -2,13 +2,14 @@ Run scripts/extract_anns.py to generate these json files. """ + def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True): - if split_name == 'diff': + if split_name == "diff": train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra) - elif split_name == 'same': + elif split_name == "same": train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra) else: - print('Split not defined') + print("Split not defined") assert False return train_sets, test_sets @@ -18,73 +19,126 @@ def split_diff(ann_dir, wav_dir, load_extra=True): train_sets = [] if load_extra: - train_sets.append({'dataset_name': 'BatDetective', - 'is_test': False, - 'is_binary': True, # just a bat / not bat dataset ie no classes - 'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json', - 'wav_path': wav_dir + 'bat_detective/audio/'}) - train_sets.append({'dataset_name': 'bat_logger_qeop_empty', - 'is_test': False, - 'is_binary': True, - 'ann_path': ann_dir + 'bat_logger_qeop_empty.json', - 'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'}) - train_sets.append({'dataset_name': 'bat_logger_2016_empty', - 'is_test': False, - 'is_binary': True, - 'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json', - 'wav_path': wav_dir + 'bat_logger_2016/audio/'}) + train_sets.append( + { + "dataset_name": "BatDetective", + "is_test": False, + "is_binary": True, # just a bat / not bat dataset ie no classes + "ann_path": ann_dir + + "train_set_bulgaria_batdetective_with_bbs.json", + "wav_path": wav_dir + "bat_detective/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_logger_qeop_empty", + "is_test": False, + "is_binary": True, + "ann_path": ann_dir + "bat_logger_qeop_empty.json", + "wav_path": wav_dir + "bat_logger_qeop_empty/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_logger_2016_empty", + "is_test": False, + "is_binary": True, + "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json", + "wav_path": wav_dir + "bat_logger_2016/audio/", + } + ) # train_sets.append({'dataset_name': 'brazil_data_binary', # 'is_test': False, # 'ann_path': ann_dir + 'brazil_data_binary.json', # 'wav_path': wav_dir + 'brazil_data/audio/'}) - train_sets.append({'dataset_name': 'echobank', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'Echobank_train_expert.json', - 'wav_path': wav_dir + 'echobank/audio/'}) - train_sets.append({'dataset_name': 'sn_scot_nor', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert.json', - 'wav_path': wav_dir + 'sn_scot_nor/audio/'}) - train_sets.append({'dataset_name': 'BCT_1_sec', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BCT_1_sec_train_expert.json', - 'wav_path': wav_dir + 'BCT_1_sec/audio/'}) - train_sets.append({'dataset_name': 'bcireland', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'bcireland_expert.json', - 'wav_path': wav_dir + 'bcireland/audio/'}) - train_sets.append({'dataset_name': 'rhinolophus_steve_BCT', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert.json', - 
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) + train_sets.append( + { + "dataset_name": "echobank", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "Echobank_train_expert.json", + "wav_path": wav_dir + "echobank/audio/", + } + ) + train_sets.append( + { + "dataset_name": "sn_scot_nor", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "sn_scot_nor_0.5_expert.json", + "wav_path": wav_dir + "sn_scot_nor/audio/", + } + ) + train_sets.append( + { + "dataset_name": "BCT_1_sec", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "BCT_1_sec_train_expert.json", + "wav_path": wav_dir + "BCT_1_sec/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bcireland", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "bcireland_expert.json", + "wav_path": wav_dir + "bcireland/audio/", + } + ) + train_sets.append( + { + "dataset_name": "rhinolophus_steve_BCT", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "rhinolophus_steve_BCT_expert.json", + "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", + } + ) test_sets = [] - test_sets.append({'dataset_name': 'bat_data_martyn_2018', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json', - 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2018_test', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json', - 'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2019', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2019_test', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2018", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json", + "wav_path": wav_dir + "bat_data_martyn_2018/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2018_test", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json", + "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2019", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json", + "wav_path": wav_dir + "bat_data_martyn_2019/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2019_test", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json", + "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", + } + ) return train_sets, test_sets @@ -93,71 +147,124 @@ def split_same(ann_dir, wav_dir, load_extra=True): train_sets = [] if load_extra: - train_sets.append({'dataset_name': 'BatDetective', - 'is_test': False, - 'is_binary': True, - 'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json', - 'wav_path': wav_dir + 'bat_detective/audio/'}) - train_sets.append({'dataset_name': 'bat_logger_qeop_empty', 
- 'is_test': False, - 'is_binary': True, - 'ann_path': ann_dir + 'bat_logger_qeop_empty.json', - 'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'}) - train_sets.append({'dataset_name': 'bat_logger_2016_empty', - 'is_test': False, - 'is_binary': True, - 'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json', - 'wav_path': wav_dir + 'bat_logger_2016/audio/'}) + train_sets.append( + { + "dataset_name": "BatDetective", + "is_test": False, + "is_binary": True, + "ann_path": ann_dir + + "train_set_bulgaria_batdetective_with_bbs.json", + "wav_path": wav_dir + "bat_detective/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_logger_qeop_empty", + "is_test": False, + "is_binary": True, + "ann_path": ann_dir + "bat_logger_qeop_empty.json", + "wav_path": wav_dir + "bat_logger_qeop_empty/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_logger_2016_empty", + "is_test": False, + "is_binary": True, + "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json", + "wav_path": wav_dir + "bat_logger_2016/audio/", + } + ) # train_sets.append({'dataset_name': 'brazil_data_binary', # 'is_test': False, # 'ann_path': ann_dir + 'brazil_data_binary.json', # 'wav_path': wav_dir + 'brazil_data/audio/'}) - train_sets.append({'dataset_name': 'echobank', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'Echobank_train_expert_TRAIN.json', - 'wav_path': wav_dir + 'echobank/audio/'}) - train_sets.append({'dataset_name': 'sn_scot_nor', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TRAIN.json', - 'wav_path': wav_dir + 'sn_scot_nor/audio/'}) - train_sets.append({'dataset_name': 'BCT_1_sec', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BCT_1_sec_train_expert_TRAIN.json', - 'wav_path': wav_dir + 'BCT_1_sec/audio/'}) - train_sets.append({'dataset_name': 'bcireland', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'bcireland_expert_TRAIN.json', - 'wav_path': wav_dir + 'bcireland/audio/'}) - train_sets.append({'dataset_name': 'rhinolophus_steve_BCT', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TRAIN.json', - 'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) - train_sets.append({'dataset_name': 'bat_data_martyn_2018', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json', - 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) - train_sets.append({'dataset_name': 'bat_data_martyn_2018_test', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json', - 'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'}) - train_sets.append({'dataset_name': 'bat_data_martyn_2019', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) - train_sets.append({'dataset_name': 'bat_data_martyn_2019_test', - 'is_test': False, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) + train_sets.append( + { + "dataset_name": "echobank", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "Echobank_train_expert_TRAIN.json", + "wav_path": wav_dir + "echobank/audio/", + } + ) + train_sets.append( + { + "dataset_name": "sn_scot_nor", + "is_test": False, + 
"is_binary": False, + "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TRAIN.json", + "wav_path": wav_dir + "sn_scot_nor/audio/", + } + ) + train_sets.append( + { + "dataset_name": "BCT_1_sec", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "BCT_1_sec_train_expert_TRAIN.json", + "wav_path": wav_dir + "BCT_1_sec/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bcireland", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "bcireland_expert_TRAIN.json", + "wav_path": wav_dir + "bcireland/audio/", + } + ) + train_sets.append( + { + "dataset_name": "rhinolophus_steve_BCT", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TRAIN.json", + "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_data_martyn_2018", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json", + "wav_path": wav_dir + "bat_data_martyn_2018/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_data_martyn_2018_test", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json", + "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_data_martyn_2019", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json", + "wav_path": wav_dir + "bat_data_martyn_2019/audio/", + } + ) + train_sets.append( + { + "dataset_name": "bat_data_martyn_2019_test", + "is_test": False, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json", + "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", + } + ) # train_sets.append({'dataset_name': 'bat_data_martyn_2021_train', # 'is_test': False, @@ -171,51 +278,91 @@ def split_same(ann_dir, wav_dir, load_extra=True): # 'wav_path': wav_dir + 'volunteers_2021/audio/'}) test_sets = [] - test_sets.append({'dataset_name': 'echobank', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'Echobank_train_expert_TEST.json', - 'wav_path': wav_dir + 'echobank/audio/'}) - test_sets.append({'dataset_name': 'sn_scot_nor', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TEST.json', - 'wav_path': wav_dir + 'sn_scot_nor/audio/'}) - test_sets.append({'dataset_name': 'BCT_1_sec', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BCT_1_sec_train_expert_TEST.json', - 'wav_path': wav_dir + 'BCT_1_sec/audio/'}) - test_sets.append({'dataset_name': 'bcireland', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'bcireland_expert_TEST.json', - 'wav_path': wav_dir + 'bcireland/audio/'}) - test_sets.append({'dataset_name': 'rhinolophus_steve_BCT', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TEST.json', - 'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2018', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json', - 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2018_test', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json', - 'wav_path': wav_dir 
+ 'bat_data_martyn_2018_test/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2019', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'}) - test_sets.append({'dataset_name': 'bat_data_martyn_2019_test', - 'is_test': True, - 'is_binary': False, - 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json', - 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'}) + test_sets.append( + { + "dataset_name": "echobank", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + "Echobank_train_expert_TEST.json", + "wav_path": wav_dir + "echobank/audio/", + } + ) + test_sets.append( + { + "dataset_name": "sn_scot_nor", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TEST.json", + "wav_path": wav_dir + "sn_scot_nor/audio/", + } + ) + test_sets.append( + { + "dataset_name": "BCT_1_sec", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + "BCT_1_sec_train_expert_TEST.json", + "wav_path": wav_dir + "BCT_1_sec/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bcireland", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + "bcireland_expert_TEST.json", + "wav_path": wav_dir + "bcireland/audio/", + } + ) + test_sets.append( + { + "dataset_name": "rhinolophus_steve_BCT", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TEST.json", + "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2018", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json", + "wav_path": wav_dir + "bat_data_martyn_2018/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2018_test", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json", + "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2019", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json", + "wav_path": wav_dir + "bat_data_martyn_2019/audio/", + } + ) + test_sets.append( + { + "dataset_name": "bat_data_martyn_2019_test", + "is_test": True, + "is_binary": False, + "ann_path": ann_dir + + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json", + "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", + } + ) # test_sets.append({'dataset_name': 'bat_data_martyn_2021_test', # 'is_test': True, diff --git a/bat_detect/train/train_utils.py b/bat_detect/train/train_utils.py index cff92e4..62441a7 100644 --- a/bat_detect/train/train_utils.py +++ b/bat_detect/train/train_utils.py @@ -1,42 +1,52 @@ -import numpy as np -import random -import os import glob import json +import os +import random + +import numpy as np def write_notes_file(file_name, text): - with open(file_name, 'a') as da: - da.write(text + '\n') + with open(file_name, "a") as da: + da.write(text + "\n") def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path): - ddict = {'dataset_name': dataset_name, 'is_test': is_test, 'is_binary': False, - 'ann_path': ann_path, 'wav_path': wav_path} + ddict = { + "dataset_name": dataset_name, + "is_test": is_test, + "is_binary": False, + "ann_path": ann_path, + "wav_path": 
wav_path,
+    }
     return ddict


 def get_short_class_names(class_names, str_len=3):
     class_names_short = []
     for cc in class_names:
-        class_names_short.append(' '.join([sp[:str_len] for sp in cc.split(' ')]))
+        class_names_short.append(
+            " ".join([sp[:str_len] for sp in cc.split(" ")])
+        )
     return class_names_short


 def remove_dupes(data_train, data_test):
-    test_ids = [dd['id'] for dd in data_test]
+    test_ids = [dd["id"] for dd in data_test]
     data_train_prune = []
     for aa in data_train:
-        if aa['id'] not in test_ids:
+        if aa["id"] not in test_ids:
             data_train_prune.append(aa)
     diff = len(data_train) - len(data_train_prune)
     if diff != 0:
-        print(diff, 'items removed from train set')
+        print(diff, "items removed from train set")
     return data_train_prune


 def get_genus_mapping(class_names):
-    genus_names, genus_mapping = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
+    genus_names, genus_mapping = np.unique(
+        [cc.split(" ")[0] for cc in class_names], return_inverse=True
+    )
     return genus_names.tolist(), genus_mapping.tolist()


@@ -47,97 +57,110 @@ def standardize_low_freq(data, class_of_interest):
     low_freqs = []
     high_freqs = []
     for dd in data:
-        for aa in dd['annotation']:
-            if aa['class'] == class_of_interest:
-                low_freqs.append(aa['low_freq'])
-                high_freqs.append(aa['high_freq'])
+        for aa in dd["annotation"]:
+            if aa["class"] == class_of_interest:
+                low_freqs.append(aa["low_freq"])
+                high_freqs.append(aa["high_freq"])

     low_mean = np.mean(low_freqs)
     high_mean = np.mean(high_freqs)
-    assert(low_mean < high_mean)
+    assert low_mean < high_mean

-    print('\nStandardizing low and high frequency for:')
+    print("\nStandardizing low and high frequency for:")
     print(class_of_interest)
-    print('low: ', round(low_mean, 2))
-    print('high: ', round(high_mean, 2))
+    print("low: ", round(low_mean, 2))
+    print("high: ", round(high_mean, 2))

     # only set the low freq, high stays the same
     # assumes that low_mean < high_mean
     for dd in data:
-        for aa in dd['annotation']:
-            if aa['class'] == class_of_interest:
-                aa['low_freq'] = low_mean
-                if aa['high_freq'] < low_mean:
-                    aa['high_freq'] = high_mean
+        for aa in dd["annotation"]:
+            if aa["class"] == class_of_interest:
+                aa["low_freq"] = low_mean
+                if aa["high_freq"] < low_mean:
+                    aa["high_freq"] = high_mean

     return data


-def load_set_of_anns(data, classes_to_ignore=[], events_of_interest=None,
-                     convert_to_genus=False, verbose=True, list_of_anns=False,
-                     filter_issues=False, name_replace=False):
+def load_set_of_anns(
+    data,
+    classes_to_ignore=[],
+    events_of_interest=None,
+    convert_to_genus=False,
+    verbose=True,
+    list_of_anns=False,
+    filter_issues=False,
+    name_replace=False,
+):

     # load the annotations
     anns = []
     if list_of_anns:
         # path to list of individual json files
-        anns.extend(load_anns_from_path(data['ann_path'], data['wav_path']))
+        anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
     else:
         # dictionary of datasets
         for dd in data:
-            anns.extend(load_anns(dd['ann_path'], dd['wav_path']))
+            anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))

     # discarding unannotated files
-    anns = [aa for aa in anns if aa['annotated'] is True]
+    anns = [aa for aa in anns if aa["annotated"] is True]

     # filter files that have annotation issues - if the input is a dictionary of
     # datasets, this will likely have already been done
     if filter_issues:
-        anns = [aa for aa in anns if aa['issues'] is False]
+        anns = [aa for aa in anns if aa["issues"] is False]

     # check for some basic formatting errors with class names
     for ann in anns:
-        for aa in
ann['annotation']: - aa['class'] = aa['class'].strip() + for aa in ann["annotation"]: + aa["class"] = aa["class"].strip() # only load specified events - i.e. types of calls if events_of_interest is not None: for ann in anns: filtered_events = [] - for aa in ann['annotation']: - if aa['event'] in events_of_interest: + for aa in ann["annotation"]: + if aa["event"] in events_of_interest: filtered_events.append(aa) - ann['annotation'] = filtered_events + ann["annotation"] = filtered_events # change class names # replace_names will be a dictionary mapping input name to output if type(name_replace) is dict: for ann in anns: - for aa in ann['annotation']: - if aa['class'] in name_replace: - aa['class'] = name_replace[aa['class']] + for aa in ann["annotation"]: + if aa["class"] in name_replace: + aa["class"] = name_replace[aa["class"]] # convert everything to genus name if convert_to_genus: for ann in anns: - for aa in ann['annotation']: - aa['class'] = aa['class'].split(' ')[0] + for aa in ann["annotation"]: + aa["class"] = aa["class"].split(" ")[0] # get unique class names class_names_all = [] for ann in anns: - for aa in ann['annotation']: - if aa['class'] not in classes_to_ignore: - class_names_all.append(aa['class']) + for aa in ann["annotation"]: + if aa["class"] not in classes_to_ignore: + class_names_all.append(aa["class"]) class_names, class_cnts = np.unique(class_names_all, return_counts=True) - class_inv_freq = (class_cnts.sum() / (len(class_names) * class_cnts.astype(np.float32))) + class_inv_freq = class_cnts.sum() / ( + len(class_names) * class_cnts.astype(np.float32) + ) if verbose: - print('Class count:') + print("Class count:") str_len = np.max([len(cc) for cc in class_names]) + 5 for cc in range(len(class_names)): - print(str(cc).ljust(5) + class_names[cc].ljust(str_len) + str(class_cnts[cc])) + print( + str(cc).ljust(5) + + class_names[cc].ljust(str_len) + + str(class_cnts[cc]) + ) if len(classes_to_ignore) == 0: return anns @@ -150,36 +173,37 @@ def load_anns(ann_file_name, raw_audio_dir): anns = json.load(da) for aa in anns: - aa['file_path'] = raw_audio_dir + aa['id'] + aa["file_path"] = raw_audio_dir + aa["id"] return anns def load_anns_from_path(ann_file_dir, raw_audio_dir): - files = glob.glob(ann_file_dir + '*.json') + files = glob.glob(ann_file_dir + "*.json") anns = [] for ff in files: with open(ff) as da: ann = json.load(da) - ann['file_path'] = raw_audio_dir + ann['id'] + ann["file_path"] = raw_audio_dir + ann["id"] anns.append(ann) return anns class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() + """Computes and stores the average and current value""" - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 + def __init__(self): + self.reset() - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py index 4a18d74..3ad648b 100644 --- a/bat_detect/utils/audio_utils.py +++ b/bat_detect/utils/audio_utils.py @@ -1,89 +1,142 @@ -import numpy as np -from . import wavfile import warnings -import torch + import librosa +import numpy as np +import torch + +from . 
import wavfile def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): - nfft = np.floor(fft_win_length*sampling_rate) # int() uses floor - noverlap = np.floor(fft_overlap*nfft) - return (time_in_file*sampling_rate-noverlap) / (nfft - noverlap) + nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor + noverlap = np.floor(fft_overlap * nfft) + return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap) # NOTE this is also defined in post_process def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): - nfft = np.floor(fft_win_length*sampling_rate) - noverlap = np.floor(fft_overlap*nfft) - return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate - #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window + nfft = np.floor(fft_win_length * sampling_rate) + noverlap = np.floor(fft_overlap * nfft) + return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate + # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window -def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False, check_spec_size=True): +def generate_spectrogram( + audio, + sampling_rate, + params, + return_spec_for_viz=False, + check_spec_size=True, +): # generate spectrogram - spec = gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']) + spec = gen_mag_spectrogram( + audio, sampling_rate, params["fft_win_length"], params["fft_overlap"] + ) # crop to min/max freq - max_freq = round(params['max_freq']*params['fft_win_length']) - min_freq = round(params['min_freq']*params['fft_win_length']) + max_freq = round(params["max_freq"] * params["fft_win_length"]) + min_freq = round(params["min_freq"] * params["fft_win_length"]) if spec.shape[0] < max_freq: freq_pad = max_freq - spec.shape[0] - spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)) - spec_cropped = spec[-max_freq:spec.shape[0]-min_freq, :] + spec = np.vstack( + (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec) + ) + spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :] - if params['spec_scale'] == 'log': - log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) - #log_scaling = (1.0 / sampling_rate)*0.1 - #log_scaling = (1.0 / sampling_rate)*10e4 - spec = np.log1p(log_scaling*spec_cropped) - elif params['spec_scale'] == 'pcen': + if params["spec_scale"] == "log": + log_scaling = ( + 2.0 + * (1.0 / sampling_rate) + * ( + 1.0 + / ( + np.abs( + np.hanning( + int(params["fft_win_length"] * sampling_rate) + ) + ) + ** 2 + ).sum() + ) + ) + # log_scaling = (1.0 / sampling_rate)*0.1 + # log_scaling = (1.0 / sampling_rate)*10e4 + spec = np.log1p(log_scaling * spec_cropped) + elif params["spec_scale"] == "pcen": spec = pcen(spec_cropped, sampling_rate) - elif params['spec_scale'] == 'none': + elif params["spec_scale"] == "none": pass - if params['denoise_spec_avg']: + if params["denoise_spec_avg"]: spec = spec - np.mean(spec, 1)[:, np.newaxis] spec.clip(min=0, out=spec) - if params['max_scale_spec']: + if params["max_scale_spec"]: spec = spec / (spec.max() + 10e-6) # needs to be divisible by specific factor - if not it should have been padded - #if check_spec_size: - #assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0) - #assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0) + # if 
check_spec_size:
+    # assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
+    # assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)

     # for visualization purposes - use log scaled spectrogram
     if return_spec_for_viz:
-        log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
-        spec_for_viz = np.log1p(log_scaling*spec_cropped).astype(np.float32)
+        log_scaling = (
+            2.0
+            * (1.0 / sampling_rate)
+            * (
+                1.0
+                / (
+                    np.abs(
+                        np.hanning(
+                            int(params["fft_win_length"] * sampling_rate)
+                        )
+                    )
+                    ** 2
+                ).sum()
+            )
+        )
+        spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
     else:
         spec_for_viz = None

     return spec, spec_for_viz


-def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False):
+def load_audio_file(
+    audio_file,
+    time_exp_fact,
+    target_samp_rate,
+    scale=False,
+    max_duration=False,
+):
     with warnings.catch_warnings():
-        warnings.filterwarnings('ignore', category=wavfile.WavFileWarning)
-        #sampling_rate, audio_raw = wavfile.read(audio_file)
+        warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
+        # sampling_rate, audio_raw = wavfile.read(audio_file)
         audio_raw, sampling_rate = librosa.load(audio_file, sr=None)

     if len(audio_raw.shape) > 1:
-        raise Exception('Currently does not handle stereo files')
+        raise Exception("Currently does not handle stereo files")
     sampling_rate = sampling_rate * time_exp_fact

     # resample - need to do this after correcting for time expansion
     sampling_rate_old = sampling_rate
     sampling_rate = target_samp_rate
-    audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase')
+    audio_raw = librosa.resample(
+        audio_raw,
+        orig_sr=sampling_rate_old,
+        target_sr=sampling_rate,
+        res_type="polyphase",
+    )

     # clipping maximum duration
     if max_duration is not False:
-        max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0])
+        max_duration = np.minimum(
+            int(sampling_rate * max_duration), audio_raw.shape[0]
+        )
         audio_raw = audio_raw[:max_duration]
-    
+
     # convert to float32 and scale
     audio_raw = audio_raw.astype(np.float32)
     if scale:
@@ -93,38 +146,53 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, ma

     return sampling_rate, audio_raw


-def pad_audio(audio_raw, fs, ms, overlap_perc, resize_factor, divide_factor, fixed_width=None):
+def pad_audio(
+    audio_raw,
+    fs,
+    ms,
+    overlap_perc,
+    resize_factor,
+    divide_factor,
+    fixed_width=None,
+):
     # Adds zeros to the end of the raw data so that the generated spectrogram
     # will be evenly divisible by `divide_factor`
     # Also deals with very short audio clips and fixed_width during training
     # This code could be clearer; clean it up
-    nfft = int(ms*fs)
-    noverlap = int(overlap_perc*nfft)
+    nfft = int(ms * fs)
+    noverlap = int(overlap_perc * nfft)
     step = nfft - noverlap
-    min_size = int(divide_factor*(1.0/resize_factor))
-    spec_width = ((audio_raw.shape[0]-noverlap)//step)
+    min_size = int(divide_factor * (1.0 / resize_factor))
+    spec_width = (audio_raw.shape[0] - noverlap) // step
     spec_width_rs = spec_width * resize_factor

     if fixed_width is not None and spec_width < fixed_width:
         # too small
         # used during training to ensure all the batches are the same size
-        diff = fixed_width*step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
+        diff = fixed_width * step + noverlap -
audio_raw.shape[0] + audio_raw = np.hstack( + (audio_raw, np.zeros(diff, dtype=audio_raw.dtype)) + ) elif fixed_width is not None and spec_width > fixed_width: # too big # used during training to ensure all the batches are the same size - diff = fixed_width*step + noverlap - audio_raw.shape[0] + diff = fixed_width * step + noverlap - audio_raw.shape[0] audio_raw = audio_raw[:diff] - elif spec_width_rs < min_size or (np.floor(spec_width_rs) % divide_factor) != 0: + elif ( + spec_width_rs < min_size + or (np.floor(spec_width_rs) % divide_factor) != 0 + ): # need to be at least min_size div_amt = np.ceil(spec_width_rs / float(divide_factor)) div_amt = np.maximum(1, div_amt) - target_size = int(div_amt*divide_factor*(1.0/resize_factor)) - diff = target_size*step + noverlap - audio_raw.shape[0] - audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype))) + target_size = int(div_amt * divide_factor * (1.0 / resize_factor)) + diff = target_size * step + noverlap - audio_raw.shape[0] + audio_raw = np.hstack( + (audio_raw, np.zeros(diff, dtype=audio_raw.dtype)) + ) return audio_raw @@ -133,14 +201,16 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc): # Computes magnitude spectrogram by specifying time. x = x.astype(np.float32) - nfft = int(ms*fs) - noverlap = int(overlap_perc*nfft) + nfft = int(ms * fs) + noverlap = int(overlap_perc * nfft) # window data step = nfft - noverlap # compute spec - spec, _ = librosa.core.spectrum._spectrogram(y=x, power=1, n_fft=nfft, hop_length=step, center=False) + spec, _ = librosa.core.spectrum._spectrogram( + y=x, power=1, n_fft=nfft, hop_length=step, center=False + ) # remove DC component and flip vertical orientation spec = np.flipud(spec[1:, :]) @@ -149,8 +219,8 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc): def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc): - nfft = int(ms*fs) - nstep = round((1.0-overlap_perc)*nfft) + nfft = int(ms * fs) + nstep = round((1.0 - overlap_perc) * nfft) han_win = torch.hann_window(nfft, periodic=False).to(x.device) @@ -158,12 +228,14 @@ def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc): spec = complex_spec.pow(2.0).sum(-1) # remove DC component and flip vertically - spec = torch.flipud(spec[0, 1:,:]) + spec = torch.flipud(spec[0, 1:, :]) return spec def pcen(spec_cropped, sampling_rate): # TODO should be passing hop_length too i.e. 
step - spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate/10).astype(np.float32) + spec = librosa.pcen( + spec_cropped * (2**31), sr=sampling_rate / 10 + ).astype(np.float32) return spec diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py index fef9828..7d2470f 100644 --- a/bat_detect/utils/detector_utils.py +++ b/bat_detect/utils/detector_utils.py @@ -1,39 +1,40 @@ -import torch -import torch.nn.functional as F -import os -import numpy as np -import pandas as pd import json +import os import sys -from bat_detect.detector import models +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F + import bat_detect.detector.compute_features as feats import bat_detect.detector.post_process as pp import bat_detect.utils.audio_utils as au +from bat_detect.detector import models def get_default_bd_args(): args = {} - args['detection_threshold'] = 0.001 - args['time_expansion_factor'] = 1 - args['audio_dir'] = '' - args['ann_dir'] = '' - args['spec_slices'] = False - args['chunk_size'] = 3 - args['spec_features'] = False - args['cnn_features'] = False - args['quiet'] = True - args['save_preds_if_empty'] = True - args['ann_dir'] = os.path.join(args['ann_dir'], '') + args["detection_threshold"] = 0.001 + args["time_expansion_factor"] = 1 + args["audio_dir"] = "" + args["ann_dir"] = "" + args["spec_slices"] = False + args["chunk_size"] = 3 + args["spec_features"] = False + args["cnn_features"] = False + args["quiet"] = True + args["save_preds_if_empty"] = True + args["ann_dir"] = os.path.join(args["ann_dir"], "") return args - - + + def get_audio_files(ip_dir): matches = [] for root, dirnames, filenames in os.walk(ip_dir): for filename in filenames: - if filename.lower().endswith('.wav'): + if filename.lower().endswith(".wav"): matches.append(os.path.join(root, filename)) return matches @@ -41,35 +42,47 @@ def get_audio_files(ip_dir): def load_model(model_path, load_weights=True): # load model - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if os.path.isfile(model_path): net_params = torch.load(model_path, map_location=device) else: - print('Error: model not found.') + print("Error: model not found.") sys.exit(1) - params = net_params['params'] - params['device'] = device + params = net_params["params"] + params["device"] = device - if params['model_name'] == 'Net2DFast': - model = models.Net2DFast(params['num_filters'], num_classes=len(params['class_names']), - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) - elif params['model_name'] == 'Net2DFastNoAttn': - model = models.Net2DFastNoAttn(params['num_filters'], num_classes=len(params['class_names']), - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) - elif params['model_name'] == 'Net2DFastNoCoordConv': - model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=len(params['class_names']), - emb_dim=params['emb_dim'], ip_height=params['ip_height'], - resize_factor=params['resize_factor']) + if params["model_name"] == "Net2DFast": + model = models.Net2DFast( + params["num_filters"], + num_classes=len(params["class_names"]), + emb_dim=params["emb_dim"], + ip_height=params["ip_height"], + resize_factor=params["resize_factor"], + ) + elif params["model_name"] == "Net2DFastNoAttn": + model = models.Net2DFastNoAttn( + params["num_filters"], + 
num_classes=len(params["class_names"]),
+            emb_dim=params["emb_dim"],
+            ip_height=params["ip_height"],
+            resize_factor=params["resize_factor"],
+        )
+    elif params["model_name"] == "Net2DFastNoCoordConv":
+        model = models.Net2DFastNoCoordConv(
+            params["num_filters"],
+            num_classes=len(params["class_names"]),
+            emb_dim=params["emb_dim"],
+            ip_height=params["ip_height"],
+            resize_factor=params["resize_factor"],
+        )
     else:
-        print('Error: unknown model.')
+        print("Error: unknown model.")

     if load_weights:
-        model.load_state_dict(net_params['state_dict'])
+        model.load_state_dict(net_params["state_dict"])

-    model = model.to(params['device'])
+    model = model.to(params["device"])
     model.eval()

     return model, params
@@ -78,11 +91,13 @@ def load_model(model_path, load_weights=True):

 def merge_results(predictions, spec_feats, cnn_feats, spec_slices):

     predictions_m = {}
-    num_preds = np.sum([len(pp['det_probs']) for pp in predictions])
+    num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])

     if num_preds > 0:
         for kk in predictions[0].keys():
-            predictions_m[kk] = np.hstack([pp[kk] for pp in predictions if pp['det_probs'].shape[0] > 0])
+            predictions_m[kk] = np.hstack(
+                [pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
+            )
     else:
         # hack for the case where no calls are detected, as we still need some of the key names in the dict
         predictions_m = predictions[0]

@@ -94,47 +109,60 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
     return predictions_m, spec_feats, cnn_feats, spec_slices


-def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
+def convert_results(
+    file_id,
+    time_exp,
+    duration,
+    params,
+    predictions,
+    spec_feats,
+    cnn_feats,
+    spec_slices,
+):

     # create a single dictionary - this is the format used by the annotation tool
     pred_dict = {}
-    pred_dict['id'] = file_id
-    pred_dict['annotated'] = False
-    pred_dict['issues'] = False
-    pred_dict['notes'] = 'Automatically generated.'
-    pred_dict['time_exp'] = time_exp
-    pred_dict['duration'] = round(duration, 4)
-    pred_dict['annotation'] = []
+    pred_dict["id"] = file_id
+    pred_dict["annotated"] = False
+    pred_dict["issues"] = False
+    pred_dict["notes"] = "Automatically generated."
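For reference, the dictionary assembled here follows the annotation-tool JSON format that train_utils.load_anns reads back in. A minimal record, with invented values but the exact keys this function sets, looks like:

# illustrative record only - keys match convert_results, values are made up
example_pred_dict = {
    "id": "example.wav",
    "annotated": False,
    "issues": False,
    "notes": "Automatically generated.",
    "time_exp": 1,
    "duration": 0.5,
    "class_name": "Myotis mystacinus",  # overall file-level prediction
    "annotation": [
        {
            "start_time": 0.011,
            "end_time": 0.023,
            "low_freq": 35000,
            "high_freq": 80000,
            "class": "Myotis mystacinus",
            "class_prob": 0.803,
            "det_prob": 0.912,
            "individual": "-1",
            "event": "Echolocation",
        }
    ],
}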
+ pred_dict["time_exp"] = time_exp + pred_dict["duration"] = round(duration, 4) + pred_dict["annotation"] = [] - class_prob_best = predictions['class_probs'].max(0) - class_ind_best = predictions['class_probs'].argmax(0) - class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs']) - pred_dict['class_name'] = params['class_names'][np.argmax(class_overall)] + class_prob_best = predictions["class_probs"].max(0) + class_ind_best = predictions["class_probs"].argmax(0) + class_overall = pp.overall_class_pred( + predictions["det_probs"], predictions["class_probs"] + ) + pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)] - for ii in range(predictions['det_probs'].shape[0]): + for ii in range(predictions["det_probs"].shape[0]): res = {} - res['start_time'] = round(float(predictions['start_times'][ii]), 4) - res['end_time'] = round(float(predictions['end_times'][ii]), 4) - res['low_freq'] = int(predictions['low_freqs'][ii]) - res['high_freq'] = int(predictions['high_freqs'][ii]) - res['class'] = str(params['class_names'][int(class_ind_best[ii])]) - res['class_prob'] = round(float(class_prob_best[ii]), 3) - res['det_prob'] = round(float(predictions['det_probs'][ii]), 3) - res['individual'] = '-1' - res['event'] = 'Echolocation' - pred_dict['annotation'].append(res) + res["start_time"] = round(float(predictions["start_times"][ii]), 4) + res["end_time"] = round(float(predictions["end_times"][ii]), 4) + res["low_freq"] = int(predictions["low_freqs"][ii]) + res["high_freq"] = int(predictions["high_freqs"][ii]) + res["class"] = str(params["class_names"][int(class_ind_best[ii])]) + res["class_prob"] = round(float(class_prob_best[ii]), 3) + res["det_prob"] = round(float(predictions["det_probs"][ii]), 3) + res["individual"] = "-1" + res["event"] = "Echolocation" + pred_dict["annotation"].append(res) # combine into final results dictionary results = {} - results['pred_dict'] = pred_dict + results["pred_dict"] = pred_dict if len(spec_feats) > 0: - results['spec_feats'] = spec_feats - results['spec_feat_names'] = feats.get_feature_names() + results["spec_feats"] = spec_feats + results["spec_feat_names"] = feats.get_feature_names() if len(cnn_feats) > 0: - results['cnn_feats'] = cnn_feats - results['cnn_feat_names'] = [str(ii) for ii in range(cnn_feats.shape[1])] + results["cnn_feats"] = cnn_feats + results["cnn_feat_names"] = [ + str(ii) for ii in range(cnn_feats.shape[1]) + ] if len(spec_slices) > 0: - results['spec_slices'] = spec_slices + results["spec_slices"] = spec_slices return results @@ -146,144 +174,214 @@ def save_results_to_file(results, op_path): os.makedirs(os.path.dirname(op_path)) # save csv file - if there are predictions - result_list = [res for res in results['pred_dict']['annotation']] + result_list = [res for res in results["pred_dict"]["annotation"]] df = pd.DataFrame(result_list) - df['file_name'] = [results['pred_dict']['id']]*len(result_list) - df.index.name = 'id' - if 'class_prob' in df.columns: - df = df[['det_prob', 'start_time', 'end_time', 'high_freq', - 'low_freq', 'class', 'class_prob']] - df.to_csv(op_path + '.csv', sep=',') + df["file_name"] = [results["pred_dict"]["id"]] * len(result_list) + df.index.name = "id" + if "class_prob" in df.columns: + df = df[ + [ + "det_prob", + "start_time", + "end_time", + "high_freq", + "low_freq", + "class", + "class_prob", + ] + ] + df.to_csv(op_path + ".csv", sep=",") # save features - if 'spec_feats' in results.keys(): - df = pd.DataFrame(results['spec_feats'], 
columns=results['spec_feat_names']) - df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f') + if "spec_feats" in results.keys(): + df = pd.DataFrame( + results["spec_feats"], columns=results["spec_feat_names"] + ) + df.to_csv( + op_path + "_spec_features.csv", + sep=",", + index=False, + float_format="%.5f", + ) - if 'cnn_feats' in results.keys(): - df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names']) - df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f') + if "cnn_feats" in results.keys(): + df = pd.DataFrame( + results["cnn_feats"], columns=results["cnn_feat_names"] + ) + df.to_csv( + op_path + "_cnn_features.csv", + sep=",", + index=False, + float_format="%.5f", + ) # save json file - with open(op_path + '.json', 'w') as da: - json.dump(results['pred_dict'], da, indent=2, sort_keys=True) + with open(op_path + ".json", "w") as da: + json.dump(results["pred_dict"], da, indent=2, sort_keys=True) def compute_spectrogram(audio, sampling_rate, params, return_np=False): # pad audio so it is evenly divisible by downsampling factors duration = audio.shape[0] / float(sampling_rate) - audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], - params['fft_overlap'], params['resize_factor'], - params['spec_divide_factor']) + audio = au.pad_audio( + audio, + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + params["resize_factor"], + params["spec_divide_factor"], + ) # generate spectrogram spec, _ = au.generate_spectrogram(audio, sampling_rate, params) # convert to pytorch - spec = torch.from_numpy(spec).to(params['device']) + spec = torch.from_numpy(spec).to(params["device"]) spec = spec.unsqueeze(0).unsqueeze(0) # resize the spec - rs = params['resize_factor'] - spec_op_shape = (int(params['spec_height']*rs), int(spec.shape[-1]*rs)) - spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False) + rs = params["resize_factor"] + spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs)) + spec = F.interpolate( + spec, size=spec_op_shape, mode="bilinear", align_corners=False + ) if return_np: - spec_np = spec[0,0,:].cpu().data.numpy() + spec_np = spec[0, 0, :].cpu().data.numpy() else: spec_np = None return duration, spec, spec_np -def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False): +def process_file( + audio_file, + model, + params, + args, + time_exp=None, + top_n=5, + return_raw_preds=False, + max_duration=False, +): # store temporary results here predictions = [] - spec_feats = [] - cnn_feats = [] + spec_feats = [] + cnn_feats = [] spec_slices = [] # get time expansion factor if time_exp is None: - time_exp = args['time_expansion_factor'] + time_exp = args["time_expansion_factor"] - params['detection_threshold'] = args['detection_threshold'] + params["detection_threshold"] = args["detection_threshold"] # load audio file - sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp, - params['target_samp_rate'], params['scale_raw_audio']) + sampling_rate, audio_full = au.load_audio_file( + audio_file, + time_exp, + params["target_samp_rate"], + params["scale_raw_audio"], + ) # clipping maximum duration if max_duration is not False: - max_duration = np.minimum(int(sampling_rate*max_duration), audio_full.shape[0]) + max_duration = np.minimum( + int(sampling_rate * max_duration), audio_full.shape[0] + ) audio_full = audio_full[:max_duration] - + duration_full = 
audio_full.shape[0] / float(sampling_rate) - return_np_spec = args['spec_features'] or args['spec_slices'] + return_np_spec = args["spec_features"] or args["spec_slices"] # loop through larger file and split into chunks # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders - num_chunks = int(np.ceil(duration_full/args['chunk_size'])) + num_chunks = int(np.ceil(duration_full / args["chunk_size"])) for chunk_id in range(num_chunks): # chunk - chunk_time = args['chunk_size']*chunk_id - chunk_length = int(sampling_rate*args['chunk_size']) - start_sample = chunk_id*chunk_length - end_sample = np.minimum((chunk_id+1)*chunk_length, audio_full.shape[0]) + chunk_time = args["chunk_size"] * chunk_id + chunk_length = int(sampling_rate * args["chunk_size"]) + start_sample = chunk_id * chunk_length + end_sample = np.minimum( + (chunk_id + 1) * chunk_length, audio_full.shape[0] + ) audio = audio_full[start_sample:end_sample] # load audio file and compute spectrogram - duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec) + duration, spec, spec_np = compute_spectrogram( + audio, sampling_rate, params, return_np_spec + ) # evaluate model with torch.no_grad(): - outputs = model(spec, return_feats=args['cnn_features']) + outputs = model(spec, return_feats=args["cnn_features"]) # run non-max suppression - pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)])) + pred_nms, features = pp.run_nms( + outputs, params, np.array([float(sampling_rate)]) + ) pred_nms = pred_nms[0] - pred_nms['start_times'] += chunk_time - pred_nms['end_times'] += chunk_time + pred_nms["start_times"] += chunk_time + pred_nms["end_times"] += chunk_time # if we have a background class - if pred_nms['class_probs'].shape[0] > len(params['class_names']): - pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :] + if pred_nms["class_probs"].shape[0] > len(params["class_names"]): + pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :] predictions.append(pred_nms) # extract features - if there are any calls detected - if (pred_nms['det_probs'].shape[0] > 0): - if args['spec_features']: + if pred_nms["det_probs"].shape[0] > 0: + if args["spec_features"]: spec_feats.append(feats.get_feats(spec_np, pred_nms, params)) - if args['cnn_features']: + if args["cnn_features"]: cnn_feats.append(features[0]) - if args['spec_slices']: - spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params)) + if args["spec_slices"]: + spec_slices.extend( + feats.extract_spec_slices(spec_np, pred_nms, params) + ) # convert the predictions into output dictionary file_id = os.path.basename(audio_file) - predictions, spec_feats, cnn_feats, spec_slices =\ - merge_results(predictions, spec_feats, cnn_feats, spec_slices) - results = convert_results(file_id, time_exp, duration_full, params, - predictions, spec_feats, cnn_feats, spec_slices) + predictions, spec_feats, cnn_feats, spec_slices = merge_results( + predictions, spec_feats, cnn_feats, spec_slices + ) + results = convert_results( + file_id, + time_exp, + duration_full, + params, + predictions, + spec_feats, + cnn_feats, + spec_slices, + ) # summarize results - if not args['quiet']: - num_detections = len(results['pred_dict']['annotation']) - print('{}'.format(num_detections) + ' call(s) detected above the threshold.') + if not args["quiet"]: + num_detections = len(results["pred_dict"]["annotation"]) + print( + "{}".format(num_detections) + + " call(s) detected above the threshold." 
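As context for the chunking above: long files are processed in fixed chunk_size pieces, with the final piece truncated at the file end, and (per the TODO) there is no border overlap yet, so calls at chunk boundaries can be split or double-counted. A standalone sketch of the arithmetic, with invented values:

# chunking arithmetic sketch (all values invented)
import numpy as np

sampling_rate = 256000  # Hz, after resampling to target_samp_rate
duration_full = 7.2  # seconds
chunk_size = 3  # seconds; the default from get_default_bd_args
num_samples = int(duration_full * sampling_rate)  # 1843200

num_chunks = int(np.ceil(duration_full / chunk_size))  # 3
chunk_length = int(sampling_rate * chunk_size)  # 768000 samples
for chunk_id in range(num_chunks):
    start_sample = chunk_id * chunk_length
    end_sample = np.minimum((chunk_id + 1) * chunk_length, num_samples)
    # final chunk is the short remainder: samples [1536000, 1843200)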
+        )

     # print results for top n classes
-    if not args['quiet'] and (num_detections > 0):
-        class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
-        print('species name'.ljust(30) + 'probablity present')
+    if not args["quiet"] and (num_detections > 0):
+        class_overall = pp.overall_class_pred(
+            predictions["det_probs"], predictions["class_probs"]
+        )
+        print("species name".ljust(30) + "probability present")

         for cc in np.argsort(class_overall)[::-1][:top_n]:
-            print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))
+            print(
+                params["class_names"][cc].ljust(30)
+                + str(round(class_overall[cc], 3))
+            )

     if return_raw_preds:
         return predictions
diff --git a/bat_detect/utils/plot_utils.py b/bat_detect/utils/plot_utils.py
index 5b38f65..ce88375 100644
--- a/bat_detect/utils/plot_utils.py
+++ b/bat_detect/utils/plot_utils.py
@@ -1,63 +1,109 @@
-import numpy as np
-import matplotlib.pyplot as plt
 import json
-from sklearn.metrics import confusion_matrix
+
+import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib import patches
 from matplotlib.collections import PatchCollection
+from sklearn.metrics import confusion_matrix

 from . import audio_utils as au


-def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False):
+def create_box_image(
+    spec,
+    fig,
+    detections_ip,
+    start_time,
+    end_time,
+    duration,
+    params,
+    max_val,
+    hide_axis=True,
+    plot_class_names=False,
+):

     # filter detections
     stop_time = start_time + duration
     detections = []
     for bb in detections_ip:
-        if (bb['start_time'] >= start_time) and (bb['start_time'] < stop_time-0.02): #(bb['end_time'] < end_time):
+        if (bb["start_time"] >= start_time) and (
+            bb["start_time"] < stop_time - 0.02
+        ):  # (bb['end_time'] < end_time):
             detections.append(bb)

     # create figure
     freq_scale = 1000  # turn Hz to kHz
-    min_freq = params['min_freq']//freq_scale
-    max_freq = params['max_freq']//freq_scale
+    min_freq = params["min_freq"] // freq_scale
+    max_freq = params["max_freq"] // freq_scale
     y_extent = [0, duration, min_freq, max_freq]

     if hide_axis:
-        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
         ax.set_axis_off()
         fig.add_axes(ax)
     else:
         ax = plt.gca()

-    plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val)
+    plt.imshow(
+        spec,
+        aspect="auto",
+        cmap="plasma",
+        extent=y_extent,
+        vmin=0,
+        vmax=max_val,
+    )
     boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
     ax.add_collection(PatchCollection(boxes, match_original=True))
     plt.grid(False)

     if plot_class_names:
         for ii, bb in enumerate(boxes):
-            txt = ' '.join([sp[:3] for sp in detections_ip[ii]['class'].split(' ')])
-            font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
+            txt = " ".join(
+                [sp[:3] for sp in detections_ip[ii]["class"].split(" ")]
+            )
+            font_info = {
+                "color": "white",
+                "size": 10,
+                "weight": "bold",
+                "alpha": bb.get_alpha(),
+            }
             y_pos = bb.get_xy()[1] + bb.get_height()
             if y_pos > (max_freq - 10):
                 y_pos = max_freq - 10
             plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)


-def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None):
+def save_ann_spec(
+    op_path,
+    spec,
+    min_freq,
+    max_freq,
+    duration,
+    start_time,
+    title_text="",
+    anns=None,
+):

     # create figure and plot boxes
     freq_scale = 1000  # turn Hz to kHz
-    min_freq = min_freq//freq_scale
-    max_freq =
max_freq//freq_scale + min_freq = min_freq // freq_scale + max_freq = max_freq // freq_scale y_extent = [0, duration, min_freq, max_freq] - plt.close('all') - fig = plt.figure(0, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100) - plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=spec.max()*1.1) + plt.close("all") + fig = plt.figure( + 0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100 + ) + plt.imshow( + spec, + aspect="auto", + cmap="plasma", + extent=y_extent, + vmin=0, + vmax=spec.max() * 1.1, + ) - plt.ylabel('Freq - kHz') - plt.xlabel('Time - secs') - if title_text != '': + plt.ylabel("Freq - kHz") + plt.xlabel("Time - secs") + if title_text != "": plt.title(title_text) plt.tight_layout() @@ -66,122 +112,185 @@ def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time) plt.gca().add_collection(PatchCollection(boxes, match_original=True)) for ii, bb in enumerate(boxes): - txt = ' '.join([sp[:3] for sp in anns[ii]['class'].split(' ')]) - font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} + txt = " ".join([sp[:3] for sp in anns[ii]["class"].split(" ")]) + font_info = { + "color": "white", + "size": 10, + "weight": "bold", + "alpha": bb.get_alpha(), + } y_pos = bb.get_xy()[1] + bb.get_height() if y_pos > (max_freq - 10): y_pos = max_freq - 10 plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) - print('Saving figure to:', op_path) + print("Saving figure to:", op_path) plt.savefig(op_path) -def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False): +def plot_pts( + fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False +): plt.figure(fig_id) un_class, labels = np.unique(class_names, return_inverse=True) un_labels = np.unique(labels) if un_labels.shape[0] > len(colors): - colors = [plt.cm.jet(float(ii)/un_labels.shape[0]) for ii in un_labels] + colors = [ + plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels + ] for ii, u in enumerate(un_labels): - inds = np.where(labels==u)[0] - plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii], label=str(un_class[ii]), s=marker_size) + inds = np.where(labels == u)[0] + plt.scatter( + feats[inds, 0], + feats[inds, 1], + c=colors[ii], + label=str(un_class[ii]), + s=marker_size, + ) if plot_legend: plt.legend() plt.xticks([]) plt.yticks([]) - plt.title('downsampled features') + plt.title("downsampled features") -def plot_bounding_box_patch(pred, freq_scale, ecolor='w'): +def plot_bounding_box_patch(pred, freq_scale, ecolor="w"): patch_collect = [] - for bb in range(len(pred['start_times'])): - xx = pred['start_times'][bb] - ww = pred['end_times'][bb] - pred['start_times'][bb] - yy = pred['low_freqs'][bb] / freq_scale - hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale + for bb in range(len(pred["start_times"])): + xx = pred["start_times"][bb] + ww = pred["end_times"][bb] - pred["start_times"][bb] + yy = pred["low_freqs"][bb] / freq_scale + hh = (pred["high_freqs"][bb] - pred["low_freqs"][bb]) / freq_scale - if 'det_probs' in pred.keys(): - alpha_val = pred['det_probs'][bb] + if "det_probs" in pred.keys(): + alpha_val = pred["det_probs"][bb] else: alpha_val = 1.0 - patch_collect.append(patches.Rectangle((xx, yy), ww, hh, linewidth=1, - edgecolor=ecolor, facecolor='none', alpha=alpha_val)) + patch_collect.append( + patches.Rectangle( + (xx, yy), + ww, + hh, + linewidth=1, + edgecolor=ecolor, + 
facecolor="none", + alpha=alpha_val, + ) + ) return patch_collect def plot_bounding_box_patch_ann(anns, freq_scale, start_time): patch_collect = [] for aa in range(len(anns)): - xx = anns[aa]['start_time'] - start_time - ww = anns[aa]['end_time'] - anns[aa]['start_time'] - yy = anns[aa]['low_freq'] / freq_scale - hh = (anns[aa]['high_freq'] - anns[aa]['low_freq']) / freq_scale - if 'det_prob' in anns[aa]: - alpha = anns[aa]['det_prob'] + xx = anns[aa]["start_time"] - start_time + ww = anns[aa]["end_time"] - anns[aa]["start_time"] + yy = anns[aa]["low_freq"] / freq_scale + hh = (anns[aa]["high_freq"] - anns[aa]["low_freq"]) / freq_scale + if "det_prob" in anns[aa]: + alpha = anns[aa]["det_prob"] else: alpha = 1.0 - patch_collect.append(patches.Rectangle((xx,yy), ww, hh, linewidth=1, - edgecolor='w', facecolor='none', alpha=alpha)) + patch_collect.append( + patches.Rectangle( + (xx, yy), + ww, + hh, + linewidth=1, + edgecolor="w", + facecolor="none", + alpha=alpha, + ) + ) return patch_collect -def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title, - op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True): +def plot_spec( + spec, + sampling_rate, + duration, + gt, + pred, + params, + plot_title, + op_file_name, + pred_2d_hm, + plot_boxes=True, + fixed_aspect=True, +): if fixed_aspect: # ouptut image will be this width irrespective of the duration of the audio file width = 12 else: - width = 12*duration + width = 12 * duration fig = plt.figure(1, figsize=(width, 8)) - ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h + ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h ax1 = plt.axes([0.05, 0.33, 0.9, 0.30]) ax2 = plt.axes([0.05, 0.01, 0.9, 0.30]) freq_scale = 1000 # turn Hz in kHz - #duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) - y_extent = [0, duration, params['min_freq']//freq_scale, params['max_freq']//freq_scale] + # duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) + y_extent = [ + 0, + duration, + params["min_freq"] // freq_scale, + params["max_freq"] // freq_scale, + ] # plot gt boxes - ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent) + ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent) ax0.xaxis.set_ticklabels([]) - font_info = {'color': 'white', 'size': 12, 'weight': 'bold'} - ax0.text(0, params['min_freq']//freq_scale, 'Ground Truth', fontdict=font_info) + font_info = {"color": "white", "size": 12, "weight": "bold"} + ax0.text( + 0, params["min_freq"] // freq_scale, "Ground Truth", fontdict=font_info + ) plt.grid(False) if plot_boxes: boxes = plot_bounding_box_patch(gt, freq_scale) ax0.add_collection(PatchCollection(boxes, match_original=True)) for ii, bb in enumerate(boxes): - class_id = int(gt['class_ids'][ii]) + class_id = int(gt["class_ids"][ii]) if class_id < 0: - txt = params['generic_class'][0] + txt = params["generic_class"][0] else: - txt = params['class_names_short'][class_id] - font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} + txt = params["class_names_short"][class_id] + font_info = { + "color": "white", + "size": 10, + "weight": "bold", + "alpha": bb.get_alpha(), + } y_pos = bb.get_xy()[1] + bb.get_height() ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) # plot predicted boxes - ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent) + ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent) ax1.xaxis.set_ticklabels([]) - font_info = 
{'color': 'white', 'size': 12, 'weight': 'bold'} - ax1.text(0, params['min_freq']//freq_scale, 'Prediction', fontdict=font_info) + font_info = {"color": "white", "size": 12, "weight": "bold"} + ax1.text( + 0, params["min_freq"] // freq_scale, "Prediction", fontdict=font_info + ) plt.grid(False) if plot_boxes: boxes = plot_bounding_box_patch(pred, freq_scale) ax1.add_collection(PatchCollection(boxes, match_original=True)) for ii, bb in enumerate(boxes): - if pred['class_probs'].shape[0] > len(params['class_names_short']): - class_id = pred['class_probs'][:-1, ii].argmax() + if pred["class_probs"].shape[0] > len(params["class_names_short"]): + class_id = pred["class_probs"][:-1, ii].argmax() else: - class_id = pred['class_probs'][:, ii].argmax() - txt = params['class_names_short'][class_id] - font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()} + class_id = pred["class_probs"][:, ii].argmax() + txt = params["class_names_short"][class_id] + font_info = { + "color": "white", + "size": 10, + "weight": "bold", + "alpha": bb.get_alpha(), + } y_pos = bb.get_xy()[1] + bb.get_height() ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info) @@ -190,10 +299,18 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title, min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min() max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max() - ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val]) - #ax2.xaxis.set_ticklabels([]) - font_info = {'color': 'white', 'size': 12, 'weight': 'bold'} - ax2.text(0, params['min_freq']//freq_scale, 'Heatmap', fontdict=font_info) + ax2.imshow( + pred_2d_hm, + aspect="auto", + cmap="plasma", + extent=y_extent, + clim=[min_val, max_val], + ) + # ax2.xaxis.set_ticklabels([]) + font_info = {"color": "white", "size": 12, "weight": "bold"} + ax2.text( + 0, params["min_freq"] // freq_scale, "Heatmap", fontdict=font_info + ) plt.grid(False) @@ -204,107 +321,151 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title, plt.close(1) -def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''): - precision = results['precision'] - recall = results['recall'] - avg_prec = results['avg_prec'] +def plot_pr_curve( + op_dir, plt_title, file_name, results, file_type="png", title_text="" +): + precision = results["precision"] + recall = results["recall"] + avg_prec = results["avg_prec"] - plt.figure(0, figsize=(10,8)) + plt.figure(0, figsize=(10, 8)) plt.plot(recall, precision) - plt.ylabel('Precision', fontsize=20) - plt.xlabel('Recall', fontsize=20) - if title_text != '': - plt.title(title_text, fontdict={'fontsize': 28}) + plt.ylabel("Precision", fontsize=20) + plt.xlabel("Recall", fontsize=20) + if title_text != "": + plt.title(title_text, fontdict={"fontsize": 28}) else: - plt.title(plt_title + ' {:.3f}\n'.format(avg_prec)) - plt.xlim(0,1.02) - plt.ylim(0,1.02) + plt.title(plt_title + " {:.3f}\n".format(avg_prec)) + plt.xlim(0, 1.02) + plt.ylim(0, 1.02) plt.grid(True) plt.tight_layout() - plt.savefig(op_dir + file_name + '.' + file_type) + plt.savefig(op_dir + file_name + "." 
+ file_type) plt.close(0) -def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''): - plt.figure(0, figsize=(10,8)) - plt.ylabel('Precision', fontsize=20) - plt.xlabel('Recall', fontsize=20) - plt.xlim(0,1.02) - plt.ylim(0,1.02) +def plot_pr_curve_class( + op_dir, plt_title, file_name, results, file_type="png", title_text="" +): + plt.figure(0, figsize=(10, 8)) + plt.ylabel("Precision", fontsize=20) + plt.xlabel("Recall", fontsize=20) + plt.xlim(0, 1.02) + plt.ylim(0, 1.02) plt.grid(True) - linestyles = ['-', ':', '--'] - markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*'] - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + linestyles = ["-", ":", "--"] + markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"] + colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] # plot the PR curves - for ii, rr in enumerate(results['class_pr']): - class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')]) - cur_color = colors[int(ii%10)] - plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color, - linestyle=linestyles[int(ii//10)], lw=2.5) + for ii, rr in enumerate(results["class_pr"]): + class_name = " ".join([sp[:3] for sp in rr["name"].split(" ")]) + cur_color = colors[int(ii % 10)] + plt.plot( + rr["recall"], + rr["precision"], + label=class_name, + color=cur_color, + linestyle=linestyles[int(ii // 10)], + lw=2.5, + ) - #print(class_name) + # print(class_name) # plot the location of the confidence threshold values - for jj, tt in enumerate(rr['thresholds']): - ind = rr['thresholds_inds'][jj] + for jj, tt in enumerate(rr["thresholds"]): + ind = rr["thresholds_inds"][jj] if ind > -1: - plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj], - color=cur_color, ms=10) - #print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3)) + plt.plot( + rr["recall"][ind], + rr["precision"][ind], + markers[jj], + color=cur_color, + ms=10, + ) + # print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3)) - if title_text != '': - plt.title(title_text, fontdict={'fontsize': 28}) + if title_text != "": + plt.title(title_text, fontdict={"fontsize": 28}) else: - plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class'])) - plt.legend(loc='lower left', prop={'size': 14}) + plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"])) + plt.legend(loc="lower left", prop={"size": 14}) plt.tight_layout() - plt.savefig(op_dir + file_name + '.' + file_type) + plt.savefig(op_dir + file_name + "." 
+ file_type) plt.close(0) -def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''): +def plot_confusion_matrix( + op_dir, + op_file, + gt, + pred, + file_acc, + class_names_long, + verbose=False, + file_type="png", + title_text="", +): # shorten the class names for plotting class_names = [] for cc in class_names_long: - class_name_sm = ''.join([cc_sm[:3] + ' ' for cc_sm in cc.split(' ')])[:-1] + class_name_sm = "".join([cc_sm[:3] + " " for cc_sm in cc.split(" ")])[ + :-1 + ] class_names.append(class_name_sm) num_classes = len(class_names) - cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32) + cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype( + np.float32 + ) cm_norm = cm.sum(1) valid_inds = np.where(cm_norm > 0)[0] - cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] - cm[np.where(cm_norm ==- 0)[0], :] = np.nan + cm[valid_inds, :] = ( + cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] + ) + cm[np.where(cm_norm == -0)[0], :] = np.nan if verbose: - print('Per class accuracy:') + print("Per class accuracy:") str_len = np.max([len(cc) for cc in class_names_long]) + 5 accs = np.diag(cm) for ii, cc in enumerate(class_names_long): if np.isnan(accs[ii]): print(str(ii).ljust(5) + cc.ljust(str_len)) else: - print(str(ii).ljust(5) + cc.ljust(str_len) + '{:.2f}'.format(accs[ii]*100)) + print( + str(ii).ljust(5) + + cc.ljust(str_len) + + "{:.2f}".format(accs[ii] * 100) + ) - plt.figure(0, figsize=(10,8)) - plt.imshow(cm, vmin=0, vmax=1, cmap='plasma') + plt.figure(0, figsize=(10, 8)) + plt.imshow(cm, vmin=0, vmax=1, cmap="plasma") plt.colorbar() - plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical') + plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical") plt.yticks(np.arange(cm.shape[0]), class_names) - plt.xlabel('Predicted', fontsize=20) - plt.ylabel('Ground Truth', fontsize=20) - if title_text != '': - plt.title(title_text, fontdict={'fontsize': 28}) + plt.xlabel("Predicted", fontsize=20) + plt.ylabel("Ground Truth", fontsize=20) + if title_text != "": + plt.title(title_text, fontdict={"fontsize": 28}) else: - plt.title(op_file + ' {:.3f}\n'.format(file_acc)) + plt.title(op_file + " {:.3f}\n".format(file_acc)) plt.tight_layout() - plt.savefig(op_dir + op_file + '.' + file_type) - plt.close('all') + plt.savefig(op_dir + op_file + "." 
+ file_type) + plt.close("all") class LossPlotter(object): - def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False): + def __init__( + self, + op_file_name, + duration, + labels, + ylim, + class_names, + axis_labels=None, + logy=False, + ): self.reset() self.op_file_name = op_file_name self.duration = duration # length of x axis @@ -327,11 +488,16 @@ class LossPlotter(object): self.save_confusion_matrix(gt, pred) def save_plot(self): - linestyles = ['-', ':', '--'] - plt.figure(0, figsize=(8,5)) + linestyles = ["-", ":", "--"] + plt.figure(0, figsize=(8, 5)) for ii in range(len(self.vals[0])): l_vals = [vv[ii] for vv in self.vals] - plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=linestyles[int(ii//10)]) + plt.plot( + self.epochs, + l_vals, + label=self.labels[ii], + linestyle=linestyles[int(ii // 10)], + ) plt.xlim(0, np.maximum(self.duration, len(self.vals))) if self.ylim is not None: plt.ylim(self.ylim[0], self.ylim[1]) @@ -339,33 +505,41 @@ class LossPlotter(object): plt.xlabel(self.axis_labels[0]) plt.ylabel(self.axis_labels[1]) if self.logy: - plt.gca().set_yscale('log') + plt.gca().set_yscale("log") plt.grid(True) - plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0) + plt.legend( + bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0 + ) plt.tight_layout() plt.savefig(self.op_file_name) plt.close(0) def save_json(self): data = {} - data['epochs'] = self.epochs + data["epochs"] = self.epochs for ii in range(len(self.vals[0])): - data[self.labels[ii]] = [round(vv[ii],4) for vv in self.vals] - with open(self.op_file_name[:-4] + '.json', 'w') as da: + data[self.labels[ii]] = [round(vv[ii], 4) for vv in self.vals] + with open(self.op_file_name[:-4] + ".json", "w") as da: json.dump(data, da, indent=2) def save_confusion_matrix(self, gt, pred): plt.figure(0) - cm = confusion_matrix(gt, pred, np.arange(len(self.class_names))).astype(np.float32) + cm = confusion_matrix( + gt, pred, np.arange(len(self.class_names)) + ).astype(np.float32) cm_norm = cm.sum(1) valid_inds = np.where(cm_norm > 0)[0] - cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] - plt.imshow(cm, vmin=0, vmax=1, cmap='plasma') + cm[valid_inds, :] = ( + cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis] + ) + plt.imshow(cm, vmin=0, vmax=1, cmap="plasma") plt.colorbar() - plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical') + plt.xticks( + np.arange(cm.shape[1]), self.class_names, rotation="vertical" + ) plt.yticks(np.arange(cm.shape[0]), self.class_names) - plt.xlabel('Predicted') - plt.ylabel('Ground Truth') + plt.xlabel("Predicted") + plt.ylabel("Ground Truth") plt.tight_layout() - plt.savefig(self.op_file_name[:-4] + '_cm.png') + plt.savefig(self.op_file_name[:-4] + "_cm.png") plt.close(0) diff --git a/bat_detect/utils/visualize.py b/bat_detect/utils/visualize.py index bea7f6b..54be1df 100644 --- a/bat_detect/utils/visualize.py +++ b/bat_detect/utils/visualize.py @@ -1,19 +1,46 @@ -import numpy as np import matplotlib.pyplot as plt +import numpy as np from matplotlib import patches -from sklearn.svm import LinearSVC from matplotlib.axes._axes import _log as matplotlib_axes_logger -matplotlib_axes_logger.setLevel('ERROR') +from sklearn.svm import LinearSVC + +matplotlib_axes_logger.setLevel("ERROR") -colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', - '#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff', - '#9A6324', '#fffac8', '#800000', 
'#aaffc3', '#808000', '#ffd8b1',
-          '#000075', '#a9a9a9']
+colors = [
+    "#e6194B",
+    "#3cb44b",
+    "#ffe119",
+    "#4363d8",
+    "#f58231",
+    "#911eb4",
+    "#42d4f4",
+    "#f032e6",
+    "#bfef45",
+    "#fabebe",
+    "#469990",
+    "#e6beff",
+    "#9A6324",
+    "#fffac8",
+    "#800000",
+    "#aaffc3",
+    "#808000",
+    "#ffd8b1",
+    "#000075",
+    "#a9a9a9",
+]


 class InteractivePlotter:
-    def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
+    def __init__(
+        self,
+        feats_ds,
+        feats,
+        spec_slices,
+        call_info,
+        freq_lims,
+        allow_training,
+    ):
         """
         Plots 2D low dimensional features on left and corresponding spectrograms on the right.
@@ -24,78 +51,123 @@ class InteractivePlotter:
         self.spec_slices = spec_slices
         self.call_info = call_info

-        #_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
+        # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
         self.labels = np.zeros(len(call_info), dtype=int)
-        self.annotated = np.zeros(self.labels.shape[0], dtype=np.int)  # can populate this with 1's where we have labels
-        self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))]
+        self.annotated = np.zeros(
+            self.labels.shape[0], dtype=int
+        )  # can populate this with 1's where we have labels
+        self.labels_cols = [
+            colors[self.labels[ii]] for ii in range(len(self.labels))
+        ]
         self.freq_lims = freq_lims
         self.allow_training = allow_training
         self.pt_size = 5.0
-        self.spec_pad = 0.2  # this much padding has been applied to the spec slices
+        self.spec_pad = (
+            0.2  # this much padding has been applied to the spec slices
+        )
         self.fig_width = 12
         self.fig_height = 8
         self.current_id = 0
         max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
         self.max_width = self.spec_slices[max_ind].shape[1]
-        self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))
-
+        self.blank_spec = np.zeros(
+            (self.spec_slices[0].shape[0], self.max_width)
+        )

     def plot(self, fig_id):
-        self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height),
-                                         gridspec_kw={'width_ratios': [2, 1]})
+        self.fig, self.ax = plt.subplots(
+            nrows=1,
+            ncols=2,
+            num=fig_id,
+            figsize=(self.fig_width, self.fig_height),
+            gridspec_kw={"width_ratios": [2, 1]},
+        )
         plt.tight_layout()

         # plot 2D TSNE features
-        self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
-                                              c=self.labels_cols, s=self.pt_size, picker=5)
-        self.ax[0].set_title('TSNE of Call Features')
+        self.low_dim_plt = self.ax[0].scatter(
+            self.feats_ds[:, 0],
+            self.feats_ds[:, 1],
+            c=self.labels_cols,
+            s=self.pt_size,
+            picker=5,
+        )
+        self.ax[0].set_title("TSNE of Call Features")
         self.ax[0].set_xticks([])
         self.ax[0].set_yticks([])

         # plot clip from spectrogram
-        spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
-        self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto')
+        spec_min_max = (
+            0,
+            self.blank_spec.shape[1],
+            self.freq_lims[0],
+            self.freq_lims[1],
+        )
+        self.ax[1].imshow(
+            self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
+        )
         self.spec_im = self.ax[1].get_images()[0]
-        self.ax[1].set_title('Spectrogram')
-        self.ax[1].grid(color='w', linewidth=0.5)
+        self.ax[1].set_title("Spectrogram")
+        self.ax[1].grid(color="w", linewidth=0.5)
         self.ax[1].set_xticks([])
-        self.ax[1].set_ylabel('kHz')
+        self.ax[1].set_ylabel("kHz")

-        bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False)
+        bbox_orig = 
patches.Rectangle( + (0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False + ) self.ax[1].add_patch(bbox_orig) - self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20),textcoords='offset points', - bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->')) + self.annot = self.ax[0].annotate( + "", + xy=(0, 0), + xytext=(20, 20), + textcoords="offset points", + bbox=dict(boxstyle="round", fc="w"), + arrowprops=dict(arrowstyle="->"), + ) self.annot.set_visible(False) - self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover) - self.fig.canvas.mpl_connect('key_press_event', self.key_press) - + self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover) + self.fig.canvas.mpl_connect("key_press_event", self.key_press) def mouse_hover(self, event): vis = self.annot.get_visible() if event.inaxes == self.ax[0]: cont, ind = self.low_dim_plt.contains(event) if cont: - self.current_id = ind['ind'][0] + self.current_id = ind["ind"][0] # copy spec into full window - probably a better way of doing this new_spec = self.blank_spec.copy() - w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2 - new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id] + w_diff = ( + self.blank_spec.shape[1] + - self.spec_slices[self.current_id].shape[1] + ) // 2 + new_spec[ + :, + w_diff : self.spec_slices[self.current_id].shape[1] + + w_diff, + ] = self.spec_slices[self.current_id] self.spec_im.set_data(new_spec) self.spec_im.set_clim(vmin=0, vmax=new_spec.max()) # draw bounding box around call self.ax[1].patches[0].remove() - spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad) - xx = w_diff + self.spec_pad*spec_width_orig + spec_width_orig = self.spec_slices[self.current_id].shape[ + 1 + ] / (1.0 + 2.0 * self.spec_pad) + xx = w_diff + self.spec_pad * spec_width_orig ww = spec_width_orig - yy = self.call_info[self.current_id]['low_freq']/1000 - hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000 - bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False) + yy = self.call_info[self.current_id]["low_freq"] / 1000 + hh = ( + self.call_info[self.current_id]["high_freq"] + - self.call_info[self.current_id]["low_freq"] + ) / 1000 + bbox = patches.Rectangle( + (xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False + ) self.ax[1].add_patch(bbox) # update annotation arrow @@ -104,38 +176,54 @@ class InteractivePlotter: self.annot.set_visible(True) # write call info - info_str = self.call_info[self.current_id]['file_name'] + ', time=' \ - + str(round(self.call_info[self.current_id]['start_time'],3)) \ - + ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3)) + info_str = ( + self.call_info[self.current_id]["file_name"] + + ", time=" + + str( + round(self.call_info[self.current_id]["start_time"], 3) + ) + + ", prob=" + + str( + round(self.call_info[self.current_id]["det_prob"], 3) + ) + ) self.ax[0].set_xlabel(info_str) # redraw self.fig.canvas.draw_idle() - def key_press(self, event): if event.key.isdigit(): self.labels_cols[self.current_id] = colors[int(event.key)] self.labels[self.current_id] = int(event.key) self.annotated[self.current_id] = 1 - elif event.key == 'enter' and self.allow_training: + elif event.key == "enter" and self.allow_training: self.train_classifier() - elif event.key == 'x' and self.allow_training: + elif event.key == "x" and self.allow_training: 
self.get_classifier_params() - self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1], - c=self.labels_cols, s=self.pt_size) + self.ax[0].scatter( + self.feats_ds[:, 0], + self.feats_ds[:, 1], + c=self.labels_cols, + s=self.pt_size, + ) self.fig.canvas.draw_idle() - def train_classifier(self): # TODO maybe it's better to classify in 2D space - but then can't be linear ... inds = np.where(self.annotated == 1)[0] labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True) if labs_un.shape[0] > 1: # needs at least 2 classes - self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001, - intercept_scaling=1.0, max_iter=2000) + self.clf = LinearSVC( + C=1.0, + penalty="l2", + loss="squared_hinge", + tol=0.0001, + intercept_scaling=1.0, + max_iter=2000, + ) self.clf.fit(self.feats[inds, :], self.labels[inds]) @@ -145,14 +233,13 @@ class InteractivePlotter: for ii in inds_unlab: self.labels_cols[ii] = colors[self.labels[ii]] else: - print('Not enough data - please label more classes.') - + print("Not enough data - please label more classes.") def get_classifier_params(self): res = {} if self.clf is None: - print('Model not trained!') + print("Model not trained!") else: - res['weights'] = self.clf.coef_.astype(np.float32) - res['biases'] = self.clf.intercept_.astype(np.float32) + res["weights"] = self.clf.coef_.astype(np.float32) + res["biases"] = self.clf.intercept_.astype(np.float32) return res diff --git a/bat_detect/utils/wavfile.py b/bat_detect/utils/wavfile.py index a6715b0..7fee660 100644 --- a/bat_detect/utils/wavfile.py +++ b/bat_detect/utils/wavfile.py @@ -8,23 +8,25 @@ Functions `write`: Write a numpy array as a WAV file. """ -from __future__ import division, print_function, absolute_import +from __future__ import absolute_import, division, print_function -import sys -import numpy -import struct -import warnings import os +import struct +import sys +import warnings + +import numpy class WavFileWarning(UserWarning): pass + _big_endian = False WAVE_FORMAT_PCM = 0x0001 WAVE_FORMAT_IEEE_FLOAT = 0x0003 -WAVE_FORMAT_EXTENSIBLE = 0xfffe +WAVE_FORMAT_EXTENSIBLE = 0xFFFE KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT) # assumes file pointer is immediately @@ -33,10 +35,10 @@ KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT) def _read_fmt_chunk(fid): if _big_endian: - fmt = '>' + fmt = ">" else: - fmt = '<' - res = struct.unpack(fmt+'iHHIIHH',fid.read(20)) + fmt = "<" + res = struct.unpack(fmt + "iHHIIHH", fid.read(20)) size, comp, noc, rate, sbytes, ba, bits = res if comp not in KNOWN_WAVE_FORMATS or size > 16: comp = WAVE_FORMAT_PCM @@ -51,41 +53,42 @@ def _read_fmt_chunk(fid): # after the 'data' id def _read_data_chunk(fid, comp, noc, bits, mmap=False): if _big_endian: - fmt = '>i' + fmt = ">i" else: - fmt = ' 1: - data = data.reshape(-1,noc) + data = data.reshape(-1, noc) return data def _skip_unknown_chunk(fid): if _big_endian: - fmt = '>i' + fmt = ">i" else: - fmt = '' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'): + fid.write(b"data") + fid.write(struct.pack("" or ( + data.dtype.byteorder == "=" and sys.byteorder == "big" + ): data = data.byteswap() _array_tofile(fid, data) @@ -273,19 +286,22 @@ def write(filename, rate, data): # position at start of the file (replacing the 4 bytes of zeros) size = fid.tell() fid.seek(4) - fid.write(struct.pack('= 3: + def _array_tofile(fid, data): # ravel gives a c-contiguous buffer - fid.write(data.ravel().view('b').data) + fid.write(data.ravel().view("b").data) + else: + def 
_array_tofile(fid, data):
         fid.write(data.tostring())

diff --git a/run_batdetect.py b/run_batdetect.py
index 9655d45..f9d96ab 100644
--- a/run_batdetect.py
+++ b/run_batdetect.py
@@ -1,67 +1,115 @@
-import os
 import argparse
+import os
+
 import bat_detect.utils.detector_utils as du


 def main(args):
-    print('Loading model: ' + args['model_path'])
-    model, params = du.load_model(args['model_path'])
+    print("Loading model: " + args["model_path"])
+    model, params = du.load_model(args["model_path"])

-    print('\nInput directory: ' + args['audio_dir'])
-    files = du.get_audio_files(args['audio_dir'])
-    print('Number of audio files: {}'.format(len(files)))
-    print('\nSaving results to: ' + args['ann_dir'])
+    print("\nInput directory: " + args["audio_dir"])
+    files = du.get_audio_files(args["audio_dir"])
+    print("Number of audio files: {}".format(len(files)))
+    print("\nSaving results to: " + args["ann_dir"])

     # process files
     error_files = []
     for ii, audio_file in enumerate(files):
-        print('\n' + str(ii).ljust(6) + os.path.basename(audio_file))
+        print("\n" + str(ii).ljust(6) + os.path.basename(audio_file))
         try:
             results = du.process_file(audio_file, model, params, args)
-            if args['save_preds_if_empty'] or (len(results['pred_dict']['annotation']) > 0):
-                results_path = audio_file.replace(args['audio_dir'], args['ann_dir'])
+            if args["save_preds_if_empty"] or (
+                len(results["pred_dict"]["annotation"]) > 0
+            ):
+                results_path = audio_file.replace(
+                    args["audio_dir"], args["ann_dir"]
+                )
                 du.save_results_to_file(results, results_path)
         except:
             error_files.append(audio_file)
             print("Error processing file!")

-    print('\nResults saved to: ' + args['ann_dir'])
+    print("\nResults saved to: " + args["ann_dir"])

     if len(error_files) > 0:
-        print('\nUnable to process the follow files:')
+        print("\nUnable to process the following files:")
         for err in error_files:
-            print(' ' + err)
+            print(" " + err)


 if __name__ == "__main__":

-    info_str = '\nBatDetect2 - Detection and Classification\n' + \
-               '  Assumes audio files are mono, not stereo.\n' + \
-               '  Spaces in the input paths will throw an error. Wrap in quotes "".\n' + \
-               '  Input files should be short in duration e.g. < 30 seconds.\n'
+    info_str = (
+        "\nBatDetect2 - Detection and Classification\n"
+        + "  Assumes audio files are mono, not stereo.\n"
+        + '  Spaces in the input paths will throw an error. Wrap in quotes "".\n'
+        + "  Input files should be short in duration e.g. < 30 seconds.\n"
+    )

     print(info_str)
     parser = argparse.ArgumentParser()
-    parser.add_argument('audio_dir', type=str, help='Input directory for audio')
-    parser.add_argument('ann_dir', type=str, help='Output directory for where the predictions will be stored')
-    parser.add_argument('detection_threshold', type=float, help='Cut-off probability for detector e.g. 
0.1') - parser.add_argument('--cnn_features', action='store_true', default=False, dest='cnn_features', - help='Extracts CNN call features') - parser.add_argument('--spec_features', action='store_true', default=False, dest='spec_features', - help='Extracts low level call features') - parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor', - help='The time expansion factor used for all files (default is 1)') - parser.add_argument('--quiet', action='store_true', default=False, dest='quiet', - help='Minimize output printing') - parser.add_argument('--save_preds_if_empty', action='store_true', default=False, dest='save_preds_if_empty', - help='Save empty annotation file if no detections made.') - parser.add_argument('--model_path', type=str, default='models/Net2DFast_UK_same.pth.tar', - help='Path to trained BatDetect2 model') + parser.add_argument( + "audio_dir", type=str, help="Input directory for audio" + ) + parser.add_argument( + "ann_dir", + type=str, + help="Output directory for where the predictions will be stored", + ) + parser.add_argument( + "detection_threshold", + type=float, + help="Cut-off probability for detector e.g. 0.1", + ) + parser.add_argument( + "--cnn_features", + action="store_true", + default=False, + dest="cnn_features", + help="Extracts CNN call features", + ) + parser.add_argument( + "--spec_features", + action="store_true", + default=False, + dest="spec_features", + help="Extracts low level call features", + ) + parser.add_argument( + "--time_expansion_factor", + type=int, + default=1, + dest="time_expansion_factor", + help="The time expansion factor used for all files (default is 1)", + ) + parser.add_argument( + "--quiet", + action="store_true", + default=False, + dest="quiet", + help="Minimize output printing", + ) + parser.add_argument( + "--save_preds_if_empty", + action="store_true", + default=False, + dest="save_preds_if_empty", + help="Save empty annotation file if no detections made.", + ) + parser.add_argument( + "--model_path", + type=str, + default="models/Net2DFast_UK_same.pth.tar", + help="Path to trained BatDetect2 model", + ) args = vars(parser.parse_args()) - args['spec_slices'] = False # used for visualization - args['chunk_size'] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks - args['ann_dir'] = os.path.join(args['ann_dir'], '') + args["spec_slices"] = False # used for visualization + args[ + "chunk_size" + ] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks + args["ann_dir"] = os.path.join(args["ann_dir"], "") main(args) diff --git a/scripts/gen_dataset_summary_image.py b/scripts/gen_dataset_summary_image.py index b789584..cb823d6 100644 --- a/scripts/gen_dataset_summary_image.py +++ b/scripts/gen_dataset_summary_image.py @@ -3,62 +3,97 @@ Loads a set of annotations corresponding to a dataset and saves an image which is the mean spectrogram for each class. 
""" +import argparse +import os +import sys + import matplotlib.pyplot as plt import numpy as np -import os -import argparse -import sys import viz_helpers as vz -sys.path.append(os.path.join('..')) -import bat_detect.train.train_utils as tu +sys.path.append(os.path.join("..")) import bat_detect.detector.parameters as parameters -import bat_detect.utils.audio_utils as au import bat_detect.train.train_split as ts - +import bat_detect.train.train_utils as tu +import bat_detect.utils.audio_utils as au if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('audio_path', type=str, help='Input directory for audio') - parser.add_argument('op_dir', type=str, - help='Path to where single annotation json file is stored') - parser.add_argument('--ann_file', type=str, - help='Path to where single annotation json file is stored') - parser.add_argument('--uk_split', type=str, default='', - help='Set as: diff or same') - parser.add_argument('--file_type', type=str, default='png', - help='Type of image to save png or pdf') + parser.add_argument( + "audio_path", type=str, help="Input directory for audio" + ) + parser.add_argument( + "op_dir", + type=str, + help="Path to where single annotation json file is stored", + ) + parser.add_argument( + "--ann_file", + type=str, + help="Path to where single annotation json file is stored", + ) + parser.add_argument( + "--uk_split", type=str, default="", help="Set as: diff or same" + ) + parser.add_argument( + "--file_type", + type=str, + default="png", + help="Type of image to save png or pdf", + ) args = vars(parser.parse_args()) - if not os.path.isdir(args['op_dir']): - os.makedirs(args['op_dir']) + if not os.path.isdir(args["op_dir"]): + os.makedirs(args["op_dir"]) params = parameters.get_params(False) - params['smooth_spec'] = False - params['spec_width'] = 48 - params['norm_type'] = 'log' # log, pcen - params['aud_pad'] = 0.005 - classes_to_ignore = params['classes_to_ignore'] + params['generic_class'] - + params["smooth_spec"] = False + params["spec_width"] = 48 + params["norm_type"] = "log" # log, pcen + params["aud_pad"] = 0.005 + classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] # load train annotations - if args['uk_split'] == '': - print('\nLoading:', args['ann_file'], '\n') - dataset_name = os.path.basename(args['ann_file']).replace('.json', '') + if args["uk_split"] == "": + print("\nLoading:", args["ann_file"], "\n") + dataset_name = os.path.basename(args["ann_file"]).replace(".json", "") datasets = [] - datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path'])) + datasets.append( + tu.get_blank_dataset_dict( + dataset_name, False, args["ann_file"], args["audio_path"] + ) + ) else: # load uk data - special case - print('\nLoading:', args['uk_split'], '\n') - dataset_name = 'uk_' + args['uk_split'] # should be uk_diff, or uk_same - datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False) + print("\nLoading:", args["uk_split"], "\n") + dataset_name = ( + "uk_" + args["uk_split"] + ) # should be uk_diff, or uk_same + datasets, _ = ts.get_train_test_data( + args["ann_file"], + args["audio_path"], + args["uk_split"], + load_extra=False, + ) - anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest']) + anns, class_names, _ = tu.load_set_of_anns( + datasets, classes_to_ignore, params["events_of_interest"] + ) class_names_order = range(len(class_names)) - x_train, 
y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type']) + x_train, y_train = vz.load_data( + anns, + params, + class_names, + smooth_spec=params["smooth_spec"], + norm_type=params["norm_type"], + ) - op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type']) - vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order) - print('\nImage saved to:', op_file_name) + op_file_name = os.path.join( + args["op_dir"], dataset_name + "." + args["file_type"] + ) + vz.save_summary_image( + x_train, y_train, class_names, params, op_file_name, class_names_order + ) + print("\nImage saved to:", op_file_name) diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py index 11f76de..3d4cffa 100644 --- a/scripts/gen_spec_image.py +++ b/scripts/gen_spec_image.py @@ -7,24 +7,27 @@ Will save images with: 3) spectrogram with predicted boxes """ -import numpy as np -import sys -import os import argparse -import matplotlib.pyplot as plt import json +import os +import sys -sys.path.append(os.path.join('..')) +import matplotlib.pyplot as plt +import numpy as np + +sys.path.append(os.path.join("..")) import bat_detect.evaluate.evaluate_models as evlm +import bat_detect.utils.audio_utils as au import bat_detect.utils.detector_utils as du import bat_detect.utils.plot_utils as viz -import bat_detect.utils.audio_utils as au def filter_anns(anns, start_time, stop_time): anns_op = [] for aa in anns: - if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02): + if (aa["start_time"] >= start_time) and ( + aa["start_time"] < stop_time - 0.02 + ): anns_op.append(aa) return anns_op @@ -32,85 +35,172 @@ def filter_anns(anns, start_time, stop_time): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('audio_file', type=str, help='Path to audio file') - parser.add_argument('model_path', type=str, help='Path to BatDetect model') - parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file') - parser.add_argument('--op_dir', type=str, default='plots/', - help='Output directory for plots') - parser.add_argument('--file_type', type=str, default='png', - help='Type of image to save png or pdf') - parser.add_argument('--title_text', type=str, default='', - help='Text to add as title of plots') - parser.add_argument('--detection_threshold', type=float, default=0.2, - help='Threshold for output detections') - parser.add_argument('--start_time', type=float, default=0.0, - help='Start time for cropped file') - parser.add_argument('--stop_time', type=float, default=0.5, - help='End time for cropped file') - parser.add_argument('--time_expansion_factor', type=int, default=1, - help='Time expansion factor') - + parser.add_argument("audio_file", type=str, help="Path to audio file") + parser.add_argument("model_path", type=str, help="Path to BatDetect model") + parser.add_argument( + "--ann_file", type=str, default="", help="Path to annotation file" + ) + parser.add_argument( + "--op_dir", + type=str, + default="plots/", + help="Output directory for plots", + ) + parser.add_argument( + "--file_type", + type=str, + default="png", + help="Type of image to save png or pdf", + ) + parser.add_argument( + "--title_text", + type=str, + default="", + help="Text to add as title of plots", + ) + parser.add_argument( + "--detection_threshold", + type=float, + default=0.2, + help="Threshold for output detections", + ) + parser.add_argument( + 
"--start_time", + type=float, + default=0.0, + help="Start time for cropped file", + ) + parser.add_argument( + "--stop_time", + type=float, + default=0.5, + help="End time for cropped file", + ) + parser.add_argument( + "--time_expansion_factor", + type=int, + default=1, + help="Time expansion factor", + ) + args_cmd = vars(parser.parse_args()) - - # load the model + + # load the model bd_args = du.get_default_bd_args() - model, params_bd = du.load_model(args_cmd['model_path']) - bd_args['detection_threshold'] = args_cmd['detection_threshold'] - bd_args['time_expansion_factor'] = args_cmd['time_expansion_factor'] - + model, params_bd = du.load_model(args_cmd["model_path"]) + bd_args["detection_threshold"] = args_cmd["detection_threshold"] + bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"] + # load the annotation if it exists gt_present = False - if args_cmd['ann_file'] != '': - if os.path.isfile(args_cmd['ann_file']): - with open(args_cmd['ann_file']) as da: + if args_cmd["ann_file"] != "": + if os.path.isfile(args_cmd["ann_file"]): + with open(args_cmd["ann_file"]) as da: gt_anns = json.load(da) - gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time']) + gt_anns = filter_anns( + gt_anns["annotation"], + args_cmd["start_time"], + args_cmd["stop_time"], + ) gt_present = True else: - print('Annotation file not found: ', args_cmd['ann_file']) + print("Annotation file not found: ", args_cmd["ann_file"]) # load the audio file - if not os.path.isfile(args_cmd['audio_file']): - print('Audio file not found: ', args_cmd['audio_file']) + if not os.path.isfile(args_cmd["audio_file"]): + print("Audio file not found: ", args_cmd["audio_file"]) sys.exit() - + # load audio and crop - print('\nProcessing: ' + os.path.basename(args_cmd['audio_file'])) - print('\nOutput directory: ' + args_cmd['op_dir']) - sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'], - params_bd['target_samp_rate'], params_bd['scale_raw_audio']) - st_samp = int(sampling_rate*args_cmd['start_time']) - en_samp = int(sampling_rate*args_cmd['stop_time']) + print("\nProcessing: " + os.path.basename(args_cmd["audio_file"])) + print("\nOutput directory: " + args_cmd["op_dir"]) + sampling_rate, audio = au.load_audio_file( + args_cmd["audio_file"], + args_cmd["time_exp"], + params_bd["target_samp_rate"], + params_bd["scale_raw_audio"], + ) + st_samp = int(sampling_rate * args_cmd["start_time"]) + en_samp = int(sampling_rate * args_cmd["stop_time"]) if en_samp > audio.shape[0]: - audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))) + audio = np.hstack( + (audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype)) + ) audio = audio[st_samp:en_samp] duration = audio.shape[0] / sampling_rate - print('File duration: {} seconds'.format(duration)) + print("File duration: {} seconds".format(duration)) # create spec for viz - spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False) + spec, _ = au.generate_spectrogram( + audio, sampling_rate, params_bd, True, False + ) # run model and filter detections so only keep ones in relevant time range - results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args) - pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time']) - print(len(pred_anns), 'Detections') + results = du.process_file( + args_cmd["audio_file"], model, params_bd, bd_args + ) + pred_anns = filter_anns( + 
results["pred_dict"]["annotation"], + args_cmd["start_time"], + args_cmd["stop_time"], + ) + print(len(pred_anns), "Detections") # save output - if not os.path.isdir(args_cmd['op_dir']): - os.makedirs(args_cmd['op_dir']) - + if not os.path.isdir(args_cmd["op_dir"]): + os.makedirs(args_cmd["op_dir"]) + # create output file names - op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type'] - op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean) - op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type'] - op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred) + op_path_clean = ( + os.path.basename(args_cmd["audio_file"])[:-4] + + "_clean." + + args_cmd["file_type"] + ) + op_path_clean = os.path.join(args_cmd["op_dir"], op_path_clean) + op_path_pred = ( + os.path.basename(args_cmd["audio_file"])[:-4] + + "_pred." + + args_cmd["file_type"] + ) + op_path_pred = os.path.join(args_cmd["op_dir"], op_path_pred) # create and save iamges - viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None) - viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns) + viz.save_ann_spec( + op_path_clean, + spec, + params_bd["min_freq"], + params_bd["max_freq"], + duration, + args_cmd["start_time"], + "", + None, + ) + viz.save_ann_spec( + op_path_pred, + spec, + params_bd["min_freq"], + params_bd["max_freq"], + duration, + args_cmd["start_time"], + "", + pred_anns, + ) if gt_present: - op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type'] - op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt) - viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns) + op_path_gt = ( + os.path.basename(args_cmd["audio_file"])[:-4] + + "_gt." + + args_cmd["file_type"] + ) + op_path_gt = os.path.join(args_cmd["op_dir"], op_path_gt) + viz.save_ann_spec( + op_path_gt, + spec, + params_bd["min_freq"], + params_bd["max_freq"], + duration, + args_cmd["start_time"], + "", + gt_anns, + ) diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py index cccfcf8..3c055ec 100644 --- a/scripts/gen_spec_video.py +++ b/scripts/gen_spec_video.py @@ -8,57 +8,83 @@ Notes: Best to use system one - see ffmpeg_path. 
""" -from scipy.io import wavfile +import argparse import os import shutil +import sys + import matplotlib.pyplot as plt import numpy as np -import argparse +from scipy.io import wavfile -import sys -sys.path.append(os.path.join('..')) +sys.path.append(os.path.join("..")) import bat_detect.detector.parameters as parameters import bat_detect.utils.audio_utils as au -import bat_detect.utils.plot_utils as viz import bat_detect.utils.detector_utils as du - +import bat_detect.utils.plot_utils as viz if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('audio_file', type=str, help='Path to input audio file') - parser.add_argument('model_path', type=str, help='Path to trained BatDetect model') - parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory') - parser.add_argument('--no_detector', action='store_true', help='Do not run detector') - parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names') - parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis') - parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector') - parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor', - help='The time expansion factor used for all files (default is 1)') + parser.add_argument( + "audio_file", type=str, help="Path to input audio file" + ) + parser.add_argument( + "model_path", type=str, help="Path to trained BatDetect model" + ) + parser.add_argument( + "--op_dir", + type=str, + default="generated_vids/", + help="Path to output directory", + ) + parser.add_argument( + "--no_detector", action="store_true", help="Do not run detector" + ) + parser.add_argument( + "--plot_class_names_off", + action="store_true", + help="Do not plot class names", + ) + parser.add_argument( + "--disable_axis", action="store_true", help="Do not plot axis" + ) + parser.add_argument( + "--detection_threshold", + type=float, + default=0.2, + help="Cut-off probability for detector", + ) + parser.add_argument( + "--time_expansion_factor", + type=int, + default=1, + dest="time_expansion_factor", + help="The time expansion factor used for all files (default is 1)", + ) args_cmd = vars(parser.parse_args()) # file of interest - audio_file = args_cmd['audio_file'] - op_dir = args_cmd['op_dir'] - op_str = '_output' - ffmpeg_path = '/usr/bin/' + audio_file = args_cmd["audio_file"] + op_dir = args_cmd["op_dir"] + op_str = "_output" + ffmpeg_path = "/usr/bin/" if not os.path.isfile(audio_file): - print('Audio file not found: ', audio_file) + print("Audio file not found: ", audio_file) sys.exit() - if not os.path.isfile(args_cmd['model_path']): - print('Model not found: ', model_path) + if not os.path.isfile(args_cmd["model_path"]): + print("Model not found: ", model_path) sys.exit() - start_time = 0.0 duration = 0.5 reveal_boxes = True # makes the boxes appear one at a time fps = 24 dpi = 100 - op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '') + op_dir_tmp = os.path.join(op_dir, "op_tmp_vids", "") if not os.path.isdir(op_dir_tmp): os.makedirs(op_dir_tmp) if not os.path.isdir(op_dir): @@ -66,105 +92,176 @@ if __name__ == "__main__": params = parameters.get_params(False) args = du.get_default_bd_args() - args['time_expansion_factor'] = args_cmd['time_expansion_factor'] - args['detection_threshold'] = args_cmd['detection_threshold'] - + args["time_expansion_factor"] = args_cmd["time_expansion_factor"] + 
args["detection_threshold"] = args_cmd["detection_threshold"] # load audio file - print('\nProcessing: ' + os.path.basename(audio_file)) - print('\nOutput directory: ' + op_dir) - sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate']) - audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)] + print("\nProcessing: " + os.path.basename(audio_file)) + print("\nOutput directory: " + op_dir) + sampling_rate, audio = au.load_audio_file( + audio_file, args["time_expansion_factor"], params["target_samp_rate"] + ) + audio = audio[ + int(sampling_rate * start_time) : int( + sampling_rate * start_time + sampling_rate * duration + ) + ] audio_orig = audio.copy() - audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], - params['fft_overlap'], params['resize_factor'], - params['spec_divide_factor']) + audio = au.pad_audio( + audio, + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + params["resize_factor"], + params["spec_divide_factor"], + ) # generate spectrogram spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True) - max_val = spec.max()*1.1 + max_val = spec.max() * 1.1 - - if not args_cmd['no_detector']: - print(' Loading model and running detector on entire file ...') - model, det_params = du.load_model(args_cmd['model_path']) - det_params['detection_threshold'] = args['detection_threshold'] + if not args_cmd["no_detector"]: + print(" Loading model and running detector on entire file ...") + model, det_params = du.load_model(args_cmd["model_path"]) + det_params["detection_threshold"] = args["detection_threshold"] results = du.process_file(audio_file, model, det_params, args) - print(' Processing detections and plotting ...') + print(" Processing detections and plotting ...") detections = [] - for bb in results['pred_dict']['annotation']: - if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration): + for bb in results["pred_dict"]["annotation"]: + if (bb["start_time"] >= start_time) and ( + bb["end_time"] < start_time + duration + ): detections.append(bb) # plot boxes - fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi) - duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) - viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val, - plot_class_names=not args_cmd['plot_class_names_off']) - op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png') + fig = plt.figure( + 1, figsize=(spec.shape[1] / dpi, spec.shape[0] / dpi), dpi=dpi + ) + duration = au.x_coords_to_time( + spec.shape[1], + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + ) + viz.create_box_image( + spec, + fig, + detections, + start_time, + start_time + duration, + duration, + params, + max_val, + plot_class_names=not args_cmd["plot_class_names_off"], + ) + op_im_file_boxes = os.path.join( + op_dir, os.path.basename(audio_file)[:-4] + op_str + "_boxes.png" + ) fig.savefig(op_im_file_boxes, dpi=dpi) plt.close(1) spec_with_boxes = plt.imread(op_im_file_boxes) - - print(' Saving audio file ...') - if args['time_expansion_factor']==1: - sampling_rate_op = int(sampling_rate/10.0) + print(" Saving audio file ...") + if args["time_expansion_factor"] == 1: + sampling_rate_op = int(sampling_rate / 10.0) else: sampling_rate_op = sampling_rate - op_audio_file = os.path.join(op_dir, 
os.path.basename(audio_file)[:-4] + op_str + '.wav') + op_audio_file = os.path.join( + op_dir, os.path.basename(audio_file)[:-4] + op_str + ".wav" + ) wavfile.write(op_audio_file, sampling_rate_op, audio_orig) - - print(' Saving image ...') - op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png') - plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma') + print(" Saving image ...") + op_im_file = os.path.join( + op_dir, os.path.basename(audio_file)[:-4] + op_str + ".png" + ) + plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap="plasma") spec_blank = plt.imread(op_im_file) # create figure freq_scale = 1000 # turn Hz to kHz - min_freq = params['min_freq']//freq_scale - max_freq = params['max_freq']//freq_scale + min_freq = params["min_freq"] // freq_scale + max_freq = params["max_freq"] // freq_scale y_extent = [0, duration, min_freq, max_freq] - print(' Saving video frames ...') + print(" Saving video frames ...") # save images that will be combined into video # will either plot with or without boxes - for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))): - if not args_cmd['no_detector']: + for ii, col in enumerate( + np.linspace(0, spec.shape[1] - 1, int(fps * duration * 10)) + ): + if not args_cmd["no_detector"]: spec_op = spec_with_boxes.copy() if ii > 0: spec_op[:, int(col), :] = 1.0 if reveal_boxes: - spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :] + spec_op[:, int(col) + 1 :, :] = spec_blank[ + :, int(col) + 1 :, : + ] elif ii == 0 and reveal_boxes: spec_op = spec_blank - if not args_cmd['disable_axis']: - plt.close('all') - fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi) - plt.xlabel('Time - seconds') - plt.ylabel('Frequency - kHz') - plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto') + if not args_cmd["disable_axis"]: + plt.close("all") + fig = plt.figure( + ii, + figsize=( + 1.2 * (spec_op.shape[1] / dpi), + 1.5 * (spec_op.shape[0] / dpi), + ), + dpi=dpi, + ) + plt.xlabel("Time - seconds") + plt.ylabel("Frequency - kHz") + plt.imshow( + spec_op, + vmin=0, + vmax=1.0, + cmap="plasma", + extent=y_extent, + aspect="auto", + ) plt.tight_layout() - fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi) + fig.savefig(op_dir_tmp + str(ii).zfill(4) + ".png", dpi=dpi) else: - plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma') + plt.imsave( + op_dir_tmp + str(ii).zfill(4) + ".png", + spec_op, + vmin=0, + vmax=1.0, + cmap="plasma", + ) else: spec_op = spec.copy() if ii > 0: spec_op[:, int(col)] = max_val - plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma') + plt.imsave( + op_dir_tmp + str(ii).zfill(4) + ".png", + spec_op, + vmin=0, + vmax=max_val, + cmap="plasma", + ) - - print(' Creating video ...') - op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi') - ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \ - '-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file) + print(" Creating video ...") + op_vid_file = os.path.join( + op_dir, os.path.basename(audio_file)[:-4] + op_str + ".avi" + ) + ffmpeg_cmd = ( + "ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 " + "-crf 25 -pix_fmt yuv420p -acodec copy {}".format( + fps, + 
spec.shape[1], + spec.shape[0], + op_dir_tmp, + op_audio_file, + op_vid_file, + ) + ) ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd os.system(ffmpeg_cmd) - print(' Deleting temporary files ...') + print(" Deleting temporary files ...") if os.path.isdir(op_dir_tmp): - shutil.rmtree(op_dir_tmp) + shutil.rmtree(op_dir_tmp) diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py index 2f55836..667bb9c 100644 --- a/scripts/viz_helpers.py +++ b/scripts/viz_helpers.py @@ -1,41 +1,70 @@ -import numpy as np -import matplotlib.pyplot as plt -from scipy import ndimage import os import sys -sys.path.append(os.path.join('..')) + +import matplotlib.pyplot as plt +import numpy as np +from scipy import ndimage + +sys.path.append(os.path.join("..")) import bat_detect.utils.audio_utils as au -def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False): - max_freq = round(params['max_freq']*params['fft_win_length']) - min_freq = round(params['min_freq']*params['fft_win_length']) +def generate_spectrogram_data( + audio, sampling_rate, params, norm_type="log", smooth_spec=False +): + max_freq = round(params["max_freq"] * params["fft_win_length"]) + min_freq = round(params["min_freq"] * params["fft_win_length"]) # create spectrogram - numpy - spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']) - #spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy() + spec = au.gen_mag_spectrogram( + audio, sampling_rate, params["fft_win_length"], params["fft_overlap"] + ) + # spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy() if spec.shape[0] < max_freq: freq_pad = max_freq - spec.shape[0] - spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)) - spec = spec[-max_freq:spec.shape[0]-min_freq, :] + spec = np.vstack( + (np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec) + ) + spec = spec[-max_freq : spec.shape[0] - min_freq, :] - if norm_type == 'log': - log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum()) + if norm_type == "log": + log_scaling = ( + 2.0 + * (1.0 / sampling_rate) + * ( + 1.0 + / ( + np.abs( + np.hanning( + int(params["fft_win_length"] * sampling_rate) + ) + ) + ** 2 + ).sum() + ) + ) ##log_scaling = 0.01 - spec = np.log(1.0 + log_scaling*spec).astype(np.float32) - elif norm_type == 'pcen': + spec = np.log(1.0 + log_scaling * spec).astype(np.float32) + elif norm_type == "pcen": spec = au.pcen(spec, sampling_rate) else: pass if smooth_spec: - spec = ndimage.gaussian_filter(spec, 1) + spec = ndimage.gaussian_filter(spec, 1) return spec -def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False): +def load_data( + anns, + params, + class_names, + smooth_spec=False, + norm_type="log", + extract_bg=False, +): specs = [] labels = [] coords = [] @@ -43,67 +72,106 @@ def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', ext sampling_rates = [] file_names = [] for cur_file in anns: - sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'], - params['target_samp_rate'], params['scale_raw_audio']) + sampling_rate, audio_orig = au.load_audio_file( + cur_file["file_path"], + cur_file["time_exp"], + params["target_samp_rate"], + params["scale_raw_audio"], + ) - for ann in cur_file['annotation']: - if ann['class'] not in 
params['classes_to_ignore'] and ann['class'] in class_names: + for ann in cur_file["annotation"]: + if ( + ann["class"] not in params["classes_to_ignore"] + and ann["class"] in class_names + ): # clip out of bounds - if ann['low_freq'] < params['min_freq']: - ann['low_freq'] = params['min_freq'] - if ann['high_freq'] > params['max_freq']: - ann['high_freq'] = params['max_freq'] + if ann["low_freq"] < params["min_freq"]: + ann["low_freq"] = params["min_freq"] + if ann["high_freq"] > params["max_freq"]: + ann["high_freq"] = params["max_freq"] # load cropped audio - start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad']) + start_samp_diff = int(sampling_rate * ann["start_time"]) - int( + sampling_rate * params["aud_pad"] + ) start_samp = np.maximum(0, start_samp_diff) - end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad'])) + end_samp = np.minimum( + audio_orig.shape[0], + int(sampling_rate * ann["end_time"]) * 2 + + int(sampling_rate * params["aud_pad"]), + ) audio = audio_orig[start_samp:end_samp] if start_samp_diff < 0: # need to pad at start if the call is at the very begining - audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio)) + audio = np.hstack( + (np.zeros(-start_samp_diff, dtype=np.float32), audio) + ) - nfft = int(params['fft_win_length']*sampling_rate) - noverlap = int(params['fft_overlap']*nfft) - max_samps = params['spec_width']*(nfft - noverlap) + noverlap + nfft = int(params["fft_win_length"] * sampling_rate) + noverlap = int(params["fft_overlap"] * nfft) + max_samps = params["spec_width"] * (nfft - noverlap) + noverlap if max_samps > audio.shape[0]: - audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0]))) + audio = np.hstack( + (audio, np.zeros(max_samps - audio.shape[0])) + ) audio = audio[:max_samps].astype(np.float32) - audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'], - params['fft_overlap'], params['resize_factor'], - params['spec_divide_factor']) + audio = au.pad_audio( + audio, + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + params["resize_factor"], + params["spec_divide_factor"], + ) # generate spectrogram - spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']] + spec = generate_spectrogram_data( + audio, sampling_rate, params, norm_type, smooth_spec + )[:, : params["spec_width"]] specs.append(spec[np.newaxis, ...]) - labels.append(ann['class']) + labels.append(ann["class"]) audios.append(audio) sampling_rates.append(sampling_rate) - file_names.append(cur_file['file_path']) + file_names.append(cur_file["file_path"]) # position in crop - x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap'])) - y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length'] + x1 = int( + au.time_to_x_coords( + np.array(params["aud_pad"]), + sampling_rate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + y1 = (ann["low_freq"] - params["min_freq"]) * params[ + "fft_win_length" + ] coords.append((y1, x1)) - _, file_ids = np.unique(file_names, return_inverse=True) labels = np.array([class_names.index(ll) for ll in labels]) - #return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names + # return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names return np.vstack(specs), labels -def save_summary_image(specs, labels, 
species_names, params, op_file_name='plots/all_species.png', order=None): +def save_summary_image( + specs, + labels, + species_names, + params, + op_file_name="plots/all_species.png", + order=None, +): # takes the mean for each class and plots it on a grid mean_specs = [] max_band = [] for ii in range(len(species_names)): - inds = np.where(labels==ii)[0] + inds = np.where(labels == ii)[0] mu = specs[inds, :].mean(0) max_band.append(np.argmax(mu.sum(1))) mean_specs.append(mu) @@ -113,11 +181,21 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots order = np.arange(len(species_names)) max_cols = 6 - nrows = int(np.ceil(len(species_names)/max_cols)) + nrows = int(np.ceil(len(species_names) / max_cols)) ncols = np.minimum(len(species_names), max_cols) - fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2}) - spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000) + fig, ax = plt.subplots( + nrows=nrows, + ncols=ncols, + figsize=(ncols * 3.3, nrows * 6), + gridspec_kw={"wspace": 0, "hspace": 0.2}, + ) + spec_min_max = ( + 0, + mean_specs[0].shape[1], + params["min_freq"] / 1000, + params["max_freq"] / 1000, + ) ii = 0 for row in ax: @@ -126,17 +204,24 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots for col in row: if ii >= len(species_names): - col.axis('off') + col.axis("off") else: - inds = np.where(labels==order[ii])[0] - col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal') - col.grid(color='w', alpha=0.3, linewidth=0.3) + inds = np.where(labels == order[ii])[0] + col.imshow( + mean_specs[order[ii]], + extent=spec_min_max, + cmap="plasma", + aspect="equal", + ) + col.grid(color="w", alpha=0.3, linewidth=0.3) col.set_xticks([]) - col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]]) - col.tick_params(axis='both', which='major', labelsize=7) + col.title.set_text( + str(ii + 1) + " " + species_names[order[ii]] + ) + col.tick_params(axis="both", which="major", labelsize=7) ii += 1 - #plt.tight_layout() - #plt.show() + # plt.tight_layout() + # plt.show() plt.savefig(op_file_name) - plt.close('all') + plt.close("all")
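
Note on usage (not part of the diff above): this commit only reformats with Black and isort, so the public entry points keep their existing behaviour. As a quick sanity check of the reformatted API, a minimal sketch of driving the detector programmatically is given below. It assumes the repository root is on PYTHONPATH, that the default checkpoint models/Net2DFast_UK_same.pth.tar from run_batdetect.py is present, and that example.wav is a hypothetical short mono recording:

    import bat_detect.utils.detector_utils as du

    # default run-time arguments, as used by run_batdetect.py
    args = du.get_default_bd_args()
    args["detection_threshold"] = 0.3  # cut-off probability for detections
    args["time_expansion_factor"] = 1  # recordings are not time-expanded

    # load the trained model and its associated parameters
    model, params = du.load_model("models/Net2DFast_UK_same.pth.tar")

    # run detection on a single file and print one line per detected call
    results = du.process_file("example.wav", model, params, args)
    for ann in results["pred_dict"]["annotation"]:
        print(ann["class"], ann["start_time"], ann["det_prob"])

The equivalent batch invocation via the command-line script is python run_batdetect.py <audio_dir> <ann_dir> 0.3, which writes one annotation file per input recording to <ann_dir> (empty files only when --save_preds_if_empty is set).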