diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..0fd091c
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Format code with Black and isort
+3c17a2337166245de8df778fe174aad997e14e8f
diff --git a/.ropeproject/autoimport.db b/.ropeproject/autoimport.db
new file mode 100644
index 0000000..585d383
Binary files /dev/null and b/.ropeproject/autoimport.db differ
diff --git a/app.py b/app.py
index ae44690..1c884f0 100644
--- a/app.py
+++ b/app.py
@@ -1,84 +1,121 @@
-import gradio as gr
import os
+
+import gradio as gr
import matplotlib.pyplot as plt
-import pandas as pd
import numpy as np
+import pandas as pd
-import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
+import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
-
# setup the arguments
args = {}
args = du.get_default_bd_args()
-args['detection_threshold'] = 0.3
-args['time_expansion_factor'] = 1
-args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
+args["detection_threshold"] = 0.3
+args["time_expansion_factor"] = 1
+args["model_path"] = "models/Net2DFast_UK_same.pth.tar"
max_duration = 2.0
# load the model
-model, params = du.load_model(args['model_path'])
+model, params = du.load_model(args["model_path"])
-df = gr.Dataframe(
- headers=["species", "time", "detection_prob", "species_prob"],
- datatype=["str", "str", "str", "str"],
- row_count=1,
- col_count=(4, "fixed"),
- label='Predictions'
- )
-
-examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
- ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
- ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
+df = gr.Dataframe(
+ headers=["species", "time", "detection_prob", "species_prob"],
+ datatype=["str", "str", "str", "str"],
+ row_count=1,
+ col_count=(4, "fixed"),
+ label="Predictions",
+)
+
+examples = [
+ ["example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 0.3],
+ ["example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 0.3],
+ ["example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 0.3],
+]
def make_prediction(file_name=None, detection_threshold=0.3):
-
+
if file_name is not None:
audio_file = file_name
else:
return "You must provide an input audio file."
-
- if detection_threshold is not None and detection_threshold != '':
- args['detection_threshold'] = float(detection_threshold)
-
+
+ if detection_threshold is not None and detection_threshold != "":
+ args["detection_threshold"] = float(detection_threshold)
+
# process the file to generate predictions
- results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
-
- anns = [ann for ann in results['pred_dict']['annotation']]
- clss = [aa['class'] for aa in anns]
- st_time = [aa['start_time'] for aa in anns]
- cls_prob = [aa['class_prob'] for aa in anns]
- det_prob = [aa['det_prob'] for aa in anns]
- data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
-
+ results = du.process_file(
+ audio_file, model, params, args, max_duration=max_duration
+ )
+
+ anns = [ann for ann in results["pred_dict"]["annotation"]]
+ clss = [aa["class"] for aa in anns]
+ st_time = [aa["start_time"] for aa in anns]
+ cls_prob = [aa["class_prob"] for aa in anns]
+ det_prob = [aa["det_prob"] for aa in anns]
+ data = {
+ "species": clss,
+ "time": st_time,
+ "detection_prob": det_prob,
+ "species_prob": cls_prob,
+ }
+
df = pd.DataFrame(data=data)
im = generate_results_image(audio_file, anns)
-
+
return [df, im]
-def generate_results_image(audio_file, anns):
-
+def generate_results_image(audio_file, anns):
+
# load audio
- sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
- params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)
+ sampling_rate, audio = au.load_audio_file(
+ audio_file,
+ args["time_expansion_factor"],
+ params["target_samp_rate"],
+ params["scale_raw_audio"],
+ max_duration=max_duration,
+ )
duration = audio.shape[0] / sampling_rate
-
+
# generate spec
- spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
+ spec, spec_viz = au.generate_spectrogram(
+ audio, sampling_rate, params, True, False
+ )
# create fig
- plt.close('all')
- fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
- spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
- viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
- plt.ylabel('Freq - kHz')
- plt.xlabel('Time - secs')
+ plt.close("all")
+ fig = plt.figure(
+ 1,
+ figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
+ dpi=100,
+ frameon=False,
+ )
+ spec_duration = au.x_coords_to_time(
+ spec.shape[1],
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ viz.create_box_image(
+ spec,
+ fig,
+ anns,
+ 0,
+ spec_duration,
+ spec_duration,
+ params,
+ spec.max() * 1.1,
+ False,
+ True,
+ )
+ plt.ylabel("Freq - kHz")
+ plt.xlabel("Time - secs")
plt.tight_layout()
-
+
# convert fig to image
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
@@ -88,21 +125,23 @@ def generate_results_image(audio_file, anns):
return im
-descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
- "
This model is only trained on bat species from the UK. If the input " \
- "file is longer than 2 seconds, only the first 2 seconds will be processed." \
- "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
+descr_txt = (
+ "Demo of BatDetect2 deep learning-based bat echolocation call detection. "
+ "
This model is only trained on bat species from the UK. If the input "
+ "file is longer than 2 seconds, only the first 2 seconds will be processed."
+ "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
+)
gr.Interface(
- fn = make_prediction,
- inputs = [gr.Audio(source="upload", type="filepath", optional=True),
- gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
- outputs = [df, gr.Image(label="Visualisation")],
- theme = "huggingface",
- title = "BatDetect2 Demo",
- description = descr_txt,
- examples = examples,
- allow_flagging = 'never',
+ fn=make_prediction,
+ inputs=[
+ gr.Audio(source="upload", type="filepath", optional=True),
+ gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
+ ],
+ outputs=[df, gr.Image(label="Visualisation")],
+ theme="huggingface",
+ title="BatDetect2 Demo",
+ description=descr_txt,
+ examples=examples,
+ allow_flagging="never",
).launch()
-
-
diff --git a/bat_detect/detector/compute_features.py b/bat_detect/detector/compute_features.py
index b6f4b8b..368c2db 100644
--- a/bat_detect/detector/compute_features.py
+++ b/bat_detect/detector/compute_features.py
@@ -2,8 +2,10 @@ import numpy as np
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
- spec_ind = spec_height-spec_ind
- return round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
+ spec_ind = spec_height - spec_ind
+ return round(
+ (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
+ )
def extract_spec_slices(spec, pred_nms, params):
@@ -11,28 +13,40 @@ def extract_spec_slices(spec, pred_nms, params):
    Extracts spectrogram slices from the spectrogram based on detected call locations.
"""
- x_pos = pred_nms['x_pos']
- y_pos = pred_nms['y_pos']
- bb_width = pred_nms['bb_width']
- bb_height = pred_nms['bb_height']
- slices = []
+ x_pos = pred_nms["x_pos"]
+ y_pos = pred_nms["y_pos"]
+ bb_width = pred_nms["bb_width"]
+ bb_height = pred_nms["bb_height"]
+ slices = []
# add 20% padding either side of call
- pad = bb_width*0.2
- x_pos_pad = x_pos - pad
- bb_width_pad = bb_width + 2*pad
+ pad = bb_width * 0.2
+ x_pos_pad = x_pos - pad
+ bb_width_pad = bb_width + 2 * pad
- for ff in range(len(pred_nms['det_probs'])):
+ for ff in range(len(pred_nms["det_probs"])):
x_start = int(np.maximum(0, x_pos_pad[ff]))
- x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos_pad[ff] + bb_width_pad[ff])))
+ x_end = int(
+ np.minimum(
+ spec.shape[1] - 1, np.round(x_pos_pad[ff] + bb_width_pad[ff])
+ )
+ )
slices.append(spec[:, x_start:x_end].astype(np.float16))
return slices
def get_feature_names():
- feature_names = ['duration', 'low_freq_bb', 'high_freq_bb', 'bandwidth',
- 'max_power_bb', 'max_power', 'max_power_first',
- 'max_power_second', 'call_interval']
+ feature_names = [
+ "duration",
+ "low_freq_bb",
+ "high_freq_bb",
+ "bandwidth",
+ "max_power_bb",
+ "max_power",
+ "max_power_first",
+ "max_power_second",
+ "call_interval",
+ ]
return feature_names
@@ -45,40 +59,76 @@ def get_feats(spec, pred_nms, params):
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
"""
- x_pos = pred_nms['x_pos']
- y_pos = pred_nms['y_pos']
- bb_width = pred_nms['bb_width']
- bb_height = pred_nms['bb_height']
+ x_pos = pred_nms["x_pos"]
+ y_pos = pred_nms["y_pos"]
+ bb_width = pred_nms["bb_width"]
+ bb_height = pred_nms["bb_height"]
- feature_names = get_feature_names()
- num_detections = len(pred_nms['det_probs'])
- features = np.ones((num_detections, len(feature_names)), dtype=np.float32)*-1
+ feature_names = get_feature_names()
+ num_detections = len(pred_nms["det_probs"])
+ features = (
+ np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1
+ )
for ff in range(num_detections):
x_start = int(np.maximum(0, x_pos[ff]))
- x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos[ff] + bb_width[ff])))
+ x_end = int(
+ np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff]))
+ )
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top
- y_low = int(np.minimum(spec.shape[0]-1, y_pos[ff]))
- y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
+ y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff]))
+ y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
spec_slice = spec[:, x_start:x_end]
if spec_slice.shape[1] > 1:
- features[ff, 0] = round(pred_nms['end_times'][ff] - pred_nms['start_times'][ff], 5)
- features[ff, 1] = int(pred_nms['low_freqs'][ff])
- features[ff, 2] = int(pred_nms['high_freqs'][ff])
- features[ff, 3] = int(pred_nms['high_freqs'][ff] - pred_nms['low_freqs'][ff])
- features[ff, 4] = int(convert_int_to_freq(y_high+spec_slice[y_high:y_low, :].sum(1).argmax(),
- spec.shape[0], params['min_freq'], params['max_freq']))
- features[ff, 5] = int(convert_int_to_freq(spec_slice.sum(1).argmax(),
- spec.shape[0], params['min_freq'], params['max_freq']))
- hlf_val = spec_slice.shape[1]//2
+ features[ff, 0] = round(
+ pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5
+ )
+ features[ff, 1] = int(pred_nms["low_freqs"][ff])
+ features[ff, 2] = int(pred_nms["high_freqs"][ff])
+ features[ff, 3] = int(
+ pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
+ )
+ features[ff, 4] = int(
+ convert_int_to_freq(
+ y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
+ spec.shape[0],
+ params["min_freq"],
+ params["max_freq"],
+ )
+ )
+ features[ff, 5] = int(
+ convert_int_to_freq(
+ spec_slice.sum(1).argmax(),
+ spec.shape[0],
+ params["min_freq"],
+ params["max_freq"],
+ )
+ )
+ hlf_val = spec_slice.shape[1] // 2
- features[ff, 6] = int(convert_int_to_freq(spec_slice[:, :hlf_val].sum(1).argmax(),
- spec.shape[0], params['min_freq'], params['max_freq']))
- features[ff, 7] = int(convert_int_to_freq(spec_slice[:, hlf_val:].sum(1).argmax(),
- spec.shape[0], params['min_freq'], params['max_freq']))
+ features[ff, 6] = int(
+ convert_int_to_freq(
+ spec_slice[:, :hlf_val].sum(1).argmax(),
+ spec.shape[0],
+ params["min_freq"],
+ params["max_freq"],
+ )
+ )
+ features[ff, 7] = int(
+ convert_int_to_freq(
+ spec_slice[:, hlf_val:].sum(1).argmax(),
+ spec.shape[0],
+ params["min_freq"],
+ params["max_freq"],
+ )
+ )
if ff > 0:
- features[ff, 8] = round(pred_nms['start_times'][ff] - pred_nms['start_times'][ff-1], 5)
+ features[ff, 8] = round(
+ pred_nms["start_times"][ff]
+ - pred_nms["start_times"][ff - 1],
+ 5,
+ )
return features
diff --git a/bat_detect/detector/model_helpers.py b/bat_detect/detector/model_helpers.py
index c91ef04..e237f7c 100644
--- a/bat_detect/detector/model_helpers.py
+++ b/bat_detect/detector/model_helpers.py
@@ -1,47 +1,71 @@
-import torch.nn as nn
-import torch
-import torch.nn.functional as F
-import numpy as np
import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
class SelfAttention(nn.Module):
def __init__(self, ip_dim, att_dim):
super(SelfAttention, self).__init__()
        # Note, does not encode position information (absolute or relative)
self.temperature = 1.0
- self.att_dim = att_dim
+ self.att_dim = att_dim
self.key_fun = nn.Linear(ip_dim, att_dim)
self.val_fun = nn.Linear(ip_dim, att_dim)
self.que_fun = nn.Linear(ip_dim, att_dim)
self.pro_fun = nn.Linear(att_dim, ip_dim)
def forward(self, x):
- x = x.squeeze(2).permute(0,2,1)
+ x = x.squeeze(2).permute(0, 2, 1)
- kk = torch.matmul(x, self.key_fun.weight.T) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
- qq = torch.matmul(x, self.que_fun.weight.T) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
- vv = torch.matmul(x, self.val_fun.weight.T) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
+ kk = torch.matmul(
+ x, self.key_fun.weight.T
+ ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
+ qq = torch.matmul(
+ x, self.que_fun.weight.T
+ ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
+ vv = torch.matmul(
+ x, self.val_fun.weight.T
+ ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
- kk_qq = torch.bmm(kk, qq.permute(0,2,1)) / (self.temperature*self.att_dim)
- att_weights = F.softmax(kk_qq, 1) # each col of each attention matrix sums to 1
- att = torch.bmm(vv.permute(0,2,1), att_weights)
+ kk_qq = torch.bmm(kk, qq.permute(0, 2, 1)) / (
+ self.temperature * self.att_dim
+ )
+ att_weights = F.softmax(
+ kk_qq, 1
+ ) # each col of each attention matrix sums to 1
+ att = torch.bmm(vv.permute(0, 2, 1), att_weights)
- op = torch.matmul(att.permute(0,2,1), self.pro_fun.weight.T) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
- op = op.permute(0,2,1).unsqueeze(2)
+ op = torch.matmul(
+ att.permute(0, 2, 1), self.pro_fun.weight.T
+ ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
+ op = op.permute(0, 2, 1).unsqueeze(2)
return op
class ConvBlockDownCoordF(nn.Module):
- def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1):
+ def __init__(
+ self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
+ ):
super(ConvBlockDownCoordF, self).__init__()
- self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height)[None, None, ..., None], requires_grad=False)
- self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
+ self.coords = nn.Parameter(
+ torch.linspace(-1, 1, ip_height)[None, None, ..., None],
+ requires_grad=False,
+ )
+ self.conv = nn.Conv2d(
+ in_chn + 1,
+ out_chn,
+ kernel_size=k_size,
+ padding=pad_size,
+ stride=stride,
+ )
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x):
- freq_info = self.coords.repeat(x.shape[0],1,1,x.shape[3])
+ freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True)
@@ -49,9 +73,17 @@ class ConvBlockDownCoordF(nn.Module):
class ConvBlockDownStandard(nn.Module):
- def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1):
+ def __init__(
+ self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
+ ):
super(ConvBlockDownStandard, self).__init__()
- self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
+ self.conv = nn.Conv2d(
+ in_chn,
+ out_chn,
+ kernel_size=k_size,
+ padding=pad_size,
+ stride=stride,
+ )
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x):
@@ -61,17 +93,41 @@ class ConvBlockDownStandard(nn.Module):
class ConvBlockUpF(nn.Module):
- def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
+ def __init__(
+ self,
+ in_chn,
+ out_chn,
+ ip_height,
+ k_size=3,
+ pad_size=1,
+ up_mode="bilinear",
+ up_scale=(2, 2),
+ ):
super(ConvBlockUpF, self).__init__()
self.up_scale = up_scale
self.up_mode = up_mode
- self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height*up_scale[0])[None, None, ..., None], requires_grad=False)
- self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size)
+ self.coords = nn.Parameter(
+ torch.linspace(-1, 1, ip_height * up_scale[0])[
+ None, None, ..., None
+ ],
+ requires_grad=False,
+ )
+ self.conv = nn.Conv2d(
+ in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
+ )
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x):
- op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
- freq_info = self.coords.repeat(op.shape[0],1,1,op.shape[3])
+ op = F.interpolate(
+ x,
+ size=(
+ x.shape[-2] * self.up_scale[0],
+ x.shape[-1] * self.up_scale[1],
+ ),
+ mode=self.up_mode,
+ align_corners=False,
+ )
+ freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
op = torch.cat((op, freq_info), 1)
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
@@ -79,15 +135,34 @@ class ConvBlockUpF(nn.Module):
class ConvBlockUpStandard(nn.Module):
- def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
+ def __init__(
+ self,
+ in_chn,
+ out_chn,
+ ip_height=None,
+ k_size=3,
+ pad_size=1,
+ up_mode="bilinear",
+ up_scale=(2, 2),
+ ):
super(ConvBlockUpStandard, self).__init__()
self.up_scale = up_scale
self.up_mode = up_mode
- self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size)
+ self.conv = nn.Conv2d(
+ in_chn, out_chn, kernel_size=k_size, padding=pad_size
+ )
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x):
- op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
+ op = F.interpolate(
+ x,
+ size=(
+ x.shape[-2] * self.up_scale[0],
+ x.shape[-1] * self.up_scale[1],
+ ),
+ mode=self.up_mode,
+ align_corners=False,
+ )
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py
index fc7b5b4..b39cbf4 100644
--- a/bat_detect/detector/models.py
+++ b/bat_detect/detector/models.py
@@ -1,52 +1,97 @@
-import torch.nn as nn
-import torch
-import torch.nn.functional as F
import numpy as np
-from .model_helpers import *
-
-import torchvision
-
+import torch
import torch.fft
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
from torch import nn
+from .model_helpers import *
+
class Net2DFast(nn.Module):
- def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+ def __init__(
+ self,
+ num_filts,
+ num_classes=0,
+ emb_dim=0,
+ ip_height=128,
+ resize_factor=0.5,
+ ):
super(Net2DFast, self).__init__()
self.num_classes = num_classes
self.emb_dim = emb_dim
self.num_filts = num_filts
self.resize_factor = resize_factor
self.ip_height_rs = ip_height
- self.bneck_height = self.ip_height_rs//32
+ self.bneck_height = self.ip_height_rs // 32
# encoder
- self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
- self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
- self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
- self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
- self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+ self.conv_dn_0 = ConvBlockDownCoordF(
+ 1,
+ num_filts // 4,
+ self.ip_height_rs,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_1 = ConvBlockDownCoordF(
+ num_filts // 4,
+ num_filts // 2,
+ self.ip_height_rs // 2,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_2 = ConvBlockDownCoordF(
+ num_filts // 2,
+ num_filts,
+ self.ip_height_rs // 4,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+ self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
# bottleneck
- self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
- self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
- self.att = SelfAttention(num_filts*2, num_filts*2)
+ self.conv_1d = nn.Conv2d(
+ num_filts * 2,
+ num_filts * 2,
+ (self.ip_height_rs // 8, 1),
+ padding=0,
+ )
+ self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
+ self.att = SelfAttention(num_filts * 2, num_filts * 2)
# decoder
- self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
- self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
- self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
+ self.conv_up_2 = ConvBlockUpF(
+ num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+ )
+ self.conv_up_3 = ConvBlockUpF(
+ num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+ )
+ self.conv_up_4 = ConvBlockUpF(
+ num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+ )
# output
# +1 to include background class for class output
- self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
- self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
- self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
- self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+ self.conv_op = nn.Conv2d(
+ num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+ )
+ self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+ self.conv_size_op = nn.Conv2d(
+ num_filts // 4, 2, kernel_size=1, padding=0
+ )
+ self.conv_classes_op = nn.Conv2d(
+ num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+ )
if self.emb_dim > 0:
- self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
-
+ self.conv_emb = nn.Conv2d(
+ num_filts, self.emb_dim, kernel_size=1, padding=0
+ )
def forward(self, ip, return_feats=False):
@@ -59,33 +104,40 @@ class Net2DFast(nn.Module):
# bottleneck
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x)
- x = x.repeat([1,1,self.bneck_height*4,1])
+ x = x.repeat([1, 1, self.bneck_height * 4, 1])
# decoder
- x = self.conv_up_2(x+x3)
- x = self.conv_up_3(x+x2)
- x = self.conv_up_4(x+x1)
+ x = self.conv_up_2(x + x3)
+ x = self.conv_up_3(x + x2)
+ x = self.conv_up_4(x + x1)
# output
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
- cls = self.conv_classes_op(x)
+ cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
- op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
- op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
- op['pred_class'] = comb
- op['pred_class_un_norm'] = cls
+ op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+ op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+ op["pred_class"] = comb
+ op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
- op['pred_emb'] = self.conv_emb(x)
+ op["pred_emb"] = self.conv_emb(x)
if return_feats:
- op['features'] = x
+ op["features"] = x
return op
class Net2DFastNoAttn(nn.Module):
- def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+ def __init__(
+ self,
+ num_filts,
+ num_classes=0,
+ emb_dim=0,
+ ip_height=128,
+ resize_factor=0.5,
+ ):
super(Net2DFastNoAttn, self).__init__()
self.num_classes = num_classes
@@ -93,31 +145,70 @@ class Net2DFastNoAttn(nn.Module):
self.num_filts = num_filts
self.resize_factor = resize_factor
self.ip_height_rs = ip_height
- self.bneck_height = self.ip_height_rs//32
+ self.bneck_height = self.ip_height_rs // 32
- self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
- self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
- self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
- self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
- self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+ self.conv_dn_0 = ConvBlockDownCoordF(
+ 1,
+ num_filts // 4,
+ self.ip_height_rs,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_1 = ConvBlockDownCoordF(
+ num_filts // 4,
+ num_filts // 2,
+ self.ip_height_rs // 2,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_2 = ConvBlockDownCoordF(
+ num_filts // 2,
+ num_filts,
+ self.ip_height_rs // 4,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+ self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
- self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
- self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
+ self.conv_1d = nn.Conv2d(
+ num_filts * 2,
+ num_filts * 2,
+ (self.ip_height_rs // 8, 1),
+ padding=0,
+ )
+ self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
-
- self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
- self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
- self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
+ self.conv_up_2 = ConvBlockUpF(
+ num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+ )
+ self.conv_up_3 = ConvBlockUpF(
+ num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+ )
+ self.conv_up_4 = ConvBlockUpF(
+ num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+ )
# output
# +1 to include background class for class output
- self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
- self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
- self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
- self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+ self.conv_op = nn.Conv2d(
+ num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+ )
+ self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+ self.conv_size_op = nn.Conv2d(
+ num_filts // 4, 2, kernel_size=1, padding=0
+ )
+ self.conv_classes_op = nn.Conv2d(
+ num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+ )
if self.emb_dim > 0:
- self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
+ self.conv_emb = nn.Conv2d(
+ num_filts, self.emb_dim, kernel_size=1, padding=0
+ )
def forward(self, ip, return_feats=False):
@@ -127,31 +218,38 @@ class Net2DFastNoAttn(nn.Module):
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
- x = x.repeat([1,1,self.bneck_height*4,1])
+ x = x.repeat([1, 1, self.bneck_height * 4, 1])
- x = self.conv_up_2(x+x3)
- x = self.conv_up_3(x+x2)
- x = self.conv_up_4(x+x1)
+ x = self.conv_up_2(x + x3)
+ x = self.conv_up_3(x + x2)
+ x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
- cls = self.conv_classes_op(x)
+ cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
- op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
- op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
- op['pred_class'] = comb
- op['pred_class_un_norm'] = cls
+ op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+ op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+ op["pred_class"] = comb
+ op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
- op['pred_emb'] = self.conv_emb(x)
+ op["pred_emb"] = self.conv_emb(x)
if return_feats:
- op['features'] = x
+ op["features"] = x
return op
class Net2DFastNoCoordConv(nn.Module):
- def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
+ def __init__(
+ self,
+ num_filts,
+ num_classes=0,
+ emb_dim=0,
+ ip_height=128,
+ resize_factor=0.5,
+ ):
super(Net2DFastNoCoordConv, self).__init__()
self.num_classes = num_classes
@@ -159,32 +257,72 @@ class Net2DFastNoCoordConv(nn.Module):
self.num_filts = num_filts
self.resize_factor = resize_factor
self.ip_height_rs = ip_height
- self.bneck_height = self.ip_height_rs//32
+ self.bneck_height = self.ip_height_rs // 32
- self.conv_dn_0 = ConvBlockDownStandard(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
- self.conv_dn_1 = ConvBlockDownStandard(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
- self.conv_dn_2 = ConvBlockDownStandard(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
- self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
- self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
+ self.conv_dn_0 = ConvBlockDownStandard(
+ 1,
+ num_filts // 4,
+ self.ip_height_rs,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_1 = ConvBlockDownStandard(
+ num_filts // 4,
+ num_filts // 2,
+ self.ip_height_rs // 2,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_2 = ConvBlockDownStandard(
+ num_filts // 2,
+ num_filts,
+ self.ip_height_rs // 4,
+ k_size=3,
+ pad_size=1,
+ stride=1,
+ )
+ self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1)
+ self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2)
- self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
- self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
+ self.conv_1d = nn.Conv2d(
+ num_filts * 2,
+ num_filts * 2,
+ (self.ip_height_rs // 8, 1),
+ padding=0,
+ )
+ self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2)
- self.att = SelfAttention(num_filts*2, num_filts*2)
+ self.att = SelfAttention(num_filts * 2, num_filts * 2)
- self.conv_up_2 = ConvBlockUpStandard(num_filts*2, num_filts//2, self.ip_height_rs//8)
- self.conv_up_3 = ConvBlockUpStandard(num_filts//2, num_filts//4, self.ip_height_rs//4)
- self.conv_up_4 = ConvBlockUpStandard(num_filts//4, num_filts//4, self.ip_height_rs//2)
+ self.conv_up_2 = ConvBlockUpStandard(
+ num_filts * 2, num_filts // 2, self.ip_height_rs // 8
+ )
+ self.conv_up_3 = ConvBlockUpStandard(
+ num_filts // 2, num_filts // 4, self.ip_height_rs // 4
+ )
+ self.conv_up_4 = ConvBlockUpStandard(
+ num_filts // 4, num_filts // 4, self.ip_height_rs // 2
+ )
# output
# +1 to include background class for class output
- self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
- self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
- self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
- self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
+ self.conv_op = nn.Conv2d(
+ num_filts // 4, num_filts // 4, kernel_size=3, padding=1
+ )
+ self.conv_op_bn = nn.BatchNorm2d(num_filts // 4)
+ self.conv_size_op = nn.Conv2d(
+ num_filts // 4, 2, kernel_size=1, padding=0
+ )
+ self.conv_classes_op = nn.Conv2d(
+ num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+ )
if self.emb_dim > 0:
- self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
+ self.conv_emb = nn.Conv2d(
+ num_filts, self.emb_dim, kernel_size=1, padding=0
+ )
def forward(self, ip, return_feats=False):
@@ -195,24 +333,24 @@ class Net2DFastNoCoordConv(nn.Module):
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x)
- x = x.repeat([1,1,self.bneck_height*4,1])
+ x = x.repeat([1, 1, self.bneck_height * 4, 1])
- x = self.conv_up_2(x+x3)
- x = self.conv_up_3(x+x2)
- x = self.conv_up_4(x+x1)
+ x = self.conv_up_2(x + x3)
+ x = self.conv_up_3(x + x2)
+ x = self.conv_up_4(x + x1)
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
- cls = self.conv_classes_op(x)
+ cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1)
op = {}
- op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
- op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
- op['pred_class'] = comb
- op['pred_class_un_norm'] = cls
+ op["pred_det"] = comb[:, :-1, :, :].sum(1).unsqueeze(1)
+ op["pred_size"] = F.relu(self.conv_size_op(x), inplace=True)
+ op["pred_class"] = comb
+ op["pred_class_un_norm"] = cls
if self.emb_dim > 0:
- op['pred_emb'] = self.conv_emb(x)
+ op["pred_emb"] = self.conv_emb(x)
if return_feats:
- op['features'] = x
+ op["features"] = x
return op
diff --git a/bat_detect/detector/parameters.py b/bat_detect/detector/parameters.py
index 10276eb..d93ac8c 100644
--- a/bat_detect/detector/parameters.py
+++ b/bat_detect/detector/parameters.py
@@ -1,108 +1,164 @@
-import numpy as np
-import os
import datetime
+import os
+
+import numpy as np
def mk_dir(path):
if not os.path.isdir(path):
os.makedirs(path)
-
-
-def get_params(make_dirs=False, exps_dir='../../experiments/'):
+
+
+def get_params(make_dirs=False, exps_dir="../../experiments/"):
params = {}
- params['model_name'] = 'Net2DFast' # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
- params['num_filters'] = 128
+ params[
+ "model_name"
+ ] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
+ params["num_filters"] = 128
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
- model_name = now_str + '.pth.tar'
- params['experiment'] = os.path.join(exps_dir, now_str, '')
- params['model_file_name'] = os.path.join(params['experiment'], model_name)
- params['op_im_dir'] = os.path.join(params['experiment'], 'op_ims', '')
- params['op_im_dir_test'] = os.path.join(params['experiment'], 'op_ims_test', '')
- #params['notes'] = '' # can save notes about an experiment here
-
+ model_name = now_str + ".pth.tar"
+ params["experiment"] = os.path.join(exps_dir, now_str, "")
+ params["model_file_name"] = os.path.join(params["experiment"], model_name)
+ params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
+ params["op_im_dir_test"] = os.path.join(
+ params["experiment"], "op_ims_test", ""
+ )
+ # params['notes'] = '' # can save notes about an experiment here
# spec parameters
- params['target_samp_rate'] = 256000 # resamples all audio so that it is at this rate
- params['fft_win_length'] = 512 / 256000.0 # in milliseconds, amount of time per stft time step
- params['fft_overlap'] = 0.75 # stft window overlap
+ params[
+ "target_samp_rate"
+ ] = 256000 # resamples all audio so that it is at this rate
+ params["fft_win_length"] = (
+ 512 / 256000.0
+ ) # in milliseconds, amount of time per stft time step
+ params["fft_overlap"] = 0.75 # stft window overlap
- params['max_freq'] = 120000 # in Hz, everything above this will be discarded
- params['min_freq'] = 10000 # in Hz, everything below this will be discarded
+ params[
+ "max_freq"
+ ] = 120000 # in Hz, everything above this will be discarded
+ params[
+ "min_freq"
+ ] = 10000 # in Hz, everything below this will be discarded
- params['resize_factor'] = 0.5 # resize so the spectrogram at the input of the network
- params['spec_height'] = 256 # units are number of frequency bins (before resizing is performed)
- params['spec_train_width'] = 512 # units are number of time steps (before resizing is performed)
- params['spec_divide_factor'] = 32 # spectrogram should be divisible by this amount in width and height
+ params[
+ "resize_factor"
+ ] = 0.5 # resize so the spectrogram at the input of the network
+ params[
+ "spec_height"
+ ] = 256 # units are number of frequency bins (before resizing is performed)
+ params[
+ "spec_train_width"
+ ] = 512 # units are number of time steps (before resizing is performed)
+ params[
+ "spec_divide_factor"
+ ] = 32 # spectrogram should be divisible by this amount in width and height
# spec processing params
- params['denoise_spec_avg'] = True # removes the mean for each frequency band
- params['scale_raw_audio'] = False # scales the raw audio to [-1, 1]
- params['max_scale_spec'] = False # scales the spectrogram so that it is max 1
- params['spec_scale'] = 'pcen' # 'log', 'pcen', 'none'
+ params[
+ "denoise_spec_avg"
+ ] = True # removes the mean for each frequency band
+ params["scale_raw_audio"] = False # scales the raw audio to [-1, 1]
+ params[
+ "max_scale_spec"
+ ] = False # scales the spectrogram so that it is max 1
+ params["spec_scale"] = "pcen" # 'log', 'pcen', 'none'
# detection params
- params['detection_overlap'] = 0.01 # has to be within this number of ms to count as detection
- params['ignore_start_end'] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
- params['detection_threshold'] = 0.01 # the smaller this is the better the recall will be
- params['nms_kernel_size'] = 9
- params['nms_top_k_per_sec'] = 200 # keep top K highest predictions per second of audio
- params['target_sigma'] = 2.0
+ params[
+ "detection_overlap"
+ ] = 0.01 # has to be within this number of ms to count as detection
+ params[
+ "ignore_start_end"
+ ] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
+ params[
+ "detection_threshold"
+ ] = 0.01 # the smaller this is the better the recall will be
+ params["nms_kernel_size"] = 9
+ params[
+ "nms_top_k_per_sec"
+ ] = 200 # keep top K highest predictions per second of audio
+ params["target_sigma"] = 2.0
# augmentation params
- params['aug_prob'] = 0.20 # augmentations will be performed with this probability
- params['augment_at_train'] = True
- params['augment_at_train_combine'] = True
- params['echo_max_delay'] = 0.005 # simulate echo by adding copy of raw audio
- params['stretch_squeeze_delta'] = 0.04 # stretch or squeeze spec
- params['mask_max_time_perc'] = 0.05 # max mask size - here percentage, not ideal
- params['mask_max_freq_perc'] = 0.10 # max mask size - here percentage, not ideal
- params['spec_amp_scaling'] = 2.0 # multiply the "volume" by 0:X times current amount
- params['aug_sampling_rates'] = [220500, 256000, 300000, 312500, 384000, 441000, 500000]
+ params[
+ "aug_prob"
+ ] = 0.20 # augmentations will be performed with this probability
+ params["augment_at_train"] = True
+ params["augment_at_train_combine"] = True
+ params[
+ "echo_max_delay"
+ ] = 0.005 # simulate echo by adding copy of raw audio
+ params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
+ params[
+ "mask_max_time_perc"
+ ] = 0.05 # max mask size - here percentage, not ideal
+ params[
+ "mask_max_freq_perc"
+ ] = 0.10 # max mask size - here percentage, not ideal
+ params[
+ "spec_amp_scaling"
+ ] = 2.0 # multiply the "volume" by 0:X times current amount
+ params["aug_sampling_rates"] = [
+ 220500,
+ 256000,
+ 300000,
+ 312500,
+ 384000,
+ 441000,
+ 500000,
+ ]
# loss params
- params['train_loss'] = 'focal' # mse or focal
- params['det_loss_weight'] = 1.0 # weight for the detection part of the loss
- params['size_loss_weight'] = 0.1 # weight for the bbox size loss
- params['class_loss_weight'] = 2.0 # weight for the classification loss
- params['individual_loss_weight'] = 0.0 # not used
- if params['individual_loss_weight'] == 0.0:
- params['emb_dim'] = 0 # number of dimensions used for individual id embedding
+ params["train_loss"] = "focal" # mse or focal
+ params[
+ "det_loss_weight"
+ ] = 1.0 # weight for the detection part of the loss
+ params["size_loss_weight"] = 0.1 # weight for the bbox size loss
+ params["class_loss_weight"] = 2.0 # weight for the classification loss
+ params["individual_loss_weight"] = 0.0 # not used
+ if params["individual_loss_weight"] == 0.0:
+ params[
+ "emb_dim"
+ ] = 0 # number of dimensions used for individual id embedding
else:
- params['emb_dim'] = 3
+ params["emb_dim"] = 3
# train params
- params['lr'] = 0.001
- params['batch_size'] = 8
- params['num_workers'] = 4
- params['num_epochs'] = 200
- params['num_eval_epochs'] = 5 # run evaluation every X epochs
- params['device'] = 'cuda'
- params['save_test_image_during_train'] = False
- params['save_test_image_after_train'] = True
+ params["lr"] = 0.001
+ params["batch_size"] = 8
+ params["num_workers"] = 4
+ params["num_epochs"] = 200
+ params["num_eval_epochs"] = 5 # run evaluation every X epochs
+ params["device"] = "cuda"
+ params["save_test_image_during_train"] = False
+ params["save_test_image_after_train"] = True
- params['convert_to_genus'] = False
- params['genus_mapping'] = []
- params['class_names'] = []
- params['classes_to_ignore'] = ['', ' ', 'Unknown', 'Not Bat']
- params['generic_class'] = ['Bat']
- params['events_of_interest'] = ['Echolocation'] # will ignore all other types of events e.g. social calls
+ params["convert_to_genus"] = False
+ params["genus_mapping"] = []
+ params["class_names"] = []
+ params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
+ params["generic_class"] = ["Bat"]
+ params["events_of_interest"] = [
+ "Echolocation"
+ ] # will ignore all other types of events e.g. social calls
# the classes in this list are standardized during training so that the same low and high freq are used
- params['standardize_classs_names'] = []
+ params["standardize_classs_names"] = []
# create directories
if make_dirs:
- print('Model name : ' + params['model_name'])
- print('Model file : ' + params['model_file_name'])
- print('Experiment : ' + params['experiment'])
+ print("Model name : " + params["model_name"])
+ print("Model file : " + params["model_file_name"])
+ print("Experiment : " + params["experiment"])
- mk_dir(params['experiment'])
- if params['save_test_image_during_train']:
- mk_dir(params['op_im_dir'])
- if params['save_test_image_after_train']:
- mk_dir(params['op_im_dir_test'])
- mk_dir(os.path.dirname(params['model_file_name']))
+ mk_dir(params["experiment"])
+ if params["save_test_image_during_train"]:
+ mk_dir(params["op_im_dir"])
+ if params["save_test_image_after_train"]:
+ mk_dir(params["op_im_dir_test"])
+ mk_dir(os.path.dirname(params["model_file_name"]))
return params
diff --git a/bat_detect/detector/post_process.py b/bat_detect/detector/post_process.py
index 757831f..2745cdf 100644
--- a/bat_detect/detector/post_process.py
+++ b/bat_detect/detector/post_process.py
@@ -1,35 +1,42 @@
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import numpy as np
-np.seterr(divide='ignore', invalid='ignore')
+
+np.seterr(divide="ignore", invalid="ignore")
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
- nfft = int(fft_win_length*sampling_rate)
- noverlap = int(fft_overlap*nfft)
- return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
- #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
+ nfft = int(fft_win_length * sampling_rate)
+ noverlap = int(fft_overlap * nfft)
+ return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
+ # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def overall_class_pred(det_prob, class_prob):
- weighted_pred = (class_prob*det_prob).sum(1)
+ weighted_pred = (class_prob * det_prob).sum(1)
return weighted_pred / weighted_pred.sum()
def run_nms(outputs, params, sampling_rate):
- pred_det = outputs['pred_det'] # probability of box
- pred_size = outputs['pred_size'] # box size
+ pred_det = outputs["pred_det"] # probability of box
+ pred_size = outputs["pred_size"] # box size
- pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size'])
- freq_rescale = (params['max_freq'] - params['min_freq']) /pred_det.shape[-2]
+ pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
+ freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
+ -2
+ ]
# NOTE there will be small differences depending on which sampling rate is chosen
# as we are choosing the same sampling rate for the entire batch
- duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(),
- params['fft_win_length'], params['fft_overlap'])
- top_k = int(duration * params['nms_top_k_per_sec'])
+ duration = x_coords_to_time(
+ pred_det.shape[-1],
+ sampling_rate[0].item(),
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ top_k = int(duration * params["nms_top_k_per_sec"])
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs
@@ -38,30 +45,47 @@ def run_nms(outputs, params, sampling_rate):
for ii in range(pred_det_nms.shape[0]):
# get valid indices
inds_ord = torch.argsort(x_pos[ii, :])
- valid_inds = scores[ii, inds_ord] > params['detection_threshold']
+ valid_inds = scores[ii, inds_ord] > params["detection_threshold"]
valid_inds = inds_ord[valid_inds]
# create result dictionary
pred = {}
- pred['det_probs'] = scores[ii, valid_inds]
- pred['x_pos'] = x_pos[ii, valid_inds]
- pred['y_pos'] = y_pos[ii, valid_inds]
- pred['bb_width'] = pred_size[ii, 0, pred['y_pos'], pred['x_pos']]
- pred['bb_height'] = pred_size[ii, 1, pred['y_pos'], pred['x_pos']]
- pred['start_times'] = x_coords_to_time(pred['x_pos'].float() / params['resize_factor'],
- sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
- pred['end_times'] = x_coords_to_time((pred['x_pos'].float()+pred['bb_width']) / params['resize_factor'],
- sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
- pred['low_freqs'] = (pred_size[ii].shape[1] - pred['y_pos'].float())*freq_rescale + params['min_freq']
- pred['high_freqs'] = pred['low_freqs'] + pred['bb_height']*freq_rescale
+ pred["det_probs"] = scores[ii, valid_inds]
+ pred["x_pos"] = x_pos[ii, valid_inds]
+ pred["y_pos"] = y_pos[ii, valid_inds]
+ pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]]
+ pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]]
+ pred["start_times"] = x_coords_to_time(
+ pred["x_pos"].float() / params["resize_factor"],
+ sampling_rate[ii].item(),
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ pred["end_times"] = x_coords_to_time(
+ (pred["x_pos"].float() + pred["bb_width"])
+ / params["resize_factor"],
+ sampling_rate[ii].item(),
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ pred["low_freqs"] = (
+ pred_size[ii].shape[1] - pred["y_pos"].float()
+ ) * freq_rescale + params["min_freq"]
+ pred["high_freqs"] = (
+ pred["low_freqs"] + pred["bb_height"] * freq_rescale
+ )
# extract the per class votes
- if 'pred_class' in outputs:
- pred['class_probs'] = outputs['pred_class'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]]
+ if "pred_class" in outputs:
+ pred["class_probs"] = outputs["pred_class"][
+ ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+ ]
# extract the model features
- if 'features' in outputs:
- feat = outputs['features'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]].transpose(0, 1)
+ if "features" in outputs:
+ feat = outputs["features"][
+ ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]
+ ].transpose(0, 1)
feat = feat.cpu().numpy().astype(np.float32)
feats.append(feat)
@@ -82,7 +106,9 @@ def non_max_suppression(heat, kernel_size):
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
- hmax = nn.functional.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w))
+ hmax = nn.functional.max_pool2d(
+ heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
+ )
keep = (hmax == heat).float()
return heat * keep
@@ -94,7 +120,7 @@ def get_topk_scores(scores, K):
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
topk_inds = topk_inds % (height * width)
- topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long()
- topk_xs = (topk_inds % width).long()
+ topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
+ topk_xs = (topk_inds % width).long()
return topk_scores, topk_ys, topk_xs
diff --git a/bat_detect/evaluate/evaluate_models.py b/bat_detect/evaluate/evaluate_models.py
index 0fc8ae9..6b7c460 100644
--- a/bat_detect/evaluate/evaluate_models.py
+++ b/bat_detect/evaluate/evaluate_models.py
@@ -2,67 +2,76 @@
Evaluates trained model on test set and generates plots.
"""
-import numpy as np
-import sys
-import os
+import argparse
import copy
import json
+import os
+import sys
+
+import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
-import argparse
-sys.path.append('../../')
-import bat_detect.utils.detector_utils as du
-import bat_detect.train.train_utils as tu
+sys.path.append("../../")
import bat_detect.detector.parameters as parameters
import bat_detect.train.evaluate as evl
+import bat_detect.train.train_utils as tu
+import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as pu
def get_blank_annotation(ip_str):
res = {}
- res['class_name'] = ''
- res['duration'] = -1
- res['id'] = ''# fileName
- res['issues'] = False
- res['notes'] = ip_str
- res['time_exp'] = 1
- res['annotated'] = False
- res['annotation'] = []
+ res["class_name"] = ""
+ res["duration"] = -1
+ res["id"] = "" # fileName
+ res["issues"] = False
+ res["notes"] = ip_str
+ res["time_exp"] = 1
+ res["annotated"] = False
+ res["annotation"] = []
ann = {}
- ann['class'] = ''
- ann['event'] = 'Echolocation'
- ann['individual'] = -1
- ann['start_time'] = -1
- ann['end_time'] = -1
- ann['low_freq'] = -1
- ann['high_freq'] = -1
- ann['confidence'] = -1
+ ann["class"] = ""
+ ann["event"] = "Echolocation"
+ ann["individual"] = -1
+ ann["start_time"] = -1
+ ann["end_time"] = -1
+ ann["low_freq"] = -1
+ ann["high_freq"] = -1
+ ann["confidence"] = -1
return copy.deepcopy(res), copy.deepcopy(ann)
def create_genus_mapping(gt_test, preds, class_names):
# rolls the per class predictions and ground truth back up to genus level
- class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
- genus_to_cls_map = [np.where(np.array(cls_to_genus) == cc)[0] for cc in range(len(class_names_genus))]
+ class_names_genus, cls_to_genus = np.unique(
+ [cc.split(" ")[0] for cc in class_names], return_inverse=True
+ )
+ genus_to_cls_map = [
+ np.where(np.array(cls_to_genus) == cc)[0]
+ for cc in range(len(class_names_genus))
+ ]
gt_test_g = []
for gg in gt_test:
gg_g = copy.deepcopy(gg)
- inds = np.where(gg_g['class_ids']!=-1)[0]
- gg_g['class_ids'][inds] = cls_to_genus[gg_g['class_ids'][inds]]
+ inds = np.where(gg_g["class_ids"] != -1)[0]
+ gg_g["class_ids"][inds] = cls_to_genus[gg_g["class_ids"][inds]]
gt_test_g.append(gg_g)
    # note, will have entries greater than one as we are summing across the respective classes
preds_g = []
for pp in preds:
pp_g = copy.deepcopy(pp)
- pp_g['class_probs'] = np.zeros((len(class_names_genus), pp_g['class_probs'].shape[1]), dtype=np.float32)
+ pp_g["class_probs"] = np.zeros(
+ (len(class_names_genus), pp_g["class_probs"].shape[1]),
+ dtype=np.float32,
+ )
for cc, inds in enumerate(genus_to_cls_map):
- pp_g['class_probs'][cc, :] = pp['class_probs'][inds, :].sum(0)
+ pp_g["class_probs"][cc, :] = pp["class_probs"][inds, :].sum(0)
preds_g.append(pp_g)
return class_names_genus, preds_g, gt_test_g
@@ -70,56 +79,70 @@ def create_genus_mapping(gt_test, preds, class_names):
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
- res, ann = get_blank_annotation('Generated by Tadarida')
+ res, ann = get_blank_annotation("Generated by Tadarida")
# create the annotations in the correct format
- da_c = pd.read_csv(ip_dir + dataset + '/' + file_of_interest.replace('.wav', '.ta').replace('.WAV', '.ta'), sep='\t')
+ da_c = pd.read_csv(
+ ip_dir
+ + dataset
+ + "/"
+ + file_of_interest.replace(".wav", ".ta").replace(".WAV", ".ta"),
+ sep="\t",
+ )
res_c = copy.deepcopy(res)
- res_c['id'] = file_of_interest
- res_c['dataset'] = dataset
- res_c['feats'] = da_c.iloc[:, 6:].values.astype(np.float32)
+ res_c["id"] = file_of_interest
+ res_c["dataset"] = dataset
+ res_c["feats"] = da_c.iloc[:, 6:].values.astype(np.float32)
if da_c.shape[0] > 0:
- res_c['class_name'] = ''
- res_c['class_prob'] = 0.0
+ res_c["class_name"] = ""
+ res_c["class_prob"] = 0.0
for aa in range(da_c.shape[0]):
ann_c = copy.deepcopy(ann)
- ann_c['class'] = 'Not Bat' # will assign to class later
- ann_c['start_time'] = np.round(da_c.iloc[aa]['StTime']/1000.0 ,5)
- ann_c['end_time'] = np.round((da_c.iloc[aa]['StTime'] + da_c.iloc[aa]['Dur'])/1000.0, 5)
- ann_c['low_freq'] = np.round(da_c.iloc[aa]['Fmin'] * 1000.0, 2)
- ann_c['high_freq'] = np.round(da_c.iloc[aa]['Fmax'] * 1000.0, 2)
- ann_c['det_prob'] = 0.0
- res_c['annotation'].append(ann_c)
+ ann_c["class"] = "Not Bat" # will assign to class later
+ ann_c["start_time"] = np.round(da_c.iloc[aa]["StTime"] / 1000.0, 5)
+ ann_c["end_time"] = np.round(
+ (da_c.iloc[aa]["StTime"] + da_c.iloc[aa]["Dur"]) / 1000.0, 5
+ )
+ ann_c["low_freq"] = np.round(da_c.iloc[aa]["Fmin"] * 1000.0, 2)
+ ann_c["high_freq"] = np.round(da_c.iloc[aa]["Fmax"] * 1000.0, 2)
+ ann_c["det_prob"] = 0.0
+ res_c["annotation"].append(ann_c)
return res_c
-def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_accepted_species=True):
+def load_sonobat_meta(
+ ip_dir,
+ datasets,
+ region_classifier,
+ class_names,
+ only_accepted_species=True,
+):
sp_dict = {}
for ss in class_names:
- sp_key = ss.split(' ')[0][:3] + ss.split(' ')[1][:3]
+ sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
sp_dict[sp_key] = ss
- sp_dict['x'] = '' # not bat
- sp_dict['Bat'] = 'Bat'
+ sp_dict["x"] = "" # not bat
+ sp_dict["Bat"] = "Bat"
sonobat_meta = {}
for tt in datasets:
- dataset = tt['dataset_name']
- sb_ip_dir = ip_dir + dataset + '/' + region_classifier + '/'
+ dataset = tt["dataset_name"]
+ sb_ip_dir = ip_dir + dataset + "/" + region_classifier + "/"
# load the call level predictions
- ip_file_p = sb_ip_dir + dataset + '_Parameters_v4.5.0.txt'
- #ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
- da = pd.read_csv(ip_file_p, sep='\t')
+ ip_file_p = sb_ip_dir + dataset + "_Parameters_v4.5.0.txt"
+ # ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
+ da = pd.read_csv(ip_file_p, sep="\t")
# load the file level predictions
- ip_file_b = sb_ip_dir + dataset + '_SonoBatch_v4.5.0.txt'
- #ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
+ ip_file_b = sb_ip_dir + dataset + "_SonoBatch_v4.5.0.txt"
+ # ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
with open(ip_file_b) as f:
lines = f.readlines()
@@ -129,7 +152,7 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
file_res = {}
for ll in lines:
# note this does not seem to parse the file very well
- ll_data = ll.split('\t')
+ ll_data = ll.split("\t")
# there are sometimes many different species names per file
if only_accepted_species:
@@ -137,20 +160,24 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
ind = 4
else:
# choosing ""~Spp" if "SppAccp" does not exist
- if ll_data[4] != 'x':
- ind = 4 # choosing "SppAccp", along with "Prob" here
+ if ll_data[4] != "x":
+ ind = 4 # choosing "SppAccp", along with "Prob" here
else:
ind = 8 # choosing "~Spp", along with "~Prob" here
sp_name_1 = sp_dict[ll_data[ind]]
- prob_1 = ll_data[ind+1]
- if prob_1 == 'x':
+ prob_1 = ll_data[ind + 1]
+ if prob_1 == "x":
prob_1 = 0.0
- file_res[ll_data[1]] = {'id':ll_data[1], 'species_1':sp_name_1, 'prob_1':prob_1}
+ file_res[ll_data[1]] = {
+ "id": ll_data[1],
+ "species_1": sp_name_1,
+ "prob_1": prob_1,
+ }
sonobat_meta[dataset] = {}
- sonobat_meta[dataset]['file_res'] = file_res
- sonobat_meta[dataset]['call_info'] = da
+ sonobat_meta[dataset]["file_res"] = file_res
+ sonobat_meta[dataset]["call_info"] = da
return sonobat_meta
@@ -158,34 +185,38 @@ def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_acc
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
# create the annotations in the correct format
- res, ann = get_blank_annotation('Generated by Sonobat')
+ res, ann = get_blank_annotation("Generated by Sonobat")
res_c = copy.deepcopy(res)
- res_c['id'] = id
- res_c['dataset'] = dataset
+ res_c["id"] = id
+ res_c["dataset"] = dataset
- da = sb_meta[dataset]['call_info']
- da_c = da[da['Filename'] == id]
+ da = sb_meta[dataset]["call_info"]
+ da_c = da[da["Filename"] == id]
- file_res = sb_meta[dataset]['file_res']
- res_c['feats'] = np.zeros((0,0))
+ file_res = sb_meta[dataset]["file_res"]
+ res_c["feats"] = np.zeros((0, 0))
if da_c.shape[0] > 0:
- res_c['class_name'] = file_res[id]['species_1']
- res_c['class_prob'] = file_res[id]['prob_1']
- res_c['feats'] = da_c.iloc[:, 3:105].values.astype(np.float32)
+ res_c["class_name"] = file_res[id]["species_1"]
+ res_c["class_prob"] = file_res[id]["prob_1"]
+ res_c["feats"] = da_c.iloc[:, 3:105].values.astype(np.float32)
for aa in range(da_c.shape[0]):
ann_c = copy.deepcopy(ann)
if set_class_name is None:
- ann_c['class'] = file_res[id]['species_1']
+ ann_c["class"] = file_res[id]["species_1"]
else:
- ann_c['class'] = set_class_name
- ann_c['start_time'] = np.round(da_c.iloc[aa]['TimeInFile'] / 1000.0 ,5)
- ann_c['end_time'] = np.round(ann_c['start_time'] + da_c.iloc[aa]['CallDuration']/1000.0, 5)
- ann_c['low_freq'] = np.round(da_c.iloc[aa]['LowFreq'] * 1000.0, 2)
- ann_c['high_freq'] = np.round(da_c.iloc[aa]['HiFreq'] * 1000.0, 2)
- ann_c['det_prob'] = np.round(da_c.iloc[aa]['Quality'], 3)
- res_c['annotation'].append(ann_c)
+ ann_c["class"] = set_class_name
+ ann_c["start_time"] = np.round(
+ da_c.iloc[aa]["TimeInFile"] / 1000.0, 5
+ )
+ ann_c["end_time"] = np.round(
+ ann_c["start_time"] + da_c.iloc[aa]["CallDuration"] / 1000.0, 5
+ )
+ ann_c["low_freq"] = np.round(da_c.iloc[aa]["LowFreq"] * 1000.0, 2)
+ ann_c["high_freq"] = np.round(da_c.iloc[aa]["HiFreq"] * 1000.0, 2)
+ ann_c["det_prob"] = np.round(da_c.iloc[aa]["Quality"], 3)
+ res_c["annotation"].append(ann_c)
return res_c
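
The rounding lines above are unit conversions: SonoBat's `TimeInFile` and `CallDuration` appear to be in milliseconds and its `LowFreq`/`HiFreq` in kHz, whereas the annotation format wants seconds and Hz. Traced on a hypothetical row:

    row = {"TimeInFile": 1234.5, "CallDuration": 3.2, "LowFreq": 45.0, "HiFreq": 102.0}
    start_time = round(row["TimeInFile"] / 1000.0, 5)               # 1.2345 s
    end_time = round(start_time + row["CallDuration"] / 1000.0, 5)  # 1.2377 s
    low_freq = round(row["LowFreq"] * 1000.0, 2)                    # 45000.0 Hz
    high_freq = round(row["HiFreq"] * 1000.0, 2)                    # 102000.0 Hz
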
@@ -193,8 +224,18 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in):
    freq_scale = 10000000.0  # ensure that both axes span roughly the same range
- bb_g = [bb_g_in['start_time'], bb_g_in['low_freq']/freq_scale, bb_g_in['end_time'], bb_g_in['high_freq']/freq_scale]
- bb_p = [bb_p_in['start_time'], bb_p_in['low_freq']/freq_scale, bb_p_in['end_time'], bb_p_in['high_freq']/freq_scale]
+ bb_g = [
+ bb_g_in["start_time"],
+ bb_g_in["low_freq"] / freq_scale,
+ bb_g_in["end_time"],
+ bb_g_in["high_freq"] / freq_scale,
+ ]
+ bb_p = [
+ bb_p_in["start_time"],
+ bb_p_in["low_freq"] / freq_scale,
+ bb_p_in["end_time"],
+ bb_p_in["high_freq"] / freq_scale,
+ ]
xA = max(bb_g[0], bb_p[0])
yA = max(bb_g[1], bb_p[1])
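
The hunk is cut off before the intersection/union arithmetic, but the computation is standard IoU over axis-aligned time-frequency boxes, with frequency divided by 1e7 so both axes have comparable magnitudes. A self-contained sketch under that assumption:

    def iou_time_freq(a, b, freq_scale=10000000.0):
        # boxes as (start_time, low_freq, end_time, high_freq); freq is rescaled
        ax0, ay0, ax1, ay1 = a[0], a[1] / freq_scale, a[2], a[3] / freq_scale
        bx0, by0, bx1, by1 = b[0], b[1] / freq_scale, b[2], b[3] / freq_scale
        iw = max(0.0, min(ax1, bx1) - max(ax0, bx0))
        ih = max(0.0, min(ay1, by1) - max(ay0, by0))
        inter = iw * ih
        union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - inter
        return inter / union if union > 0 else 0.0
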
@@ -220,13 +261,15 @@ def bb_overlap(bb_g_in, bb_p_in):
def assign_to_gt(gt, pred, iou_thresh):
# this will edit pred in place
- num_preds = len(pred['annotation'])
- num_gts = len(gt['annotation'])
+ num_preds = len(pred["annotation"])
+ num_gts = len(gt["annotation"])
if num_preds > 0 and num_gts > 0:
iou_m = np.zeros((num_preds, num_gts))
for ii in range(num_preds):
for jj in range(num_gts):
- iou_m[ii, jj] = bb_overlap(gt['annotation'][jj], pred['annotation'][ii])
+ iou_m[ii, jj] = bb_overlap(
+ gt["annotation"][jj], pred["annotation"][ii]
+ )
# greedily assign detections to ground truths
# needs to be greater than some threshold and we cannot assign GT
@@ -235,7 +278,9 @@ def assign_to_gt(gt, pred, iou_thresh):
for jj in range(num_gts):
max_iou = np.argmax(iou_m[:, jj])
if iou_m[max_iou, jj] > iou_thresh:
- pred['annotation'][max_iou]['class'] = gt['annotation'][jj]['class']
+ pred["annotation"][max_iou]["class"] = gt["annotation"][jj][
+ "class"
+ ]
iou_m[max_iou, :] = -1.0
return pred
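
The greedy pass runs column by column: each ground-truth box claims its highest-IoU prediction, and setting the claimed row to -1 prevents one prediction from being assigned to two ground truths. A toy run with made-up IoU values:

    import numpy as np

    iou_m = np.array([[0.60, 0.55],   # pred 0 vs gt 0, gt 1
                      [0.10, 0.50]])  # pred 1 vs gt 0, gt 1
    iou_thresh = 0.01
    gt_classes = ["Myotis", "Pipistrellus"]
    assigned = {}
    for jj in range(iou_m.shape[1]):
        best = np.argmax(iou_m[:, jj])
        if iou_m[best, jj] > iou_thresh:
            assigned[best] = gt_classes[jj]
            iou_m[best, :] = -1.0  # this prediction can no longer be claimed
    print(assigned)  # {0: 'Myotis', 1: 'Pipistrellus'}
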
@@ -244,27 +289,39 @@ def assign_to_gt(gt, pred, iou_thresh):
def parse_data(data, class_names, non_event_classes, is_pred=False):
class_names_all = class_names + non_event_classes
- data['class_names'] = np.array([aa['class'] for aa in data['annotation']])
- data['start_times'] = np.array([aa['start_time'] for aa in data['annotation']])
- data['end_times'] = np.array([aa['end_time'] for aa in data['annotation']])
- data['high_freqs'] = np.array([float(aa['high_freq']) for aa in data['annotation']])
- data['low_freqs'] = np.array([float(aa['low_freq']) for aa in data['annotation']])
+ data["class_names"] = np.array([aa["class"] for aa in data["annotation"]])
+ data["start_times"] = np.array(
+ [aa["start_time"] for aa in data["annotation"]]
+ )
+ data["end_times"] = np.array([aa["end_time"] for aa in data["annotation"]])
+ data["high_freqs"] = np.array(
+ [float(aa["high_freq"]) for aa in data["annotation"]]
+ )
+ data["low_freqs"] = np.array(
+ [float(aa["low_freq"]) for aa in data["annotation"]]
+ )
if is_pred:
# when loading predictions
- data['det_probs'] = np.array([float(aa['det_prob']) for aa in data['annotation']])
- data['class_probs'] = np.zeros((len(class_names)+1, len(data['annotation'])))
- data['class_ids'] = np.array([class_names_all.index(aa['class']) for aa in data['annotation']]).astype(np.int32)
+ data["det_probs"] = np.array(
+ [float(aa["det_prob"]) for aa in data["annotation"]]
+ )
+ data["class_probs"] = np.zeros(
+ (len(class_names) + 1, len(data["annotation"]))
+ )
+ data["class_ids"] = np.array(
+ [class_names_all.index(aa["class"]) for aa in data["annotation"]]
+ ).astype(np.int32)
else:
# when loading ground truth
# if the class label is not in the set of interest then set to -1
labels = []
- for aa in data['annotation']:
- if aa['class'] in class_names:
- labels.append(class_names_all.index(aa['class']))
+ for aa in data["annotation"]:
+ if aa["class"] in class_names:
+ labels.append(class_names_all.index(aa["class"]))
else:
labels.append(-1)
- data['class_ids'] = np.array(labels).astype(np.int32)
+ data["class_ids"] = np.array(labels).astype(np.int32)
return data
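
On the ground-truth side, -1 acts as a mask id: annotations whose class is outside the training set stay in the arrays but can be excluded from class-level scoring later. For example, with illustrative class lists:

    class_names = ["Myotis daubentonii"]  # classes the model was trained on
    non_event_classes = ["Not Bat", "Bat", "Unknown"]
    class_names_all = class_names + non_event_classes
    ann_classes = ["Myotis daubentonii", "Plecotus auritus"]
    labels = [class_names_all.index(cc) if cc in class_names else -1 for cc in ann_classes]
    # labels == [0, -1]
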
@@ -272,12 +329,17 @@ def parse_data(data, class_names, non_event_classes, is_pred=False):
def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
gt_data = []
for dd in datasets:
- print('\n' + dd['dataset_name'])
- gt_dataset = tu.load_set_of_anns([dd], events_of_interest=events_of_interest, verbose=True)
- gt_dataset = [parse_data(gg, class_names, classes_to_ignore, False) for gg in gt_dataset]
+ print("\n" + dd["dataset_name"])
+ gt_dataset = tu.load_set_of_anns(
+ [dd], events_of_interest=events_of_interest, verbose=True
+ )
+ gt_dataset = [
+ parse_data(gg, class_names, classes_to_ignore, False)
+ for gg in gt_dataset
+ ]
for gt in gt_dataset:
- gt['dataset_name'] = dd['dataset_name']
+ gt["dataset_name"] = dd["dataset_name"]
gt_data.extend(gt_dataset)
@@ -300,69 +362,103 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_train)
- tr_acc = (y_pred==y_train).mean()
- #print('Train acc', round(tr_acc*100, 2))
+ tr_acc = (y_pred == y_train).mean()
+ # print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class
def eval_rf_model(clf, pred, un_train_class, num_classes):
# stores the prediction in place
- if pred['feats'].shape[0] > 0:
- pred['class_probs'] = np.zeros((num_classes, pred['feats'].shape[0]))
- pred['class_probs'][un_train_class, :] = clf.predict_proba(pred['feats']).T
- pred['det_probs'] = pred['class_probs'][:-1, :].sum(0)
+ if pred["feats"].shape[0] > 0:
+ pred["class_probs"] = np.zeros((num_classes, pred["feats"].shape[0]))
+ pred["class_probs"][un_train_class, :] = clf.predict_proba(
+ pred["feats"]
+ ).T
+ pred["det_probs"] = pred["class_probs"][:-1, :].sum(0)
else:
- pred['class_probs'] = np.zeros((num_classes, 0))
- pred['det_probs'] = np.zeros(0)
+ pred["class_probs"] = np.zeros((num_classes, 0))
+ pred["det_probs"] = np.zeros(0)
return pred
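
The `un_train_class` indexing matters because scikit-learn's `predict_proba` only returns columns for class ids that appeared during `fit`; scattering those columns into the right rows realigns them with the global class list. A minimal sketch of the same pattern, with made-up shapes and labels:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    x_train = np.random.RandomState(0).rand(20, 5)  # made-up features
    y_train = np.array([0, 2, 5] * 6 + [0, 2])      # only classes 0, 2, 5 occur
    clf = RandomForestClassifier(random_state=2001).fit(x_train, y_train)
    un_train_class = np.unique(y_train)             # array([0, 2, 5])

    num_classes = 7
    feats = np.random.RandomState(1).rand(3, 5)     # 3 detections to score
    class_probs = np.zeros((num_classes, feats.shape[0]))
    class_probs[un_train_class, :] = clf.predict_proba(feats).T  # rows 0, 2, 5 filled
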
def save_summary_to_json(op_dir, mod_name, results):
op = {}
- op['avg_prec'] = round(results['avg_prec'], 3)
- op['avg_prec_class'] = round(results['avg_prec_class'], 3)
- op['top_class'] = round(results['top_class']['avg_prec'], 3)
- op['file_acc'] = round(results['file_acc'], 3)
- op['model'] = mod_name
+ op["avg_prec"] = round(results["avg_prec"], 3)
+ op["avg_prec_class"] = round(results["avg_prec_class"], 3)
+ op["top_class"] = round(results["top_class"]["avg_prec"], 3)
+ op["file_acc"] = round(results["file_acc"], 3)
+ op["model"] = mod_name
- op['per_class'] = {}
- for cc in results['class_pr']:
- op['per_class'][cc['name']] = cc['avg_prec']
+ op["per_class"] = {}
+ for cc in results["class_pr"]:
+ op["per_class"][cc["name"]] = cc["avg_prec"]
- op_file_name = os.path.join(op_dir, mod_name + '_results.json')
- with open(op_file_name, 'w') as da:
+ op_file_name = os.path.join(op_dir, mod_name + "_results.json")
+ with open(op_file_name, "w") as da:
json.dump(op, da, indent=2)
-def print_results(model_name, mod_str, results, op_dir, class_names, file_type, title_text=''):
- print('\nResults - ' + model_name)
- print('avg_prec ', round(results['avg_prec'], 3))
- print('avg_prec_class', round(results['avg_prec_class'], 3))
- print('top_class ', round(results['top_class']['avg_prec'], 3))
- print('file_acc ', round(results['file_acc'], 3))
+def print_results(
+ model_name, mod_str, results, op_dir, class_names, file_type, title_text=""
+):
+ print("\nResults - " + model_name)
+ print("avg_prec ", round(results["avg_prec"], 3))
+ print("avg_prec_class", round(results["avg_prec_class"], 3))
+ print("top_class ", round(results["top_class"]["avg_prec"], 3))
+ print("file_acc ", round(results["file_acc"], 3))
- print('\nSaving ' + model_name + ' results to: ' + op_dir)
+ print("\nSaving " + model_name + " results to: " + op_dir)
save_summary_to_json(op_dir, mod_str, results)
- pu.plot_pr_curve(op_dir, mod_str+'_test_all_det', mod_str+'_test_all_det', results, file_type, title_text + 'Detection PR')
- pu.plot_pr_curve(op_dir, mod_str+'_test_all_top_class', mod_str+'_test_all_top_class', results['top_class'], file_type, title_text + 'Top Class')
- pu.plot_pr_curve_class(op_dir, mod_str+'_test_all_class', mod_str+'_test_all_class', results, file_type, title_text + 'Per-Class PR')
- pu.plot_confusion_matrix(op_dir, mod_str+'_confusion', results['gt_valid_file'], results['pred_valid_file'],
- results['file_acc'], class_names, True, file_type, title_text + 'Confusion Matrix')
+ pu.plot_pr_curve(
+ op_dir,
+ mod_str + "_test_all_det",
+ mod_str + "_test_all_det",
+ results,
+ file_type,
+ title_text + "Detection PR",
+ )
+ pu.plot_pr_curve(
+ op_dir,
+ mod_str + "_test_all_top_class",
+ mod_str + "_test_all_top_class",
+ results["top_class"],
+ file_type,
+ title_text + "Top Class",
+ )
+ pu.plot_pr_curve_class(
+ op_dir,
+ mod_str + "_test_all_class",
+ mod_str + "_test_all_class",
+ results,
+ file_type,
+ title_text + "Per-Class PR",
+ )
+ pu.plot_confusion_matrix(
+ op_dir,
+ mod_str + "_confusion",
+ results["gt_valid_file"],
+ results["pred_valid_file"],
+ results["file_acc"],
+ class_names,
+ True,
+ file_type,
+ title_text + "Confusion Matrix",
+ )
def add_root_path_back(data_sets, ann_path, wav_path):
for dd in data_sets:
- dd['ann_path'] = os.path.join(ann_path, dd['ann_path'])
- dd['wav_path'] = os.path.join(wav_path, dd['wav_path'])
+ dd["ann_path"] = os.path.join(ann_path, dd["ann_path"])
+ dd["wav_path"] = os.path.join(wav_path, dd["wav_path"])
return data_sets
def check_classes_in_train(gt_list, class_names):
- num_gt_total = np.sum([gg['start_times'].shape[0] for gg in gt_list])
+ num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0
for gt in gt_list:
- for cc in gt['class_names']:
+ for cc in gt["class_names"]:
if cc not in class_names:
num_with_no_class += 1
return num_with_no_class
@@ -371,195 +467,335 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('op_dir', type=str, default='plots/results_compare/',
- help='Output directory for plots')
- parser.add_argument('data_dir', type=str,
- help='Path to root of datasets')
- parser.add_argument('ann_dir', type=str,
- help='Path to extracted annotations')
- parser.add_argument('bd_model_path', type=str,
- help='Path to BatDetect model')
- parser.add_argument('--test_file', type=str, default='',
- help='Path to json file used for evaluation.')
- parser.add_argument('--sb_ip_dir', type=str, default='',
- help='Path to sonobat predictions')
- parser.add_argument('--sb_region_classifier', type=str, default='south',
- help='Path to sonobat predictions')
- parser.add_argument('--td_ip_dir', type=str, default='',
- help='Path to tadarida_D predictions')
- parser.add_argument('--iou_thresh', type=float, default=0.01,
- help='IOU threshold for assigning predictions to ground truth')
- parser.add_argument('--file_type', type=str, default='png',
- help='Type of image to save - png or pdf')
- parser.add_argument('--title_text', type=str, default='',
- help='Text to add as title of plots')
- parser.add_argument('--rand_seed', type=int, default=2001,
- help='Random seed')
+ parser.add_argument(
+ "op_dir",
+ type=str,
+ default="plots/results_compare/",
+ help="Output directory for plots",
+ )
+ parser.add_argument("data_dir", type=str, help="Path to root of datasets")
+ parser.add_argument(
+ "ann_dir", type=str, help="Path to extracted annotations"
+ )
+ parser.add_argument(
+ "bd_model_path", type=str, help="Path to BatDetect model"
+ )
+ parser.add_argument(
+ "--test_file",
+ type=str,
+ default="",
+ help="Path to json file used for evaluation.",
+ )
+ parser.add_argument(
+ "--sb_ip_dir", type=str, default="", help="Path to sonobat predictions"
+ )
+ parser.add_argument(
+ "--sb_region_classifier",
+ type=str,
+ default="south",
+ help="Path to sonobat predictions",
+ )
+ parser.add_argument(
+ "--td_ip_dir",
+ type=str,
+ default="",
+ help="Path to tadarida_D predictions",
+ )
+ parser.add_argument(
+ "--iou_thresh",
+ type=float,
+ default=0.01,
+ help="IOU threshold for assigning predictions to ground truth",
+ )
+ parser.add_argument(
+ "--file_type",
+ type=str,
+ default="png",
+ help="Type of image to save - png or pdf",
+ )
+ parser.add_argument(
+ "--title_text",
+ type=str,
+ default="",
+ help="Text to add as title of plots",
+ )
+ parser.add_argument(
+ "--rand_seed", type=int, default=2001, help="Random seed"
+ )
args = vars(parser.parse_args())
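
For reference, an invocation might look like the following; the script name and paths are placeholders, since the file header for this script falls outside this excerpt:

    python compare_models.py plots/results_compare/ /path/to/datasets/ /path/to/anns/ \
        /path/to/model.pth.tar --sb_ip_dir /path/to/sonobat_preds/ --file_type pdf
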
- np.random.seed(args['rand_seed'])
-
- if not os.path.isdir(args['op_dir']):
- os.makedirs(args['op_dir'])
+ np.random.seed(args["rand_seed"])
+ if not os.path.isdir(args["op_dir"]):
+ os.makedirs(args["op_dir"])
# load the model
params_eval = parameters.get_params(False)
- _, params_bd = du.load_model(args['bd_model_path'])
+ _, params_bd = du.load_model(args["bd_model_path"])
- class_names = params_bd['class_names']
+ class_names = params_bd["class_names"]
num_classes = len(class_names) + 1 # num classes plus background class
- classes_to_ignore = ['Not Bat', 'Bat', 'Unknown']
- events_of_interest = ['Echolocation']
+ classes_to_ignore = ["Not Bat", "Bat", "Unknown"]
+ events_of_interest = ["Echolocation"]
# load test data
- if args['test_file'] == '':
+ if args["test_file"] == "":
# load the test files of interest from the trained model
- test_sets = add_root_path_back(params_bd['test_sets'], args['ann_dir'], args['data_dir'])
- test_sets = [dd for dd in test_sets if not dd['is_binary']] # exclude bat/not datasets
+ test_sets = add_root_path_back(
+ params_bd["test_sets"], args["ann_dir"], args["data_dir"]
+ )
+ test_sets = [
+ dd for dd in test_sets if not dd["is_binary"]
+ ] # exclude bat/not datasets
else:
# user specified annotation file to evaluate
test_dict = {}
- test_dict['dataset_name'] = args['test_file'].replace('.json', '')
- test_dict['is_test'] = True
- test_dict['is_binary'] = True
- test_dict['ann_path'] = os.path.join(args['ann_dir'], args['test_file'])
- test_dict['wav_path'] = args['data_dir']
+ test_dict["dataset_name"] = args["test_file"].replace(".json", "")
+ test_dict["is_test"] = True
+ test_dict["is_binary"] = True
+ test_dict["ann_path"] = os.path.join(
+ args["ann_dir"], args["test_file"]
+ )
+ test_dict["wav_path"] = args["data_dir"]
test_sets = [test_dict]
# load the gt for the test set
- gt_test = load_gt_data(test_sets, events_of_interest, class_names, classes_to_ignore)
- total_num_calls = np.sum([gg['start_times'].shape[0] for gg in gt_test])
- print('\nTotal number of test files:', len(gt_test))
- print('Total number of test calls:', np.sum([gg['start_times'].shape[0] for gg in gt_test]))
+ gt_test = load_gt_data(
+ test_sets, events_of_interest, class_names, classes_to_ignore
+ )
+ total_num_calls = np.sum([gg["start_times"].shape[0] for gg in gt_test])
+ print("\nTotal number of test files:", len(gt_test))
+    print("Total number of test calls:", total_num_calls)
# check if test contains classes not in the train set
num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class:
- print('Classes from the test set are not in the train set.')
+ print("Classes from the test set are not in the train set.")
assert False
# only need the train data if evaluating Sonobat or Tadarida
- if args['sb_ip_dir'] != '' or args['td_ip_dir'] != '':
- train_sets = add_root_path_back(params_bd['train_sets'], args['ann_dir'], args['data_dir'])
- train_sets = [dd for dd in train_sets if not dd['is_binary']] # exclude bat/not datasets
- gt_train = load_gt_data(train_sets, events_of_interest, class_names, classes_to_ignore)
-
+ if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
+ train_sets = add_root_path_back(
+ params_bd["train_sets"], args["ann_dir"], args["data_dir"]
+ )
+ train_sets = [
+ dd for dd in train_sets if not dd["is_binary"]
+ ] # exclude bat/not datasets
+ gt_train = load_gt_data(
+ train_sets, events_of_interest, class_names, classes_to_ignore
+ )
#
# evaluate Sonobat by training random forest classifier
#
# NOTE: Sonobat may only make predictions for a subset of the files
#
- if args['sb_ip_dir'] != '':
- sb_meta = load_sonobat_meta(args['sb_ip_dir'], train_sets + test_sets, args['sb_region_classifier'], class_names)
+ if args["sb_ip_dir"] != "":
+ sb_meta = load_sonobat_meta(
+ args["sb_ip_dir"],
+ train_sets + test_sets,
+ args["sb_region_classifier"],
+ class_names,
+ )
preds_sb = []
keep_inds_sb = []
for ii, gt in enumerate(gt_test):
- sb_pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta)
- if sb_pred['class_name'] != '':
- sb_pred = parse_data(sb_pred, class_names, classes_to_ignore, True)
- sb_pred['class_probs'][sb_pred['class_ids'], np.arange(sb_pred['class_probs'].shape[1])] = sb_pred['det_probs']
+ sb_pred = load_sonobat_preds(gt["dataset_name"], gt["id"], sb_meta)
+ if sb_pred["class_name"] != "":
+ sb_pred = parse_data(
+ sb_pred, class_names, classes_to_ignore, True
+ )
+ sb_pred["class_probs"][
+ sb_pred["class_ids"],
+ np.arange(sb_pred["class_probs"].shape[1]),
+ ] = sb_pred["det_probs"]
preds_sb.append(sb_pred)
keep_inds_sb.append(ii)
- results_sb = evl.evaluate_predictions([gt_test[ii] for ii in keep_inds_sb], preds_sb, class_names,
- params_eval['detection_overlap'], params_eval['ignore_start_end'])
- print_results('Sonobat', 'sb', results_sb, args['op_dir'], class_names,
- args['file_type'], args['title_text'] + ' - Species - ')
- print('Only reporting results for', len(keep_inds_sb), 'files, out of', len(gt_test))
-
+ results_sb = evl.evaluate_predictions(
+ [gt_test[ii] for ii in keep_inds_sb],
+ preds_sb,
+ class_names,
+ params_eval["detection_overlap"],
+ params_eval["ignore_start_end"],
+ )
+ print_results(
+ "Sonobat",
+ "sb",
+ results_sb,
+ args["op_dir"],
+ class_names,
+ args["file_type"],
+ args["title_text"] + " - Species - ",
+ )
+ print(
+ "Only reporting results for",
+ len(keep_inds_sb),
+ "files, out of",
+ len(gt_test),
+ )
# train our own random forest on sonobat features
x_train = []
y_train = []
for gt in gt_train:
- pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat')
+ pred = load_sonobat_preds(
+ gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
+ )
- if len(pred['annotation']) > 0:
+ if len(pred["annotation"]) > 0:
# compute detection overlap with ground truth to determine which are the TP detections
- assign_to_gt(gt, pred, args['iou_thresh'])
+ assign_to_gt(gt, pred, args["iou_thresh"])
pred = parse_data(pred, class_names, classes_to_ignore, True)
- x_train.append(pred['feats'])
- y_train.append(pred['class_ids'])
+ x_train.append(pred["feats"])
+ y_train.append(pred["class_ids"])
        # train random forest on sonobat predictions
- clf_sb, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed'])
+ clf_sb, un_train_class = train_rf_model(
+ x_train, y_train, num_classes, args["rand_seed"]
+ )
# run the model on the test set
preds_sb_rf = []
for gt in gt_test:
- pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat')
+ pred = load_sonobat_preds(
+ gt["dataset_name"], gt["id"], sb_meta, "Not Bat"
+ )
pred = parse_data(pred, class_names, classes_to_ignore, True)
pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes)
preds_sb_rf.append(pred)
- results_sb_rf = evl.evaluate_predictions(gt_test, preds_sb_rf, class_names,
- params_eval['detection_overlap'], params_eval['ignore_start_end'])
- print_results('Sonobat RF', 'sb_rf', results_sb_rf, args['op_dir'], class_names,
- args['file_type'], args['title_text'] + ' - Species - ')
- print('\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n')
-
+ results_sb_rf = evl.evaluate_predictions(
+ gt_test,
+ preds_sb_rf,
+ class_names,
+ params_eval["detection_overlap"],
+ params_eval["ignore_start_end"],
+ )
+ print_results(
+ "Sonobat RF",
+ "sb_rf",
+ results_sb_rf,
+ args["op_dir"],
+ class_names,
+ args["file_type"],
+ args["title_text"] + " - Species - ",
+ )
+ print(
+ "\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n"
+ )
#
# evaluate Tadarida-D by training random forest classifier
#
- if args['td_ip_dir'] != '':
+ if args["td_ip_dir"] != "":
x_train = []
y_train = []
for gt in gt_train:
- pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id'])
+ pred = load_tadarida_pred(
+ args["td_ip_dir"], gt["dataset_name"], gt["id"]
+ )
# compute detection overlap with ground truth to determine which are the TP detections
- assign_to_gt(gt, pred, args['iou_thresh'])
+ assign_to_gt(gt, pred, args["iou_thresh"])
pred = parse_data(pred, class_names, classes_to_ignore, True)
- x_train.append(pred['feats'])
- y_train.append(pred['class_ids'])
+ x_train.append(pred["feats"])
+ y_train.append(pred["class_ids"])
# train random forest on Tadarida-D predictions
- clf_td, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed'])
+ clf_td, un_train_class = train_rf_model(
+ x_train, y_train, num_classes, args["rand_seed"]
+ )
# run the model on the test set
preds_td = []
for gt in gt_test:
- pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id'])
+ pred = load_tadarida_pred(
+ args["td_ip_dir"], gt["dataset_name"], gt["id"]
+ )
pred = parse_data(pred, class_names, classes_to_ignore, True)
pred = eval_rf_model(clf_td, pred, un_train_class, num_classes)
preds_td.append(pred)
- results_td = evl.evaluate_predictions(gt_test, preds_td, class_names,
- params_eval['detection_overlap'], params_eval['ignore_start_end'])
- print_results('Tadarida', 'td_rf', results_td, args['op_dir'], class_names,
- args['file_type'], args['title_text'] + ' - Species - ')
-
+ results_td = evl.evaluate_predictions(
+ gt_test,
+ preds_td,
+ class_names,
+ params_eval["detection_overlap"],
+ params_eval["ignore_start_end"],
+ )
+ print_results(
+ "Tadarida",
+ "td_rf",
+ results_td,
+ args["op_dir"],
+ class_names,
+ args["file_type"],
+ args["title_text"] + " - Species - ",
+ )
#
# evaluate BatDetect
#
- if args['bd_model_path'] != '':
+ if args["bd_model_path"] != "":
# load model
bd_args = du.get_default_bd_args()
- model, params_bd = du.load_model(args['bd_model_path'])
+ model, params_bd = du.load_model(args["bd_model_path"])
# check if the class names are the same
- if params_bd['class_names'] != class_names:
- print('Warning: Class names are not the same as the trained model')
+ if params_bd["class_names"] != class_names:
+ print("Warning: Class names are not the same as the trained model")
assert False
preds_bd = []
for ii, gg in enumerate(gt_test):
- pred = du.process_file(gg['file_path'], model, params_bd, bd_args, return_raw_preds=True)
+ pred = du.process_file(
+ gg["file_path"],
+ model,
+ params_bd,
+ bd_args,
+ return_raw_preds=True,
+ )
preds_bd.append(pred)
- results_bd = evl.evaluate_predictions(gt_test, preds_bd, class_names,
- params_eval['detection_overlap'], params_eval['ignore_start_end'])
- print_results('BatDetect', 'bd', results_bd, args['op_dir'],
- class_names, args['file_type'], args['title_text'] + ' - Species - ')
+ results_bd = evl.evaluate_predictions(
+ gt_test,
+ preds_bd,
+ class_names,
+ params_eval["detection_overlap"],
+ params_eval["ignore_start_end"],
+ )
+ print_results(
+ "BatDetect",
+ "bd",
+ results_bd,
+ args["op_dir"],
+ class_names,
+ args["file_type"],
+ args["title_text"] + " - Species - ",
+ )
# evaluate genus level
- class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(gt_test, preds_bd, class_names)
- results_bd_genus = evl.evaluate_predictions(gt_test_g, preds_bd_g, class_names_genus,
- params_eval['detection_overlap'], params_eval['ignore_start_end'])
- print_results('BatDetect Genus', 'bd_genus', results_bd_genus, args['op_dir'],
- class_names_genus, args['file_type'], args['title_text'] + ' - Genus - ')
+ class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(
+ gt_test, preds_bd, class_names
+ )
+ results_bd_genus = evl.evaluate_predictions(
+ gt_test_g,
+ preds_bd_g,
+ class_names_genus,
+ params_eval["detection_overlap"],
+ params_eval["ignore_start_end"],
+ )
+ print_results(
+ "BatDetect Genus",
+ "bd_genus",
+ results_bd_genus,
+ args["op_dir"],
+ class_names_genus,
+ args["file_type"],
+ args["title_text"] + " - Genus - ",
+ )
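
`create_genus_mapping` is defined outside this excerpt; assuming the class names are Latin binomials, the genus is plausibly the first token of each name, so the coarser label set can be built like this (class names are illustrative):

    class_names = ["Myotis daubentonii", "Myotis nattereri", "Pipistrellus pipistrellus"]
    genus_names = sorted({cc.split(" ")[0] for cc in class_names})               # ['Myotis', 'Pipistrellus']
    genus_mapping = [genus_names.index(cc.split(" ")[0]) for cc in class_names]  # [0, 0, 1]
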
diff --git a/bat_detect/finetune/finetune_model.py b/bat_detect/finetune/finetune_model.py
index 4fecc48..8c20e22 100644
--- a/bat_detect/finetune/finetune_model.py
+++ b/bat_detect/finetune/finetune_model.py
@@ -1,183 +1,325 @@
-import numpy as np
-import matplotlib.pyplot as plt
+import argparse
+import glob
+import json
import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
-import json
-import argparse
-import glob
-import sys
-sys.path.append(os.path.join('..', '..'))
-import bat_detect.train.train_model as tm
+sys.path.append(os.path.join("..", ".."))
+import bat_detect.detector.models as models
+import bat_detect.detector.parameters as parameters
+import bat_detect.detector.post_process as pp
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
-import bat_detect.train.train_utils as tu
import bat_detect.train.losses as losses
-
-import bat_detect.detector.parameters as parameters
-import bat_detect.detector.models as models
-import bat_detect.detector.post_process as pp
-import bat_detect.utils.plot_utils as pu
+import bat_detect.train.train_model as tm
+import bat_detect.train.train_utils as tu
import bat_detect.utils.detector_utils as du
-
+import bat_detect.utils.plot_utils as pu
if __name__ == "__main__":
- info_str = '\nBatDetect - Finetune Model\n'
+ info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
parser = argparse.ArgumentParser()
- parser.add_argument('audio_path', type=str, help='Input directory for audio')
- parser.add_argument('train_ann_path', type=str,
- help='Path to where train annotation file is stored')
- parser.add_argument('test_ann_path', type=str,
- help='Path to where test annotation file is stored')
- parser.add_argument('model_path', type=str,
- help='Path to pretrained model')
- parser.add_argument('--op_model_name', type=str, default='',
- help='Path and name for finetuned model')
- parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
- help='Number of finetuning epochs')
- parser.add_argument('--finetune_only_last_layer', action='store_true',
- help='Only train final layers')
- parser.add_argument('--train_from_scratch', action='store_true',
- help='Do not use pretrained weights')
- parser.add_argument('--do_not_save_images', action='store_false',
- help='Do not save images at the end of training')
- parser.add_argument('--notes', type=str, default='',
- help='Notes to save in text file')
+ parser.add_argument(
+ "audio_path", type=str, help="Input directory for audio"
+ )
+ parser.add_argument(
+ "train_ann_path",
+ type=str,
+ help="Path to where train annotation file is stored",
+ )
+ parser.add_argument(
+ "test_ann_path",
+ type=str,
+ help="Path to where test annotation file is stored",
+ )
+ parser.add_argument(
+ "model_path", type=str, help="Path to pretrained model"
+ )
+ parser.add_argument(
+ "--op_model_name",
+ type=str,
+ default="",
+ help="Path and name for finetuned model",
+ )
+ parser.add_argument(
+ "--num_epochs",
+ type=int,
+ default=200,
+ dest="num_epochs",
+ help="Number of finetuning epochs",
+ )
+ parser.add_argument(
+ "--finetune_only_last_layer",
+ action="store_true",
+ help="Only train final layers",
+ )
+ parser.add_argument(
+ "--train_from_scratch",
+ action="store_true",
+ help="Do not use pretrained weights",
+ )
+ parser.add_argument(
+ "--do_not_save_images",
+ action="store_false",
+ help="Do not save images at the end of training",
+ )
+ parser.add_argument(
+ "--notes", type=str, default="", help="Notes to save in text file"
+ )
args = vars(parser.parse_args())
- params = parameters.get_params(True, '../../experiments/')
+ params = parameters.get_params(True, "../../experiments/")
if torch.cuda.is_available():
- params['device'] = 'cuda'
+ params["device"] = "cuda"
else:
- params['device'] = 'cpu'
- print('\nNote, this will be a lot faster if you use computer with a GPU.\n')
+ params["device"] = "cpu"
+ print(
+ "\nNote, this will be a lot faster if you use computer with a GPU.\n"
+ )
- print('\nAudio directory: ' + args['audio_path'])
- print('Train file: ' + args['train_ann_path'])
- print('Test file: ' + args['test_ann_path'])
- print('Loading model: ' + args['model_path'])
+ print("\nAudio directory: " + args["audio_path"])
+ print("Train file: " + args["train_ann_path"])
+ print("Test file: " + args["test_ann_path"])
+ print("Loading model: " + args["model_path"])
- dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')
+ dataset_name = (
+ os.path.basename(args["train_ann_path"])
+ .replace(".json", "")
+ .replace("_TRAIN", "")
+ )
- if args['train_from_scratch']:
- print('\nTraining model from scratch i.e. not using pretrained weights')
- model, params_train = du.load_model(args['model_path'], False)
+ if args["train_from_scratch"]:
+ print(
+ "\nTraining model from scratch i.e. not using pretrained weights"
+ )
+ model, params_train = du.load_model(args["model_path"], False)
else:
- model, params_train = du.load_model(args['model_path'], True)
- model.to(params['device'])
+ model, params_train = du.load_model(args["model_path"], True)
+ model.to(params["device"])
- params['num_epochs'] = args['num_epochs']
- if args['op_model_name'] != '':
- params['model_file_name'] = args['op_model_name']
- classes_to_ignore = params['classes_to_ignore']+params['generic_class']
+ params["num_epochs"] = args["num_epochs"]
+ if args["op_model_name"] != "":
+ params["model_file_name"] = args["op_model_name"]
+ classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# save notes file
- params['notes'] = args['notes']
- if args['notes'] != '':
- tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])
-
+ params["notes"] = args["notes"]
+ if args["notes"] != "":
+ tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
# load train annotations
train_sets = []
- train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path']))
- params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])]
+ train_sets.append(
+ tu.get_blank_dataset_dict(
+ dataset_name, False, args["train_ann_path"], args["audio_path"]
+ )
+ )
+ params["train_sets"] = [
+ tu.get_blank_dataset_dict(
+ dataset_name,
+ False,
+ os.path.basename(args["train_ann_path"]),
+ args["audio_path"],
+ )
+ ]
- print('\nTrain set:')
- data_train, params['class_names'], params['class_inv_freq'] = \
- tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
- print('Number of files', len(data_train))
+ print("\nTrain set:")
+ (
+ data_train,
+ params["class_names"],
+ params["class_inv_freq"],
+ ) = tu.load_set_of_anns(
+ train_sets, classes_to_ignore, params["events_of_interest"]
+ )
+ print("Number of files", len(data_train))
- params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
- params['class_names_short'] = tu.get_short_class_names(params['class_names'])
+ params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
+ params["class_names"]
+ )
+ params["class_names_short"] = tu.get_short_class_names(
+ params["class_names"]
+ )
# load test annotations
test_sets = []
- test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path']))
- params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])]
+ test_sets.append(
+ tu.get_blank_dataset_dict(
+ dataset_name, True, args["test_ann_path"], args["audio_path"]
+ )
+ )
+ params["test_sets"] = [
+ tu.get_blank_dataset_dict(
+ dataset_name,
+ True,
+ os.path.basename(args["test_ann_path"]),
+ args["audio_path"],
+ )
+ ]
- print('\nTest set:')
- data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
- print('Number of files', len(data_test))
+ print("\nTest set:")
+ data_test, _, _ = tu.load_set_of_anns(
+ test_sets, classes_to_ignore, params["events_of_interest"]
+ )
+ print("Number of files", len(data_test))
# train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
- shuffle=True, num_workers=params['num_workers'], pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=params["batch_size"],
+ shuffle=True,
+ num_workers=params["num_workers"],
+ pin_memory=True,
+ )
# test loader - batch size of one because of variable file length
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
- shuffle=False, num_workers=params['num_workers'], pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=params["num_workers"],
+ pin_memory=True,
+ )
inputs_train = next(iter(train_loader))
- params['ip_height'] = inputs_train['spec'].shape[2]
- print('\ntrain batch size :', inputs_train['spec'].shape)
+ params["ip_height"] = inputs_train["spec"].shape[2]
+ print("\ntrain batch size :", inputs_train["spec"].shape)
- assert(params_train['model_name'] == 'Net2DFast')
- print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n')
+ assert params_train["model_name"] == "Net2DFast"
+ print(
+ "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
+ )
# set the number of output classes
num_filts = model.conv_classes_op.in_channels
k_size = model.conv_classes_op.kernel_size
pad = model.conv_classes_op.padding
- model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1, kernel_size=k_size, padding=pad)
- model.conv_classes_op.to(params['device'])
+ model.conv_classes_op = torch.nn.Conv2d(
+ num_filts,
+ len(params["class_names"]) + 1,
+ kernel_size=k_size,
+ padding=pad,
+ )
+ model.conv_classes_op.to(params["device"])
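
Replacing `conv_classes_op` is the usual way to retarget a detection head: the backbone keeps its pretrained weights and only the final conv is rebuilt so its output channels match the new `len(class_names) + 1`. A standalone sketch of the pattern, with illustrative layer sizes:

    import torch

    old_head = torch.nn.Conv2d(32, 18, kernel_size=(3, 3), padding=(1, 1))  # e.g. 17 classes + background
    num_new_classes = 5
    new_head = torch.nn.Conv2d(
        old_head.in_channels,
        num_new_classes + 1,  # classes plus background
        kernel_size=old_head.kernel_size,
        padding=old_head.padding,
    )
    # new_head starts with random weights, hence the need to finetune
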
- if args['finetune_only_last_layer']:
- print('\nOnly finetuning the final layers.\n')
- train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op']
- train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i]
+ if args["finetune_only_last_layer"]:
+ print("\nOnly finetuning the final layers.\n")
+ train_layers_i = [
+ "conv_classes",
+ "conv_classes_op",
+ "conv_size",
+ "conv_size_op",
+ ]
+ train_layers = [tt + ".weight" for tt in train_layers_i] + [
+ tt + ".bias" for tt in train_layers_i
+ ]
for name, param in model.named_parameters():
if name in train_layers:
param.requires_grad = True
else:
param.requires_grad = False
- optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
- scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
- if params['train_loss'] == 'mse':
+ optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
+ scheduler = CosineAnnealingLR(
+ optimizer, params["num_epochs"] * len(train_loader)
+ )
+ if params["train_loss"] == "mse":
det_criterion = losses.mse_loss
- elif params['train_loss'] == 'focal':
+ elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss
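
Note the scheduler horizon: `T_max = num_epochs * len(train_loader)` only gives a single cosine decay over the whole run if `scheduler.step()` is called once per batch (presumably inside `tm.train`). A minimal sketch of that convention:

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR

    w = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([w], lr=0.001)
    scheduler = CosineAnnealingLR(optimizer, T_max=200 * 50)  # epochs * batches per epoch
    for batch in range(50):  # stepped per batch, not per epoch
        optimizer.step()
        scheduler.step()
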
# plotting
- train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
- ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
- test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
- ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
- test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
- ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
- test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
- params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])
+ train_plt_ls = pu.LossPlotter(
+ params["experiment"] + "train_loss.png",
+ params["num_epochs"] + 1,
+ ["train_loss"],
+ None,
+ None,
+ ["epoch", "train_loss"],
+ logy=True,
+ )
+ test_plt_ls = pu.LossPlotter(
+ params["experiment"] + "test_loss.png",
+ params["num_epochs"] + 1,
+ ["test_loss"],
+ None,
+ None,
+ ["epoch", "test_loss"],
+ logy=True,
+ )
+ test_plt = pu.LossPlotter(
+ params["experiment"] + "test.png",
+ params["num_epochs"] + 1,
+ ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
+ [0, 1],
+ None,
+ ["epoch", ""],
+ )
+ test_plt_class = pu.LossPlotter(
+ params["experiment"] + "test_avg_prec.png",
+ params["num_epochs"] + 1,
+ params["class_names_short"],
+ [0, 1],
+ params["class_names_short"],
+ ["epoch", "avg_prec"],
+ )
# main train loop
- for epoch in range(0, params['num_epochs']+1):
+ for epoch in range(0, params["num_epochs"] + 1):
- train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
- train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])
+ train_loss = tm.train(
+ model,
+ epoch,
+ train_loader,
+ det_criterion,
+ optimizer,
+ scheduler,
+ params,
+ )
+ train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
- if epoch % params['num_eval_epochs'] == 0:
+ if epoch % params["num_eval_epochs"] == 0:
# detection accuracy on test set
- test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
- test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
- test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
- test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
- test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
- pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)
+ test_res, test_loss = tm.test(
+ model, epoch, test_loader, det_criterion, params
+ )
+ test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
+ test_plt.update_and_save(
+ epoch,
+ [
+ test_res["avg_prec"],
+ test_res["rec_at_x"],
+ test_res["avg_prec_class"],
+ test_res["file_acc"],
+ test_res["top_class"]["avg_prec"],
+ ],
+ )
+ test_plt_class.update_and_save(
+ epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
+ )
+ pu.plot_pr_curve_class(
+ params["experiment"], "test_pr", "test_pr", test_res
+ )
# save finetuned model
- print('saving model to: ' + params['model_file_name'])
- op_state = {'epoch': epoch + 1,
- 'state_dict': model.state_dict(),
- 'params' : params}
- torch.save(op_state, params['model_file_name'])
-
+ print("saving model to: " + params["model_file_name"])
+ op_state = {
+ "epoch": epoch + 1,
+ "state_dict": model.state_dict(),
+ "params": params,
+ }
+ torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set
- if not args['do_not_save_images']:
+ if not args["do_not_save_images"]:
tm.save_images_batch(model, test_loader, params)
diff --git a/bat_detect/finetune/prep_data_finetune.py b/bat_detect/finetune/prep_data_finetune.py
index 3e86cd4..bf86e97 100644
--- a/bat_detect/finetune/prep_data_finetune.py
+++ b/bat_detect/finetune/prep_data_finetune.py
@@ -1,32 +1,33 @@
-import numpy as np
import argparse
-import os
import json
-
+import os
import sys
-sys.path.append(os.path.join('..', '..'))
+
+import numpy as np
+
+sys.path.append(os.path.join("..", ".."))
import bat_detect.train.train_utils as tu
def print_dataset_stats(data, split_name, classes_to_ignore):
- print('\nSplit:', split_name)
- print('Num files:', len(data))
+ print("\nSplit:", split_name)
+ print("Num files:", len(data))
class_cnts = {}
for dd in data:
- for aa in dd['annotation']:
- if aa['class'] not in classes_to_ignore:
- if aa['class'] in class_cnts:
- class_cnts[aa['class']] += 1
+ for aa in dd["annotation"]:
+ if aa["class"] not in classes_to_ignore:
+ if aa["class"] in class_cnts:
+ class_cnts[aa["class"]] += 1
else:
- class_cnts[aa['class']] = 1
+ class_cnts[aa["class"]] = 1
if len(class_cnts) == 0:
class_names = []
else:
class_names = np.sort([*class_cnts]).tolist()
- print('Class count:')
+ print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for ii, cc in enumerate(class_names):
@@ -41,111 +42,169 @@ def load_file_names(file_name):
with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()]
for ff in files:
- if ff.lower()[-3:] != 'wav':
- print('Error: Filenames need to end in .wav - ', ff)
- assert(False)
+ if ff.lower()[-3:] != "wav":
+ print("Error: Filenames need to end in .wav - ", ff)
+ assert False
else:
- print('Error: Input file not found - ', file_name)
- assert(False)
+ print("Error: Input file not found - ", file_name)
+ assert False
return files
if __name__ == "__main__":
- info_str = '\nBatDetect - Prepare Data for Finetuning\n'
+ info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str)
parser = argparse.ArgumentParser()
- parser.add_argument('dataset_name', type=str, help='Name to call your dataset')
- parser.add_argument('audio_dir', type=str, help='Input directory for audio')
- parser.add_argument('ann_dir', type=str, help='Input directory for where the audio annotations are stored')
- parser.add_argument('op_dir', type=str, help='Path where the train and test splits will be stored')
- parser.add_argument('--percent_val', type=float, default=0.20,
- help='Hold out this much data for validation. Should be number between 0 and 1')
- parser.add_argument('--rand_seed', type=int, default=2001,
- help='Random seed used for creating the validation split')
- parser.add_argument('--train_file', type=str, default='',
- help='Text file where each line is a wav file in train split')
- parser.add_argument('--test_file', type=str, default='',
- help='Text file where each line is a wav file in test split')
- parser.add_argument('--input_class_names', type=str, default='',
- help='Specify names of classes that you want to change. Separate with ";"')
- parser.add_argument('--output_class_names', type=str, default='',
- help='New class names to use instead. One to one mapping with "--input_class_names". \
- Separate with ";"')
+ parser.add_argument(
+ "dataset_name", type=str, help="Name to call your dataset"
+ )
+ parser.add_argument(
+ "audio_dir", type=str, help="Input directory for audio"
+ )
+ parser.add_argument(
+ "ann_dir",
+ type=str,
+ help="Input directory for where the audio annotations are stored",
+ )
+ parser.add_argument(
+ "op_dir",
+ type=str,
+ help="Path where the train and test splits will be stored",
+ )
+ parser.add_argument(
+ "--percent_val",
+ type=float,
+ default=0.20,
+ help="Hold out this much data for validation. Should be number between 0 and 1",
+ )
+ parser.add_argument(
+ "--rand_seed",
+ type=int,
+ default=2001,
+ help="Random seed used for creating the validation split",
+ )
+ parser.add_argument(
+ "--train_file",
+ type=str,
+ default="",
+ help="Text file where each line is a wav file in train split",
+ )
+ parser.add_argument(
+ "--test_file",
+ type=str,
+ default="",
+ help="Text file where each line is a wav file in test split",
+ )
+ parser.add_argument(
+ "--input_class_names",
+ type=str,
+ default="",
+ help='Specify names of classes that you want to change. Separate with ";"',
+ )
+ parser.add_argument(
+ "--output_class_names",
+ type=str,
+ default="",
+ help='New class names to use instead. One to one mapping with "--input_class_names". \
+ Separate with ";"',
+ )
args = vars(parser.parse_args())
+ np.random.seed(args["rand_seed"])
- np.random.seed(args['rand_seed'])
+ classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
+ generic_class = ["Bat"]
+ events_of_interest = ["Echolocation"]
- classes_to_ignore = ['', ' ', 'Unknown', 'Not Bat']
- generic_class = ['Bat']
- events_of_interest = ['Echolocation']
-
- if args['input_class_names'] != '' and args['output_class_names'] != '':
+ if args["input_class_names"] != "" and args["output_class_names"] != "":
# change the names of the classes
- ip_names = args['input_class_names'].split(';')
- op_names = args['output_class_names'].split(';')
+ ip_names = args["input_class_names"].split(";")
+ op_names = args["output_class_names"].split(";")
name_dict = dict(zip(ip_names, op_names))
else:
name_dict = False
# load annotations
- data_all, _, _ = tu.load_set_of_anns({'ann_path': args['ann_dir'], 'wav_path': args['audio_dir']},
- classes_to_ignore, events_of_interest, False, False,
- list_of_anns=True, filter_issues=True, name_replace=name_dict)
+ data_all, _, _ = tu.load_set_of_anns(
+ {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
+ classes_to_ignore,
+ events_of_interest,
+ False,
+ False,
+ list_of_anns=True,
+ filter_issues=True,
+ name_replace=name_dict,
+ )
- print('Dataset name: ' + args['dataset_name'])
- print('Audio directory: ' + args['audio_dir'])
- print('Annotation directory: ' + args['ann_dir'])
- print('Ouput directory: ' + args['op_dir'])
- print('Num annotated files: ' + str(len(data_all)))
+ print("Dataset name: " + args["dataset_name"])
+ print("Audio directory: " + args["audio_dir"])
+ print("Annotation directory: " + args["ann_dir"])
+ print("Ouput directory: " + args["op_dir"])
+ print("Num annotated files: " + str(len(data_all)))
- if args['train_file'] != '' and args['test_file'] != '':
+ if args["train_file"] != "" and args["test_file"] != "":
        # user has specified the train / test split
- train_files = load_file_names(args['train_file'])
- test_files = load_file_names(args['test_file'])
- file_names_all = [dd['id'] for dd in data_all]
- train_inds = [file_names_all.index(ff) for ff in train_files if ff in file_names_all]
- test_inds = [file_names_all.index(ff) for ff in test_files if ff in file_names_all]
+ train_files = load_file_names(args["train_file"])
+ test_files = load_file_names(args["test_file"])
+ file_names_all = [dd["id"] for dd in data_all]
+ train_inds = [
+ file_names_all.index(ff)
+ for ff in train_files
+ if ff in file_names_all
+ ]
+ test_inds = [
+ file_names_all.index(ff)
+ for ff in test_files
+ if ff in file_names_all
+ ]
else:
# split the data into train and test at the file level
num_exs = len(data_all)
- test_inds = np.random.choice(np.arange(num_exs), int(num_exs*args['percent_val']), replace=False)
+ test_inds = np.random.choice(
+ np.arange(num_exs),
+ int(num_exs * args["percent_val"]),
+ replace=False,
+ )
test_inds = np.sort(test_inds)
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
data_train = [data_all[ii] for ii in train_inds]
data_test = [data_all[ii] for ii in test_inds]
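
Because the held-out indices are drawn with NumPy after seeding, the same `--rand_seed` always reproduces the same validation split. In isolation:

    import numpy as np

    np.random.seed(2001)
    num_exs, percent_val = 10, 0.2
    test_inds = np.sort(
        np.random.choice(np.arange(num_exs), int(num_exs * percent_val), replace=False)
    )
    train_inds = np.setdiff1d(np.arange(num_exs), test_inds)  # the other 8 files
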
- if not os.path.isdir(args['op_dir']):
- os.makedirs(args['op_dir'])
- op_name = os.path.join(args['op_dir'], args['dataset_name'])
- op_name_train = op_name + '_TRAIN.json'
- op_name_test = op_name + '_TEST.json'
+ if not os.path.isdir(args["op_dir"]):
+ os.makedirs(args["op_dir"])
+ op_name = os.path.join(args["op_dir"], args["dataset_name"])
+ op_name_train = op_name + "_TRAIN.json"
+ op_name_test = op_name + "_TEST.json"
- class_un_train = print_dataset_stats(data_train, 'Train', classes_to_ignore)
- class_un_test = print_dataset_stats(data_test, 'Test', classes_to_ignore)
+ class_un_train = print_dataset_stats(
+ data_train, "Train", classes_to_ignore
+ )
+ class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
if len(data_train) > 0 and len(data_test) > 0:
if class_un_train != class_un_test:
- print('\nError: some classes are not in both the training and test sets.\
- \nTry a different random seed "--rand_seed".')
+ print(
+ '\nError: some classes are not in both the training and test sets.\
+ \nTry a different random seed "--rand_seed".'
+ )
assert False
- print('\n')
+ print("\n")
if len(data_train) == 0:
- print('No train annotations to save')
+ print("No train annotations to save")
else:
- print('Saving: ', op_name_train)
- with open(op_name_train, 'w') as da:
+ print("Saving: ", op_name_train)
+ with open(op_name_train, "w") as da:
json.dump(data_train, da, indent=2)
if len(data_test) == 0:
- print('No test annotations to save')
+ print("No test annotations to save")
else:
- print('Saving: ', op_name_test)
- with open(op_name_test, 'w') as da:
+ print("Saving: ", op_name_test)
+ with open(op_name_test, "w") as da:
json.dump(data_test, da, indent=2)
diff --git a/bat_detect/train/audio_dataloader.py b/bat_detect/train/audio_dataloader.py
index a36ec0b..ffd8086 100644
--- a/bat_detect/train/audio_dataloader.py
+++ b/bat_detect/train/audio_dataloader.py
@@ -1,71 +1,101 @@
-import torch
-import random
-import numpy as np
import copy
+import os
+import random
+import sys
+
import librosa
+import numpy as np
+import torch
import torch.nn.functional as F
import torchaudio
-import os
-import sys
-sys.path.append(os.path.join('..', '..'))
+sys.path.append(os.path.join("..", ".."))
import bat_detect.utils.audio_utils as au
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
# spec may be resized on input into the network
- num_classes = len(params['class_names'])
- op_height = spec_op_shape[0]
- op_width = spec_op_shape[1]
- freq_per_bin = (params['max_freq'] - params['min_freq']) / op_height
+ num_classes = len(params["class_names"])
+ op_height = spec_op_shape[0]
+ op_width = spec_op_shape[1]
+ freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
# start and end times
- x_pos_start = au.time_to_x_coords(ann['start_times'], sampling_rate,
- params['fft_win_length'], params['fft_overlap'])
- x_pos_start = (params['resize_factor']*x_pos_start).astype(np.int)
- x_pos_end = au.time_to_x_coords(ann['end_times'], sampling_rate,
- params['fft_win_length'], params['fft_overlap'])
- x_pos_end = (params['resize_factor']*x_pos_end).astype(np.int)
+ x_pos_start = au.time_to_x_coords(
+ ann["start_times"],
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
+ x_pos_end = au.time_to_x_coords(
+ ann["end_times"],
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
# location on y axis i.e. frequency
- y_pos_low = (ann['low_freqs'] - params['min_freq']) / freq_per_bin
- y_pos_low = (op_height - y_pos_low).astype(np.int)
- y_pos_high = (ann['high_freqs'] - params['min_freq']) / freq_per_bin
+ y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
+    y_pos_low = (op_height - y_pos_low).astype(int)
+    y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
    y_pos_high = (op_height - y_pos_high).astype(int)
- bb_widths = x_pos_end - x_pos_start
- bb_heights = (y_pos_low - y_pos_high)
+ bb_widths = x_pos_end - x_pos_start
+ bb_heights = y_pos_low - y_pos_high
- valid_inds = np.where((x_pos_start >= 0) & (x_pos_start < op_width) &
- (y_pos_low >= 0) & (y_pos_low < (op_height-1)))[0]
+ valid_inds = np.where(
+ (x_pos_start >= 0)
+ & (x_pos_start < op_width)
+ & (y_pos_low >= 0)
+ & (y_pos_low < (op_height - 1))
+ )[0]
ann_aug = {}
- ann_aug['x_inds'] = x_pos_start[valid_inds]
- ann_aug['y_inds'] = y_pos_low[valid_inds]
- keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids']
+ ann_aug["x_inds"] = x_pos_start[valid_inds]
+ ann_aug["y_inds"] = y_pos_low[valid_inds]
+ keys = [
+ "start_times",
+ "end_times",
+ "high_freqs",
+ "low_freqs",
+ "class_ids",
+ "individual_ids",
+ ]
for kk in keys:
ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage
- if len(ann_aug['individual_ids']) == 1:
- ann_aug['individual_ids'][0] = 0
+ if len(ann_aug["individual_ids"]) == 1:
+ ann_aug["individual_ids"][0] = 0
- y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
+ y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class
- y_2d_classes = np.zeros((num_classes+1, op_height, op_width), dtype=np.float32)
+ y_2d_classes = np.zeros(
+ (num_classes + 1, op_height, op_width), dtype=np.float32
+ )
# create 2D ground truth heatmaps
for ii in valid_inds:
- draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'])
- #draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
+ draw_gaussian(
+ y_2d_det[0, :],
+ (x_pos_start[ii], y_pos_low[ii]),
+ params["target_sigma"],
+ )
+ # draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
- cls_id = ann['class_ids'][ii]
+ cls_id = ann["class_ids"][ii]
if cls_id > -1:
- draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'])
- #draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
+ draw_gaussian(
+ y_2d_classes[cls_id, :],
+ (x_pos_start[ii], y_pos_low[ii]),
+ params["target_sigma"],
+ )
+ # draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
    # be careful: this will have a 1.0 in places where there is an event but the gt class is unknown
# this will be masked in training anyway
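
The coordinate mapping is worth pinning down with numbers: frequency rows count down from the top of the spectrogram (row 0 is `max_freq`), so a call's low frequency lands at `op_height` minus its bin offset. With illustrative parameters:

    min_freq, max_freq, op_height = 10000, 120000, 128  # illustrative values
    freq_per_bin = (max_freq - min_freq) / op_height    # ~859.4 Hz per row
    low_freq = 45000.0
    row = int(op_height - (low_freq - min_freq) / freq_per_bin)  # 87; row 0 is max_freq
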
@@ -96,20 +126,24 @@ def draw_gaussian(heatmap, center, sigmax, sigmay=None):
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
- #g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
- g = np.exp(- ((x - x0) ** 2)/(2 * sigmax ** 2) - ((y - y0) ** 2)/(2 * sigmay ** 2))
+ # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ g = np.exp(
+ -((x - x0) ** 2) / (2 * sigmax**2)
+ - ((y - y0) ** 2) / (2 * sigmay**2)
+ )
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
- heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
- heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]],
- g[g_y[0]:g_y[1], g_x[0]:g_x[1]])
+ heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
+ heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
+ g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
+ )
return True
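
`draw_gaussian` stamps an unnormalised 2D Gaussian onto the heatmap, taking an element-wise max with what is already there so overlapping calls never sum above 1.0. Stripped of the boundary clipping, the kernel itself is as below (the kernel extent is an assumption, since its computation falls outside the hunk):

    import numpy as np

    sigmax = sigmay = 2.0
    size = int(6 * sigmax + 3)  # assumed kernel extent
    x = np.arange(0, size, 1, np.float32)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    g = np.exp(-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2))
    # g peaks at 1.0 in the centre; heatmap = np.maximum(heatmap, g) on the clipped window
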
def pad_aray(ip_array, pad_size):
- return np.hstack((ip_array, np.ones(pad_size, dtype=np.int)*-1))
+    return np.hstack((ip_array, np.ones(pad_size, dtype=int) * -1))
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
@@ -121,24 +155,37 @@ def warp_spec_aug(spec, ann, return_spec_for_viz, params):
if return_spec_for_viz:
assert False
- delta = params['stretch_squeeze_delta']
+ delta = params["stretch_squeeze_delta"]
op_size = (spec.shape[1], spec.shape[2])
- resize_fract_r = np.random.rand()*delta*2 - delta + 1.0
- resize_amt = int(spec.shape[2]*resize_fract_r)
+ resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
+ resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]:
- spec_r = torch.cat((spec, torch.zeros((1, spec.shape[1], resize_amt-spec.shape[2]), dtype=spec.dtype)), 2)
+ spec_r = torch.cat(
+ (
+ spec,
+ torch.zeros(
+ (1, spec.shape[1], resize_amt - spec.shape[2]),
+ dtype=spec.dtype,
+ ),
+ ),
+ 2,
+ )
else:
spec_r = spec[:, :, :resize_amt]
- spec = F.interpolate(spec_r.unsqueeze(0), size=op_size, mode='bilinear', align_corners=False).squeeze(0)
- ann['start_times'] *= (1.0/resize_fract_r)
- ann['end_times'] *= (1.0/resize_fract_r)
+ spec = F.interpolate(
+ spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
+ ).squeeze(0)
+ ann["start_times"] *= 1.0 / resize_fract_r
+ ann["end_times"] *= 1.0 / resize_fract_r
return spec
def mask_time_aug(spec, params):
# Mask out a random block of time - repeat up to 3 times
    # SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
- fm = torchaudio.transforms.TimeMasking(int(spec.shape[1]*params['mask_max_time_perc']))
+ fm = torchaudio.transforms.TimeMasking(
+ int(spec.shape[1] * params["mask_max_time_perc"])
+ )
for ii in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
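+# note: TimeMasking zeroes a random span along the last (time) axis; the
+# transform is applied between one and three times per spectrogram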
@@ -147,40 +194,59 @@ def mask_time_aug(spec, params):
def mask_freq_aug(spec, params):
    # Mask out a random frequency range - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
- fm = torchaudio.transforms.FrequencyMasking(int(spec.shape[1]*params['mask_max_freq_perc']))
+ fm = torchaudio.transforms.FrequencyMasking(
+ int(spec.shape[1] * params["mask_max_freq_perc"])
+ )
for ii in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def scale_vol_aug(spec, params):
- return spec * np.random.random()*params['spec_amp_scaling']
+ return spec * np.random.random() * params["spec_amp_scaling"]
def echo_aug(audio, sampling_rate, params):
- sample_offset = int(params['echo_max_delay']*np.random.random()*sampling_rate) + 1
- audio[:-sample_offset] += np.random.random()*audio[sample_offset:]
+ sample_offset = (
+ int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
+ )
+ audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
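+# note: echo_aug adds a randomly attenuated copy of the signal, delayed by up to
+# echo_max_delay seconds, back onto the waveform in place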
def resample_aug(audio, sampling_rate, params):
sampling_rate_old = sampling_rate
- sampling_rate = np.random.choice(params['aug_sampling_rates'])
- audio = librosa.resample(audio, sampling_rate_old, sampling_rate, res_type='polyphase')
+ sampling_rate = np.random.choice(params["aug_sampling_rates"])
+ audio = librosa.resample(
+ audio, sampling_rate_old, sampling_rate, res_type="polyphase"
+ )
- audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
- params['fft_overlap'], params['resize_factor'],
- params['spec_divide_factor'], params['spec_train_width'])
+ audio = au.pad_audio(
+ audio,
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ params["resize_factor"],
+ params["spec_divide_factor"],
+ params["spec_train_width"],
+ )
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
if sampling_rate != sampling_rate2:
- audio2 = librosa.resample(audio2, sampling_rate2, sampling_rate, res_type='polyphase')
+ audio2 = librosa.resample(
+ audio2, sampling_rate2, sampling_rate, res_type="polyphase"
+ )
sampling_rate2 = sampling_rate
if audio2.shape[0] < num_samples:
- audio2 = np.hstack((audio2, np.zeros((num_samples-audio2.shape[0]), dtype=audio2.dtype)))
+ audio2 = np.hstack(
+ (
+ audio2,
+ np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
+ )
+ )
elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples]
return audio2, sampling_rate2
@@ -189,33 +255,43 @@ def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
# resample so they are the same
- audio2, sampling_rate2 = resample_audio(audio.shape[0], sampling_rate, audio2, sampling_rate2)
+ audio2, sampling_rate2 = resample_audio(
+ audio.shape[0], sampling_rate, audio2, sampling_rate2
+ )
# # set mean and std to be the same
# audio2 = (audio2 - audio2.mean())
# audio2 = (audio2/audio2.std())*audio.std()
# audio2 = audio2 + audio.mean()
- if ann['annotated'] and (ann2['annotated']) and \
- (sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]):
- comb_weight = 0.3 + np.random.random()*0.4
- audio = comb_weight*audio + (1-comb_weight)*audio2
- inds = np.argsort(np.hstack((ann['start_times'], ann2['start_times'])))
+ if (
+ ann["annotated"]
+ and (ann2["annotated"])
+ and (sampling_rate2 == sampling_rate)
+ and (audio.shape[0] == audio2.shape[0])
+ ):
+ comb_weight = 0.3 + np.random.random() * 0.4
+ audio = comb_weight * audio + (1 - comb_weight) * audio2
+ inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys():
# when combining calls from different files, assume they come from different individuals
- if kk == 'individual_ids':
- if (ann[kk]>-1).sum() > 0:
- ann2[kk][ann2[kk]>-1] += np.max(ann[kk][ann[kk]>-1]) + 1
+ if kk == "individual_ids":
+ if (ann[kk] > -1).sum() > 0:
+ ann2[kk][ann2[kk] > -1] += (
+ np.max(ann[kk][ann[kk] > -1]) + 1
+ )
- if (kk != 'class_id_file') and (kk != 'annotated'):
+ if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
return audio, ann
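+# note: when both files are annotated and match in rate and length, the audio is
+# mixed with a random weight in [0.3, 0.7] (mixup style) and the annotations are
+# merged and re-sorted by start time, with individual ids offset to stay distinct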
class AudioLoader(torch.utils.data.Dataset):
- def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
+ def __init__(
+ self, data_anns_ip, params, dataset_name=None, is_train=False
+ ):
self.data_anns = []
self.is_train = is_train
@@ -227,53 +303,70 @@ class AudioLoader(torch.utils.data.Dataset):
# filter out unused annotation here
filtered_annotations = []
- for ii, aa in enumerate(dd['annotation']):
+ for ii, aa in enumerate(dd["annotation"]):
- if 'individual' in aa.keys():
- aa['individual'] = int(aa['individual'])
+ if "individual" in aa.keys():
+ aa["individual"] = int(aa["individual"])
                # if only one call is labeled, it has to be from the same individual
- if len(dd['annotation']) == 1:
- aa['individual'] = 0
+ if len(dd["annotation"]) == 1:
+ aa["individual"] = 0
# convert class name into class label
- if aa['class'] in self.params['class_names']:
- aa['class_id'] = self.params['class_names'].index(aa['class'])
+ if aa["class"] in self.params["class_names"]:
+ aa["class_id"] = self.params["class_names"].index(
+ aa["class"]
+ )
else:
- aa['class_id'] = -1
+ aa["class_id"] = -1
- if aa['class'] not in self.params['classes_to_ignore']:
+ if aa["class"] not in self.params["classes_to_ignore"]:
filtered_annotations.append(aa)
- dd['annotation'] = filtered_annotations
- dd['start_times'] = np.array([aa['start_time'] for aa in dd['annotation']])
- dd['end_times'] = np.array([aa['end_time'] for aa in dd['annotation']])
- dd['high_freqs'] = np.array([float(aa['high_freq']) for aa in dd['annotation']])
- dd['low_freqs'] = np.array([float(aa['low_freq']) for aa in dd['annotation']])
- dd['class_ids'] = np.array([aa['class_id'] for aa in dd['annotation']]).astype(np.int)
- dd['individual_ids'] = np.array([aa['individual'] for aa in dd['annotation']]).astype(np.int)
+ dd["annotation"] = filtered_annotations
+ dd["start_times"] = np.array(
+ [aa["start_time"] for aa in dd["annotation"]]
+ )
+ dd["end_times"] = np.array(
+ [aa["end_time"] for aa in dd["annotation"]]
+ )
+ dd["high_freqs"] = np.array(
+ [float(aa["high_freq"]) for aa in dd["annotation"]]
+ )
+ dd["low_freqs"] = np.array(
+ [float(aa["low_freq"]) for aa in dd["annotation"]]
+ )
+ dd["class_ids"] = np.array(
+ [aa["class_id"] for aa in dd["annotation"]]
+ ).astype(int)
+ dd["individual_ids"] = np.array(
+ [aa["individual"] for aa in dd["annotation"]]
+ ).astype(int)
# file level class name
- dd['class_id_file'] = -1
- if 'class_name' in dd.keys():
- if dd['class_name'] in self.params['class_names']:
- dd['class_id_file'] = self.params['class_names'].index(dd['class_name'])
+ dd["class_id_file"] = -1
+ if "class_name" in dd.keys():
+ if dd["class_name"] in self.params["class_names"]:
+ dd["class_id_file"] = self.params["class_names"].index(
+ dd["class_name"]
+ )
self.data_anns.append(dd)
- ann_cnt = [len(aa['annotation']) for aa in self.data_anns]
- self.max_num_anns = 2*np.max(ann_cnt) # x2 because we may be combining files during training
+ ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
+ self.max_num_anns = 2 * np.max(
+ ann_cnt
+ ) # x2 because we may be combining files during training
- print('\n')
+ print("\n")
if dataset_name is not None:
- print('Dataset : ' + dataset_name)
+ print("Dataset : " + dataset_name)
if self.is_train:
- print('Split type : train')
+ print("Split type : train")
else:
- print('Split type : test')
- print('Num files : ' + str(len(self.data_anns)))
- print('Num calls : ' + str(np.sum(ann_cnt)))
-
+ print("Split type : test")
+ print("Num files : " + str(len(self.data_anns)))
+ print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(self, index=None):
@@ -281,110 +374,171 @@ class AudioLoader(torch.utils.data.Dataset):
        if index is None:
index = np.random.randint(0, len(self.data_anns))
- audio_file = self.data_anns[index]['file_path']
- sampling_rate, audio_raw = au.load_audio_file(audio_file, self.data_anns[index]['time_exp'],
- self.params['target_samp_rate'], self.params['scale_raw_audio'])
+ audio_file = self.data_anns[index]["file_path"]
+ sampling_rate, audio_raw = au.load_audio_file(
+ audio_file,
+ self.data_anns[index]["time_exp"],
+ self.params["target_samp_rate"],
+ self.params["scale_raw_audio"],
+ )
# copy annotation
ann = {}
- ann['annotated'] = self.data_anns[index]['annotated']
- ann['class_id_file'] = self.data_anns[index]['class_id_file']
- keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids']
+ ann["annotated"] = self.data_anns[index]["annotated"]
+ ann["class_id_file"] = self.data_anns[index]["class_id_file"]
+ keys = [
+ "start_times",
+ "end_times",
+ "high_freqs",
+ "low_freqs",
+ "class_ids",
+ "individual_ids",
+ ]
for kk in keys:
ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop
if self.is_train:
- nfft = int(self.params['fft_win_length']*sampling_rate)
- noverlap = int(self.params['fft_overlap']*nfft)
- length_samples = self.params['spec_train_width']*(nfft - noverlap) + noverlap
+ nfft = int(self.params["fft_win_length"] * sampling_rate)
+ noverlap = int(self.params["fft_overlap"] * nfft)
+ length_samples = (
+ self.params["spec_train_width"] * (nfft - noverlap) + noverlap
+ )
if audio_raw.shape[0] - length_samples > 0:
- sample_crop = np.random.randint(audio_raw.shape[0] - length_samples)
+ sample_crop = np.random.randint(
+ audio_raw.shape[0] - length_samples
+ )
else:
sample_crop = 0
- audio_raw = audio_raw[sample_crop:sample_crop+length_samples]
- ann['start_times'] = ann['start_times'] - sample_crop/float(sampling_rate)
- ann['end_times'] = ann['end_times'] - sample_crop/float(sampling_rate)
+ audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
+ ann["start_times"] = ann["start_times"] - sample_crop / float(
+ sampling_rate
+ )
+ ann["end_times"] = ann["end_times"] - sample_crop / float(
+ sampling_rate
+ )
# pad audio
if self.is_train:
- op_spec_target_size = self.params['spec_train_width']
+ op_spec_target_size = self.params["spec_train_width"]
else:
op_spec_target_size = None
- audio_raw = au.pad_audio(audio_raw, sampling_rate, self.params['fft_win_length'],
- self.params['fft_overlap'], self.params['resize_factor'],
- self.params['spec_divide_factor'], op_spec_target_size)
+ audio_raw = au.pad_audio(
+ audio_raw,
+ sampling_rate,
+ self.params["fft_win_length"],
+ self.params["fft_overlap"],
+ self.params["resize_factor"],
+ self.params["spec_divide_factor"],
+ op_spec_target_size,
+ )
duration = audio_raw.shape[0] / float(sampling_rate)
# sort based on time
- inds = np.argsort(ann['start_times'])
+ inds = np.argsort(ann["start_times"])
for kk in ann.keys():
- if (kk != 'class_id_file') and (kk != 'annotated'):
+ if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = ann[kk][inds]
return audio_raw, sampling_rate, duration, ann
-
def __getitem__(self, index):
# load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
# augment on raw audio
- if self.is_train and self.params['augment_at_train']:
+ if self.is_train and self.params["augment_at_train"]:
# augment - combine with random audio file
- if self.params['augment_at_train_combine'] and np.random.random() < self.params['aug_prob']:
- audio2, sampling_rate2, duration2, ann2 = self.get_file_and_anns()
- audio, ann = combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2)
+ if (
+ self.params["augment_at_train_combine"]
+ and np.random.random() < self.params["aug_prob"]
+ ):
+ (
+ audio2,
+ sampling_rate2,
+ duration2,
+ ann2,
+ ) = self.get_file_and_anns()
+ audio, ann = combine_audio_aug(
+ audio, sampling_rate, ann, audio2, sampling_rate2, ann2
+ )
# simulate echo by adding delayed copy of the file
- if np.random.random() < self.params['aug_prob']:
+ if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params)
# resample the audio
- #if np.random.random() < self.params['aug_prob']:
+ # if np.random.random() < self.params['aug_prob']:
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
# create spectrogram
- spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, self.params, self.return_spec_for_viz)
- rsf = self.params['resize_factor']
- spec_op_shape = (int(self.params['spec_height']*rsf), int(spec.shape[1]*rsf))
+ spec, spec_for_viz = au.generate_spectrogram(
+ audio, sampling_rate, self.params, self.return_spec_for_viz
+ )
+ rsf = self.params["resize_factor"]
+ spec_op_shape = (
+ int(self.params["spec_height"] * rsf),
+ int(spec.shape[1] * rsf),
+ )
# resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
- spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False).squeeze(0)
+ spec = F.interpolate(
+ spec, size=spec_op_shape, mode="bilinear", align_corners=False
+ ).squeeze(0)
# augment spectrogram
- if self.is_train and self.params['augment_at_train']:
+ if self.is_train and self.params["augment_at_train"]:
- if np.random.random() < self.params['aug_prob']:
+ if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
- if np.random.random() < self.params['aug_prob']:
- spec = warp_spec_aug(spec, ann, self.return_spec_for_viz, self.params)
+ if np.random.random() < self.params["aug_prob"]:
+ spec = warp_spec_aug(
+ spec, ann, self.return_spec_for_viz, self.params
+ )
- if np.random.random() < self.params['aug_prob']:
+ if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params)
- if np.random.random() < self.params['aug_prob']:
+ if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params)
outputs = {}
- outputs['spec'] = spec
+ outputs["spec"] = spec
if self.return_spec_for_viz:
- outputs['spec_for_viz'] = torch.from_numpy(spec_for_viz).unsqueeze(0)
+ outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
+ 0
+ )
# create ground truth heatmaps
- outputs['y_2d_det'], outputs['y_2d_size'], outputs['y_2d_classes'], ann_aug =\
- generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
+ (
+ outputs["y_2d_det"],
+ outputs["y_2d_size"],
+ outputs["y_2d_classes"],
+ ann_aug,
+ ) = generate_gt_heatmaps(
+ spec_op_shape, sampling_rate, ann, self.params
+ )
        # hack to get around the requirement that all vectors are the same length in
# the output batch
- pad_size = self.max_num_anns-len(ann_aug['individual_ids'])
- outputs['is_valid'] = pad_aray(np.ones(len(ann_aug['individual_ids'])), pad_size)
- keys = ['class_ids', 'individual_ids', 'x_inds', 'y_inds',
- 'start_times', 'end_times', 'low_freqs', 'high_freqs']
+ pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
+ outputs["is_valid"] = pad_aray(
+ np.ones(len(ann_aug["individual_ids"])), pad_size
+ )
+ keys = [
+ "class_ids",
+ "individual_ids",
+ "x_inds",
+ "y_inds",
+ "start_times",
+ "end_times",
+ "low_freqs",
+ "high_freqs",
+ ]
for kk in keys:
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
@@ -394,14 +548,13 @@ class AudioLoader(torch.utils.data.Dataset):
outputs[kk] = torch.from_numpy(outputs[kk])
# scalars
- outputs['class_id_file'] = ann['class_id_file']
- outputs['annotated'] = ann['annotated']
- outputs['duration'] = duration
- outputs['sampling_rate'] = sampling_rate
- outputs['file_id'] = index
+ outputs["class_id_file"] = ann["class_id_file"]
+ outputs["annotated"] = ann["annotated"]
+ outputs["duration"] = duration
+ outputs["sampling_rate"] = sampling_rate
+ outputs["file_id"] = index
return outputs
-
def __len__(self):
return len(self.data_anns)
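+# minimal usage sketch (mirrors train_model.py below; paths and params assumed):
+#   dataset = AudioLoader(data_train, params, is_train=True)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=params["batch_size"])
+#   batch = next(iter(loader))  # dict with 'spec', 'y_2d_det', 'y_2d_size', ...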
diff --git a/bat_detect/train/evaluate.py b/bat_detect/train/evaluate.py
index b88719f..a926fbb 100755
--- a/bat_detect/train/evaluate.py
+++ b/bat_detect/train/evaluate.py
@@ -1,6 +1,10 @@
import numpy as np
-from sklearn.metrics import roc_curve, auc
-from sklearn.metrics import accuracy_score, balanced_accuracy_score
+from sklearn.metrics import (
+ accuracy_score,
+ auc,
+ balanced_accuracy_score,
+ roc_curve,
+)
def compute_error_auc(op_str, gt, pred, prob):
@@ -13,8 +17,11 @@ def compute_error_auc(op_str, gt, pred, prob):
fpr, tpr, thresholds = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr)
- print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc))
- #return class_acc, roc_auc
+ print(
+ op_str
+ + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc)
+ )
+ # return class_acc, roc_auc
def calc_average_precision(recall, precision):
@@ -25,10 +32,10 @@ def calc_average_precision(recall, precision):
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
- for ii in range(mprec.shape[0]-2, -1,-1):
- mprec[ii] = np.maximum(mprec[ii], mprec[ii+1])
- inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1
- ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum()
+ for ii in range(mprec.shape[0] - 2, -1, -1):
+ mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
+ inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
+ ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return float(ave_prec)
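+# Pascal VOC 2012 style AP: precision is first made monotonically non-increasing
+# from the right, then AP = sum_i (r_i - r_{i-1}) * p_i over the recall change
+# points; e.g. recall [0.5, 1.0] with precision [1.0, 0.5] gives
+# 0.5 * 1.0 + 0.5 * 0.5 = 0.75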
@@ -37,7 +44,7 @@ def calc_recall_at_x(recall, precision, x=0.95):
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
- inds = np.where(precision[::-1]>x)[0]
+ inds = np.where(precision[::-1] > x)[0]
if len(inds) > 0:
return float(recall[::-1][inds[0]])
else:
@@ -51,7 +58,15 @@ def compute_affinity_1d(pred_box, gt_boxes, threshold):
return valid_detection, np.argmin(score)
-def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end):
+def compute_pre_rec(
+ gts,
+ preds,
+ eval_mode,
+ class_of_interest,
+ num_classes,
+ threshold,
+ ignore_start_end,
+):
"""
Computes precision and recall. Assumes that each file has been exhaustively
    annotated. Will not count predicted detections with a start time that is within
@@ -78,26 +93,40 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
for pid, pp in enumerate(preds):
# filter predicted calls that are too near the start or end of the file
- file_dur = gts[pid]['duration']
- valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end))
+ file_dur = gts[pid]["duration"]
+ valid_inds = (pp["start_times"] >= ignore_start_end) & (
+ pp["start_times"] <= (file_dur - ignore_start_end)
+ )
- pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds],
- pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T)
+ pred_boxes.append(
+ np.vstack(
+ (
+ pp["start_times"][valid_inds],
+ pp["end_times"][valid_inds],
+ pp["low_freqs"][valid_inds],
+ pp["high_freqs"][valid_inds],
+ )
+ ).T
+ )
- if eval_mode == 'detection':
+ if eval_mode == "detection":
# overall detection
- confidence.append(pp['det_probs'][valid_inds])
- elif eval_mode == 'per_class':
+ confidence.append(pp["det_probs"][valid_inds])
+ elif eval_mode == "per_class":
# per class
- confidence.append(pp['class_probs'].T[valid_inds, class_of_interest])
- elif eval_mode == 'top_class':
+ confidence.append(
+ pp["class_probs"].T[valid_inds, class_of_interest]
+ )
+ elif eval_mode == "top_class":
# per class - note that sometimes 'class_probs' can be num_classes+1 in size
- top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1)
- confidence.append(pp['class_probs'].T[valid_inds, top_class])
+ top_class = np.argmax(
+ pp["class_probs"].T[valid_inds, :num_classes], 1
+ )
+ confidence.append(pp["class_probs"].T[valid_inds, top_class])
pred_class.append(top_class)
        # be careful, assuming the order in the list is the same as the GT
- file_ids.append([pid]*valid_inds.sum())
+ file_ids.append([pid] * valid_inds.sum())
confidence = np.hstack(confidence)
    file_ids = np.hstack(file_ids).astype(int)
@@ -105,7 +134,6 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
if len(pred_class) > 0:
pred_class = np.hstack(pred_class)
-
# extract relevant ground truth boxes
gt_boxes = []
gt_assigned = []
@@ -115,32 +143,42 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
for gg in gts:
# filter ground truth calls that are too near the start or end of the file
- file_dur = gg['duration']
- valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end))
+ file_dur = gg["duration"]
+ valid_inds = (gg["start_times"] >= ignore_start_end) & (
+ gg["start_times"] <= (file_dur - ignore_start_end)
+ )
        # note, files with an incorrect duration will cause a problem
- if (gg['start_times'] > file_dur).sum() > 0:
- print('Error: file duration incorrect for', gg['id'])
- assert(False)
+ if (gg["start_times"] > file_dur).sum() > 0:
+ print("Error: file duration incorrect for", gg["id"])
+ assert False
- boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds],
- gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T
- gen_class = gg['class_ids'][valid_inds] == -1
- class_ids = gg['class_ids'][valid_inds]
+ boxes = np.vstack(
+ (
+ gg["start_times"][valid_inds],
+ gg["end_times"][valid_inds],
+ gg["low_freqs"][valid_inds],
+ gg["high_freqs"][valid_inds],
+ )
+ ).T
+ gen_class = gg["class_ids"][valid_inds] == -1
+ class_ids = gg["class_ids"][valid_inds]
# keep track of the number of relevant ground truth calls
- if eval_mode == 'detection':
+ if eval_mode == "detection":
# all valid ones
- num_positives += len(gg['start_times'][valid_inds])
- elif eval_mode == 'per_class':
+ num_positives += len(gg["start_times"][valid_inds])
+ elif eval_mode == "per_class":
# all valid ones with class of interest
- num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum()
- elif eval_mode == 'top_class':
+ num_positives += (
+ gg["class_ids"][valid_inds] == class_of_interest
+ ).sum()
+ elif eval_mode == "top_class":
# all valid ones with non generic class
- num_positives += (gg['class_ids'][valid_inds] > -1).sum()
+ num_positives += (gg["class_ids"][valid_inds] > -1).sum()
# find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
- if eval_mode == 'per_class':
+ if eval_mode == "per_class":
class_inds = (class_ids == class_of_interest) | (class_ids == -1)
boxes = boxes[class_inds, :]
gen_class = gen_class[class_inds]
@@ -151,25 +189,27 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
gt_generic_class.append(gen_class)
gt_class.append(class_ids)
-
# loop through detections and keep track of those that have been assigned
- true_pos = np.zeros(confidence.shape[0])
- valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True
- sorted_inds = np.argsort(confidence)[::-1] # sort high to low
+ true_pos = np.zeros(confidence.shape[0])
+ valid_inds = np.ones(confidence.shape[0]) == 1  # initialize to True
+ sorted_inds = np.argsort(confidence)[::-1] # sort high to low
for ii, ind in enumerate(sorted_inds):
gt_id = file_ids[ind]
valid_det = False
if gt_boxes[gt_id].shape[0] > 0:
# compute overlap
- valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id],
- threshold)
+ valid_det, det_ind = compute_affinity_1d(
+ pred_boxes[ind], gt_boxes[gt_id], threshold
+ )
# valid detection that has not already been assigned
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
count_as_true_pos = True
- if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]):
+ if eval_mode == "top_class" and (
+ gt_class[gt_id][det_ind] != pred_class[ind]
+ ):
# needs to be the same class
count_as_true_pos = False
@@ -181,40 +221,43 @@ def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, thres
# if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
# and eval_mode != 'detection', then ignore it
if gt_generic_class[gt_id][det_ind]:
- if eval_mode == 'per_class' or eval_mode == 'top_class':
+ if eval_mode == "per_class" or eval_mode == "top_class":
valid_inds[ii] = False
-
# store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9)
    thresholds_inds = np.zeros(len(thresholds), dtype=int)
for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
- thresholds_inds[thresholds_inds==0] = -1
+ thresholds_inds[thresholds_inds == 0] = -1
# compute precision and recall
- true_pos = true_pos[valid_inds]
- false_pos_c = np.cumsum(1-true_pos)
- true_pos_c = np.cumsum(true_pos)
+ true_pos = true_pos[valid_inds]
+ false_pos_c = np.cumsum(1 - true_pos)
+ true_pos_c = np.cumsum(true_pos)
recall = true_pos_c / num_positives
- precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps)
+ precision = true_pos_c / np.maximum(
+ true_pos_c + false_pos_c, np.finfo(np.float64).eps
+ )
results = {}
- results['recall'] = recall
- results['precision'] = precision
- results['num_gt'] = num_positives
+ results["recall"] = recall
+ results["precision"] = precision
+ results["num_gt"] = num_positives
- results['thresholds'] = thresholds
- results['thresholds_inds'] = thresholds_inds
+ results["thresholds"] = thresholds
+ results["thresholds_inds"] = thresholds_inds
if num_positives == 0:
- results['avg_prec'] = np.nan
- results['rec_at_x'] = np.nan
+ results["avg_prec"] = np.nan
+ results["rec_at_x"] = np.nan
else:
- results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5)
- results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5)
+ results["avg_prec"] = np.round(
+ calc_average_precision(recall, precision), 5
+ )
+ results["rec_at_x"] = np.round(calc_recall_at_x(recall, precision), 5)
return results
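+# note: detections are swept from highest to lowest confidence; using cumulative
+# true/false positive counts, recall = TP / num_positives and
+# precision = TP / max(TP + FP, eps); matched generic-class ground truth is
+# ignored (neither TP nor FP) in the 'per_class' and 'top_class' modes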
@@ -230,19 +273,19 @@ def compute_file_accuracy_simple(gts, preds, num_classes):
gt_valid = []
pred_valid = []
for ii in range(len(gts)):
- gt_class = np.unique(gts[ii]['class_ids'])
+ gt_class = np.unique(gts[ii]["class_ids"])
if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0])
- pred = preds[ii]['class_probs'][:num_classes, :].T
+ pred = preds[ii]["class_probs"][:num_classes, :].T
pred_valid.append(np.argmax(pred.mean(0)))
acc = (np.array(gt_valid) == np.array(pred_valid)).mean()
res = {}
- res['num_valid_files'] = len(gt_valid)
- res['num_total_files'] = len(gts)
- res['gt_valid_file'] = gt_valid
- res['pred_valid_file'] = pred_valid
- res['file_acc'] = np.round(acc, 5)
+ res["num_valid_files"] = len(gt_valid)
+ res["num_total_files"] = len(gts)
+ res["gt_valid_file"] = gt_valid
+ res["pred_valid_file"] = pred_valid
+ res["file_acc"] = np.round(acc, 5)
return res
@@ -256,12 +299,20 @@ def compute_file_accuracy(gts, preds, num_classes):
# compute min and max scoring range - then threshold
min_val = 0
- mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0]
+ mins = [
+ pp["class_probs"].min()
+ for pp in preds
+ if pp["class_probs"].shape[1] > 0
+ ]
if len(mins) > 0:
min_val = np.min(mins)
max_val = 1.0
- maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0]
+ maxes = [
+ pp["class_probs"].max()
+ for pp in preds
+ if pp["class_probs"].shape[1] > 0
+ ]
if len(maxes) > 0:
max_val = np.max(maxes)
@@ -272,33 +323,37 @@ def compute_file_accuracy(gts, preds, num_classes):
gt_valid = []
pred_valid_all = []
for ii in range(len(gts)):
- gt_class = np.unique(gts[ii]['class_ids'])
+ gt_class = np.unique(gts[ii]["class_ids"])
if len(gt_class) == 1 and gt_class[0] != -1:
gt_valid.append(gt_class[0])
- pred = preds[ii]['class_probs'][:num_classes, :].T
+ pred = preds[ii]["class_probs"][:num_classes, :].T
p_class = np.zeros(len(thresh))
for tt in range(len(thresh)):
- p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax()
+ p_class[tt] = (pred * (pred >= thresh[tt])).sum(0).argmax()
pred_valid_all.append(p_class)
# pick the result corresponding to the overall best threshold
pred_valid_all = np.vstack(pred_valid_all)
- acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0)
+ acc_per_thresh = (
+ np.array(gt_valid)[..., np.newaxis] == pred_valid_all
+ ).mean(0)
best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh]
    pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
res = {}
- res['num_valid_files'] = len(gt_valid)
- res['num_total_files'] = len(gts)
- res['gt_valid_file'] = gt_valid
- res['pred_valid_file'] = pred_valid
- res['file_acc'] = np.round(best_acc, 5)
+ res["num_valid_files"] = len(gt_valid)
+ res["num_total_files"] = len(gts)
+ res["gt_valid_file"] = gt_valid
+ res["pred_valid_file"] = pred_valid
+ res["file_acc"] = np.round(best_acc, 5)
return res
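+# note: per file, class probabilities are thresholded and summed at each
+# candidate threshold in 'thresh' (built above from min_val and max_val), and
+# the threshold with the best overall file accuracy is kept; only files whose
+# ground truth contains a single non-generic class are scored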
-def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0):
+def evaluate_predictions(
+ gts, preds, class_names, detection_overlap, ignore_start_end=0.0
+):
"""
Computes metrics derived from the precision and recall.
Assumes that gts and preds are both lists of the same lengths, with ground
@@ -307,24 +362,50 @@ def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_star
Returns the overall detection results, and per class results
"""
- assert(len(gts) == len(preds))
+ assert len(gts) == len(preds)
num_classes = len(class_names)
# evaluate detection on its own i.e. ignoring class
- det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end)
- top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end)
- det_results['top_class'] = top_class
+ det_results = compute_pre_rec(
+ gts,
+ preds,
+ "detection",
+ None,
+ num_classes,
+ detection_overlap,
+ ignore_start_end,
+ )
+ top_class = compute_pre_rec(
+ gts,
+ preds,
+ "top_class",
+ None,
+ num_classes,
+ detection_overlap,
+ ignore_start_end,
+ )
+ det_results["top_class"] = top_class
# per class evaluation
- det_results['class_pr'] = []
+ det_results["class_pr"] = []
for cc in range(num_classes):
- res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end)
- res['name'] = class_names[cc]
- det_results['class_pr'].append(res)
+ res = compute_pre_rec(
+ gts,
+ preds,
+ "per_class",
+ cc,
+ num_classes,
+ detection_overlap,
+ ignore_start_end,
+ )
+ res["name"] = class_names[cc]
+ det_results["class_pr"].append(res)
# ignores classes that are not present in the test set
- det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0])
- det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5)
+ det_results["avg_prec_class"] = np.mean(
+ [rs["avg_prec"] for rs in det_results["class_pr"] if rs["num_gt"] > 0]
+ )
+ det_results["avg_prec_class"] = np.round(det_results["avg_prec_class"], 5)
# file level evaluation
res_file = compute_file_accuracy(gts, preds, num_classes)
diff --git a/bat_detect/train/losses.py b/bat_detect/train/losses.py
index aaef2c4..02bfdd6 100644
--- a/bat_detect/train/losses.py
+++ b/bat_detect/train/losses.py
@@ -7,7 +7,9 @@ def bbox_size_loss(pred_size, gt_size):
Bounding box size loss. Only compute loss where there is a bounding box.
"""
gt_size_mask = (gt_size > 0).float()
- return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5))
+ return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
+ gt_size_mask.sum() + 1e-5
+ )
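+# note: bbox_size_loss is an L1 loss on the predicted box sizes, evaluated only
+# at cells that contain a ground-truth box (gt_size > 0) and averaged over the
+# number of masked entries (the 1e-5 guards against division by zero)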
def focal_loss(pred, gt, weights=None, valid_mask=None):
@@ -24,20 +26,25 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
neg_inds = gt.lt(1).float()
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
- neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds
+ neg_loss = (
+ torch.log(1 - pred + eps)
+ * torch.pow(pred, alpha)
+ * torch.pow(1 - gt, beta)
+ * neg_inds
+ )
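+ # neg_loss down-weights pixels near (but not at) a ground-truth peak via the
+ # (1 - gt)^beta term, as in the penalty-reduced focal loss used by
+ # CornerNet/CenterNet-style detectors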
if weights is not None:
- pos_loss = pos_loss*weights
- #neg_loss = neg_loss*weights
+ pos_loss = pos_loss * weights
+ # neg_loss = neg_loss*weights
if valid_mask is not None:
- pos_loss = pos_loss*valid_mask
- neg_loss = neg_loss*valid_mask
+ pos_loss = pos_loss * valid_mask
+ neg_loss = neg_loss * valid_mask
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
- num_pos = pos_inds.float().sum()
+ num_pos = pos_inds.float().sum()
if num_pos == 0:
loss = -neg_loss
else:
@@ -47,10 +54,10 @@ def focal_loss(pred, gt, weights=None, valid_mask=None):
def mse_loss(pred, gt, weights=None, valid_mask=None):
"""
- Mean squared error loss.
+ Mean squared error loss.
"""
if valid_mask is None:
- op = ((gt-pred)**2).mean()
+ op = ((gt - pred) ** 2).mean()
else:
- op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum()
+ op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op
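+# note: when valid_mask is supplied, the squared error is averaged over the
+# valid elements only rather than over the full tensor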
diff --git a/bat_detect/train/train_model.py b/bat_detect/train/train_model.py
index d955216..2fd33fe 100644
--- a/bat_detect/train/train_model.py
+++ b/bat_detect/train/train_model.py
@@ -1,32 +1,33 @@
-import numpy as np
-import matplotlib.pyplot as plt
+import argparse
+import json
import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
-import json
-import argparse
-import sys
-sys.path.append(os.path.join('..', '..'))
-
-import bat_detect.detector.parameters as parameters
-import bat_detect.detector.models as models
-import bat_detect.detector.post_process as pp
-import bat_detect.utils.plot_utils as pu
-
-import bat_detect.train.audio_dataloader as adl
-import bat_detect.train.evaluate as evl
-import bat_detect.train.train_utils as tu
-import bat_detect.train.train_split as ts
-import bat_detect.train.losses as losses
+sys.path.append(os.path.join("..", ".."))
import warnings
+
+import bat_detect.detector.models as models
+import bat_detect.detector.parameters as parameters
+import bat_detect.detector.post_process as pp
+import bat_detect.train.audio_dataloader as adl
+import bat_detect.train.evaluate as evl
+import bat_detect.train.losses as losses
+import bat_detect.train.train_split as ts
+import bat_detect.train.train_utils as tu
+import bat_detect.utils.plot_utils as pu
+
warnings.filterwarnings("ignore", category=UserWarning)
def save_images_batch(model, data_loader, params):
- print('\nsaving images ...')
+ print("\nsaving images ...")
is_train_state = data_loader.dataset.is_train
data_loader.dataset.is_train = False
@@ -36,67 +37,112 @@ def save_images_batch(model, data_loader, params):
ind = 0 # first image in each batch
with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader):
- data = inputs['spec'].to(params['device'])
+ data = inputs["spec"].to(params["device"])
outputs = model(data)
- spec_viz = inputs['spec_for_viz'].data.cpu().numpy()
- orig_index = inputs['file_id'][ind]
- plot_title = data_loader.dataset.data_anns[orig_index]['id']
- op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg'
- save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title)
+ spec_viz = inputs["spec_for_viz"].data.cpu().numpy()
+ orig_index = inputs["file_id"][ind]
+ plot_title = data_loader.dataset.data_anns[orig_index]["id"]
+ op_file_name = (
+ params["op_im_dir_test"]
+ + data_loader.dataset.data_anns[orig_index]["id"]
+ + ".jpg"
+ )
+ save_image(
+ spec_viz,
+ outputs,
+ ind,
+ inputs,
+ params,
+ op_file_name,
+ plot_title,
+ )
data_loader.dataset.is_train = is_train_state
data_loader.dataset.return_spec_for_viz = False
-def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title):
- pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
- pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy()
+def save_image(
+ spec_viz, outputs, ind, inputs, params, op_file_name, plot_title
+):
+ pred_nms, _ = pp.run_nms(outputs, params, inputs["sampling_rate"].float())
+ pred_hm = outputs["pred_det"][ind, 0, :].data.cpu().numpy()
spec_viz = spec_viz[ind, 0, :]
- gt = parse_gt_data(inputs)[ind]
- sampling_rate = inputs['sampling_rate'][ind].item()
- duration = inputs['duration'][ind].item()
+ gt = parse_gt_data(inputs)[ind]
+ sampling_rate = inputs["sampling_rate"][ind].item()
+ duration = inputs["duration"][ind].item()
- pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind],
- params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False)
+ pu.plot_spec(
+ spec_viz,
+ sampling_rate,
+ duration,
+ gt,
+ pred_nms[ind],
+ params,
+ plot_title,
+ op_file_name,
+ pred_hm,
+ plot_boxes=True,
+ fixed_aspect=False,
+ )
-def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq):
+def loss_fun(
+ outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
+):
# detection loss
- loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det)
+ loss = params["det_loss_weight"] * det_criterion(
+ outputs["pred_det"], gt_det
+ )
# bounding box size loss
- loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size)
+ loss += params["size_loss_weight"] * losses.bbox_size_loss(
+ outputs["pred_size"], gt_size
+ )
# classification loss
valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
- p_class = outputs['pred_class'][:, :-1, :]
- loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask)
+ p_class = outputs["pred_class"][:, :-1, :]
+ loss += params["class_loss_weight"] * det_criterion(
+ p_class, gt_class[:, :-1, :], valid_mask=valid_mask
+ )
return loss
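+# note: the total loss is det_loss_weight * detection loss
+# + size_loss_weight * box-size L1 + class_loss_weight * class loss, where the
+# last (generic) class channel is dropped and the class term is masked to
+# pixels that carry a known class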
-def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params):
+def train(
+ model, epoch, data_loader, det_criterion, optimizer, scheduler, params
+):
model.train()
train_loss = tu.AverageMeter()
- class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
+ class_inv_freq = torch.from_numpy(
+ np.array(params["class_inv_freq"], dtype=np.float32)
+ ).to(params["device"])
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
- print('\nEpoch', epoch)
+ print("\nEpoch", epoch)
for batch_idx, inputs in enumerate(data_loader):
- data = inputs['spec'].to(params['device'])
- gt_det = inputs['y_2d_det'].to(params['device'])
- gt_size = inputs['y_2d_size'].to(params['device'])
- gt_class = inputs['y_2d_classes'].to(params['device'])
+ data = inputs["spec"].to(params["device"])
+ gt_det = inputs["y_2d_det"].to(params["device"])
+ gt_size = inputs["y_2d_size"].to(params["device"])
+ gt_class = inputs["y_2d_classes"].to(params["device"])
optimizer.zero_grad()
outputs = model(data)
- loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
+ loss = loss_fun(
+ outputs,
+ gt_det,
+ gt_size,
+ gt_class,
+ det_criterion,
+ params,
+ class_inv_freq,
+ )
train_loss.update(loss.item(), data.shape[0])
loss.backward()
@@ -104,13 +150,18 @@ def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params
scheduler.step()
if batch_idx % 50 == 0 and batch_idx != 0:
- print('[{}/{}]\tLoss: {:.4f}'.format(
- batch_idx * len(data), len(data_loader.dataset), train_loss.avg))
+ print(
+ "[{}/{}]\tLoss: {:.4f}".format(
+ batch_idx * len(data),
+ len(data_loader.dataset),
+ train_loss.avg,
+ )
+ )
- print('Train loss : {:.4f}'.format(train_loss.avg))
+ print("Train loss : {:.4f}".format(train_loss.avg))
res = {}
- res['train_loss'] = float(train_loss.avg)
+ res["train_loss"] = float(train_loss.avg)
return res
@@ -120,16 +171,18 @@ def test(model, epoch, data_loader, det_criterion, params):
ground_truths = []
test_loss = tu.AverageMeter()
- class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
+ class_inv_freq = torch.from_numpy(
+ np.array(params["class_inv_freq"], dtype=np.float32)
+ ).to(params["device"])
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
with torch.no_grad():
for batch_idx, inputs in enumerate(data_loader):
- data = inputs['spec'].to(params['device'])
- gt_det = inputs['y_2d_det'].to(params['device'])
- gt_size = inputs['y_2d_size'].to(params['device'])
- gt_class = inputs['y_2d_classes'].to(params['device'])
+ data = inputs["spec"].to(params["device"])
+ gt_det = inputs["y_2d_det"].to(params["device"])
+ gt_size = inputs["y_2d_size"].to(params["device"])
+ gt_class = inputs["y_2d_classes"].to(params["device"])
outputs = model(data)
@@ -139,41 +192,79 @@ def test(model, epoch, data_loader, det_criterion, params):
# for kk in ['pred_det', 'pred_size', 'pred_class']:
# outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)
- if params['save_test_image_during_train'] and batch_idx == 0:
+ if params["save_test_image_during_train"] and batch_idx == 0:
# for visualization - save the first prediction
ind = 0
- orig_index = inputs['file_id'][ind]
- plot_title = data_loader.dataset.data_anns[orig_index]['id']
- op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg'
- save_image(data, outputs, ind, inputs, params, op_file_name, plot_title)
+ orig_index = inputs["file_id"][ind]
+ plot_title = data_loader.dataset.data_anns[orig_index]["id"]
+ op_file_name = (
+ params["op_im_dir"]
+ + str(orig_index.item()).zfill(4)
+ + "_"
+ + str(epoch).zfill(4)
+ + "_pred.jpg"
+ )
+ save_image(
+ data,
+ outputs,
+ ind,
+ inputs,
+ params,
+ op_file_name,
+ plot_title,
+ )
- loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
+ loss = loss_fun(
+ outputs,
+ gt_det,
+ gt_size,
+ gt_class,
+ det_criterion,
+ params,
+ class_inv_freq,
+ )
test_loss.update(loss.item(), data.shape[0])
# do NMS
- pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
+ pred_nms, _ = pp.run_nms(
+ outputs, params, inputs["sampling_rate"].float()
+ )
predictions.extend(pred_nms)
ground_truths.extend(parse_gt_data(inputs))
- res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'],
- params['detection_overlap'], params['ignore_start_end'])
+ res_det = evl.evaluate_predictions(
+ ground_truths,
+ predictions,
+ params["class_names"],
+ params["detection_overlap"],
+ params["ignore_start_end"],
+ )
- print('\nTest loss : {:.4f}'.format(test_loss.avg))
- print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x']))
- print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec']))
- print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'],
- res_det['num_valid_files'], res_det['num_total_files']))
- print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class']))
+ print("\nTest loss : {:.4f}".format(test_loss.avg))
+ print("Rec at 0.95 (det) : {:.4f}".format(res_det["rec_at_x"]))
+ print("Avg prec (cls) : {:.4f}".format(res_det["avg_prec"]))
+ print(
+ "File acc (cls) : {:.2f} - for {} out of {}".format(
+ res_det["file_acc"],
+ res_det["num_valid_files"],
+ res_det["num_total_files"],
+ )
+ )
+ print("Cls Avg prec (cls) : {:.4f}".format(res_det["avg_prec_class"]))
- print('\nPer class average precision')
- str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5
- for cc, rs in enumerate(res_det['class_pr']):
- if rs['num_gt'] > 0:
- print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec']))
+ print("\nPer class average precision")
+ str_len = np.max([len(rs["name"]) for rs in res_det["class_pr"]]) + 5
+ for cc, rs in enumerate(res_det["class_pr"]):
+ if rs["num_gt"] > 0:
+ print(
+ str(cc).ljust(5)
+ + rs["name"].ljust(str_len)
+ + "{:.4f}".format(rs["avg_prec"])
+ )
res = {}
- res['test_loss'] = float(test_loss.avg)
+ res["test_loss"] = float(test_loss.avg)
return res_det, res
@@ -181,176 +272,288 @@ def test(model, epoch, data_loader, det_criterion, params):
def parse_gt_data(inputs):
# reads the torch arrays into a dictionary of numpy arrays, taking care to
# remove padding data i.e. not valid ones
- keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids']
+ keys = [
+ "start_times",
+ "end_times",
+ "low_freqs",
+ "high_freqs",
+ "class_ids",
+ "individual_ids",
+ ]
batch_data = []
- for ind in range(inputs['start_times'].shape[0]):
- is_valid = inputs['is_valid'][ind]==1
+ for ind in range(inputs["start_times"].shape[0]):
+ is_valid = inputs["is_valid"][ind] == 1
gt = {}
for kk in keys:
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
- gt['duration'] = inputs['duration'][ind].item()
- gt['file_id'] = inputs['file_id'][ind].item()
- gt['class_id_file'] = inputs['class_id_file'][ind].item()
+ gt["duration"] = inputs["duration"][ind].item()
+ gt["file_id"] = inputs["file_id"][ind].item()
+ gt["class_id_file"] = inputs["class_id_file"][ind].item()
batch_data.append(gt)
return batch_data
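+# note: the 'is_valid' flag strips the -1 padding added by pad_aray in the
+# dataloader, leaving per-file numpy arrays of just the real annotations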
def select_model(params):
- num_classes = len(params['class_names'])
- if params['model_name'] == 'Net2DFast':
- model = models.Net2DFast(params['num_filters'], num_classes=num_classes,
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
- elif params['model_name'] == 'Net2DFastNoAttn':
- model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes,
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
- elif params['model_name'] == 'Net2DFastNoCoordConv':
- model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes,
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
+ num_classes = len(params["class_names"])
+ if params["model_name"] == "Net2DFast":
+ model = models.Net2DFast(
+ params["num_filters"],
+ num_classes=num_classes,
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
+ elif params["model_name"] == "Net2DFastNoAttn":
+ model = models.Net2DFastNoAttn(
+ params["num_filters"],
+ num_classes=num_classes,
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
+ elif params["model_name"] == "Net2DFastNoCoordConv":
+ model = models.Net2DFastNoCoordConv(
+ params["num_filters"],
+ num_classes=num_classes,
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
else:
- print('No valid network specified')
+ print("No valid network specified")
return model
if __name__ == "__main__":
- plt.close('all')
+ plt.close("all")
params = parameters.get_params(True)
if torch.cuda.is_available():
- params['device'] = 'cuda'
+ params["device"] = "cuda"
else:
- params['device'] = 'cpu'
+ params["device"] = "cpu"
    # setup arg parser and populate it with existing parameters - will not work with lists
parser = argparse.ArgumentParser()
- parser.add_argument('data_dir', type=str,
- help='Path to root of datasets')
- parser.add_argument('ann_dir', type=str,
- help='Path to extracted annotations')
- parser.add_argument('--train_split', type=str, default='diff', # diff, same
- help='Which train split to use')
- parser.add_argument('--notes', type=str, default='',
- help='Notes to save in text file')
- parser.add_argument('--do_not_save_images', action='store_false',
- help='Do not save images at the end of training')
- parser.add_argument('--standardize_classs_names_ip', type=str,
- default='Rhinolophus ferrumequinum;Rhinolophus hipposideros',
- help='Will set low and high frequency the same for these classes. Separate names with ";"')
+ parser.add_argument("data_dir", type=str, help="Path to root of datasets")
+ parser.add_argument(
+ "ann_dir", type=str, help="Path to extracted annotations"
+ )
+ parser.add_argument(
+ "--train_split",
+ type=str,
+ default="diff", # diff, same
+ help="Which train split to use",
+ )
+ parser.add_argument(
+ "--notes", type=str, default="", help="Notes to save in text file"
+ )
+ parser.add_argument(
+ "--do_not_save_images",
+ action="store_false",
+ help="Do not save images at the end of training",
+ )
+ parser.add_argument(
+ "--standardize_classs_names_ip",
+ type=str,
+ default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
+ help='Will set low and high frequency the same for these classes. Separate names with ";"',
+ )
for key, val in params.items():
- parser.add_argument('--'+key, type=type(val), default=val)
+ parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args())
# save notes file
- if params['notes'] != '':
- tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes'])
+ if params["notes"] != "":
+ tu.write_notes_file(
+ params["experiment"] + "notes.txt", params["notes"]
+ )
# load the training and test meta data - there are different splits defined
- train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split'])
- train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split'])
+ train_sets, test_sets = ts.get_train_test_data(
+ params["ann_dir"], params["data_dir"], params["train_split"]
+ )
+ train_sets_no_path, test_sets_no_path = ts.get_train_test_data(
+ "", "", params["train_split"]
+ )
# keep track of what we have trained on
- params['train_sets'] = train_sets_no_path
- params['test_sets'] = test_sets_no_path
+ params["train_sets"] = train_sets_no_path
+ params["test_sets"] = test_sets_no_path
# load train annotations - merge them all together
- print('\nTraining on:')
+ print("\nTraining on:")
for tt in train_sets:
- print(tt['ann_path'])
- classes_to_ignore = params['classes_to_ignore']+params['generic_class']
- data_train, params['class_names'], params['class_inv_freq'] = \
- tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
- params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
- params['class_names_short'] = tu.get_short_class_names(params['class_names'])
+ print(tt["ann_path"])
+ classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
+ (
+ data_train,
+ params["class_names"],
+ params["class_inv_freq"],
+ ) = tu.load_set_of_anns(
+ train_sets,
+ classes_to_ignore,
+ params["events_of_interest"],
+ params["convert_to_genus"],
+ )
+ params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
+ params["class_names"]
+ )
+ params["class_names_short"] = tu.get_short_class_names(
+ params["class_names"]
+ )
# standardize the low and high frequency value for specified classes
- params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';')
- for cc in params['standardize_classs_names']:
- if cc in params['class_names']:
+ params["standardize_classs_names"] = params[
+ "standardize_classs_names_ip"
+ ].split(";")
+ for cc in params["standardize_classs_names"]:
+ if cc in params["class_names"]:
data_train = tu.standardize_low_freq(data_train, cc)
else:
- print(cc, 'not found')
+ print(cc, "not found")
# train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
- shuffle=True, num_workers=params['num_workers'], pin_memory=True)
-
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=params["batch_size"],
+ shuffle=True,
+ num_workers=params["num_workers"],
+ pin_memory=True,
+ )
# test set
- print('\nTesting on:')
+ print("\nTesting on:")
for tt in test_sets:
- print(tt['ann_path'])
- data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
+ print(tt["ann_path"])
+ data_test, _, _ = tu.load_set_of_anns(
+ test_sets,
+ classes_to_ignore,
+ params["events_of_interest"],
+ params["convert_to_genus"],
+ )
data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
# batch size of 1 because of variable file length
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
- shuffle=False, num_workers=params['num_workers'], pin_memory=True)
-
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=params["num_workers"],
+ pin_memory=True,
+ )
inputs_train = next(iter(train_loader))
# TODO remove params['ip_height'], this is just legacy
- params['ip_height'] = int(params['spec_height']*params['resize_factor'])
- print('\ntrain batch spec size :', inputs_train['spec'].shape)
- print('class target size :', inputs_train['y_2d_classes'].shape)
+ params["ip_height"] = int(params["spec_height"] * params["resize_factor"])
+ print("\ntrain batch spec size :", inputs_train["spec"].shape)
+ print("class target size :", inputs_train["y_2d_classes"].shape)
# select network
model = select_model(params)
- model = model.to(params['device'])
+ model = model.to(params["device"])
- optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
- #optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
- scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
- if params['train_loss'] == 'mse':
+ optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
+ # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
+ scheduler = CosineAnnealingLR(
+ optimizer, params["num_epochs"] * len(train_loader)
+ )
+ if params["train_loss"] == "mse":
det_criterion = losses.mse_loss
- elif params['train_loss'] == 'focal':
+ elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss
# save parameters to file
- with open(params['experiment'] + 'params.json', 'w') as da:
+ with open(params["experiment"] + "params.json", "w") as da:
json.dump(params, da, indent=2, sort_keys=True)
# plotting
- train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
- ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
- test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
- ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
- test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
- ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
- test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
- params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])
-
+ train_plt_ls = pu.LossPlotter(
+ params["experiment"] + "train_loss.png",
+ params["num_epochs"] + 1,
+ ["train_loss"],
+ None,
+ None,
+ ["epoch", "train_loss"],
+ logy=True,
+ )
+ test_plt_ls = pu.LossPlotter(
+ params["experiment"] + "test_loss.png",
+ params["num_epochs"] + 1,
+ ["test_loss"],
+ None,
+ None,
+ ["epoch", "test_loss"],
+ logy=True,
+ )
+ test_plt = pu.LossPlotter(
+ params["experiment"] + "test.png",
+ params["num_epochs"] + 1,
+ ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
+ [0, 1],
+ None,
+ ["epoch", ""],
+ )
+ test_plt_class = pu.LossPlotter(
+ params["experiment"] + "test_avg_prec.png",
+ params["num_epochs"] + 1,
+ params["class_names_short"],
+ [0, 1],
+ params["class_names_short"],
+ ["epoch", "avg_prec"],
+ )
#
# main train loop
- for epoch in range(0, params['num_epochs']+1):
+ for epoch in range(0, params["num_epochs"] + 1):
- train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
- train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])
+ train_loss = train(
+ model,
+ epoch,
+ train_loader,
+ det_criterion,
+ optimizer,
+ scheduler,
+ params,
+ )
+ train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
- if epoch % params['num_eval_epochs'] == 0:
+ if epoch % params["num_eval_epochs"] == 0:
# detection accuracy on test set
- test_res, test_loss = test(model, epoch, test_loader, det_criterion, params)
- test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
- test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
- test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
- test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
- pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)
-
+ test_res, test_loss = test(
+ model, epoch, test_loader, det_criterion, params
+ )
+ test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
+ test_plt.update_and_save(
+ epoch,
+ [
+ test_res["avg_prec"],
+ test_res["rec_at_x"],
+ test_res["avg_prec_class"],
+ test_res["file_acc"],
+ test_res["top_class"]["avg_prec"],
+ ],
+ )
+ test_plt_class.update_and_save(
+ epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
+ )
+ pu.plot_pr_curve_class(
+ params["experiment"], "test_pr", "test_pr", test_res
+ )
# save trained model
- print('saving model to: ' + params['model_file_name'])
- op_state = {'epoch': epoch + 1,
- 'state_dict': model.state_dict(),
- #'optimizer' : optimizer.state_dict(),
- 'params' : params}
- torch.save(op_state, params['model_file_name'])
-
+ print("saving model to: " + params["model_file_name"])
+ op_state = {
+ "epoch": epoch + 1,
+ "state_dict": model.state_dict(),
+ #'optimizer' : optimizer.state_dict(),
+ "params": params,
+ }
+ torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set
- if not args['do_not_save_images']:
+ if not args["do_not_save_images"]:
save_images_batch(model, test_loader, params)
diff --git a/bat_detect/train/train_split.py b/bat_detect/train/train_split.py
index 20972bd..01b5c03 100644
--- a/bat_detect/train/train_split.py
+++ b/bat_detect/train/train_split.py
@@ -2,13 +2,14 @@
Run scripts/extract_anns.py to generate these json files.
"""
+
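+# 'diff' holds out entire datasets for testing; 'same' splits each dataset into _TRAIN/_TEST files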
def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
- if split_name == 'diff':
+ if split_name == "diff":
train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra)
- elif split_name == 'same':
+ elif split_name == "same":
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
else:
- print('Split not defined')
+ print("Split not defined")
assert False
return train_sets, test_sets
@@ -18,73 +19,126 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
train_sets = []
if load_extra:
- train_sets.append({'dataset_name': 'BatDetective',
- 'is_test': False,
- 'is_binary': True, # just a bat / not bat dataset ie no classes
- 'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
- 'wav_path': wav_dir + 'bat_detective/audio/'})
- train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
- 'is_test': False,
- 'is_binary': True,
- 'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
- 'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
- train_sets.append({'dataset_name': 'bat_logger_2016_empty',
- 'is_test': False,
- 'is_binary': True,
- 'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
- 'wav_path': wav_dir + 'bat_logger_2016/audio/'})
+ train_sets.append(
+ {
+ "dataset_name": "BatDetective",
+ "is_test": False,
+ "is_binary": True, # just a bat / not bat dataset ie no classes
+ "ann_path": ann_dir
+ + "train_set_bulgaria_batdetective_with_bbs.json",
+ "wav_path": wav_dir + "bat_detective/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_logger_qeop_empty",
+ "is_test": False,
+ "is_binary": True,
+ "ann_path": ann_dir + "bat_logger_qeop_empty.json",
+ "wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_logger_2016_empty",
+ "is_test": False,
+ "is_binary": True,
+ "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
+ "wav_path": wav_dir + "bat_logger_2016/audio/",
+ }
+ )
# train_sets.append({'dataset_name': 'brazil_data_binary',
# 'is_test': False,
# 'ann_path': ann_dir + 'brazil_data_binary.json',
# 'wav_path': wav_dir + 'brazil_data/audio/'})
- train_sets.append({'dataset_name': 'echobank',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'Echobank_train_expert.json',
- 'wav_path': wav_dir + 'echobank/audio/'})
- train_sets.append({'dataset_name': 'sn_scot_nor',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert.json',
- 'wav_path': wav_dir + 'sn_scot_nor/audio/'})
- train_sets.append({'dataset_name': 'BCT_1_sec',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BCT_1_sec_train_expert.json',
- 'wav_path': wav_dir + 'BCT_1_sec/audio/'})
- train_sets.append({'dataset_name': 'bcireland',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'bcireland_expert.json',
- 'wav_path': wav_dir + 'bcireland/audio/'})
- train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert.json',
- 'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
+ train_sets.append(
+ {
+ "dataset_name": "echobank",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "Echobank_train_expert.json",
+ "wav_path": wav_dir + "echobank/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "sn_scot_nor",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "sn_scot_nor_0.5_expert.json",
+ "wav_path": wav_dir + "sn_scot_nor/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "BCT_1_sec",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "BCT_1_sec_train_expert.json",
+ "wav_path": wav_dir + "BCT_1_sec/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bcireland",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "bcireland_expert.json",
+ "wav_path": wav_dir + "bcireland/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "rhinolophus_steve_BCT",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "rhinolophus_steve_BCT_expert.json",
+ "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
+ }
+ )
test_sets = []
- test_sets.append({'dataset_name': 'bat_data_martyn_2018',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2019',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018_test",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019_test",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
+ }
+ )
return train_sets, test_sets
@@ -93,71 +147,124 @@ def split_same(ann_dir, wav_dir, load_extra=True):
train_sets = []
if load_extra:
- train_sets.append({'dataset_name': 'BatDetective',
- 'is_test': False,
- 'is_binary': True,
- 'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
- 'wav_path': wav_dir + 'bat_detective/audio/'})
- train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
- 'is_test': False,
- 'is_binary': True,
- 'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
- 'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
- train_sets.append({'dataset_name': 'bat_logger_2016_empty',
- 'is_test': False,
- 'is_binary': True,
- 'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
- 'wav_path': wav_dir + 'bat_logger_2016/audio/'})
+ train_sets.append(
+ {
+ "dataset_name": "BatDetective",
+ "is_test": False,
+ "is_binary": True,
+ "ann_path": ann_dir
+ + "train_set_bulgaria_batdetective_with_bbs.json",
+ "wav_path": wav_dir + "bat_detective/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_logger_qeop_empty",
+ "is_test": False,
+ "is_binary": True,
+ "ann_path": ann_dir + "bat_logger_qeop_empty.json",
+ "wav_path": wav_dir + "bat_logger_qeop_empty/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_logger_2016_empty",
+ "is_test": False,
+ "is_binary": True,
+ "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json",
+ "wav_path": wav_dir + "bat_logger_2016/audio/",
+ }
+ )
# train_sets.append({'dataset_name': 'brazil_data_binary',
# 'is_test': False,
# 'ann_path': ann_dir + 'brazil_data_binary.json',
# 'wav_path': wav_dir + 'brazil_data/audio/'})
- train_sets.append({'dataset_name': 'echobank',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'Echobank_train_expert_TRAIN.json',
- 'wav_path': wav_dir + 'echobank/audio/'})
- train_sets.append({'dataset_name': 'sn_scot_nor',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TRAIN.json',
- 'wav_path': wav_dir + 'sn_scot_nor/audio/'})
- train_sets.append({'dataset_name': 'BCT_1_sec',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BCT_1_sec_train_expert_TRAIN.json',
- 'wav_path': wav_dir + 'BCT_1_sec/audio/'})
- train_sets.append({'dataset_name': 'bcireland',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'bcireland_expert_TRAIN.json',
- 'wav_path': wav_dir + 'bcireland/audio/'})
- train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TRAIN.json',
- 'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
- train_sets.append({'dataset_name': 'bat_data_martyn_2018',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
- train_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
- train_sets.append({'dataset_name': 'bat_data_martyn_2019',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
- train_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
- 'is_test': False,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
+ train_sets.append(
+ {
+ "dataset_name": "echobank",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "Echobank_train_expert_TRAIN.json",
+ "wav_path": wav_dir + "echobank/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "sn_scot_nor",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TRAIN.json",
+ "wav_path": wav_dir + "sn_scot_nor/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "BCT_1_sec",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "BCT_1_sec_train_expert_TRAIN.json",
+ "wav_path": wav_dir + "BCT_1_sec/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bcireland",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "bcireland_expert_TRAIN.json",
+ "wav_path": wav_dir + "bcireland/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "rhinolophus_steve_BCT",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TRAIN.json",
+ "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018_test",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019/audio/",
+ }
+ )
+ train_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019_test",
+ "is_test": False,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
+ }
+ )
# train_sets.append({'dataset_name': 'bat_data_martyn_2021_train',
# 'is_test': False,
@@ -171,51 +278,91 @@ def split_same(ann_dir, wav_dir, load_extra=True):
# 'wav_path': wav_dir + 'volunteers_2021/audio/'})
test_sets = []
- test_sets.append({'dataset_name': 'echobank',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'Echobank_train_expert_TEST.json',
- 'wav_path': wav_dir + 'echobank/audio/'})
- test_sets.append({'dataset_name': 'sn_scot_nor',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TEST.json',
- 'wav_path': wav_dir + 'sn_scot_nor/audio/'})
- test_sets.append({'dataset_name': 'BCT_1_sec',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BCT_1_sec_train_expert_TEST.json',
- 'wav_path': wav_dir + 'BCT_1_sec/audio/'})
- test_sets.append({'dataset_name': 'bcireland',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'bcireland_expert_TEST.json',
- 'wav_path': wav_dir + 'bcireland/audio/'})
- test_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TEST.json',
- 'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2018',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2019',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
- test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
- 'is_test': True,
- 'is_binary': False,
- 'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json',
- 'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
+ test_sets.append(
+ {
+ "dataset_name": "echobank",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir + "Echobank_train_expert_TEST.json",
+ "wav_path": wav_dir + "echobank/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "sn_scot_nor",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TEST.json",
+ "wav_path": wav_dir + "sn_scot_nor/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "BCT_1_sec",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir + "BCT_1_sec_train_expert_TEST.json",
+ "wav_path": wav_dir + "BCT_1_sec/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bcireland",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir + "bcireland_expert_TEST.json",
+ "wav_path": wav_dir + "bcireland/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "rhinolophus_steve_BCT",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TEST.json",
+ "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2018_test",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json",
+ "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019/audio/",
+ }
+ )
+ test_sets.append(
+ {
+ "dataset_name": "bat_data_martyn_2019_test",
+ "is_test": True,
+ "is_binary": False,
+ "ann_path": ann_dir
+ + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json",
+ "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/",
+ }
+ )
# test_sets.append({'dataset_name': 'bat_data_martyn_2021_test',
# 'is_test': True,
diff --git a/bat_detect/train/train_utils.py b/bat_detect/train/train_utils.py
index cff92e4..62441a7 100644
--- a/bat_detect/train/train_utils.py
+++ b/bat_detect/train/train_utils.py
@@ -1,42 +1,52 @@
-import numpy as np
-import random
-import os
import glob
import json
+import os
+import random
+
+import numpy as np
def write_notes_file(file_name, text):
- with open(file_name, 'a') as da:
- da.write(text + '\n')
+ with open(file_name, "a") as da:
+ da.write(text + "\n")
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
- ddict = {'dataset_name': dataset_name, 'is_test': is_test, 'is_binary': False,
- 'ann_path': ann_path, 'wav_path': wav_path}
+ ddict = {
+ "dataset_name": dataset_name,
+ "is_test": is_test,
+ "is_binary": False,
+ "ann_path": ann_path,
+ "wav_path": wav_path,
+ }
return ddict
def get_short_class_names(class_names, str_len=3):
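+    # abbreviate each word of a class name, e.g. "Myotis mystacinus" -> "Myo mys"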
class_names_short = []
for cc in class_names:
- class_names_short.append(' '.join([sp[:str_len] for sp in cc.split(' ')]))
+ class_names_short.append(
+ " ".join([sp[:str_len] for sp in cc.split(" ")])
+ )
return class_names_short
def remove_dupes(data_train, data_test):
- test_ids = [dd['id'] for dd in data_test]
+ test_ids = [dd["id"] for dd in data_test]
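+    # drop any training example whose id also appears in the test set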
data_train_prune = []
for aa in data_train:
- if aa['id'] not in test_ids:
+ if aa["id"] not in test_ids:
data_train_prune.append(aa)
diff = len(data_train) - len(data_train_prune)
if diff != 0:
- print(diff, 'items removed from train set')
+ print(diff, "items removed from train set")
return data_train_prune
def get_genus_mapping(class_names):
- genus_names, genus_mapping = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
+ genus_names, genus_mapping = np.unique(
+ [cc.split(" ")[0] for cc in class_names], return_inverse=True
+ )
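+    # genus_mapping[i] is the index into genus_names for the i-th class name (np.unique return_inverse)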
return genus_names.tolist(), genus_mapping.tolist()
@@ -47,97 +57,110 @@ def standardize_low_freq(data, class_of_interest):
low_freqs = []
high_freqs = []
for dd in data:
- for aa in dd['annotation']:
- if aa['class'] == class_of_interest:
- low_freqs.append(aa['low_freq'])
- high_freqs.append(aa['high_freq'])
+ for aa in dd["annotation"]:
+ if aa["class"] == class_of_interest:
+ low_freqs.append(aa["low_freq"])
+ high_freqs.append(aa["high_freq"])
low_mean = np.mean(low_freqs)
high_mean = np.mean(high_freqs)
- assert(low_mean < high_mean)
+ assert low_mean < high_mean
- print('\nStandardizing low and high frequency for:')
+ print("\nStandardizing low and high frequency for:")
print(class_of_interest)
- print('low: ', round(low_mean, 2))
- print('high: ', round(high_mean, 2))
+ print("low: ", round(low_mean, 2))
+ print("high: ", round(high_mean, 2))
# only set the low freq, high stays the same
# assumes that low_mean < high_mean
for dd in data:
- for aa in dd['annotation']:
- if aa['class'] == class_of_interest:
- aa['low_freq'] = low_mean
- if aa['high_freq'] < low_mean:
- aa['high_freq'] = high_mean
+ for aa in dd["annotation"]:
+ if aa["class"] == class_of_interest:
+ aa["low_freq"] = low_mean
+ if aa["high_freq"] < low_mean:
+ aa["high_freq"] = high_mean
return data
-def load_set_of_anns(data, classes_to_ignore=[], events_of_interest=None,
- convert_to_genus=False, verbose=True, list_of_anns=False,
- filter_issues=False, name_replace=False):
+def load_set_of_anns(
+ data,
+ classes_to_ignore=[],
+ events_of_interest=None,
+ convert_to_genus=False,
+ verbose=True,
+ list_of_anns=False,
+ filter_issues=False,
+ name_replace=False,
+):
# load the annotations
anns = []
if list_of_anns:
# path to list of individual json files
- anns.extend(load_anns_from_path(data['ann_path'], data['wav_path']))
+ anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
- anns.extend(load_anns(dd['ann_path'], dd['wav_path']))
+ anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
     # discarding unannotated files
- anns = [aa for aa in anns if aa['annotated'] is True]
+ anns = [aa for aa in anns if aa["annotated"] is True]
     # filter files that have annotation issues - if the input is a dictionary of
     # datasets, this will likely have already been done
if filter_issues:
- anns = [aa for aa in anns if aa['issues'] is False]
+ anns = [aa for aa in anns if aa["issues"] is False]
# check for some basic formatting errors with class names
for ann in anns:
- for aa in ann['annotation']:
- aa['class'] = aa['class'].strip()
+ for aa in ann["annotation"]:
+ aa["class"] = aa["class"].strip()
# only load specified events - i.e. types of calls
if events_of_interest is not None:
for ann in anns:
filtered_events = []
- for aa in ann['annotation']:
- if aa['event'] in events_of_interest:
+ for aa in ann["annotation"]:
+ if aa["event"] in events_of_interest:
filtered_events.append(aa)
- ann['annotation'] = filtered_events
+ ann["annotation"] = filtered_events
# change class names
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
- for aa in ann['annotation']:
- if aa['class'] in name_replace:
- aa['class'] = name_replace[aa['class']]
+ for aa in ann["annotation"]:
+ if aa["class"] in name_replace:
+ aa["class"] = name_replace[aa["class"]]
# convert everything to genus name
if convert_to_genus:
for ann in anns:
- for aa in ann['annotation']:
- aa['class'] = aa['class'].split(' ')[0]
+ for aa in ann["annotation"]:
+ aa["class"] = aa["class"].split(" ")[0]
# get unique class names
class_names_all = []
for ann in anns:
- for aa in ann['annotation']:
- if aa['class'] not in classes_to_ignore:
- class_names_all.append(aa['class'])
+ for aa in ann["annotation"]:
+ if aa["class"] not in classes_to_ignore:
+ class_names_all.append(aa["class"])
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
- class_inv_freq = (class_cnts.sum() / (len(class_names) * class_cnts.astype(np.float32)))
+ class_inv_freq = class_cnts.sum() / (
+ len(class_names) * class_cnts.astype(np.float32)
+ )
if verbose:
- print('Class count:')
+ print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
- print(str(cc).ljust(5) + class_names[cc].ljust(str_len) + str(class_cnts[cc]))
+ print(
+ str(cc).ljust(5)
+ + class_names[cc].ljust(str_len)
+ + str(class_cnts[cc])
+ )
if len(classes_to_ignore) == 0:
return anns
@@ -150,36 +173,37 @@ def load_anns(ann_file_name, raw_audio_dir):
anns = json.load(da)
for aa in anns:
- aa['file_path'] = raw_audio_dir + aa['id']
+ aa["file_path"] = raw_audio_dir + aa["id"]
return anns
def load_anns_from_path(ann_file_dir, raw_audio_dir):
- files = glob.glob(ann_file_dir + '*.json')
+ files = glob.glob(ann_file_dir + "*.json")
anns = []
for ff in files:
with open(ff) as da:
ann = json.load(da)
- ann['file_path'] = raw_audio_dir + ann['id']
+ ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann)
return anns
class AverageMeter(object):
- """Computes and stores the average and current value"""
- def __init__(self):
- self.reset()
+ """Computes and stores the average and current value"""
- def reset(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
+ def __init__(self):
+ self.reset()
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/bat_detect/utils/audio_utils.py b/bat_detect/utils/audio_utils.py
index 4a18d74..3ad648b 100644
--- a/bat_detect/utils/audio_utils.py
+++ b/bat_detect/utils/audio_utils.py
@@ -1,89 +1,142 @@
-import numpy as np
-from . import wavfile
import warnings
-import torch
+
import librosa
+import numpy as np
+import torch
+
+from . import wavfile
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
- nfft = np.floor(fft_win_length*sampling_rate) # int() uses floor
- noverlap = np.floor(fft_overlap*nfft)
- return (time_in_file*sampling_rate-noverlap) / (nfft - noverlap)
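+    # convert a time in seconds to a (fractional) spectrogram column index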
+ nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
+ noverlap = np.floor(fft_overlap * nfft)
+ return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
- nfft = np.floor(fft_win_length*sampling_rate)
- noverlap = np.floor(fft_overlap*nfft)
- return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
- #return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
+ nfft = np.floor(fft_win_length * sampling_rate)
+ noverlap = np.floor(fft_overlap * nfft)
+ return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
+ # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
-def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False, check_spec_size=True):
+def generate_spectrogram(
+ audio,
+ sampling_rate,
+ params,
+ return_spec_for_viz=False,
+ check_spec_size=True,
+):
# generate spectrogram
- spec = gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
+ spec = gen_mag_spectrogram(
+ audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
+ )
# crop to min/max freq
- max_freq = round(params['max_freq']*params['fft_win_length'])
- min_freq = round(params['min_freq']*params['fft_win_length'])
+ max_freq = round(params["max_freq"] * params["fft_win_length"])
+ min_freq = round(params["min_freq"] * params["fft_win_length"])
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
- spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec))
- spec_cropped = spec[-max_freq:spec.shape[0]-min_freq, :]
+ spec = np.vstack(
+ (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
+ )
+ spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
- if params['spec_scale'] == 'log':
- log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
- #log_scaling = (1.0 / sampling_rate)*0.1
- #log_scaling = (1.0 / sampling_rate)*10e4
- spec = np.log1p(log_scaling*spec_cropped)
- elif params['spec_scale'] == 'pcen':
+ if params["spec_scale"] == "log":
+ log_scaling = (
+ 2.0
+ * (1.0 / sampling_rate)
+ * (
+ 1.0
+ / (
+ np.abs(
+ np.hanning(
+ int(params["fft_win_length"] * sampling_rate)
+ )
+ )
+ ** 2
+ ).sum()
+ )
+ )
+ # log_scaling = (1.0 / sampling_rate)*0.1
+ # log_scaling = (1.0 / sampling_rate)*10e4
+ spec = np.log1p(log_scaling * spec_cropped)
+ elif params["spec_scale"] == "pcen":
spec = pcen(spec_cropped, sampling_rate)
- elif params['spec_scale'] == 'none':
+ elif params["spec_scale"] == "none":
pass
- if params['denoise_spec_avg']:
+ if params["denoise_spec_avg"]:
spec = spec - np.mean(spec, 1)[:, np.newaxis]
spec.clip(min=0, out=spec)
- if params['max_scale_spec']:
+ if params["max_scale_spec"]:
spec = spec / (spec.max() + 10e-6)
     # needs to be divisible by a specific factor - if not, it should have been padded
- #if check_spec_size:
- #assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
- #assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
+ # if check_spec_size:
+ # assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
+ # assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# for visualization purposes - use log scaled spectrogram
if return_spec_for_viz:
- log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
- spec_for_viz = np.log1p(log_scaling*spec_cropped).astype(np.float32)
+ log_scaling = (
+ 2.0
+ * (1.0 / sampling_rate)
+ * (
+ 1.0
+ / (
+ np.abs(
+ np.hanning(
+ int(params["fft_win_length"] * sampling_rate)
+ )
+ )
+ ** 2
+ ).sum()
+ )
+ )
+ spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else:
spec_for_viz = None
return spec, spec_for_viz
-def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, max_duration=False):
+def load_audio_file(
+ audio_file,
+ time_exp_fact,
+ target_samp_rate,
+ scale=False,
+ max_duration=False,
+):
with warnings.catch_warnings():
- warnings.filterwarnings('ignore', category=wavfile.WavFileWarning)
- #sampling_rate, audio_raw = wavfile.read(audio_file)
+ warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
+ # sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
if len(audio_raw.shape) > 1:
- raise Exception('Currently does not handle stereo files')
+ raise Exception("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate
- audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase')
+ audio_raw = librosa.resample(
+ audio_raw,
+ orig_sr=sampling_rate_old,
+ target_sr=sampling_rate,
+ res_type="polyphase",
+ )
# clipping maximum duration
if max_duration is not False:
- max_duration = np.minimum(int(sampling_rate*max_duration), audio_raw.shape[0])
+ max_duration = np.minimum(
+ int(sampling_rate * max_duration), audio_raw.shape[0]
+ )
audio_raw = audio_raw[:max_duration]
-
+
# convert to float32 and scale
audio_raw = audio_raw.astype(np.float32)
if scale:
@@ -93,38 +146,53 @@ def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False, ma
return sampling_rate, audio_raw
-def pad_audio(audio_raw, fs, ms, overlap_perc, resize_factor, divide_factor, fixed_width=None):
+def pad_audio(
+ audio_raw,
+ fs,
+ ms,
+ overlap_perc,
+ resize_factor,
+ divide_factor,
+ fixed_width=None,
+):
     # Adds zeros to the end of the raw data so that the generated spectrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
     # TODO this code could be clearer; clean up
- nfft = int(ms*fs)
- noverlap = int(overlap_perc*nfft)
+ nfft = int(ms * fs)
+ noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
- min_size = int(divide_factor*(1.0/resize_factor))
- spec_width = ((audio_raw.shape[0]-noverlap)//step)
+ min_size = int(divide_factor * (1.0 / resize_factor))
+ spec_width = (audio_raw.shape[0] - noverlap) // step
spec_width_rs = spec_width * resize_factor
if fixed_width is not None and spec_width < fixed_width:
# too small
# used during training to ensure all the batches are the same size
- diff = fixed_width*step + noverlap - audio_raw.shape[0]
- audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
+ diff = fixed_width * step + noverlap - audio_raw.shape[0]
+ audio_raw = np.hstack(
+ (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
+ )
elif fixed_width is not None and spec_width > fixed_width:
# too big
# used during training to ensure all the batches are the same size
- diff = fixed_width*step + noverlap - audio_raw.shape[0]
+ diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = audio_raw[:diff]
- elif spec_width_rs < min_size or (np.floor(spec_width_rs) % divide_factor) != 0:
+ elif (
+ spec_width_rs < min_size
+ or (np.floor(spec_width_rs) % divide_factor) != 0
+ ):
# need to be at least min_size
div_amt = np.ceil(spec_width_rs / float(divide_factor))
div_amt = np.maximum(1, div_amt)
- target_size = int(div_amt*divide_factor*(1.0/resize_factor))
- diff = target_size*step + noverlap - audio_raw.shape[0]
- audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
+ target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
+ diff = target_size * step + noverlap - audio_raw.shape[0]
+ audio_raw = np.hstack(
+ (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
+ )
return audio_raw
@@ -133,14 +201,16 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
# Computes magnitude spectrogram by specifying time.
x = x.astype(np.float32)
- nfft = int(ms*fs)
- noverlap = int(overlap_perc*nfft)
+ nfft = int(ms * fs)
+ noverlap = int(overlap_perc * nfft)
# window data
step = nfft - noverlap
# compute spec
- spec, _ = librosa.core.spectrum._spectrogram(y=x, power=1, n_fft=nfft, hop_length=step, center=False)
+ spec, _ = librosa.core.spectrum._spectrogram(
+ y=x, power=1, n_fft=nfft, hop_length=step, center=False
+ )
# remove DC component and flip vertical orientation
spec = np.flipud(spec[1:, :])
@@ -149,8 +219,8 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
- nfft = int(ms*fs)
- nstep = round((1.0-overlap_perc)*nfft)
+ nfft = int(ms * fs)
+ nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(x.device)
@@ -158,12 +228,14 @@ def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
spec = complex_spec.pow(2.0).sum(-1)
# remove DC component and flip vertically
- spec = torch.flipud(spec[0, 1:,:])
+ spec = torch.flipud(spec[0, 1:, :])
return spec
def pcen(spec_cropped, sampling_rate):
# TODO should be passing hop_length too i.e. step
- spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate/10).astype(np.float32)
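+    # scale by 2**31 to mimic int32-range input, as suggested in the librosa.pcen docs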
+ spec = librosa.pcen(
+ spec_cropped * (2**31), sr=sampling_rate / 10
+ ).astype(np.float32)
return spec
diff --git a/bat_detect/utils/detector_utils.py b/bat_detect/utils/detector_utils.py
index fef9828..7d2470f 100644
--- a/bat_detect/utils/detector_utils.py
+++ b/bat_detect/utils/detector_utils.py
@@ -1,39 +1,40 @@
-import torch
-import torch.nn.functional as F
-import os
-import numpy as np
-import pandas as pd
import json
+import os
import sys
-from bat_detect.detector import models
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+
import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
+from bat_detect.detector import models
def get_default_bd_args():
args = {}
- args['detection_threshold'] = 0.001
- args['time_expansion_factor'] = 1
- args['audio_dir'] = ''
- args['ann_dir'] = ''
- args['spec_slices'] = False
- args['chunk_size'] = 3
- args['spec_features'] = False
- args['cnn_features'] = False
- args['quiet'] = True
- args['save_preds_if_empty'] = True
- args['ann_dir'] = os.path.join(args['ann_dir'], '')
+ args["detection_threshold"] = 0.001
+ args["time_expansion_factor"] = 1
+ args["audio_dir"] = ""
+ args["ann_dir"] = ""
+ args["spec_slices"] = False
+ args["chunk_size"] = 3
+ args["spec_features"] = False
+ args["cnn_features"] = False
+ args["quiet"] = True
+ args["save_preds_if_empty"] = True
+ args["ann_dir"] = os.path.join(args["ann_dir"], "")
return args
-
-
+
+
def get_audio_files(ip_dir):
matches = []
for root, dirnames, filenames in os.walk(ip_dir):
for filename in filenames:
- if filename.lower().endswith('.wav'):
+ if filename.lower().endswith(".wav"):
matches.append(os.path.join(root, filename))
return matches
@@ -41,35 +42,47 @@ def get_audio_files(ip_dir):
def load_model(model_path, load_weights=True):
# load model
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if os.path.isfile(model_path):
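+        # map_location allows a checkpoint trained on GPU to be loaded on a CPU-only machine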
net_params = torch.load(model_path, map_location=device)
else:
- print('Error: model not found.')
+ print("Error: model not found.")
sys.exit(1)
- params = net_params['params']
- params['device'] = device
+ params = net_params["params"]
+ params["device"] = device
- if params['model_name'] == 'Net2DFast':
- model = models.Net2DFast(params['num_filters'], num_classes=len(params['class_names']),
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
- elif params['model_name'] == 'Net2DFastNoAttn':
- model = models.Net2DFastNoAttn(params['num_filters'], num_classes=len(params['class_names']),
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
- elif params['model_name'] == 'Net2DFastNoCoordConv':
- model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=len(params['class_names']),
- emb_dim=params['emb_dim'], ip_height=params['ip_height'],
- resize_factor=params['resize_factor'])
+ if params["model_name"] == "Net2DFast":
+ model = models.Net2DFast(
+ params["num_filters"],
+ num_classes=len(params["class_names"]),
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
+ elif params["model_name"] == "Net2DFastNoAttn":
+ model = models.Net2DFastNoAttn(
+ params["num_filters"],
+ num_classes=len(params["class_names"]),
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
+ elif params["model_name"] == "Net2DFastNoCoordConv":
+ model = models.Net2DFastNoCoordConv(
+ params["num_filters"],
+ num_classes=len(params["class_names"]),
+ emb_dim=params["emb_dim"],
+ ip_height=params["ip_height"],
+ resize_factor=params["resize_factor"],
+ )
else:
- print('Error: unknown model.')
+ print("Error: unknown model.")
if load_weights:
- model.load_state_dict(net_params['state_dict'])
+ model.load_state_dict(net_params["state_dict"])
- model = model.to(params['device'])
+ model = model.to(params["device"])
model.eval()
return model, params
@@ -78,11 +91,13 @@ def load_model(model_path, load_weights=True):
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {}
- num_preds = np.sum([len(pp['det_probs']) for pp in predictions])
+ num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0:
for kk in predictions[0].keys():
- predictions_m[kk] = np.hstack([pp[kk] for pp in predictions if pp['det_probs'].shape[0] > 0])
+ predictions_m[kk] = np.hstack(
+ [pp[kk] for pp in predictions if pp["det_probs"].shape[0] > 0]
+ )
else:
         # hack for the case where no calls are detected, as we still need some of the key names in the dict
predictions_m = predictions[0]
@@ -94,47 +109,60 @@ def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
return predictions_m, spec_feats, cnn_feats, spec_slices
-def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
+def convert_results(
+ file_id,
+ time_exp,
+ duration,
+ params,
+ predictions,
+ spec_feats,
+ cnn_feats,
+ spec_slices,
+):
# create a single dictionary - this is the format used by the annotation tool
pred_dict = {}
- pred_dict['id'] = file_id
- pred_dict['annotated'] = False
- pred_dict['issues'] = False
- pred_dict['notes'] = 'Automatically generated.'
- pred_dict['time_exp'] = time_exp
- pred_dict['duration'] = round(duration, 4)
- pred_dict['annotation'] = []
+ pred_dict["id"] = file_id
+ pred_dict["annotated"] = False
+ pred_dict["issues"] = False
+ pred_dict["notes"] = "Automatically generated."
+ pred_dict["time_exp"] = time_exp
+ pred_dict["duration"] = round(duration, 4)
+ pred_dict["annotation"] = []
- class_prob_best = predictions['class_probs'].max(0)
- class_ind_best = predictions['class_probs'].argmax(0)
- class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
- pred_dict['class_name'] = params['class_names'][np.argmax(class_overall)]
+ class_prob_best = predictions["class_probs"].max(0)
+ class_ind_best = predictions["class_probs"].argmax(0)
+ class_overall = pp.overall_class_pred(
+ predictions["det_probs"], predictions["class_probs"]
+ )
+ pred_dict["class_name"] = params["class_names"][np.argmax(class_overall)]
- for ii in range(predictions['det_probs'].shape[0]):
+ for ii in range(predictions["det_probs"].shape[0]):
res = {}
- res['start_time'] = round(float(predictions['start_times'][ii]), 4)
- res['end_time'] = round(float(predictions['end_times'][ii]), 4)
- res['low_freq'] = int(predictions['low_freqs'][ii])
- res['high_freq'] = int(predictions['high_freqs'][ii])
- res['class'] = str(params['class_names'][int(class_ind_best[ii])])
- res['class_prob'] = round(float(class_prob_best[ii]), 3)
- res['det_prob'] = round(float(predictions['det_probs'][ii]), 3)
- res['individual'] = '-1'
- res['event'] = 'Echolocation'
- pred_dict['annotation'].append(res)
+ res["start_time"] = round(float(predictions["start_times"][ii]), 4)
+ res["end_time"] = round(float(predictions["end_times"][ii]), 4)
+ res["low_freq"] = int(predictions["low_freqs"][ii])
+ res["high_freq"] = int(predictions["high_freqs"][ii])
+ res["class"] = str(params["class_names"][int(class_ind_best[ii])])
+ res["class_prob"] = round(float(class_prob_best[ii]), 3)
+ res["det_prob"] = round(float(predictions["det_probs"][ii]), 3)
+ res["individual"] = "-1"
+ res["event"] = "Echolocation"
+ pred_dict["annotation"].append(res)
# combine into final results dictionary
results = {}
- results['pred_dict'] = pred_dict
+ results["pred_dict"] = pred_dict
if len(spec_feats) > 0:
- results['spec_feats'] = spec_feats
- results['spec_feat_names'] = feats.get_feature_names()
+ results["spec_feats"] = spec_feats
+ results["spec_feat_names"] = feats.get_feature_names()
if len(cnn_feats) > 0:
- results['cnn_feats'] = cnn_feats
- results['cnn_feat_names'] = [str(ii) for ii in range(cnn_feats.shape[1])]
+ results["cnn_feats"] = cnn_feats
+ results["cnn_feat_names"] = [
+ str(ii) for ii in range(cnn_feats.shape[1])
+ ]
if len(spec_slices) > 0:
- results['spec_slices'] = spec_slices
+ results["spec_slices"] = spec_slices
return results
@@ -146,144 +174,214 @@ def save_results_to_file(results, op_path):
os.makedirs(os.path.dirname(op_path))
# save csv file - if there are predictions
- result_list = [res for res in results['pred_dict']['annotation']]
+ result_list = [res for res in results["pred_dict"]["annotation"]]
df = pd.DataFrame(result_list)
- df['file_name'] = [results['pred_dict']['id']]*len(result_list)
- df.index.name = 'id'
- if 'class_prob' in df.columns:
- df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
- 'low_freq', 'class', 'class_prob']]
- df.to_csv(op_path + '.csv', sep=',')
+ df["file_name"] = [results["pred_dict"]["id"]] * len(result_list)
+ df.index.name = "id"
+ if "class_prob" in df.columns:
+ df = df[
+ [
+ "det_prob",
+ "start_time",
+ "end_time",
+ "high_freq",
+ "low_freq",
+ "class",
+ "class_prob",
+ ]
+ ]
+ df.to_csv(op_path + ".csv", sep=",")
# save features
- if 'spec_feats' in results.keys():
- df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
- df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')
+ if "spec_feats" in results.keys():
+ df = pd.DataFrame(
+ results["spec_feats"], columns=results["spec_feat_names"]
+ )
+ df.to_csv(
+ op_path + "_spec_features.csv",
+ sep=",",
+ index=False,
+ float_format="%.5f",
+ )
- if 'cnn_feats' in results.keys():
- df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
- df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')
+ if "cnn_feats" in results.keys():
+ df = pd.DataFrame(
+ results["cnn_feats"], columns=results["cnn_feat_names"]
+ )
+ df.to_csv(
+ op_path + "_cnn_features.csv",
+ sep=",",
+ index=False,
+ float_format="%.5f",
+ )
# save json file
- with open(op_path + '.json', 'w') as da:
- json.dump(results['pred_dict'], da, indent=2, sort_keys=True)
+ with open(op_path + ".json", "w") as da:
+ json.dump(results["pred_dict"], da, indent=2, sort_keys=True)
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
# pad audio so it is evenly divisible by downsampling factors
duration = audio.shape[0] / float(sampling_rate)
- audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
- params['fft_overlap'], params['resize_factor'],
- params['spec_divide_factor'])
+ audio = au.pad_audio(
+ audio,
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ params["resize_factor"],
+ params["spec_divide_factor"],
+ )
# generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
# convert to pytorch
- spec = torch.from_numpy(spec).to(params['device'])
+ spec = torch.from_numpy(spec).to(params["device"])
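+    # add batch and channel dims: (freq, time) -> (1, 1, freq, time)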
spec = spec.unsqueeze(0).unsqueeze(0)
# resize the spec
- rs = params['resize_factor']
- spec_op_shape = (int(params['spec_height']*rs), int(spec.shape[-1]*rs))
- spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False)
+ rs = params["resize_factor"]
+ spec_op_shape = (int(params["spec_height"] * rs), int(spec.shape[-1] * rs))
+ spec = F.interpolate(
+ spec, size=spec_op_shape, mode="bilinear", align_corners=False
+ )
if return_np:
- spec_np = spec[0,0,:].cpu().data.numpy()
+ spec_np = spec[0, 0, :].cpu().data.numpy()
else:
spec_np = None
return duration, spec, spec_np
-def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False):
+def process_file(
+ audio_file,
+ model,
+ params,
+ args,
+ time_exp=None,
+ top_n=5,
+ return_raw_preds=False,
+ max_duration=False,
+):
# store temporary results here
predictions = []
- spec_feats = []
- cnn_feats = []
+ spec_feats = []
+ cnn_feats = []
spec_slices = []
# get time expansion factor
if time_exp is None:
- time_exp = args['time_expansion_factor']
+ time_exp = args["time_expansion_factor"]
- params['detection_threshold'] = args['detection_threshold']
+ params["detection_threshold"] = args["detection_threshold"]
# load audio file
- sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
- params['target_samp_rate'], params['scale_raw_audio'])
+ sampling_rate, audio_full = au.load_audio_file(
+ audio_file,
+ time_exp,
+ params["target_samp_rate"],
+ params["scale_raw_audio"],
+ )
# clipping maximum duration
if max_duration is not False:
- max_duration = np.minimum(int(sampling_rate*max_duration), audio_full.shape[0])
+ max_duration = np.minimum(
+ int(sampling_rate * max_duration), audio_full.shape[0]
+ )
audio_full = audio_full[:max_duration]
-
+
duration_full = audio_full.shape[0] / float(sampling_rate)
- return_np_spec = args['spec_features'] or args['spec_slices']
+ return_np_spec = args["spec_features"] or args["spec_slices"]
# loop through larger file and split into chunks
# TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
- num_chunks = int(np.ceil(duration_full/args['chunk_size']))
+ num_chunks = int(np.ceil(duration_full / args["chunk_size"]))
for chunk_id in range(num_chunks):
# chunk
- chunk_time = args['chunk_size']*chunk_id
- chunk_length = int(sampling_rate*args['chunk_size'])
- start_sample = chunk_id*chunk_length
- end_sample = np.minimum((chunk_id+1)*chunk_length, audio_full.shape[0])
+ chunk_time = args["chunk_size"] * chunk_id
+ chunk_length = int(sampling_rate * args["chunk_size"])
+ start_sample = chunk_id * chunk_length
+ end_sample = np.minimum(
+ (chunk_id + 1) * chunk_length, audio_full.shape[0]
+ )
audio = audio_full[start_sample:end_sample]
# load audio file and compute spectrogram
- duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)
+ duration, spec, spec_np = compute_spectrogram(
+ audio, sampling_rate, params, return_np_spec
+ )
# evaluate model
with torch.no_grad():
- outputs = model(spec, return_feats=args['cnn_features'])
+ outputs = model(spec, return_feats=args["cnn_features"])
# run non-max suppression
- pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
+ pred_nms, features = pp.run_nms(
+ outputs, params, np.array([float(sampling_rate)])
+ )
pred_nms = pred_nms[0]
- pred_nms['start_times'] += chunk_time
- pred_nms['end_times'] += chunk_time
+ pred_nms["start_times"] += chunk_time
+ pred_nms["end_times"] += chunk_time
# if we have a background class
- if pred_nms['class_probs'].shape[0] > len(params['class_names']):
- pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]
+ if pred_nms["class_probs"].shape[0] > len(params["class_names"]):
+ pred_nms["class_probs"] = pred_nms["class_probs"][:-1, :]
predictions.append(pred_nms)
# extract features - if there are any calls detected
- if (pred_nms['det_probs'].shape[0] > 0):
- if args['spec_features']:
+ if pred_nms["det_probs"].shape[0] > 0:
+ if args["spec_features"]:
spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
- if args['cnn_features']:
+ if args["cnn_features"]:
cnn_feats.append(features[0])
- if args['spec_slices']:
- spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))
+ if args["spec_slices"]:
+ spec_slices.extend(
+ feats.extract_spec_slices(spec_np, pred_nms, params)
+ )
# convert the predictions into output dictionary
file_id = os.path.basename(audio_file)
- predictions, spec_feats, cnn_feats, spec_slices =\
- merge_results(predictions, spec_feats, cnn_feats, spec_slices)
- results = convert_results(file_id, time_exp, duration_full, params,
- predictions, spec_feats, cnn_feats, spec_slices)
+ predictions, spec_feats, cnn_feats, spec_slices = merge_results(
+ predictions, spec_feats, cnn_feats, spec_slices
+ )
+ results = convert_results(
+ file_id,
+ time_exp,
+ duration_full,
+ params,
+ predictions,
+ spec_feats,
+ cnn_feats,
+ spec_slices,
+ )
# summarize results
- if not args['quiet']:
- num_detections = len(results['pred_dict']['annotation'])
- print('{}'.format(num_detections) + ' call(s) detected above the threshold.')
+ if not args["quiet"]:
+ num_detections = len(results["pred_dict"]["annotation"])
+ print(
+ "{}".format(num_detections)
+ + " call(s) detected above the threshold."
+ )
# print results for top n classes
- if not args['quiet'] and (num_detections > 0):
- class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
- print('species name'.ljust(30) + 'probablity present')
+ if not args["quiet"] and (num_detections > 0):
+ class_overall = pp.overall_class_pred(
+ predictions["det_probs"], predictions["class_probs"]
+ )
+ print("species name".ljust(30) + "probablity present")
for cc in np.argsort(class_overall)[::-1][:top_n]:
- print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))
+ print(
+ params["class_names"][cc].ljust(30)
+ + str(round(class_overall[cc], 3))
+ )
if return_raw_preds:
return predictions
diff --git a/bat_detect/utils/plot_utils.py b/bat_detect/utils/plot_utils.py
index 5b38f65..ce88375 100644
--- a/bat_detect/utils/plot_utils.py
+++ b/bat_detect/utils/plot_utils.py
@@ -1,63 +1,109 @@
-import numpy as np
-import matplotlib.pyplot as plt
import json
-from sklearn.metrics import confusion_matrix
+
+import matplotlib.pyplot as plt
+import numpy as np
from matplotlib import patches
from matplotlib.collections import PatchCollection
+from sklearn.metrics import confusion_matrix
from . import audio_utils as au
-def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False):
+def create_box_image(
+ spec,
+ fig,
+ detections_ip,
+ start_time,
+ end_time,
+ duration,
+ params,
+ max_val,
+ hide_axis=True,
+ plot_class_names=False,
+):
# filter detections
stop_time = start_time + duration
detections = []
for bb in detections_ip:
- if (bb['start_time'] >= start_time) and (bb['start_time'] < stop_time-0.02): #(bb['end_time'] < end_time):
+ if (bb["start_time"] >= start_time) and (
+ bb["start_time"] < stop_time - 0.02
+ ): # (bb['end_time'] < end_time):
detections.append(bb)
# create figure
freq_scale = 1000 # turn Hz to kHz
- min_freq = params['min_freq']//freq_scale
- max_freq = params['max_freq']//freq_scale
+ min_freq = params["min_freq"] // freq_scale
+ max_freq = params["max_freq"] // freq_scale
y_extent = [0, duration, min_freq, max_freq]
if hide_axis:
- ax = plt.Axes(fig, [0., 0., 1., 1.])
+ ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
ax.set_axis_off()
fig.add_axes(ax)
else:
ax = plt.gca()
- plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val)
+ plt.imshow(
+ spec,
+ aspect="auto",
+ cmap="plasma",
+ extent=y_extent,
+ vmin=0,
+ vmax=max_val,
+ )
boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
ax.add_collection(PatchCollection(boxes, match_original=True))
plt.grid(False)
if plot_class_names:
for ii, bb in enumerate(boxes):
- txt = ' '.join([sp[:3] for sp in detections_ip[ii]['class'].split(' ')])
- font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
+ txt = " ".join(
+ [sp[:3] for sp in detections_ip[ii]["class"].split(" ")]
+ )
+ font_info = {
+ "color": "white",
+ "size": 10,
+ "weight": "bold",
+ "alpha": bb.get_alpha(),
+ }
y_pos = bb.get_xy()[1] + bb.get_height()
if y_pos > (max_freq - 10):
y_pos = max_freq - 10
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
-def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None):
+def save_ann_spec(
+ op_path,
+ spec,
+ min_freq,
+ max_freq,
+ duration,
+ start_time,
+ title_text="",
+ anns=None,
+):
# create figure and plot boxes
freq_scale = 1000 # turn Hz to kHz
- min_freq = min_freq//freq_scale
- max_freq = max_freq//freq_scale
+ min_freq = min_freq // freq_scale
+ max_freq = max_freq // freq_scale
y_extent = [0, duration, min_freq, max_freq]
- plt.close('all')
- fig = plt.figure(0, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100)
- plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=spec.max()*1.1)
+ plt.close("all")
+ fig = plt.figure(
+ 0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
+ )
+ plt.imshow(
+ spec,
+ aspect="auto",
+ cmap="plasma",
+ extent=y_extent,
+ vmin=0,
+ vmax=spec.max() * 1.1,
+ )
- plt.ylabel('Freq - kHz')
- plt.xlabel('Time - secs')
- if title_text != '':
+ plt.ylabel("Freq - kHz")
+ plt.xlabel("Time - secs")
+ if title_text != "":
plt.title(title_text)
plt.tight_layout()
@@ -66,122 +112,185 @@ def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title
boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
plt.gca().add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes):
- txt = ' '.join([sp[:3] for sp in anns[ii]['class'].split(' ')])
- font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
+ txt = " ".join([sp[:3] for sp in anns[ii]["class"].split(" ")])
+ font_info = {
+ "color": "white",
+ "size": 10,
+ "weight": "bold",
+ "alpha": bb.get_alpha(),
+ }
y_pos = bb.get_xy()[1] + bb.get_height()
if y_pos > (max_freq - 10):
y_pos = max_freq - 10
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
- print('Saving figure to:', op_path)
+ print("Saving figure to:", op_path)
plt.savefig(op_path)
-def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
+def plot_pts(
+ fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False
+):
plt.figure(fig_id)
un_class, labels = np.unique(class_names, return_inverse=True)
un_labels = np.unique(labels)
if un_labels.shape[0] > len(colors):
- colors = [plt.cm.jet(float(ii)/un_labels.shape[0]) for ii in un_labels]
+ colors = [
+ plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels
+ ]
for ii, u in enumerate(un_labels):
- inds = np.where(labels==u)[0]
- plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii], label=str(un_class[ii]), s=marker_size)
+ inds = np.where(labels == u)[0]
+ plt.scatter(
+ feats[inds, 0],
+ feats[inds, 1],
+ c=colors[ii],
+ label=str(un_class[ii]),
+ s=marker_size,
+ )
if plot_legend:
plt.legend()
plt.xticks([])
plt.yticks([])
- plt.title('downsampled features')
+ plt.title("downsampled features")
-def plot_bounding_box_patch(pred, freq_scale, ecolor='w'):
+def plot_bounding_box_patch(pred, freq_scale, ecolor="w"):
patch_collect = []
- for bb in range(len(pred['start_times'])):
- xx = pred['start_times'][bb]
- ww = pred['end_times'][bb] - pred['start_times'][bb]
- yy = pred['low_freqs'][bb] / freq_scale
- hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale
+ for bb in range(len(pred["start_times"])):
+ xx = pred["start_times"][bb]
+ ww = pred["end_times"][bb] - pred["start_times"][bb]
+ yy = pred["low_freqs"][bb] / freq_scale
+ hh = (pred["high_freqs"][bb] - pred["low_freqs"][bb]) / freq_scale
- if 'det_probs' in pred.keys():
- alpha_val = pred['det_probs'][bb]
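+        # use the detection probability as the box alpha so low-confidence detections are drawn fainter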
+ if "det_probs" in pred.keys():
+ alpha_val = pred["det_probs"][bb]
else:
alpha_val = 1.0
- patch_collect.append(patches.Rectangle((xx, yy), ww, hh, linewidth=1,
- edgecolor=ecolor, facecolor='none', alpha=alpha_val))
+ patch_collect.append(
+ patches.Rectangle(
+ (xx, yy),
+ ww,
+ hh,
+ linewidth=1,
+ edgecolor=ecolor,
+ facecolor="none",
+ alpha=alpha_val,
+ )
+ )
return patch_collect
def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
patch_collect = []
for aa in range(len(anns)):
- xx = anns[aa]['start_time'] - start_time
- ww = anns[aa]['end_time'] - anns[aa]['start_time']
- yy = anns[aa]['low_freq'] / freq_scale
- hh = (anns[aa]['high_freq'] - anns[aa]['low_freq']) / freq_scale
- if 'det_prob' in anns[aa]:
- alpha = anns[aa]['det_prob']
+ xx = anns[aa]["start_time"] - start_time
+ ww = anns[aa]["end_time"] - anns[aa]["start_time"]
+ yy = anns[aa]["low_freq"] / freq_scale
+ hh = (anns[aa]["high_freq"] - anns[aa]["low_freq"]) / freq_scale
+ if "det_prob" in anns[aa]:
+ alpha = anns[aa]["det_prob"]
else:
alpha = 1.0
- patch_collect.append(patches.Rectangle((xx,yy), ww, hh, linewidth=1,
- edgecolor='w', facecolor='none', alpha=alpha))
+ patch_collect.append(
+ patches.Rectangle(
+ (xx, yy),
+ ww,
+ hh,
+ linewidth=1,
+ edgecolor="w",
+ facecolor="none",
+ alpha=alpha,
+ )
+ )
return patch_collect
-def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
- op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True):
+def plot_spec(
+ spec,
+ sampling_rate,
+ duration,
+ gt,
+ pred,
+ params,
+ plot_title,
+ op_file_name,
+ pred_2d_hm,
+ plot_boxes=True,
+ fixed_aspect=True,
+):
if fixed_aspect:
        # output image will be this width irrespective of the duration of the audio file
width = 12
else:
- width = 12*duration
+ width = 12 * duration
fig = plt.figure(1, figsize=(width, 8))
- ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h
+ ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h
ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])
    freq_scale = 1000 # turn Hz to kHz
- #duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
- y_extent = [0, duration, params['min_freq']//freq_scale, params['max_freq']//freq_scale]
+ # duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
+ y_extent = [
+ 0,
+ duration,
+ params["min_freq"] // freq_scale,
+ params["max_freq"] // freq_scale,
+ ]
# plot gt boxes
- ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
+ ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
ax0.xaxis.set_ticklabels([])
- font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
- ax0.text(0, params['min_freq']//freq_scale, 'Ground Truth', fontdict=font_info)
+ font_info = {"color": "white", "size": 12, "weight": "bold"}
+ ax0.text(
+ 0, params["min_freq"] // freq_scale, "Ground Truth", fontdict=font_info
+ )
plt.grid(False)
if plot_boxes:
boxes = plot_bounding_box_patch(gt, freq_scale)
ax0.add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes):
- class_id = int(gt['class_ids'][ii])
+ class_id = int(gt["class_ids"][ii])
if class_id < 0:
- txt = params['generic_class'][0]
+ txt = params["generic_class"][0]
else:
- txt = params['class_names_short'][class_id]
- font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
+ txt = params["class_names_short"][class_id]
+ font_info = {
+ "color": "white",
+ "size": 10,
+ "weight": "bold",
+ "alpha": bb.get_alpha(),
+ }
y_pos = bb.get_xy()[1] + bb.get_height()
ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
# plot predicted boxes
- ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
+ ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
ax1.xaxis.set_ticklabels([])
- font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
- ax1.text(0, params['min_freq']//freq_scale, 'Prediction', fontdict=font_info)
+ font_info = {"color": "white", "size": 12, "weight": "bold"}
+ ax1.text(
+ 0, params["min_freq"] // freq_scale, "Prediction", fontdict=font_info
+ )
plt.grid(False)
if plot_boxes:
boxes = plot_bounding_box_patch(pred, freq_scale)
ax1.add_collection(PatchCollection(boxes, match_original=True))
for ii, bb in enumerate(boxes):
- if pred['class_probs'].shape[0] > len(params['class_names_short']):
- class_id = pred['class_probs'][:-1, ii].argmax()
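+            # class_probs may carry an extra final row for the generic class - skip it when picking the species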
+ if pred["class_probs"].shape[0] > len(params["class_names_short"]):
+ class_id = pred["class_probs"][:-1, ii].argmax()
else:
- class_id = pred['class_probs'][:, ii].argmax()
- txt = params['class_names_short'][class_id]
- font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
+ class_id = pred["class_probs"][:, ii].argmax()
+ txt = params["class_names_short"][class_id]
+ font_info = {
+ "color": "white",
+ "size": 10,
+ "weight": "bold",
+ "alpha": bb.get_alpha(),
+ }
y_pos = bb.get_xy()[1] + bb.get_height()
ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
@@ -190,10 +299,18 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()
- ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val])
- #ax2.xaxis.set_ticklabels([])
- font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
- ax2.text(0, params['min_freq']//freq_scale, 'Heatmap', fontdict=font_info)
+ ax2.imshow(
+ pred_2d_hm,
+ aspect="auto",
+ cmap="plasma",
+ extent=y_extent,
+ clim=[min_val, max_val],
+ )
+ # ax2.xaxis.set_ticklabels([])
+ font_info = {"color": "white", "size": 12, "weight": "bold"}
+ ax2.text(
+ 0, params["min_freq"] // freq_scale, "Heatmap", fontdict=font_info
+ )
plt.grid(False)
@@ -204,107 +321,151 @@ def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
plt.close(1)
-def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
- precision = results['precision']
- recall = results['recall']
- avg_prec = results['avg_prec']
+def plot_pr_curve(
+ op_dir, plt_title, file_name, results, file_type="png", title_text=""
+):
+ precision = results["precision"]
+ recall = results["recall"]
+ avg_prec = results["avg_prec"]
- plt.figure(0, figsize=(10,8))
+ plt.figure(0, figsize=(10, 8))
plt.plot(recall, precision)
- plt.ylabel('Precision', fontsize=20)
- plt.xlabel('Recall', fontsize=20)
- if title_text != '':
- plt.title(title_text, fontdict={'fontsize': 28})
+ plt.ylabel("Precision", fontsize=20)
+ plt.xlabel("Recall", fontsize=20)
+ if title_text != "":
+ plt.title(title_text, fontdict={"fontsize": 28})
else:
- plt.title(plt_title + ' {:.3f}\n'.format(avg_prec))
- plt.xlim(0,1.02)
- plt.ylim(0,1.02)
+ plt.title(plt_title + " {:.3f}\n".format(avg_prec))
+ plt.xlim(0, 1.02)
+ plt.ylim(0, 1.02)
plt.grid(True)
plt.tight_layout()
- plt.savefig(op_dir + file_name + '.' + file_type)
+ plt.savefig(op_dir + file_name + "." + file_type)
plt.close(0)
-def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
- plt.figure(0, figsize=(10,8))
- plt.ylabel('Precision', fontsize=20)
- plt.xlabel('Recall', fontsize=20)
- plt.xlim(0,1.02)
- plt.ylim(0,1.02)
+def plot_pr_curve_class(
+ op_dir, plt_title, file_name, results, file_type="png", title_text=""
+):
+ plt.figure(0, figsize=(10, 8))
+ plt.ylabel("Precision", fontsize=20)
+ plt.xlabel("Recall", fontsize=20)
+ plt.xlim(0, 1.02)
+ plt.ylim(0, 1.02)
plt.grid(True)
- linestyles = ['-', ':', '--']
- markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*']
- colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
+ linestyles = ["-", ":", "--"]
+ markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"]
+ colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# plot the PR curves
- for ii, rr in enumerate(results['class_pr']):
- class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')])
- cur_color = colors[int(ii%10)]
- plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color,
- linestyle=linestyles[int(ii//10)], lw=2.5)
+ for ii, rr in enumerate(results["class_pr"]):
+ class_name = " ".join([sp[:3] for sp in rr["name"].split(" ")])
+ cur_color = colors[int(ii % 10)]
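+        # the colour cycle repeats every 10 classes, so the line style changes for each block of 10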
+ plt.plot(
+ rr["recall"],
+ rr["precision"],
+ label=class_name,
+ color=cur_color,
+ linestyle=linestyles[int(ii // 10)],
+ lw=2.5,
+ )
- #print(class_name)
+ # print(class_name)
# plot the location of the confidence threshold values
- for jj, tt in enumerate(rr['thresholds']):
- ind = rr['thresholds_inds'][jj]
+ for jj, tt in enumerate(rr["thresholds"]):
+ ind = rr["thresholds_inds"][jj]
if ind > -1:
- plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj],
- color=cur_color, ms=10)
- #print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))
+ plt.plot(
+ rr["recall"][ind],
+ rr["precision"][ind],
+ markers[jj],
+ color=cur_color,
+ ms=10,
+ )
+ # print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))
- if title_text != '':
- plt.title(title_text, fontdict={'fontsize': 28})
+ if title_text != "":
+ plt.title(title_text, fontdict={"fontsize": 28})
else:
- plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class']))
- plt.legend(loc='lower left', prop={'size': 14})
+ plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"]))
+ plt.legend(loc="lower left", prop={"size": 14})
plt.tight_layout()
- plt.savefig(op_dir + file_name + '.' + file_type)
+ plt.savefig(op_dir + file_name + "." + file_type)
plt.close(0)
-def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''):
+def plot_confusion_matrix(
+ op_dir,
+ op_file,
+ gt,
+ pred,
+ file_acc,
+ class_names_long,
+ verbose=False,
+ file_type="png",
+ title_text="",
+):
# shorten the class names for plotting
class_names = []
for cc in class_names_long:
- class_name_sm = ''.join([cc_sm[:3] + ' ' for cc_sm in cc.split(' ')])[:-1]
+ class_name_sm = "".join([cc_sm[:3] + " " for cc_sm in cc.split(" ")])[
+ :-1
+ ]
class_names.append(class_name_sm)
num_classes = len(class_names)
- cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
+ cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(
+ np.float32
+ )
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
- cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
- cm[np.where(cm_norm ==- 0)[0], :] = np.nan
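+    # normalise each row by its ground-truth count; classes with no examples become NaN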
+ cm[valid_inds, :] = (
+ cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+ )
+    cm[np.where(cm_norm == 0)[0], :] = np.nan
if verbose:
- print('Per class accuracy:')
+ print("Per class accuracy:")
str_len = np.max([len(cc) for cc in class_names_long]) + 5
accs = np.diag(cm)
for ii, cc in enumerate(class_names_long):
if np.isnan(accs[ii]):
print(str(ii).ljust(5) + cc.ljust(str_len))
else:
- print(str(ii).ljust(5) + cc.ljust(str_len) + '{:.2f}'.format(accs[ii]*100))
+ print(
+ str(ii).ljust(5)
+ + cc.ljust(str_len)
+ + "{:.2f}".format(accs[ii] * 100)
+ )
- plt.figure(0, figsize=(10,8))
- plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
+ plt.figure(0, figsize=(10, 8))
+ plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
plt.colorbar()
- plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical')
+ plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical")
plt.yticks(np.arange(cm.shape[0]), class_names)
- plt.xlabel('Predicted', fontsize=20)
- plt.ylabel('Ground Truth', fontsize=20)
- if title_text != '':
- plt.title(title_text, fontdict={'fontsize': 28})
+ plt.xlabel("Predicted", fontsize=20)
+ plt.ylabel("Ground Truth", fontsize=20)
+ if title_text != "":
+ plt.title(title_text, fontdict={"fontsize": 28})
else:
- plt.title(op_file + ' {:.3f}\n'.format(file_acc))
+ plt.title(op_file + " {:.3f}\n".format(file_acc))
plt.tight_layout()
- plt.savefig(op_dir + op_file + '.' + file_type)
- plt.close('all')
+ plt.savefig(op_dir + op_file + "." + file_type)
+ plt.close("all")
class LossPlotter(object):
- def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False):
+ def __init__(
+ self,
+ op_file_name,
+ duration,
+ labels,
+ ylim,
+ class_names,
+ axis_labels=None,
+ logy=False,
+ ):
self.reset()
self.op_file_name = op_file_name
self.duration = duration # length of x axis
@@ -327,11 +488,16 @@ class LossPlotter(object):
self.save_confusion_matrix(gt, pred)
def save_plot(self):
- linestyles = ['-', ':', '--']
- plt.figure(0, figsize=(8,5))
+ linestyles = ["-", ":", "--"]
+ plt.figure(0, figsize=(8, 5))
for ii in range(len(self.vals[0])):
l_vals = [vv[ii] for vv in self.vals]
- plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=linestyles[int(ii//10)])
+ plt.plot(
+ self.epochs,
+ l_vals,
+ label=self.labels[ii],
+ linestyle=linestyles[int(ii // 10)],
+ )
plt.xlim(0, np.maximum(self.duration, len(self.vals)))
if self.ylim is not None:
plt.ylim(self.ylim[0], self.ylim[1])
@@ -339,33 +505,41 @@ class LossPlotter(object):
plt.xlabel(self.axis_labels[0])
plt.ylabel(self.axis_labels[1])
if self.logy:
- plt.gca().set_yscale('log')
+ plt.gca().set_yscale("log")
plt.grid(True)
- plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0)
+ plt.legend(
+ bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0
+ )
plt.tight_layout()
plt.savefig(self.op_file_name)
plt.close(0)
def save_json(self):
data = {}
- data['epochs'] = self.epochs
+ data["epochs"] = self.epochs
for ii in range(len(self.vals[0])):
- data[self.labels[ii]] = [round(vv[ii],4) for vv in self.vals]
- with open(self.op_file_name[:-4] + '.json', 'w') as da:
+ data[self.labels[ii]] = [round(vv[ii], 4) for vv in self.vals]
+ with open(self.op_file_name[:-4] + ".json", "w") as da:
json.dump(data, da, indent=2)
def save_confusion_matrix(self, gt, pred):
plt.figure(0)
- cm = confusion_matrix(gt, pred, np.arange(len(self.class_names))).astype(np.float32)
+ cm = confusion_matrix(
+            gt, pred, labels=np.arange(len(self.class_names))
+ ).astype(np.float32)
cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0]
- cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
- plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
+ cm[valid_inds, :] = (
+ cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
+ )
+ plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
plt.colorbar()
- plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical')
+ plt.xticks(
+ np.arange(cm.shape[1]), self.class_names, rotation="vertical"
+ )
plt.yticks(np.arange(cm.shape[0]), self.class_names)
- plt.xlabel('Predicted')
- plt.ylabel('Ground Truth')
+ plt.xlabel("Predicted")
+ plt.ylabel("Ground Truth")
plt.tight_layout()
- plt.savefig(self.op_file_name[:-4] + '_cm.png')
+ plt.savefig(self.op_file_name[:-4] + "_cm.png")
plt.close(0)
diff --git a/bat_detect/utils/visualize.py b/bat_detect/utils/visualize.py
index bea7f6b..54be1df 100644
--- a/bat_detect/utils/visualize.py
+++ b/bat_detect/utils/visualize.py
@@ -1,19 +1,46 @@
-import numpy as np
import matplotlib.pyplot as plt
+import numpy as np
from matplotlib import patches
-from sklearn.svm import LinearSVC
from matplotlib.axes._axes import _log as matplotlib_axes_logger
-matplotlib_axes_logger.setLevel('ERROR')
+from sklearn.svm import LinearSVC
+
+matplotlib_axes_logger.setLevel("ERROR")
-colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4',
- '#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff',
- '#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1',
- '#000075', '#a9a9a9']
+colors = [
+ "#e6194B",
+ "#3cb44b",
+ "#ffe119",
+ "#4363d8",
+ "#f58231",
+ "#911eb4",
+ "#42d4f4",
+ "#f032e6",
+ "#bfef45",
+ "#fabebe",
+ "#469990",
+ "#e6beff",
+ "#9A6324",
+ "#fffac8",
+ "#800000",
+ "#aaffc3",
+ "#808000",
+ "#ffd8b1",
+ "#000075",
+ "#a9a9a9",
+]
class InteractivePlotter:
- def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
+ def __init__(
+ self,
+ feats_ds,
+ feats,
+ spec_slices,
+ call_info,
+ freq_lims,
+ allow_training,
+ ):
"""
        Plots 2D low dimensional features on left and corresponding spectrograms on
the right.
@@ -24,78 +51,123 @@ class InteractivePlotter:
self.spec_slices = spec_slices
self.call_info = call_info
- #_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
+ # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
        self.labels = np.zeros(len(call_info), dtype=int)
- self.annotated = np.zeros(self.labels.shape[0], dtype=np.int) # can populate this with 1's where we have labels
- self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))]
+ self.annotated = np.zeros(
+            self.labels.shape[0], dtype=int
+ ) # can populate this with 1's where we have labels
+ self.labels_cols = [
+ colors[self.labels[ii]] for ii in range(len(self.labels))
+ ]
self.freq_lims = freq_lims
self.allow_training = allow_training
self.pt_size = 5.0
- self.spec_pad = 0.2 # this much padding has been applied to the spec slices
+ self.spec_pad = (
+ 0.2 # this much padding has been applied to the spec slices
+ )
self.fig_width = 12
self.fig_height = 8
self.current_id = 0
max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
self.max_width = self.spec_slices[max_ind].shape[1]
- self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))
-
+ self.blank_spec = np.zeros(
+ (self.spec_slices[0].shape[0], self.max_width)
+ )
def plot(self, fig_id):
- self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height),
- gridspec_kw={'width_ratios': [2, 1]})
+ self.fig, self.ax = plt.subplots(
+ nrows=1,
+ ncols=2,
+ num=fig_id,
+ figsize=(self.fig_width, self.fig_height),
+ gridspec_kw={"width_ratios": [2, 1]},
+ )
plt.tight_layout()
    # plot 2D TSNE features
- self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
- c=self.labels_cols, s=self.pt_size, picker=5)
- self.ax[0].set_title('TSNE of Call Features')
+ self.low_dim_plt = self.ax[0].scatter(
+ self.feats_ds[:, 0],
+ self.feats_ds[:, 1],
+ c=self.labels_cols,
+ s=self.pt_size,
+ picker=5,
+ )
+ self.ax[0].set_title("TSNE of Call Features")
self.ax[0].set_xticks([])
self.ax[0].set_yticks([])
# plot clip from spectrogram
- spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
- self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto')
+ spec_min_max = (
+ 0,
+ self.blank_spec.shape[1],
+ self.freq_lims[0],
+ self.freq_lims[1],
+ )
+ self.ax[1].imshow(
+ self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
+ )
self.spec_im = self.ax[1].get_images()[0]
- self.ax[1].set_title('Spectrogram')
- self.ax[1].grid(color='w', linewidth=0.5)
+ self.ax[1].set_title("Spectrogram")
+ self.ax[1].grid(color="w", linewidth=0.5)
self.ax[1].set_xticks([])
- self.ax[1].set_ylabel('kHz')
+ self.ax[1].set_ylabel("kHz")
- bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False)
+ bbox_orig = patches.Rectangle(
+ (0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False
+ )
self.ax[1].add_patch(bbox_orig)
- self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20),textcoords='offset points',
- bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->'))
+ self.annot = self.ax[0].annotate(
+ "",
+ xy=(0, 0),
+ xytext=(20, 20),
+ textcoords="offset points",
+ bbox=dict(boxstyle="round", fc="w"),
+ arrowprops=dict(arrowstyle="->"),
+ )
self.annot.set_visible(False)
- self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover)
- self.fig.canvas.mpl_connect('key_press_event', self.key_press)
-
+ self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover)
+ self.fig.canvas.mpl_connect("key_press_event", self.key_press)
def mouse_hover(self, event):
vis = self.annot.get_visible()
if event.inaxes == self.ax[0]:
cont, ind = self.low_dim_plt.contains(event)
if cont:
- self.current_id = ind['ind'][0]
+ self.current_id = ind["ind"][0]
# copy spec into full window - probably a better way of doing this
new_spec = self.blank_spec.copy()
- w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2
- new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id]
+ w_diff = (
+ self.blank_spec.shape[1]
+ - self.spec_slices[self.current_id].shape[1]
+ ) // 2
+ new_spec[
+ :,
+ w_diff : self.spec_slices[self.current_id].shape[1]
+ + w_diff,
+ ] = self.spec_slices[self.current_id]
self.spec_im.set_data(new_spec)
self.spec_im.set_clim(vmin=0, vmax=new_spec.max())
# draw bounding box around call
self.ax[1].patches[0].remove()
- spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad)
- xx = w_diff + self.spec_pad*spec_width_orig
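+                    # undo the symmetric spec_pad applied to each side to recover the call's true width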
+ spec_width_orig = self.spec_slices[self.current_id].shape[
+ 1
+ ] / (1.0 + 2.0 * self.spec_pad)
+ xx = w_diff + self.spec_pad * spec_width_orig
ww = spec_width_orig
- yy = self.call_info[self.current_id]['low_freq']/1000
- hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000
- bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False)
+ yy = self.call_info[self.current_id]["low_freq"] / 1000
+ hh = (
+ self.call_info[self.current_id]["high_freq"]
+ - self.call_info[self.current_id]["low_freq"]
+ ) / 1000
+ bbox = patches.Rectangle(
+ (xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False
+ )
self.ax[1].add_patch(bbox)
# update annotation arrow
@@ -104,38 +176,54 @@ class InteractivePlotter:
self.annot.set_visible(True)
# write call info
- info_str = self.call_info[self.current_id]['file_name'] + ', time=' \
- + str(round(self.call_info[self.current_id]['start_time'],3)) \
- + ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3))
+ info_str = (
+ self.call_info[self.current_id]["file_name"]
+ + ", time="
+ + str(
+ round(self.call_info[self.current_id]["start_time"], 3)
+ )
+ + ", prob="
+ + str(
+ round(self.call_info[self.current_id]["det_prob"], 3)
+ )
+ )
self.ax[0].set_xlabel(info_str)
# redraw
self.fig.canvas.draw_idle()
-
def key_press(self, event):
if event.key.isdigit():
self.labels_cols[self.current_id] = colors[int(event.key)]
self.labels[self.current_id] = int(event.key)
self.annotated[self.current_id] = 1
- elif event.key == 'enter' and self.allow_training:
+ elif event.key == "enter" and self.allow_training:
self.train_classifier()
- elif event.key == 'x' and self.allow_training:
+ elif event.key == "x" and self.allow_training:
self.get_classifier_params()
- self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
- c=self.labels_cols, s=self.pt_size)
+ self.ax[0].scatter(
+ self.feats_ds[:, 0],
+ self.feats_ds[:, 1],
+ c=self.labels_cols,
+ s=self.pt_size,
+ )
self.fig.canvas.draw_idle()
-
def train_classifier(self):
# TODO maybe it's better to classify in 2D space - but then can't be linear ...
inds = np.where(self.annotated == 1)[0]
labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True)
if labs_un.shape[0] > 1: # needs at least 2 classes
- self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001,
- intercept_scaling=1.0, max_iter=2000)
+ self.clf = LinearSVC(
+ C=1.0,
+ penalty="l2",
+ loss="squared_hinge",
+ tol=0.0001,
+ intercept_scaling=1.0,
+ max_iter=2000,
+ )
self.clf.fit(self.feats[inds, :], self.labels[inds])
@@ -145,14 +233,13 @@ class InteractivePlotter:
for ii in inds_unlab:
self.labels_cols[ii] = colors[self.labels[ii]]
else:
- print('Not enough data - please label more classes.')
-
+ print("Not enough data - please label more classes.")
def get_classifier_params(self):
res = {}
if self.clf is None:
- print('Model not trained!')
+ print("Model not trained!")
else:
- res['weights'] = self.clf.coef_.astype(np.float32)
- res['biases'] = self.clf.intercept_.astype(np.float32)
+ res["weights"] = self.clf.coef_.astype(np.float32)
+ res["biases"] = self.clf.intercept_.astype(np.float32)
return res
diff --git a/bat_detect/utils/wavfile.py b/bat_detect/utils/wavfile.py
index a6715b0..7fee660 100644
--- a/bat_detect/utils/wavfile.py
+++ b/bat_detect/utils/wavfile.py
@@ -8,23 +8,25 @@ Functions
`write`: Write a numpy array as a WAV file.
"""
-from __future__ import division, print_function, absolute_import
+from __future__ import absolute_import, division, print_function
-import sys
-import numpy
-import struct
-import warnings
import os
+import struct
+import sys
+import warnings
+
+import numpy
class WavFileWarning(UserWarning):
pass
+
_big_endian = False
WAVE_FORMAT_PCM = 0x0001
WAVE_FORMAT_IEEE_FLOAT = 0x0003
-WAVE_FORMAT_EXTENSIBLE = 0xfffe
+WAVE_FORMAT_EXTENSIBLE = 0xFFFE
KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)
# assumes file pointer is immediately
@@ -33,10 +35,10 @@ KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)
def _read_fmt_chunk(fid):
if _big_endian:
- fmt = '>'
+ fmt = ">"
else:
- fmt = '<'
- res = struct.unpack(fmt+'iHHIIHH',fid.read(20))
+ fmt = "<"
+ res = struct.unpack(fmt + "iHHIIHH", fid.read(20))
size, comp, noc, rate, sbytes, ba, bits = res
if comp not in KNOWN_WAVE_FORMATS or size > 16:
comp = WAVE_FORMAT_PCM
@@ -51,41 +53,42 @@ def _read_fmt_chunk(fid):
# after the 'data' id
def _read_data_chunk(fid, comp, noc, bits, mmap=False):
if _big_endian:
- fmt = '>i'
+ fmt = ">i"
else:
-        fmt = '<i'
+        fmt = "<i"
    if noc > 1:
-        data = data.reshape(-1,noc)
+ data = data.reshape(-1, noc)
return data
def _skip_unknown_chunk(fid):
if _big_endian:
- fmt = '>i'
+ fmt = ">i"
else:
-        fmt = '<i'
+        fmt = "<i"
-        fid.write(b'data')
-        fid.write(struct.pack('<i', data.nbytes))
-        if data.dtype.byteorder == '>' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'):
+        fid.write(b"data")
+        fid.write(struct.pack("<i", data.nbytes))
+        if data.dtype.byteorder == ">" or (
+            data.dtype.byteorder == "=" and sys.byteorder == "big"
+        ):
data = data.byteswap()
_array_tofile(fid, data)
@@ -273,19 +286,22 @@ def write(filename, rate, data):
# position at start of the file (replacing the 4 bytes of zeros)
size = fid.tell()
fid.seek(4)
-    fid.write(struct.pack('<i', size-8))
+    fid.write(struct.pack("<i", size - 8))
if sys.version_info[0] >= 3:
+
def _array_tofile(fid, data):
# ravel gives a c-contiguous buffer
- fid.write(data.ravel().view('b').data)
+ fid.write(data.ravel().view("b").data)
+
else:
+
def _array_tofile(fid, data):
fid.write(data.tostring())
diff --git a/run_batdetect.py b/run_batdetect.py
index 9655d45..f9d96ab 100644
--- a/run_batdetect.py
+++ b/run_batdetect.py
@@ -1,67 +1,115 @@
-import os
import argparse
+import os
+
import bat_detect.utils.detector_utils as du
def main(args):
- print('Loading model: ' + args['model_path'])
- model, params = du.load_model(args['model_path'])
+ print("Loading model: " + args["model_path"])
+ model, params = du.load_model(args["model_path"])
- print('\nInput directory: ' + args['audio_dir'])
- files = du.get_audio_files(args['audio_dir'])
- print('Number of audio files: {}'.format(len(files)))
- print('\nSaving results to: ' + args['ann_dir'])
+ print("\nInput directory: " + args["audio_dir"])
+ files = du.get_audio_files(args["audio_dir"])
+ print("Number of audio files: {}".format(len(files)))
+ print("\nSaving results to: " + args["ann_dir"])
# process files
error_files = []
for ii, audio_file in enumerate(files):
- print('\n' + str(ii).ljust(6) + os.path.basename(audio_file))
+ print("\n" + str(ii).ljust(6) + os.path.basename(audio_file))
try:
results = du.process_file(audio_file, model, params, args)
- if args['save_preds_if_empty'] or (len(results['pred_dict']['annotation']) > 0):
- results_path = audio_file.replace(args['audio_dir'], args['ann_dir'])
+ if args["save_preds_if_empty"] or (
+ len(results["pred_dict"]["annotation"]) > 0
+ ):
+ results_path = audio_file.replace(
+ args["audio_dir"], args["ann_dir"]
+ )
du.save_results_to_file(results, results_path)
except:
error_files.append(audio_file)
print("Error processing file!")
- print('\nResults saved to: ' + args['ann_dir'])
+ print("\nResults saved to: " + args["ann_dir"])
if len(error_files) > 0:
- print('\nUnable to process the follow files:')
+ print("\nUnable to process the follow files:")
for err in error_files:
- print(' ' + err)
+ print(" " + err)
if __name__ == "__main__":
- info_str = '\nBatDetect2 - Detection and Classification\n' + \
- ' Assumes audio files are mono, not stereo.\n' + \
- ' Spaces in the input paths will throw an error. Wrap in quotes "".\n' + \
- ' Input files should be short in duration e.g. < 30 seconds.\n'
+ info_str = (
+ "\nBatDetect2 - Detection and Classification\n"
+ + " Assumes audio files are mono, not stereo.\n"
+ + ' Spaces in the input paths will throw an error. Wrap in quotes "".\n'
+ + " Input files should be short in duration e.g. < 30 seconds.\n"
+ )
print(info_str)
parser = argparse.ArgumentParser()
- parser.add_argument('audio_dir', type=str, help='Input directory for audio')
- parser.add_argument('ann_dir', type=str, help='Output directory for where the predictions will be stored')
- parser.add_argument('detection_threshold', type=float, help='Cut-off probability for detector e.g. 0.1')
- parser.add_argument('--cnn_features', action='store_true', default=False, dest='cnn_features',
- help='Extracts CNN call features')
- parser.add_argument('--spec_features', action='store_true', default=False, dest='spec_features',
- help='Extracts low level call features')
- parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
- help='The time expansion factor used for all files (default is 1)')
- parser.add_argument('--quiet', action='store_true', default=False, dest='quiet',
- help='Minimize output printing')
- parser.add_argument('--save_preds_if_empty', action='store_true', default=False, dest='save_preds_if_empty',
- help='Save empty annotation file if no detections made.')
- parser.add_argument('--model_path', type=str, default='models/Net2DFast_UK_same.pth.tar',
- help='Path to trained BatDetect2 model')
+ parser.add_argument(
+ "audio_dir", type=str, help="Input directory for audio"
+ )
+ parser.add_argument(
+ "ann_dir",
+ type=str,
+ help="Output directory for where the predictions will be stored",
+ )
+ parser.add_argument(
+ "detection_threshold",
+ type=float,
+ help="Cut-off probability for detector e.g. 0.1",
+ )
+ parser.add_argument(
+ "--cnn_features",
+ action="store_true",
+ default=False,
+ dest="cnn_features",
+ help="Extracts CNN call features",
+ )
+ parser.add_argument(
+ "--spec_features",
+ action="store_true",
+ default=False,
+ dest="spec_features",
+ help="Extracts low level call features",
+ )
+ parser.add_argument(
+ "--time_expansion_factor",
+ type=int,
+ default=1,
+ dest="time_expansion_factor",
+ help="The time expansion factor used for all files (default is 1)",
+ )
+ parser.add_argument(
+ "--quiet",
+ action="store_true",
+ default=False,
+ dest="quiet",
+ help="Minimize output printing",
+ )
+ parser.add_argument(
+ "--save_preds_if_empty",
+ action="store_true",
+ default=False,
+ dest="save_preds_if_empty",
+ help="Save empty annotation file if no detections made.",
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="models/Net2DFast_UK_same.pth.tar",
+ help="Path to trained BatDetect2 model",
+ )
args = vars(parser.parse_args())
- args['spec_slices'] = False # used for visualization
- args['chunk_size'] = 2 # if files greater than this amount (seconds) they will be broken down into small chunks
- args['ann_dir'] = os.path.join(args['ann_dir'], '')
+ args["spec_slices"] = False # used for visualization
+    args[
+        "chunk_size"
+    ] = 2  # files longer than this (in seconds) are broken into smaller chunks
+ args["ann_dir"] = os.path.join(args["ann_dir"], "")
main(args)
diff --git a/scripts/gen_dataset_summary_image.py b/scripts/gen_dataset_summary_image.py
index b789584..cb823d6 100644
--- a/scripts/gen_dataset_summary_image.py
+++ b/scripts/gen_dataset_summary_image.py
@@ -3,62 +3,97 @@ Loads a set of annotations corresponding to a dataset and saves an image which
is the mean spectrogram for each class.
"""
+import argparse
+import os
+import sys
+
import matplotlib.pyplot as plt
import numpy as np
-import os
-import argparse
-import sys
import viz_helpers as vz
-sys.path.append(os.path.join('..'))
-import bat_detect.train.train_utils as tu
+sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
-import bat_detect.utils.audio_utils as au
import bat_detect.train.train_split as ts
-
+import bat_detect.train.train_utils as tu
+import bat_detect.utils.audio_utils as au
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('audio_path', type=str, help='Input directory for audio')
- parser.add_argument('op_dir', type=str,
- help='Path to where single annotation json file is stored')
- parser.add_argument('--ann_file', type=str,
- help='Path to where single annotation json file is stored')
- parser.add_argument('--uk_split', type=str, default='',
- help='Set as: diff or same')
- parser.add_argument('--file_type', type=str, default='png',
- help='Type of image to save png or pdf')
+ parser.add_argument(
+ "audio_path", type=str, help="Input directory for audio"
+ )
+ parser.add_argument(
+ "op_dir",
+ type=str,
+ help="Path to where single annotation json file is stored",
+ )
+ parser.add_argument(
+ "--ann_file",
+ type=str,
+ help="Path to where single annotation json file is stored",
+ )
+ parser.add_argument(
+ "--uk_split", type=str, default="", help="Set as: diff or same"
+ )
+ parser.add_argument(
+ "--file_type",
+ type=str,
+ default="png",
+ help="Type of image to save png or pdf",
+ )
args = vars(parser.parse_args())
- if not os.path.isdir(args['op_dir']):
- os.makedirs(args['op_dir'])
+ if not os.path.isdir(args["op_dir"]):
+ os.makedirs(args["op_dir"])
params = parameters.get_params(False)
- params['smooth_spec'] = False
- params['spec_width'] = 48
- params['norm_type'] = 'log' # log, pcen
- params['aud_pad'] = 0.005
- classes_to_ignore = params['classes_to_ignore'] + params['generic_class']
-
+ params["smooth_spec"] = False
+ params["spec_width"] = 48
+ params["norm_type"] = "log" # log, pcen
+ params["aud_pad"] = 0.005
+ classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# load train annotations
- if args['uk_split'] == '':
- print('\nLoading:', args['ann_file'], '\n')
- dataset_name = os.path.basename(args['ann_file']).replace('.json', '')
+ if args["uk_split"] == "":
+ print("\nLoading:", args["ann_file"], "\n")
+ dataset_name = os.path.basename(args["ann_file"]).replace(".json", "")
datasets = []
- datasets.append(tu.get_blank_dataset_dict(dataset_name, False, args['ann_file'], args['audio_path']))
+ datasets.append(
+ tu.get_blank_dataset_dict(
+ dataset_name, False, args["ann_file"], args["audio_path"]
+ )
+ )
else:
# load uk data - special case
- print('\nLoading:', args['uk_split'], '\n')
- dataset_name = 'uk_' + args['uk_split'] # should be uk_diff, or uk_same
- datasets, _ = ts.get_train_test_data(args['ann_file'], args['audio_path'], args['uk_split'], load_extra=False)
+ print("\nLoading:", args["uk_split"], "\n")
+ dataset_name = (
+ "uk_" + args["uk_split"]
+ ) # should be uk_diff, or uk_same
+ datasets, _ = ts.get_train_test_data(
+ args["ann_file"],
+ args["audio_path"],
+ args["uk_split"],
+ load_extra=False,
+ )
- anns, class_names, _ = tu.load_set_of_anns(datasets, classes_to_ignore, params['events_of_interest'])
+ anns, class_names, _ = tu.load_set_of_anns(
+ datasets, classes_to_ignore, params["events_of_interest"]
+ )
class_names_order = range(len(class_names))
- x_train, y_train = vz.load_data(anns, params, class_names, smooth_spec=params['smooth_spec'], norm_type=params['norm_type'])
+ x_train, y_train = vz.load_data(
+ anns,
+ params,
+ class_names,
+ smooth_spec=params["smooth_spec"],
+ norm_type=params["norm_type"],
+ )
- op_file_name = os.path.join(args['op_dir'], dataset_name + '.' + args['file_type'])
- vz.save_summary_image(x_train, y_train, class_names, params, op_file_name, class_names_order)
- print('\nImage saved to:', op_file_name)
+ op_file_name = os.path.join(
+ args["op_dir"], dataset_name + "." + args["file_type"]
+ )
+ vz.save_summary_image(
+ x_train, y_train, class_names, params, op_file_name, class_names_order
+ )
+ print("\nImage saved to:", op_file_name)
diff --git a/scripts/gen_spec_image.py b/scripts/gen_spec_image.py
index 11f76de..3d4cffa 100644
--- a/scripts/gen_spec_image.py
+++ b/scripts/gen_spec_image.py
@@ -7,24 +7,27 @@ Will save images with:
3) spectrogram with predicted boxes
"""
-import numpy as np
-import sys
-import os
import argparse
-import matplotlib.pyplot as plt
import json
+import os
+import sys
-sys.path.append(os.path.join('..'))
+import matplotlib.pyplot as plt
+import numpy as np
+
+sys.path.append(os.path.join(".."))
import bat_detect.evaluate.evaluate_models as evlm
+import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
import bat_detect.utils.plot_utils as viz
-import bat_detect.utils.audio_utils as au
def filter_anns(anns, start_time, stop_time):
anns_op = []
for aa in anns:
- if (aa['start_time'] >= start_time) and (aa['start_time'] < stop_time-0.02):
+ if (aa["start_time"] >= start_time) and (
+ aa["start_time"] < stop_time - 0.02
+ ):
anns_op.append(aa)
return anns_op
@@ -32,85 +35,172 @@ def filter_anns(anns, start_time, stop_time):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('audio_file', type=str, help='Path to audio file')
- parser.add_argument('model_path', type=str, help='Path to BatDetect model')
- parser.add_argument('--ann_file', type=str, default='', help='Path to annotation file')
- parser.add_argument('--op_dir', type=str, default='plots/',
- help='Output directory for plots')
- parser.add_argument('--file_type', type=str, default='png',
- help='Type of image to save png or pdf')
- parser.add_argument('--title_text', type=str, default='',
- help='Text to add as title of plots')
- parser.add_argument('--detection_threshold', type=float, default=0.2,
- help='Threshold for output detections')
- parser.add_argument('--start_time', type=float, default=0.0,
- help='Start time for cropped file')
- parser.add_argument('--stop_time', type=float, default=0.5,
- help='End time for cropped file')
- parser.add_argument('--time_expansion_factor', type=int, default=1,
- help='Time expansion factor')
-
+ parser.add_argument("audio_file", type=str, help="Path to audio file")
+ parser.add_argument("model_path", type=str, help="Path to BatDetect model")
+ parser.add_argument(
+ "--ann_file", type=str, default="", help="Path to annotation file"
+ )
+ parser.add_argument(
+ "--op_dir",
+ type=str,
+ default="plots/",
+ help="Output directory for plots",
+ )
+ parser.add_argument(
+ "--file_type",
+ type=str,
+ default="png",
+ help="Type of image to save png or pdf",
+ )
+ parser.add_argument(
+ "--title_text",
+ type=str,
+ default="",
+ help="Text to add as title of plots",
+ )
+ parser.add_argument(
+ "--detection_threshold",
+ type=float,
+ default=0.2,
+ help="Threshold for output detections",
+ )
+ parser.add_argument(
+ "--start_time",
+ type=float,
+ default=0.0,
+ help="Start time for cropped file",
+ )
+ parser.add_argument(
+ "--stop_time",
+ type=float,
+ default=0.5,
+ help="End time for cropped file",
+ )
+ parser.add_argument(
+ "--time_expansion_factor",
+ type=int,
+ default=1,
+ help="Time expansion factor",
+ )
+
args_cmd = vars(parser.parse_args())
-
- # load the model
+
+ # load the model
bd_args = du.get_default_bd_args()
- model, params_bd = du.load_model(args_cmd['model_path'])
- bd_args['detection_threshold'] = args_cmd['detection_threshold']
- bd_args['time_expansion_factor'] = args_cmd['time_expansion_factor']
-
+ model, params_bd = du.load_model(args_cmd["model_path"])
+ bd_args["detection_threshold"] = args_cmd["detection_threshold"]
+ bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
+
# load the annotation if it exists
gt_present = False
- if args_cmd['ann_file'] != '':
- if os.path.isfile(args_cmd['ann_file']):
- with open(args_cmd['ann_file']) as da:
+ if args_cmd["ann_file"] != "":
+ if os.path.isfile(args_cmd["ann_file"]):
+ with open(args_cmd["ann_file"]) as da:
gt_anns = json.load(da)
- gt_anns = filter_anns(gt_anns['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
+ gt_anns = filter_anns(
+ gt_anns["annotation"],
+ args_cmd["start_time"],
+ args_cmd["stop_time"],
+ )
gt_present = True
else:
- print('Annotation file not found: ', args_cmd['ann_file'])
+ print("Annotation file not found: ", args_cmd["ann_file"])
# load the audio file
- if not os.path.isfile(args_cmd['audio_file']):
- print('Audio file not found: ', args_cmd['audio_file'])
+ if not os.path.isfile(args_cmd["audio_file"]):
+ print("Audio file not found: ", args_cmd["audio_file"])
sys.exit()
-
+
# load audio and crop
- print('\nProcessing: ' + os.path.basename(args_cmd['audio_file']))
- print('\nOutput directory: ' + args_cmd['op_dir'])
- sampling_rate, audio = au.load_audio_file(args_cmd['audio_file'], args_cmd['time_exp'],
- params_bd['target_samp_rate'], params_bd['scale_raw_audio'])
- st_samp = int(sampling_rate*args_cmd['start_time'])
- en_samp = int(sampling_rate*args_cmd['stop_time'])
+ print("\nProcessing: " + os.path.basename(args_cmd["audio_file"]))
+ print("\nOutput directory: " + args_cmd["op_dir"])
+ sampling_rate, audio = au.load_audio_file(
+ args_cmd["audio_file"],
+ args_cmd["time_exp"],
+ params_bd["target_samp_rate"],
+ params_bd["scale_raw_audio"],
+ )
+ st_samp = int(sampling_rate * args_cmd["start_time"])
+ en_samp = int(sampling_rate * args_cmd["stop_time"])
if en_samp > audio.shape[0]:
- audio = np.hstack((audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype)))
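+        # zero-pad the audio if the requested stop time runs past the end of the file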
+ audio = np.hstack(
+ (audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))
+ )
audio = audio[st_samp:en_samp]
duration = audio.shape[0] / sampling_rate
- print('File duration: {} seconds'.format(duration))
+ print("File duration: {} seconds".format(duration))
# create spec for viz
- spec, _ = au.generate_spectrogram(audio, sampling_rate, params_bd, True, False)
+ spec, _ = au.generate_spectrogram(
+ audio, sampling_rate, params_bd, True, False
+ )
# run model and filter detections so only keep ones in relevant time range
- results = du.process_file(args_cmd['audio_file'], model, params_bd, bd_args)
- pred_anns = filter_anns(results['pred_dict']['annotation'], args_cmd['start_time'], args_cmd['stop_time'])
- print(len(pred_anns), 'Detections')
+ results = du.process_file(
+ args_cmd["audio_file"], model, params_bd, bd_args
+ )
+ pred_anns = filter_anns(
+ results["pred_dict"]["annotation"],
+ args_cmd["start_time"],
+ args_cmd["stop_time"],
+ )
+ print(len(pred_anns), "Detections")
# save output
- if not os.path.isdir(args_cmd['op_dir']):
- os.makedirs(args_cmd['op_dir'])
-
+ if not os.path.isdir(args_cmd["op_dir"]):
+ os.makedirs(args_cmd["op_dir"])
+
# create output file names
- op_path_clean = os.path.basename(args_cmd['audio_file'])[:-4] + '_clean.' + args_cmd['file_type']
- op_path_clean = os.path.join(args_cmd['op_dir'], op_path_clean)
- op_path_pred = os.path.basename(args_cmd['audio_file'])[:-4] + '_pred.' + args_cmd['file_type']
- op_path_pred = os.path.join(args_cmd['op_dir'], op_path_pred)
+ op_path_clean = (
+ os.path.basename(args_cmd["audio_file"])[:-4]
+ + "_clean."
+ + args_cmd["file_type"]
+ )
+ op_path_clean = os.path.join(args_cmd["op_dir"], op_path_clean)
+ op_path_pred = (
+ os.path.basename(args_cmd["audio_file"])[:-4]
+ + "_pred."
+ + args_cmd["file_type"]
+ )
+ op_path_pred = os.path.join(args_cmd["op_dir"], op_path_pred)
    # create and save images
- viz.save_ann_spec(op_path_clean, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', None)
- viz.save_ann_spec(op_path_pred, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', pred_anns)
+ viz.save_ann_spec(
+ op_path_clean,
+ spec,
+ params_bd["min_freq"],
+ params_bd["max_freq"],
+ duration,
+ args_cmd["start_time"],
+ "",
+ None,
+ )
+ viz.save_ann_spec(
+ op_path_pred,
+ spec,
+ params_bd["min_freq"],
+ params_bd["max_freq"],
+ duration,
+ args_cmd["start_time"],
+ "",
+ pred_anns,
+ )
if gt_present:
- op_path_gt = os.path.basename(args_cmd['audio_file'])[:-4] + '_gt.' + args_cmd['file_type']
- op_path_gt = os.path.join(args_cmd['op_dir'], op_path_gt)
- viz.save_ann_spec(op_path_gt, spec, params_bd['min_freq'], params_bd['max_freq'], duration, args_cmd['start_time'], '', gt_anns)
+ op_path_gt = (
+ os.path.basename(args_cmd["audio_file"])[:-4]
+ + "_gt."
+ + args_cmd["file_type"]
+ )
+ op_path_gt = os.path.join(args_cmd["op_dir"], op_path_gt)
+ viz.save_ann_spec(
+ op_path_gt,
+ spec,
+ params_bd["min_freq"],
+ params_bd["max_freq"],
+ duration,
+ args_cmd["start_time"],
+ "",
+ gt_anns,
+ )
diff --git a/scripts/gen_spec_video.py b/scripts/gen_spec_video.py
index cccfcf8..3c055ec 100644
--- a/scripts/gen_spec_video.py
+++ b/scripts/gen_spec_video.py
@@ -8,57 +8,83 @@ Notes:
Best to use system one - see ffmpeg_path.
"""
-from scipy.io import wavfile
+import argparse
import os
import shutil
+import sys
+
import matplotlib.pyplot as plt
import numpy as np
-import argparse
+from scipy.io import wavfile
-import sys
-sys.path.append(os.path.join('..'))
+sys.path.append(os.path.join(".."))
import bat_detect.detector.parameters as parameters
import bat_detect.utils.audio_utils as au
-import bat_detect.utils.plot_utils as viz
import bat_detect.utils.detector_utils as du
-
+import bat_detect.utils.plot_utils as viz
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('audio_file', type=str, help='Path to input audio file')
- parser.add_argument('model_path', type=str, help='Path to trained BatDetect model')
- parser.add_argument('--op_dir', type=str, default='generated_vids/', help='Path to output directory')
- parser.add_argument('--no_detector', action='store_true', help='Do not run detector')
- parser.add_argument('--plot_class_names_off', action='store_true', help='Do not plot class names')
- parser.add_argument('--disable_axis', action='store_true', help='Do not plot axis')
- parser.add_argument('--detection_threshold', type=float, default=0.2, help='Cut-off probability for detector')
- parser.add_argument('--time_expansion_factor', type=int, default=1, dest='time_expansion_factor',
- help='The time expansion factor used for all files (default is 1)')
+ parser.add_argument(
+ "audio_file", type=str, help="Path to input audio file"
+ )
+ parser.add_argument(
+ "model_path", type=str, help="Path to trained BatDetect model"
+ )
+ parser.add_argument(
+ "--op_dir",
+ type=str,
+ default="generated_vids/",
+ help="Path to output directory",
+ )
+ parser.add_argument(
+ "--no_detector", action="store_true", help="Do not run detector"
+ )
+ parser.add_argument(
+ "--plot_class_names_off",
+ action="store_true",
+ help="Do not plot class names",
+ )
+ parser.add_argument(
+ "--disable_axis", action="store_true", help="Do not plot axis"
+ )
+ parser.add_argument(
+ "--detection_threshold",
+ type=float,
+ default=0.2,
+ help="Cut-off probability for detector",
+ )
+ parser.add_argument(
+ "--time_expansion_factor",
+ type=int,
+ default=1,
+ dest="time_expansion_factor",
+ help="The time expansion factor used for all files (default is 1)",
+ )
args_cmd = vars(parser.parse_args())
# file of interest
- audio_file = args_cmd['audio_file']
- op_dir = args_cmd['op_dir']
- op_str = '_output'
- ffmpeg_path = '/usr/bin/'
+ audio_file = args_cmd["audio_file"]
+ op_dir = args_cmd["op_dir"]
+ op_str = "_output"
+ ffmpeg_path = "/usr/bin/"
if not os.path.isfile(audio_file):
- print('Audio file not found: ', audio_file)
+ print("Audio file not found: ", audio_file)
sys.exit()
- if not os.path.isfile(args_cmd['model_path']):
- print('Model not found: ', model_path)
+ if not os.path.isfile(args_cmd["model_path"]):
+ print("Model not found: ", model_path)
sys.exit()
-
start_time = 0.0
duration = 0.5
reveal_boxes = True # makes the boxes appear one at a time
fps = 24
dpi = 100
- op_dir_tmp = os.path.join(op_dir, 'op_tmp_vids', '')
+ op_dir_tmp = os.path.join(op_dir, "op_tmp_vids", "")
if not os.path.isdir(op_dir_tmp):
os.makedirs(op_dir_tmp)
if not os.path.isdir(op_dir):
@@ -66,105 +92,176 @@ if __name__ == "__main__":
params = parameters.get_params(False)
args = du.get_default_bd_args()
- args['time_expansion_factor'] = args_cmd['time_expansion_factor']
- args['detection_threshold'] = args_cmd['detection_threshold']
-
+ args["time_expansion_factor"] = args_cmd["time_expansion_factor"]
+ args["detection_threshold"] = args_cmd["detection_threshold"]
# load audio file
- print('\nProcessing: ' + os.path.basename(audio_file))
- print('\nOutput directory: ' + op_dir)
- sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], params['target_samp_rate'])
- audio = audio[int(sampling_rate*start_time):int(sampling_rate*start_time + sampling_rate*duration)]
+ print("\nProcessing: " + os.path.basename(audio_file))
+ print("\nOutput directory: " + op_dir)
+ sampling_rate, audio = au.load_audio_file(
+ audio_file, args["time_expansion_factor"], params["target_samp_rate"]
+ )
+ audio = audio[
+ int(sampling_rate * start_time) : int(
+ sampling_rate * start_time + sampling_rate * duration
+ )
+ ]
audio_orig = audio.copy()
- audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
- params['fft_overlap'], params['resize_factor'],
- params['spec_divide_factor'])
+ audio = au.pad_audio(
+ audio,
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ params["resize_factor"],
+ params["spec_divide_factor"],
+ )
# generate spectrogram
spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True)
- max_val = spec.max()*1.1
+ max_val = spec.max() * 1.1
-
- if not args_cmd['no_detector']:
- print(' Loading model and running detector on entire file ...')
- model, det_params = du.load_model(args_cmd['model_path'])
- det_params['detection_threshold'] = args['detection_threshold']
+ if not args_cmd["no_detector"]:
+ print(" Loading model and running detector on entire file ...")
+ model, det_params = du.load_model(args_cmd["model_path"])
+ det_params["detection_threshold"] = args["detection_threshold"]
results = du.process_file(audio_file, model, det_params, args)
- print(' Processing detections and plotting ...')
+ print(" Processing detections and plotting ...")
detections = []
- for bb in results['pred_dict']['annotation']:
- if (bb['start_time'] >= start_time) and (bb['end_time'] < start_time+duration):
+ for bb in results["pred_dict"]["annotation"]:
+ if (bb["start_time"] >= start_time) and (
+ bb["end_time"] < start_time + duration
+ ):
detections.append(bb)
# plot boxes
- fig = plt.figure(1, figsize=(spec.shape[1]/dpi, spec.shape[0]/dpi), dpi=dpi)
- duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
- viz.create_box_image(spec, fig, detections, start_time, start_time+duration, duration, params, max_val,
- plot_class_names=not args_cmd['plot_class_names_off'])
- op_im_file_boxes = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '_boxes.png')
+ fig = plt.figure(
+ 1, figsize=(spec.shape[1] / dpi, spec.shape[0] / dpi), dpi=dpi
+ )
+ duration = au.x_coords_to_time(
+ spec.shape[1],
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ viz.create_box_image(
+ spec,
+ fig,
+ detections,
+ start_time,
+ start_time + duration,
+ duration,
+ params,
+ max_val,
+ plot_class_names=not args_cmd["plot_class_names_off"],
+ )
+ op_im_file_boxes = os.path.join(
+ op_dir, os.path.basename(audio_file)[:-4] + op_str + "_boxes.png"
+ )
fig.savefig(op_im_file_boxes, dpi=dpi)
plt.close(1)
spec_with_boxes = plt.imread(op_im_file_boxes)
-
- print(' Saving audio file ...')
- if args['time_expansion_factor']==1:
- sampling_rate_op = int(sampling_rate/10.0)
+ print(" Saving audio file ...")
+ if args["time_expansion_factor"] == 1:
+ sampling_rate_op = int(sampling_rate / 10.0)
else:
sampling_rate_op = sampling_rate
- op_audio_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.wav')
+ op_audio_file = os.path.join(
+ op_dir, os.path.basename(audio_file)[:-4] + op_str + ".wav"
+ )
wavfile.write(op_audio_file, sampling_rate_op, audio_orig)
-
- print(' Saving image ...')
- op_im_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.png')
- plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap='plasma')
+ print(" Saving image ...")
+ op_im_file = os.path.join(
+ op_dir, os.path.basename(audio_file)[:-4] + op_str + ".png"
+ )
+ plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap="plasma")
spec_blank = plt.imread(op_im_file)
# create figure
freq_scale = 1000 # turn Hz to kHz
- min_freq = params['min_freq']//freq_scale
- max_freq = params['max_freq']//freq_scale
+ min_freq = params["min_freq"] // freq_scale
+ max_freq = params["max_freq"] // freq_scale
y_extent = [0, duration, min_freq, max_freq]
- print(' Saving video frames ...')
+ print(" Saving video frames ...")
# save images that will be combined into video
# will either plot with or without boxes
- for ii, col in enumerate(np.linspace(0, spec.shape[1]-1, int(fps*duration*10))):
- if not args_cmd['no_detector']:
+ for ii, col in enumerate(
+ np.linspace(0, spec.shape[1] - 1, int(fps * duration * 10))
+ ):
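+        # paint a white playhead at the current column; with reveal_boxes, pixels to its right show the box-free spectrogram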
+ if not args_cmd["no_detector"]:
spec_op = spec_with_boxes.copy()
if ii > 0:
spec_op[:, int(col), :] = 1.0
if reveal_boxes:
- spec_op[:, int(col)+1:, :] = spec_blank[:, int(col)+1:, :]
+ spec_op[:, int(col) + 1 :, :] = spec_blank[
+ :, int(col) + 1 :, :
+ ]
elif ii == 0 and reveal_boxes:
spec_op = spec_blank
- if not args_cmd['disable_axis']:
- plt.close('all')
- fig = plt.figure(ii, figsize=(1.2*(spec_op.shape[1]/dpi), 1.5*(spec_op.shape[0]/dpi)), dpi=dpi)
- plt.xlabel('Time - seconds')
- plt.ylabel('Frequency - kHz')
- plt.imshow(spec_op, vmin=0, vmax=1.0, cmap='plasma', extent=y_extent, aspect='auto')
+ if not args_cmd["disable_axis"]:
+ plt.close("all")
+ fig = plt.figure(
+ ii,
+ figsize=(
+ 1.2 * (spec_op.shape[1] / dpi),
+ 1.5 * (spec_op.shape[0] / dpi),
+ ),
+ dpi=dpi,
+ )
+ plt.xlabel("Time - seconds")
+ plt.ylabel("Frequency - kHz")
+ plt.imshow(
+ spec_op,
+ vmin=0,
+ vmax=1.0,
+ cmap="plasma",
+ extent=y_extent,
+ aspect="auto",
+ )
plt.tight_layout()
- fig.savefig(op_dir_tmp + str(ii).zfill(4) + '.png', dpi=dpi)
+ fig.savefig(op_dir_tmp + str(ii).zfill(4) + ".png", dpi=dpi)
else:
- plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=1.0, cmap='plasma')
+ plt.imsave(
+ op_dir_tmp + str(ii).zfill(4) + ".png",
+ spec_op,
+ vmin=0,
+ vmax=1.0,
+ cmap="plasma",
+ )
else:
spec_op = spec.copy()
if ii > 0:
spec_op[:, int(col)] = max_val
- plt.imsave(op_dir_tmp + str(ii).zfill(4) + '.png', spec_op, vmin=0, vmax=max_val, cmap='plasma')
+ plt.imsave(
+ op_dir_tmp + str(ii).zfill(4) + ".png",
+ spec_op,
+ vmin=0,
+ vmax=max_val,
+ cmap="plasma",
+ )
-
- print(' Creating video ...')
- op_vid_file = os.path.join(op_dir, os.path.basename(audio_file)[:-4] + op_str + '.avi')
- ffmpeg_cmd = 'ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 ' \
- '-crf 25 -pix_fmt yuv420p -acodec copy {}'.format(fps, spec.shape[1], spec.shape[0], op_dir_tmp, op_audio_file, op_vid_file)
+ print(" Creating video ...")
+ op_vid_file = os.path.join(
+ op_dir, os.path.basename(audio_file)[:-4] + op_str + ".avi"
+ )
+ ffmpeg_cmd = (
+ "ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 "
+ "-crf 25 -pix_fmt yuv420p -acodec copy {}".format(
+ fps,
+ spec.shape[1],
+ spec.shape[0],
+ op_dir_tmp,
+ op_audio_file,
+ op_vid_file,
+ )
+ )
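+    # stitch the numbered PNG frames into an H.264 video at the requested
+    # frame rate and mux in the audio track written above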
ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd
os.system(ffmpeg_cmd)
- print(' Deleting temporary files ...')
+ print(" Deleting temporary files ...")
if os.path.isdir(op_dir_tmp):
- shutil.rmtree(op_dir_tmp)
+ shutil.rmtree(op_dir_tmp)
diff --git a/scripts/viz_helpers.py b/scripts/viz_helpers.py
index 2f55836..667bb9c 100644
--- a/scripts/viz_helpers.py
+++ b/scripts/viz_helpers.py
@@ -1,41 +1,70 @@
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy import ndimage
import os
import sys
-sys.path.append(os.path.join('..'))
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy import ndimage
+
+sys.path.append(os.path.join(".."))
import bat_detect.utils.audio_utils as au
-def generate_spectrogram_data(audio, sampling_rate, params, norm_type='log', smooth_spec=False):
- max_freq = round(params['max_freq']*params['fft_win_length'])
- min_freq = round(params['min_freq']*params['fft_win_length'])
+def generate_spectrogram_data(
+ audio, sampling_rate, params, norm_type="log", smooth_spec=False
+):
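+    # spectrogram bins are 1 / fft_win_length Hz wide, so multiplying a
+    # frequency in Hz by fft_win_length (in seconds) gives its bin index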
+ max_freq = round(params["max_freq"] * params["fft_win_length"])
+ min_freq = round(params["min_freq"] * params["fft_win_length"])
# create spectrogram - numpy
- spec = au.gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
- #spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
+ spec = au.gen_mag_spectrogram(
+ audio, sampling_rate, params["fft_win_length"], params["fft_overlap"]
+ )
+ # spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy()
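+    # zero-pad the frequency axis if the spectrogram does not reach max_freq,
+    # then crop it down to the [min_freq, max_freq] band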
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
- spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec))
- spec = spec[-max_freq:spec.shape[0]-min_freq, :]
+ spec = np.vstack(
+ (np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec)
+ )
+ spec = spec[-max_freq : spec.shape[0] - min_freq, :]
- if norm_type == 'log':
- log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
+ if norm_type == "log":
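+        # periodogram-style scaling, 2 / (fs * sum(hann_window ** 2)),
+        # followed by log(1 + x) compression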
+ log_scaling = (
+ 2.0
+ * (1.0 / sampling_rate)
+ * (
+ 1.0
+ / (
+ np.abs(
+ np.hanning(
+ int(params["fft_win_length"] * sampling_rate)
+ )
+ )
+ ** 2
+ ).sum()
+ )
+ )
##log_scaling = 0.01
- spec = np.log(1.0 + log_scaling*spec).astype(np.float32)
- elif norm_type == 'pcen':
+ spec = np.log(1.0 + log_scaling * spec).astype(np.float32)
+ elif norm_type == "pcen":
spec = au.pcen(spec, sampling_rate)
else:
pass
if smooth_spec:
- spec = ndimage.gaussian_filter(spec, 1)
+ spec = ndimage.gaussian_filter(spec, 1)
return spec
-def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', extract_bg=False):
+def load_data(
+ anns,
+ params,
+ class_names,
+ smooth_spec=False,
+ norm_type="log",
+ extract_bg=False,
+):
specs = []
labels = []
coords = []
@@ -43,67 +72,106 @@ def load_data(anns, params, class_names, smooth_spec=False, norm_type='log', ext
sampling_rates = []
file_names = []
for cur_file in anns:
- sampling_rate, audio_orig = au.load_audio_file(cur_file['file_path'], cur_file['time_exp'],
- params['target_samp_rate'], params['scale_raw_audio'])
+ sampling_rate, audio_orig = au.load_audio_file(
+ cur_file["file_path"],
+ cur_file["time_exp"],
+ params["target_samp_rate"],
+ params["scale_raw_audio"],
+ )
- for ann in cur_file['annotation']:
- if ann['class'] not in params['classes_to_ignore'] and ann['class'] in class_names:
+ for ann in cur_file["annotation"]:
+ if (
+ ann["class"] not in params["classes_to_ignore"]
+ and ann["class"] in class_names
+ ):
                # clip out-of-bounds annotation frequencies to the valid band
- if ann['low_freq'] < params['min_freq']:
- ann['low_freq'] = params['min_freq']
- if ann['high_freq'] > params['max_freq']:
- ann['high_freq'] = params['max_freq']
+ if ann["low_freq"] < params["min_freq"]:
+ ann["low_freq"] = params["min_freq"]
+ if ann["high_freq"] > params["max_freq"]:
+ ann["high_freq"] = params["max_freq"]
# load cropped audio
- start_samp_diff = int(sampling_rate*ann['start_time']) - int(sampling_rate*params['aud_pad'])
+ start_samp_diff = int(sampling_rate * ann["start_time"]) - int(
+ sampling_rate * params["aud_pad"]
+ )
start_samp = np.maximum(0, start_samp_diff)
- end_samp = np.minimum(audio_orig.shape[0], int(sampling_rate*ann['end_time'])*2 + int(sampling_rate*params['aud_pad']))
+ end_samp = np.minimum(
+ audio_orig.shape[0],
+ int(sampling_rate * ann["end_time"]) * 2
+ + int(sampling_rate * params["aud_pad"]),
+ )
audio = audio_orig[start_samp:end_samp]
if start_samp_diff < 0:
                    # need to pad at the start if the call is at the very beginning
- audio = np.hstack((np.zeros(-start_samp_diff, dtype=np.float32), audio))
+ audio = np.hstack(
+ (np.zeros(-start_samp_diff, dtype=np.float32), audio)
+ )
- nfft = int(params['fft_win_length']*sampling_rate)
- noverlap = int(params['fft_overlap']*nfft)
- max_samps = params['spec_width']*(nfft - noverlap) + noverlap
+ nfft = int(params["fft_win_length"] * sampling_rate)
+ noverlap = int(params["fft_overlap"] * nfft)
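+                # the number of samples that yields exactly spec_width STFT
+                # frames: spec_width hop lengths plus the final window overlap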
+ max_samps = params["spec_width"] * (nfft - noverlap) + noverlap
if max_samps > audio.shape[0]:
- audio = np.hstack((audio, np.zeros(max_samps - audio.shape[0])))
+ audio = np.hstack(
+ (audio, np.zeros(max_samps - audio.shape[0]))
+ )
audio = audio[:max_samps].astype(np.float32)
- audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
- params['fft_overlap'], params['resize_factor'],
- params['spec_divide_factor'])
+ audio = au.pad_audio(
+ audio,
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ params["resize_factor"],
+ params["spec_divide_factor"],
+ )
# generate spectrogram
- spec = generate_spectrogram_data(audio, sampling_rate, params, norm_type, smooth_spec)[:, :params['spec_width']]
+ spec = generate_spectrogram_data(
+ audio, sampling_rate, params, norm_type, smooth_spec
+ )[:, : params["spec_width"]]
specs.append(spec[np.newaxis, ...])
- labels.append(ann['class'])
+ labels.append(ann["class"])
audios.append(audio)
sampling_rates.append(sampling_rate)
- file_names.append(cur_file['file_path'])
+ file_names.append(cur_file["file_path"])
# position in crop
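+                # x: the call starts aud_pad seconds into the crop; y: offset
+                # of low_freq above min_freq, converted from Hz to a bin index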
- x1 = int(au.time_to_x_coords(np.array(params['aud_pad']), sampling_rate, params['fft_win_length'], params['fft_overlap']))
- y1 = (ann['low_freq'] - params['min_freq']) * params['fft_win_length']
+ x1 = int(
+ au.time_to_x_coords(
+ np.array(params["aud_pad"]),
+ sampling_rate,
+ params["fft_win_length"],
+ params["fft_overlap"],
+ )
+ )
+ y1 = (ann["low_freq"] - params["min_freq"]) * params[
+ "fft_win_length"
+ ]
coords.append((y1, x1))
-
_, file_ids = np.unique(file_names, return_inverse=True)
labels = np.array([class_names.index(ll) for ll in labels])
- #return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
+ # return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
return np.vstack(specs), labels
-def save_summary_image(specs, labels, species_names, params, op_file_name='plots/all_species.png', order=None):
+def save_summary_image(
+ specs,
+ labels,
+ species_names,
+ params,
+ op_file_name="plots/all_species.png",
+ order=None,
+):
# takes the mean for each class and plots it on a grid
mean_specs = []
max_band = []
for ii in range(len(species_names)):
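+        # average all spectrograms belonging to this class and record the
+        # frequency band (row) with the highest mean energy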
- inds = np.where(labels==ii)[0]
+ inds = np.where(labels == ii)[0]
mu = specs[inds, :].mean(0)
max_band.append(np.argmax(mu.sum(1)))
mean_specs.append(mu)
@@ -113,11 +181,21 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots
order = np.arange(len(species_names))
max_cols = 6
- nrows = int(np.ceil(len(species_names)/max_cols))
+ nrows = int(np.ceil(len(species_names) / max_cols))
ncols = np.minimum(len(species_names), max_cols)
- fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3.3, nrows*6), gridspec_kw = {'wspace':0, 'hspace':0.2})
- spec_min_max = (0, mean_specs[0].shape[1], params['min_freq']/1000, params['max_freq']/1000)
+ fig, ax = plt.subplots(
+ nrows=nrows,
+ ncols=ncols,
+ figsize=(ncols * 3.3, nrows * 6),
+ gridspec_kw={"wspace": 0, "hspace": 0.2},
+ )
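+    # imshow extent: x axis in spectrogram columns, y axis in kHz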
+ spec_min_max = (
+ 0,
+ mean_specs[0].shape[1],
+ params["min_freq"] / 1000,
+ params["max_freq"] / 1000,
+ )
ii = 0
for row in ax:
@@ -126,17 +204,24 @@ def save_summary_image(specs, labels, species_names, params, op_file_name='plots
for col in row:
if ii >= len(species_names):
- col.axis('off')
+ col.axis("off")
else:
- inds = np.where(labels==order[ii])[0]
- col.imshow(mean_specs[order[ii]], extent=spec_min_max, cmap='plasma', aspect='equal')
- col.grid(color='w', alpha=0.3, linewidth=0.3)
+ inds = np.where(labels == order[ii])[0]
+ col.imshow(
+ mean_specs[order[ii]],
+ extent=spec_min_max,
+ cmap="plasma",
+ aspect="equal",
+ )
+ col.grid(color="w", alpha=0.3, linewidth=0.3)
col.set_xticks([])
- col.title.set_text(str(ii+1) + ' ' + species_names[order[ii]])
- col.tick_params(axis='both', which='major', labelsize=7)
+ col.title.set_text(
+ str(ii + 1) + " " + species_names[order[ii]]
+ )
+ col.tick_params(axis="both", which="major", labelsize=7)
ii += 1
- #plt.tight_layout()
- #plt.show()
+ # plt.tight_layout()
+ # plt.show()
plt.savefig(op_file_name)
- plt.close('all')
+ plt.close("all")