mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
520 lines
16 KiB
Python
520 lines
16 KiB
Python
import json
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib import patches
|
|
from matplotlib.collections import PatchCollection
|
|
from sklearn.metrics import confusion_matrix
|
|
|
|
from . import audio_utils as au
|
|
|
|
|
|
def create_box_image(
    spec,
    fig,
    detections_ip,
    start_time,
    end_time,
    duration,
    params,
    max_val,
    hide_axis=True,
    plot_class_names=False,
):
    """Draw a spectrogram window together with the detections starting in it.

    Args:
        spec: 2D array handed to ``imshow``.
        fig: Matplotlib figure; receives a full-figure axes when
            ``hide_axis`` is true.
        detections_ip: Sequence of detection dicts. Each one needs
            ``start_time``, ``end_time``, ``low_freq`` and ``high_freq``
            (optionally ``det_prob``), plus ``class`` when
            ``plot_class_names`` is true.
        start_time: Time at which the displayed window begins; box x
            positions are drawn relative to it.
        end_time: Not used by this function; kept so existing callers keep
            working.
        duration: Length of the displayed window (same unit as the detection
            times).
        params: Mapping providing ``min_freq`` and ``max_freq`` in Hz.
        max_val: Upper colour limit (``vmax``) for the spectrogram image.
        hide_axis: If true, draw on a new axis-free axes that fills ``fig``;
            otherwise draw on the current axes.
        plot_class_names: If true, write an abbreviated class name (first
            three letters of every word) next to each box.
    """
    # Keep only the detections that start inside the window, leaving a 0.02
    # margin before the end of the window.
    stop_time = start_time + duration
    detections = [
        det
        for det in detections_ip
        if start_time <= det["start_time"] < stop_time - 0.02
    ]

    # create figure
    freq_scale = 1000  # turn Hz to kHz
    min_freq = params["min_freq"] // freq_scale
    max_freq = params["max_freq"] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    if hide_axis:
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    plt.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        vmin=0,
        vmax=max_val,
    )
    boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
    ax.add_collection(PatchCollection(boxes, match_original=True))
    plt.grid(False)

    if plot_class_names:
        # The boxes were built from the *filtered* detections, so the label of
        # each box must come from that same list (indexing the unfiltered
        # input would attach the wrong class once any detection is dropped).
        for det, bb in zip(detections, boxes):
            txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            # Place the label at the top edge of the box, but keep it at
            # least 10 kHz below the top of the plot so it stays visible.
            y_pos = bb.get_xy()[1] + bb.get_height()
            if y_pos > (max_freq - 10):
                y_pos = max_freq - 10
            plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
|
|
|
|
|
|
def save_ann_spec(
    op_path,
    spec,
    min_freq,
    max_freq,
    duration,
    start_time,
    title_text="",
    anns=None,
):
    """Render a spectrogram, optionally with annotation boxes, and save it.

    All open figures are closed first; the image is then drawn on figure 0,
    sized so that one spectrogram bin maps to one pixel at 100 dpi.

    Args:
        op_path: Destination passed to ``plt.savefig``.
        spec: 2D array to display.
        min_freq: Lower frequency limit in Hz.
        max_freq: Upper frequency limit in Hz.
        duration: Extent of the x axis.
        start_time: Offset subtracted from annotation start times.
        title_text: Figure title; skipped when equal to ``""``.
        anns: Optional sequence of annotation dicts (``start_time``,
            ``end_time``, ``low_freq``, ``high_freq``, ``class`` and
            optionally ``det_prob``).
    """
    hz_per_khz = 1000  # turn Hz to kHz
    min_freq = min_freq // hz_per_khz
    max_freq = max_freq // hz_per_khz
    extent = [0, duration, min_freq, max_freq]

    plt.close("all")
    n_rows, n_cols = spec.shape[0], spec.shape[1]
    plt.figure(0, figsize=(n_cols / 100, n_rows / 100), dpi=100)
    plt.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=extent,
        vmin=0,
        vmax=spec.max() * 1.1,
    )

    plt.ylabel("Freq - kHz")
    plt.xlabel("Time - secs")
    if title_text != "":
        plt.title(title_text)
    plt.tight_layout()

    if anns is not None:
        # draw the bounding boxes, then the abbreviated class names
        rects = plot_bounding_box_patch_ann(anns, hz_per_khz, start_time)
        plt.gca().add_collection(PatchCollection(rects, match_original=True))
        label_ceiling = max_freq - 10
        for idx, rect in enumerate(rects):
            words = anns[idx]["class"].split(" ")
            label = " ".join([word[:3] for word in words])
            style = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": rect.get_alpha(),
            }
            # top edge of the box, capped so the text stays inside the plot
            label_y = rect.get_xy()[1] + rect.get_height()
            if label_y > label_ceiling:
                label_y = label_ceiling
            plt.gca().text(rect.get_xy()[0], label_y, label, fontdict=style)

    print("Saving figure to:", op_path)
    plt.savefig(op_path)
|
|
|
|
|
|
def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
    """Scatter-plot 2D features, one colour per distinct class name.

    Args:
        fig_id: Identifier of the matplotlib figure to draw on.
        feats: Array indexed as ``feats[rows, 0]`` / ``feats[rows, 1]`` for
            the x and y coordinates.
        class_names: One class name per row of ``feats``.
        colors: Colours to use, one per class. When there are more classes
            than colours, colours are sampled from the ``jet`` colormap
            instead.
        marker_size: Marker size forwarded to ``plt.scatter``.
        plot_legend: If true, add a legend listing the class names.
    """
    plt.figure(fig_id)
    unique_names, name_ids = np.unique(class_names, return_inverse=True)
    unique_ids = np.unique(name_ids)
    num_ids = unique_ids.shape[0]

    # not enough colours supplied: fall back to an evenly sampled colormap
    if num_ids > len(colors):
        colors = [plt.cm.jet(float(uid) / num_ids) for uid in unique_ids]

    for pos, uid in enumerate(unique_ids):
        rows = np.nonzero(name_ids == uid)[0]
        xs = feats[rows, 0]
        ys = feats[rows, 1]
        plt.scatter(
            xs,
            ys,
            c=colors[pos],
            label=str(unique_names[pos]),
            s=marker_size,
        )

    if plot_legend:
        plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title("downsampled features")
|
|
|
|
|
|
def plot_bounding_box_patch(pred, freq_scale, ecolor="w"):
    """Build one unfilled rectangle patch per entry of a column-style dict.

    Args:
        pred: Mapping with parallel sequences ``start_times``, ``end_times``,
            ``low_freqs`` and ``high_freqs``; an optional ``det_probs``
            sequence supplies the alpha of each rectangle.
        freq_scale: Divisor applied to the frequency values.
        ecolor: Edge colour of the rectangles.

    Returns:
        List of ``matplotlib.patches.Rectangle``, in input order.
    """
    has_probs = "det_probs" in pred
    rects = []
    for idx in range(len(pred["start_times"])):
        left = pred["start_times"][idx]
        box_width = pred["end_times"][idx] - pred["start_times"][idx]
        bottom = pred["low_freqs"][idx] / freq_scale
        box_height = (pred["high_freqs"][idx] - pred["low_freqs"][idx]) / freq_scale

        # rectangles are fully opaque unless a probability is available
        opacity = pred["det_probs"][idx] if has_probs else 1.0

        rect = patches.Rectangle(
            (left, bottom),
            box_width,
            box_height,
            linewidth=1,
            edgecolor=ecolor,
            facecolor="none",
            alpha=opacity,
        )
        rects.append(rect)
    return rects
|
|
|
|
|
|
def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
    """Build one white, unfilled rectangle patch per annotation dict.

    Args:
        anns: Sequence of dicts with ``start_time``, ``end_time``,
            ``low_freq`` and ``high_freq``; an optional ``det_prob`` sets the
            alpha of the rectangle.
        freq_scale: Divisor applied to the frequency values.
        start_time: Offset subtracted from each annotation's start time to
            obtain the x position.

    Returns:
        List of ``matplotlib.patches.Rectangle``, in input order.
    """
    rects = []
    for ann in anns:
        left = ann["start_time"] - start_time
        box_width = ann["end_time"] - ann["start_time"]
        bottom = ann["low_freq"] / freq_scale
        box_height = (ann["high_freq"] - ann["low_freq"]) / freq_scale

        # fully opaque unless the annotation carries a probability
        opacity = ann["det_prob"] if "det_prob" in ann else 1.0

        rect = patches.Rectangle(
            (left, bottom),
            box_width,
            box_height,
            linewidth=1,
            edgecolor="w",
            facecolor="none",
            alpha=opacity,
        )
        rects.append(rect)
    return rects
|
|
|
|
|
|
def plot_spec(
    spec,
    sampling_rate,
    duration,
    gt,
    pred,
    params,
    plot_title,
    op_file_name,
    pred_2d_hm,
    plot_boxes=True,
    fixed_aspect=True,
):
    """Plot ground truth, predictions and a heatmap as three stacked panels.

    The figure (matplotlib figure 1) is saved to ``op_file_name`` when that is
    not ``None`` and is always closed before returning.

    Args:
        spec: 2D array shown in the ground-truth and prediction panels.
        sampling_rate: Not used by this function; kept for interface
            compatibility.
        duration: Extent of the x axis; also scales the figure width when
            ``fixed_aspect`` is false.
        gt: Mapping with ``start_times``, ``end_times``, ``low_freqs``,
            ``high_freqs`` and ``class_ids``. A negative class id is labelled
            with ``params["generic_class"][0]``.
        pred: Mapping with the same box keys (optionally ``det_probs``) plus
            ``class_probs``, indexed as ``[class, detection]``. When it has
            more rows than there are short class names, the last row is
            ignored when picking the predicted class.
        params: Mapping providing ``min_freq`` and ``max_freq`` (Hz),
            ``class_names_short`` and ``generic_class``.
        plot_title: Title placed above the panels.
        op_file_name: Output path, or ``None`` to skip saving.
        pred_2d_hm: Optional 2D array for the bottom panel.
        plot_boxes: If true, draw bounding boxes and class labels.
        fixed_aspect: If true, the figure is 12 wide regardless of
            ``duration``; otherwise the width is ``12 * duration``.
    """
    # output image will be this width irrespective of the duration of the
    # audio file when fixed_aspect is set
    width = 12 if fixed_aspect else 12 * duration

    fig = plt.figure(1, figsize=(width, 8))
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30])  # l b w h
    ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
    ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])

    freq_scale = 1000  # turn Hz in kHz
    min_freq = params["min_freq"] // freq_scale
    max_freq = params["max_freq"] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    # plot gt boxes
    ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax0.xaxis.set_ticklabels([])
    font_info = {"color": "white", "size": 12, "weight": "bold"}
    ax0.text(0, min_freq, "Ground Truth", fontdict=font_info)
    # Disable the grid on this panel explicitly: plt.grid() would act on the
    # current axes, which is the last one created (ax2), not this one.
    ax0.grid(False)

    if plot_boxes:
        boxes = plot_bounding_box_patch(gt, freq_scale)
        ax0.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            class_id = int(gt["class_ids"][ii])
            if class_id < 0:
                txt = params["generic_class"][0]
            else:
                txt = params["class_names_short"][class_id]
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            # label sits on the top edge of the box
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot predicted boxes
    ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax1.xaxis.set_ticklabels([])
    font_info = {"color": "white", "size": 12, "weight": "bold"}
    ax1.text(0, min_freq, "Prediction", fontdict=font_info)
    ax1.grid(False)

    if plot_boxes:
        boxes = plot_bounding_box_patch(pred, freq_scale)
        ax1.add_collection(PatchCollection(boxes, match_original=True))
        for ii, bb in enumerate(boxes):
            # If there is one more probability row than there are class
            # names, the extra (last) row is excluded from the argmax.
            if pred["class_probs"].shape[0] > len(params["class_names_short"]):
                class_id = pred["class_probs"][:-1, ii].argmax()
            else:
                class_id = pred["class_probs"][:, ii].argmax()
            txt = params["class_names_short"][class_id]
            font_info = {
                "color": "white",
                "size": 10,
                "weight": "bold",
                "alpha": bb.get_alpha(),
            }
            y_pos = bb.get_xy()[1] + bb.get_height()
            ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)

    # plot 2D heatmap
    if pred_2d_hm is not None:
        # colour limits cover at least [0, 1] and widen to fit the data
        min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
        max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()

        ax2.imshow(
            pred_2d_hm,
            aspect="auto",
            cmap="plasma",
            extent=y_extent,
            clim=[min_val, max_val],
        )
        font_info = {"color": "white", "size": 12, "weight": "bold"}
        ax2.text(0, min_freq, "Heatmap", fontdict=font_info)

    ax2.grid(False)

    plt.suptitle(plot_title)
    if op_file_name is not None:
        fig.savefig(op_file_name)

    plt.close(1)
|
|
|
|
|
|
def plot_pr_curve(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    """Plot a single precision-recall curve and save it to disk.

    The output path is the plain concatenation
    ``op_dir + file_name + "." + file_type``, so ``op_dir`` has to carry its
    own trailing separator.

    Args:
        op_dir: Output directory prefix.
        plt_title: Title prefix; the average precision is appended to it when
            ``title_text`` is ``""``.
        file_name: Output file name without extension.
        results: Mapping with ``precision``, ``recall`` and ``avg_prec``.
        file_type: File extension of the saved figure.
        title_text: If not ``""``, used verbatim as the title.
    """
    prec_vals = results["precision"]
    rec_vals = results["recall"]
    mean_ap = results["avg_prec"]
    out_path = op_dir + file_name + "." + file_type

    plt.figure(0, figsize=(10, 8))
    plt.plot(rec_vals, prec_vals)
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)

    if title_text == "":
        plt.title(plt_title + f" {mean_ap:.3f}\n")
    else:
        plt.title(title_text, fontdict={"fontsize": 28})

    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(0)
|
|
|
|
|
|
def plot_pr_curve_class(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    """Plot one precision-recall curve per class and save the figure.

    Curves cycle through the colours of the active matplotlib property cycle;
    each time the colours are exhausted the next line style is used. Colour,
    line-style and marker lookups all wrap around, so any number of classes
    and thresholds can be plotted.

    The output path is the plain concatenation
    ``op_dir + file_name + "." + file_type``.

    Args:
        op_dir: Output directory prefix (must include a trailing separator).
        plt_title: Title prefix; ``results["avg_prec_class"]`` is appended to
            it when ``title_text`` is ``""``.
        file_name: Output file name without extension.
        results: Mapping with ``class_pr`` (a sequence of dicts providing
            ``name``, ``recall``, ``precision``, ``thresholds`` and
            ``thresholds_inds``) and ``avg_prec_class``.
        file_type: File extension of the saved figure.
        title_text: If not ``""``, used verbatim as the title.
    """
    plt.figure(0, figsize=(10, 8))
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    linestyles = ["-", ":", "--"]
    markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    num_colors = len(colors)

    # plot the PR curves
    for ii, rr in enumerate(results["class_pr"]):
        class_name = " ".join([sp[:3] for sp in rr["name"].split(" ")])
        # Wrap around instead of assuming a 10-colour cycle and at most
        # three line styles; identical to the fixed scheme for the default
        # matplotlib cycle and fewer than 30 classes.
        cur_color = colors[ii % num_colors]
        cur_linestyle = linestyles[(ii // num_colors) % len(linestyles)]
        plt.plot(
            rr["recall"],
            rr["precision"],
            label=class_name,
            color=cur_color,
            linestyle=cur_linestyle,
            lw=2.5,
        )

        # plot the location of the confidence threshold values
        for jj, _ in enumerate(rr["thresholds"]):
            ind = rr["thresholds_inds"][jj]
            # a negative index means there is no point to mark
            if ind > -1:
                plt.plot(
                    rr["recall"][ind],
                    rr["precision"][ind],
                    markers[jj % len(markers)],
                    color=cur_color,
                    ms=10,
                )

    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"]))
    plt.legend(loc="lower left", prop={"size": 14})
    plt.tight_layout()
    plt.savefig(op_dir + file_name + "." + file_type)
    plt.close(0)
|
|
|
|
|
|
def plot_confusion_matrix(
    op_dir,
    op_file,
    gt,
    pred,
    file_acc,
    class_names_long,
    verbose=False,
    file_type="png",
    title_text="",
):
    """Plot a row-normalised confusion matrix and save it to disk.

    Rows belonging to classes with no ground-truth samples are set to NaN.
    The output path is ``op_dir + op_file + "." + file_type``, and all open
    figures are closed afterwards.

    Args:
        op_dir: Output directory prefix (must include a trailing separator).
        op_file: Output file name without extension; also the title prefix
            when ``title_text`` is ``""``.
        gt: Ground-truth class indices.
        pred: Predicted class indices.
        file_acc: Accuracy value appended to the default title.
        class_names_long: Full class names; the tick labels use the first
            three letters of every word.
        verbose: If true, print the accuracy of every class.
        file_type: File extension of the saved figure.
        title_text: If not ``""``, used verbatim as the title.
    """
    # shorten the class names for plotting
    class_names = [
        " ".join(word[:3] for word in long_name.split(" "))
        for long_name in class_names_long
    ]

    num_classes = len(class_names)
    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
    row_totals = cm.sum(1)

    # normalise the rows that have samples; mark the empty ones as NaN
    populated = row_totals > 0
    cm[populated, :] = cm[populated, :] / row_totals[populated][..., np.newaxis]
    cm[row_totals == 0, :] = np.nan

    if verbose:
        print("Per class accuracy:")
        pad = max([len(long_name) for long_name in class_names_long]) + 5
        per_class = np.diag(cm)
        for idx, long_name in enumerate(class_names_long):
            line = str(idx).ljust(5) + long_name.ljust(pad)
            # classes without samples are listed without a number
            if not np.isnan(per_class[idx]):
                line = line + "{:.2f}".format(per_class[idx] * 100)
            print(line)

    plt.figure(0, figsize=(10, 8))
    plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
    plt.colorbar()
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical")
    plt.yticks(np.arange(cm.shape[0]), class_names)
    plt.xlabel("Predicted", fontsize=20)
    plt.ylabel("Ground Truth", fontsize=20)
    if title_text == "":
        plt.title(op_file + " {:.3f}\n".format(file_acc))
    else:
        plt.title(title_text, fontdict={"fontsize": 28})
    plt.tight_layout()
    plt.savefig(op_dir + op_file + "." + file_type)
    plt.close("all")
|
|
|
|
|
|
class LossPlotter(object):
    """Accumulate per-epoch values and re-write a plot and JSON log each time.

    The JSON log and the confusion-matrix image are stored next to
    ``op_file_name``; their names are derived by stripping the last four
    characters of ``op_file_name`` (i.e. a three-letter extension such as
    ``.png`` is assumed).
    """

    def __init__(
        self,
        op_file_name,
        duration,
        labels,
        ylim,
        class_names,
        axis_labels=None,
        logy=False,
    ):
        """Store the plot configuration and start with an empty history.

        Args:
            op_file_name: Path of the plot image.
            duration: Minimum length of the x axis.
            labels: One legend label per tracked value.
            ylim: ``(low, high)`` y-axis limits, or ``None`` for automatic.
            class_names: Class names used for the confusion-matrix ticks.
            axis_labels: Optional ``(xlabel, ylabel)`` pair.
            logy: If true, use a logarithmic y axis.
        """
        self.reset()
        self.op_file_name = op_file_name
        self.duration = duration  # length of x axis
        self.labels = labels
        self.ylim = ylim
        self.class_names = class_names
        self.axis_labels = axis_labels
        self.logy = logy

    def reset(self):
        """Discard all recorded epochs and values."""
        self.epochs = []
        self.vals = []

    def update_and_save(self, epoch, val, gt=None, pred=None):
        """Record one epoch and refresh every output file.

        Args:
            epoch: Epoch identifier (x coordinate of the new point).
            val: Sequence with one value per entry of ``labels``.
            gt: Optional ground-truth class indices; when given, a confusion
                matrix is saved as well.
            pred: Predicted class indices matching ``gt``.
        """
        self.epochs.append(epoch)
        self.vals.append(val)
        self.save_plot()
        self.save_json()
        if gt is not None:
            self.save_confusion_matrix(gt, pred)

    def save_plot(self):
        """Plot every tracked value against the epochs and save the figure."""
        linestyles = ["-", ":", "--"]
        plt.figure(0, figsize=(8, 5))
        for ii in range(len(self.vals[0])):
            l_vals = [vv[ii] for vv in self.vals]
            plt.plot(
                self.epochs,
                l_vals,
                label=self.labels[ii],
                # switch line style every 10 series; wrap around so that 30
                # or more series do not raise an IndexError
                linestyle=linestyles[(ii // 10) % len(linestyles)],
            )
        plt.xlim(0, np.maximum(self.duration, len(self.vals)))
        if self.ylim is not None:
            plt.ylim(self.ylim[0], self.ylim[1])
        if self.axis_labels is not None:
            plt.xlabel(self.axis_labels[0])
            plt.ylabel(self.axis_labels[1])
        if self.logy:
            plt.gca().set_yscale("log")
        plt.grid(True)
        plt.legend(bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0)
        plt.tight_layout()
        plt.savefig(self.op_file_name)
        plt.close(0)

    def save_json(self):
        """Write the epochs and the values (rounded to 4 decimals) as JSON."""
        data = {}
        data["epochs"] = self.epochs
        for ii in range(len(self.vals[0])):
            data[self.labels[ii]] = [round(vv[ii], 4) for vv in self.vals]
        # same base name as the plot, with the extension swapped for .json
        with open(self.op_file_name[:-4] + ".json", "w") as da:
            json.dump(data, da, indent=2)

    def save_confusion_matrix(self, gt, pred):
        """Save a row-normalised confusion matrix next to the plot.

        Args:
            gt: Ground-truth class indices.
            pred: Predicted class indices.
        """
        plt.figure(0)
        # `labels` must be passed by keyword: it is keyword-only in current
        # scikit-learn releases.
        cm = confusion_matrix(
            gt, pred, labels=np.arange(len(self.class_names))
        ).astype(np.float32)
        cm_norm = cm.sum(1)
        # only rows with at least one sample can be normalised
        valid_inds = np.where(cm_norm > 0)[0]
        cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
        plt.colorbar()
        plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation="vertical")
        plt.yticks(np.arange(cm.shape[0]), self.class_names)
        plt.xlabel("Predicted")
        plt.ylabel("Ground Truth")
        plt.tight_layout()
        plt.savefig(self.op_file_name[:-4] + "_cm.png")
        plt.close(0)
|