import json

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import patches
from matplotlib.collections import PatchCollection
from sklearn.metrics import confusion_matrix

from . import audio_utils as au


def _shorten_class_name(name):
    """Abbreviate a class name to the first three characters of each word.

    e.g. "Myotis daubentonii" -> "Myo dau".
    """
    return " ".join(word[:3] for word in name.split(" "))


def _annotate_boxes(ax, boxes, labels, y_max=None):
    """Write one label at the top-left corner of each rectangle patch.

    Each label inherits the alpha of its box. If ``y_max`` is given, the label
    y position is clamped so it is never drawn above ``y_max``.
    """
    for box, label in zip(boxes, labels):
        font_info = {
            "color": "white",
            "size": 10,
            "weight": "bold",
            "alpha": box.get_alpha(),
        }
        x_pos = box.get_xy()[0]
        y_pos = box.get_xy()[1] + box.get_height()
        if y_max is not None:
            y_pos = min(y_pos, y_max)
        ax.text(x_pos, y_pos, label, fontdict=font_info)


def _to_json_number(value):
    """Convert a NumPy scalar to the equivalent built-in Python number.

    ``json`` cannot serialise e.g. ``np.float32``; built-in values pass
    through unchanged.
    """
    return value.item() if isinstance(value, np.generic) else value


def create_box_image(
    spec,
    fig,
    detections_ip,
    start_time,
    end_time,
    duration,
    params,
    max_val,
    hide_axis=True,
    plot_class_names=False,
):
    """Draw a spectrogram window with its detection boxes.

    Args:
        spec: 2D spectrogram image passed to ``imshow``.
        fig: Matplotlib figure to draw into when ``hide_axis`` is True.
        detections_ip: Detection dicts with ``start_time``, ``end_time``,
            ``low_freq``, ``high_freq`` (Hz), ``class`` and optionally
            ``det_prob`` (used as the box alpha).
        start_time: Start of the displayed window; box x positions are made
            relative to it.
        end_time: Unused; the window end is ``start_time + duration``. Kept
            for backward compatibility.
        duration: Length of the displayed window (x extent of the image).
        params: Dict with ``min_freq`` and ``max_freq`` in Hz.
        max_val: Upper colour limit for the spectrogram.
        hide_axis: If True, add a borderless axes covering all of ``fig``;
            otherwise draw on the current axes.
        plot_class_names: If True, label each box with its abbreviated class.
    """
    # Keep detections starting inside the window; the 0.02 margin drops boxes
    # that start right at the end of the window.
    stop_time = start_time + duration
    detections = [
        bb
        for bb in detections_ip
        if start_time <= bb["start_time"] < stop_time - 0.02
    ]

    freq_scale = 1000  # turn Hz to kHz
    min_freq = params["min_freq"] // freq_scale
    max_freq = params["max_freq"] // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    if hide_axis:
        ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        ax = plt.gca()

    # Draw on ``ax`` explicitly so the image lands on ``fig`` even when it is
    # not the current pyplot figure.
    ax.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        vmin=0,
        vmax=max_val,
    )
    boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
    ax.add_collection(PatchCollection(boxes, match_original=True))
    ax.grid(False)

    if plot_class_names:
        # Labels must come from the *filtered* detections so they line up
        # one-to-one with ``boxes``.
        labels = [_shorten_class_name(bb["class"]) for bb in detections]
        _annotate_boxes(ax, boxes, labels, y_max=max_freq - 10)


def save_ann_spec(
    op_path,
    spec,
    min_freq,
    max_freq,
    duration,
    start_time,
    title_text="",
    anns=None,
):
    """Save a spectrogram image to ``op_path``, optionally with annotations.

    Args:
        op_path: Output image path passed to ``plt.savefig``.
        spec: 2D spectrogram image.
        min_freq, max_freq: Frequency range of ``spec`` in Hz.
        duration: x extent of the image.
        start_time: Offset subtracted from annotation start times.
        title_text: Optional title; omitted when empty.
        anns: Optional list of annotation dicts (see
            ``plot_bounding_box_patch_ann``) that also carry a ``class`` key.
    """
    freq_scale = 1000  # turn Hz to kHz
    min_freq = min_freq // freq_scale
    max_freq = max_freq // freq_scale
    y_extent = [0, duration, min_freq, max_freq]

    plt.close("all")
    # Figure size in pixels (inches * dpi) matches the spectrogram's shape.
    plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
    plt.imshow(
        spec,
        aspect="auto",
        cmap="plasma",
        extent=y_extent,
        vmin=0,
        vmax=spec.max() * 1.1,
    )
    plt.ylabel("Freq - kHz")
    plt.xlabel("Time - secs")
    if title_text != "":
        plt.title(title_text)
    plt.tight_layout()

    if anns is not None:
        # draw bounding boxes and class names
        ax = plt.gca()
        boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
        ax.add_collection(PatchCollection(boxes, match_original=True))
        labels = [_shorten_class_name(aa["class"]) for aa in anns]
        _annotate_boxes(ax, boxes, labels, y_max=max_freq - 10)

    print("Saving figure to:", op_path)
    plt.savefig(op_path)


def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
    """Scatter the first two feature columns, one colour per unique class.

    If there are more unique classes than entries in ``colors``, colours are
    instead sampled evenly from the ``jet`` colormap.
    """
    plt.figure(fig_id)
    un_class, labels = np.unique(class_names, return_inverse=True)
    un_labels = np.unique(labels)

    if un_labels.shape[0] > len(colors):
        colors = [plt.cm.jet(float(ii) / un_labels.shape[0]) for ii in un_labels]

    for ii, u in enumerate(un_labels):
        inds = np.where(labels == u)[0]
        # ``color`` rather than ``c``: a single RGB(A) tuple passed as ``c``
        # is value-mapped when a class has exactly 3 or 4 points.
        plt.scatter(
            feats[inds, 0],
            feats[inds, 1],
            color=colors[ii],
            label=str(un_class[ii]),
            s=marker_size,
        )

    if plot_legend:
        plt.legend()
    plt.xticks([])
    plt.yticks([])
    plt.title("downsampled features")


def plot_bounding_box_patch(pred, freq_scale, ecolor="w"):
    """Build one ``Rectangle`` per box in a column-oriented prediction dict.

    ``pred`` holds parallel sequences ``start_times``, ``end_times``,
    ``low_freqs``, ``high_freqs`` and optionally ``det_probs`` (box alpha;
    1.0 when absent). Frequencies are divided by ``freq_scale``.
    """
    has_probs = "det_probs" in pred
    box_coords = zip(
        pred["start_times"], pred["end_times"], pred["low_freqs"], pred["high_freqs"]
    )
    patch_collect = []
    for bb, (start, end, low, high) in enumerate(box_coords):
        alpha_val = pred["det_probs"][bb] if has_probs else 1.0
        patch_collect.append(
            patches.Rectangle(
                (start, low / freq_scale),
                end - start,
                (high - low) / freq_scale,
                linewidth=1,
                edgecolor=ecolor,
                facecolor="none",
                alpha=alpha_val,
            )
        )
    return patch_collect


def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
    """Build one white ``Rectangle`` per annotation dict.

    Each annotation needs ``start_time``, ``end_time``, ``low_freq`` and
    ``high_freq``; an optional ``det_prob`` sets the alpha (default 1.0).
    x positions are relative to ``start_time``; frequencies are divided by
    ``freq_scale``.
    """
    return [
        patches.Rectangle(
            (ann["start_time"] - start_time, ann["low_freq"] / freq_scale),
            ann["end_time"] - ann["start_time"],
            (ann["high_freq"] - ann["low_freq"]) / freq_scale,
            linewidth=1,
            edgecolor="w",
            facecolor="none",
            alpha=ann["det_prob"] if "det_prob" in ann else 1.0,
        )
        for ann in anns
    ]


def plot_spec(
    spec,
    sampling_rate,
    duration,
    gt,
    pred,
    params,
    plot_title,
    op_file_name,
    pred_2d_hm,
    plot_boxes=True,
    fixed_aspect=True,
):
    """Plot ground truth, predictions and the 2D heatmap as three panels.

    ``gt`` and ``pred`` are column-oriented box dicts (see
    ``plot_bounding_box_patch``). ``gt`` also has ``class_ids`` (negative
    ids use ``params["generic_class"][0]``); ``pred`` has a 2D
    ``class_probs`` array indexed ``[class, box]``. The figure is saved to
    ``op_file_name`` when it is not None, and figure 1 is then closed.
    ``sampling_rate`` is unused.
    """
    if fixed_aspect:
        # output image will be this width irrespective of the duration of the audio file
        width = 12
    else:
        width = 12 * duration

    fig = plt.figure(1, figsize=(width, 8))
    ax0 = plt.axes([0.05, 0.65, 0.9, 0.30])  # left, bottom, width, height
    ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
    ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])

    freq_scale = 1000  # turn Hz in kHz
    min_freq = params["min_freq"] // freq_scale
    y_extent = [0, duration, min_freq, params["max_freq"] // freq_scale]
    panel_font = {"color": "white", "size": 12, "weight": "bold"}

    # ground truth panel
    ax0.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax0.xaxis.set_ticklabels([])
    ax0.text(0, min_freq, "Ground Truth", fontdict=panel_font)
    ax0.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(gt, freq_scale)
        ax0.add_collection(PatchCollection(boxes, match_original=True))
        gt_labels = []
        for ii in range(len(boxes)):
            class_id = int(gt["class_ids"][ii])
            if class_id < 0:
                gt_labels.append(params["generic_class"][0])
            else:
                gt_labels.append(params["class_names_short"][class_id])
        _annotate_boxes(ax0, boxes, gt_labels)

    # prediction panel
    ax1.imshow(spec, aspect="auto", cmap="plasma", extent=y_extent)
    ax1.xaxis.set_ticklabels([])
    ax1.text(0, min_freq, "Prediction", fontdict=panel_font)
    ax1.grid(False)
    if plot_boxes:
        boxes = plot_bounding_box_patch(pred, freq_scale)
        ax1.add_collection(PatchCollection(boxes, match_original=True))
        class_probs = pred["class_probs"]
        # If there are more probability rows than named classes, the last row
        # is excluded from the argmax.
        if class_probs.shape[0] > len(params["class_names_short"]):
            class_probs = class_probs[:-1]
        pred_labels = [
            params["class_names_short"][class_probs[:, ii].argmax()]
            for ii in range(len(boxes))
        ]
        _annotate_boxes(ax1, boxes, pred_labels)

    # 2D heatmap panel
    if pred_2d_hm is not None:
        # Colour range covers at least [0, 1], widened if the heatmap exceeds it.
        min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
        max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()
        ax2.imshow(
            pred_2d_hm,
            aspect="auto",
            cmap="plasma",
            extent=y_extent,
            clim=[min_val, max_val],
        )
        ax2.text(0, min_freq, "Heatmap", fontdict=panel_font)
    ax2.grid(False)

    fig.suptitle(plot_title)
    if op_file_name is not None:
        fig.savefig(op_file_name)
    plt.close(1)


def plot_pr_curve(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    """Plot a single precision-recall curve and save it.

    ``results`` needs ``precision``, ``recall`` and ``avg_prec``. The output
    path is ``op_dir + file_name + "." + file_type``, so ``op_dir`` must end
    with a path separator.
    """
    precision = results["precision"]
    recall = results["recall"]
    avg_prec = results["avg_prec"]

    plt.figure(0, figsize=(10, 8))
    plt.plot(recall, precision)
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)
    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(plt_title + " {:.3f}\n".format(avg_prec))
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(op_dir + file_name + "." + file_type)
    plt.close(0)


def plot_pr_curve_class(
    op_dir, plt_title, file_name, results, file_type="png", title_text=""
):
    """Plot per-class precision-recall curves with threshold markers and save.

    Each entry of ``results["class_pr"]`` needs ``name``, ``recall``,
    ``precision``, ``thresholds`` and ``thresholds_inds`` (index into the
    curve, or -1 if the threshold is not reached). The output path is
    ``op_dir + file_name + "." + file_type``.
    """
    plt.figure(0, figsize=(10, 8))
    plt.ylabel("Precision", fontsize=20)
    plt.xlabel("Recall", fontsize=20)
    plt.xlim(0, 1.02)
    plt.ylim(0, 1.02)
    plt.grid(True)

    linestyles = ["-", ":", "--"]
    markers = ["o", "v", ">", "^", "<", "s", "P", "X", "*"]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for ii, rr in enumerate(results["class_pr"]):
        class_name = _shorten_class_name(rr["name"])
        # Cycle through colours first, then move to the next line style; wrap
        # around instead of raising IndexError when there are many classes.
        cur_color = colors[ii % len(colors)]
        linestyle = linestyles[(ii // len(colors)) % len(linestyles)]
        plt.plot(
            rr["recall"],
            rr["precision"],
            label=class_name,
            color=cur_color,
            linestyle=linestyle,
            lw=2.5,
        )

        # mark the curve position reached at each confidence threshold
        threshold_pairs = zip(rr["thresholds"], rr["thresholds_inds"])
        for jj, (_, ind) in enumerate(threshold_pairs):
            if ind > -1:
                plt.plot(
                    rr["recall"][ind],
                    rr["precision"][ind],
                    markers[jj % len(markers)],
                    color=cur_color,
                    ms=10,
                )

    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(plt_title + " {:.3f}\n".format(results["avg_prec_class"]))
    plt.legend(loc="lower left", prop={"size": 14})
    plt.tight_layout()
    plt.savefig(op_dir + file_name + "." + file_type)
    plt.close(0)


def plot_confusion_matrix(
    op_dir,
    op_file,
    gt,
    pred,
    file_acc,
    class_names_long,
    verbose=False,
    file_type="png",
    title_text="",
):
    """Plot a row-normalised confusion matrix and save it.

    Rows are ground-truth classes and are normalised to sum to 1. Rows with
    no ground-truth examples are set to NaN. With ``verbose`` the diagonal
    (per-class accuracy) is printed. The output path is
    ``op_dir + op_file + "." + file_type``.
    """
    class_names = [_shorten_class_name(cc) for cc in class_names_long]
    num_classes = len(class_names)

    cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
    cm_norm = cm.sum(1)
    valid_inds = np.where(cm_norm > 0)[0]
    cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
    # accuracy is undefined for classes without ground-truth examples
    cm[np.where(cm_norm == 0)[0], :] = np.nan

    if verbose:
        print("Per class accuracy:")
        str_len = np.max([len(cc) for cc in class_names_long]) + 5
        accs = np.diag(cm)
        for ii, cc in enumerate(class_names_long):
            if np.isnan(accs[ii]):
                print(str(ii).ljust(5) + cc.ljust(str_len))
            else:
                print(
                    str(ii).ljust(5)
                    + cc.ljust(str_len)
                    + "{:.2f}".format(accs[ii] * 100)
                )

    plt.figure(0, figsize=(10, 8))
    plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
    plt.colorbar()
    plt.xticks(np.arange(cm.shape[1]), class_names, rotation="vertical")
    plt.yticks(np.arange(cm.shape[0]), class_names)
    plt.xlabel("Predicted", fontsize=20)
    plt.ylabel("Ground Truth", fontsize=20)
    if title_text != "":
        plt.title(title_text, fontdict={"fontsize": 28})
    else:
        plt.title(op_file + " {:.3f}\n".format(file_acc))
    plt.tight_layout()
    plt.savefig(op_dir + op_file + "." + file_type)
    plt.close("all")


class LossPlotter:
    """Accumulate per-epoch values and write them as a line plot and JSON.

    ``op_file_name`` is the plot path. The JSON and confusion-matrix files
    reuse it with its last 4 characters (assumed to be an extension such as
    ".png") replaced by ".json" and "_cm.png".
    """

    def __init__(
        self,
        op_file_name,
        duration,
        labels,
        ylim,
        class_names,
        axis_labels=None,
        logy=False,
    ):
        self.reset()
        self.op_file_name = op_file_name
        self.duration = duration  # minimum length of the x axis
        self.labels = labels  # one legend label per value series
        self.ylim = ylim
        self.class_names = class_names
        self.axis_labels = axis_labels
        self.logy = logy

    def reset(self):
        """Clear all recorded epochs and values."""
        self.epochs = []
        self.vals = []

    def update_and_save(self, epoch, val, gt=None, pred=None):
        """Record ``val`` (one entry per series) for ``epoch`` and rewrite outputs.

        A confusion matrix is also saved when ``gt`` is given.
        """
        self.epochs.append(epoch)
        self.vals.append(val)
        self.save_plot()
        self.save_json()
        if gt is not None:
            self.save_confusion_matrix(gt, pred)

    def save_plot(self):
        """Plot every value series against epoch and save to ``op_file_name``."""
        linestyles = ["-", ":", "--"]
        plt.figure(0, figsize=(8, 5))
        for ii in range(len(self.vals[0])):
            l_vals = [vv[ii] for vv in self.vals]
            plt.plot(
                self.epochs,
                l_vals,
                label=self.labels[ii],
                # new line style every 10 series; wrap instead of IndexError
                linestyle=linestyles[(ii // 10) % len(linestyles)],
            )
        plt.xlim(0, np.maximum(self.duration, len(self.vals)))
        if self.ylim is not None:
            plt.ylim(self.ylim[0], self.ylim[1])
        if self.axis_labels is not None:
            plt.xlabel(self.axis_labels[0])
            plt.ylabel(self.axis_labels[1])
        if self.logy:
            plt.gca().set_yscale("log")
        plt.grid(True)
        plt.legend(bbox_to_anchor=(1.01, 1), loc="upper left", borderaxespad=0.0)
        plt.tight_layout()
        plt.savefig(self.op_file_name)
        plt.close(0)

    def save_json(self):
        """Write epochs and each series (rounded to 4 decimals) as JSON."""
        data = {"epochs": self.epochs}
        for ii in range(len(self.vals[0])):
            # NumPy scalars (e.g. float32) are not JSON serialisable as-is.
            data[self.labels[ii]] = [
                round(_to_json_number(vv[ii]), 4) for vv in self.vals
            ]
        with open(self.op_file_name[:-4] + ".json", "w") as da:
            json.dump(data, da, indent=2)

    def save_confusion_matrix(self, gt, pred):
        """Save a row-normalised confusion matrix over ``class_names``."""
        plt.figure(0)
        # ``labels`` must be passed by keyword: it is keyword-only in
        # current scikit-learn releases.
        cm = confusion_matrix(
            gt, pred, labels=np.arange(len(self.class_names))
        ).astype(np.float32)
        cm_norm = cm.sum(1)
        valid_inds = np.where(cm_norm > 0)[0]
        cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
        plt.imshow(cm, vmin=0, vmax=1, cmap="plasma")
        plt.colorbar()
        plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation="vertical")
        plt.yticks(np.arange(cm.shape[0]), self.class_names)
        plt.xlabel("Predicted")
        plt.ylabel("Ground Truth")
        plt.tight_layout()
        plt.savefig(self.op_file_name[:-4] + "_cm.png")
        plt.close(0)