mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
235 lines
7.6 KiB
Python
235 lines
7.6 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib import patches
|
|
from matplotlib.axes._axes import _log as matplotlib_axes_logger
|
|
from sklearn.svm import LinearSVC
|
|
|
|
# Raise the threshold of matplotlib's axes logger so that only messages at
# ERROR level or above are emitted (presumably to hide noisy warnings about
# the colour arguments passed to scatter -- TODO confirm).
matplotlib_axes_logger.setLevel("ERROR")
|
|
|
|
|
|
# Palette of 20 distinct hex colours.  An integer class label is used as an
# index into this list to choose the colour of a point in the scatter plot,
# so the order of the entries matters.
colors = [
    "#e6194B",
    "#3cb44b",
    "#ffe119",
    "#4363d8",
    "#f58231",
    "#911eb4",
    "#42d4f4",
    "#f032e6",
    "#bfef45",
    "#fabebe",
    "#469990",
    "#e6beff",
    "#9A6324",
    "#fffac8",
    "#800000",
    "#aaffc3",
    "#808000",
    "#ffd8b1",
    "#000075",
    "#a9a9a9",
]
|
|
|
|
|
|
class InteractivePlotter:
    """Interactive viewer linking a 2D feature scatter plot to spectrogram clips.

    The left axes shows one point per call (its 2D low dimensional features);
    hovering a point shows the corresponding spectrogram slice, with the call's
    bounding box, on the right axes.  Digit keys assign a class label to the
    point that was hovered last, and (when training is enabled) ``enter`` fits
    a linear SVM on the labelled points and ``x`` queries its parameters.
    """

    def __init__(
        self,
        feats_ds,
        feats,
        spec_slices,
        call_info,
        freq_lims,
        allow_training,
    ):
        """
        Plots 2D low dimensional features on left and corresponding
        spectrograms on the right.

        Args:
            feats_ds: 2D array whose columns 0 and 1 are used as the x/y
                scatter coordinates, one row per call.
            feats: Feature array indexed by row (``feats[inds, :]``) and fed
                to the classifier; presumably the full dimensional version of
                ``feats_ds`` -- verify against caller.
            spec_slices: Non-empty sequence of 2D arrays, one spectrogram clip
                per call.  Widths (``shape[1]``) may differ; the height of the
                first slice is used for the display buffer.
            call_info: Sequence of dict-like records, one per call, providing
                the keys ``"low_freq"``, ``"high_freq"``, ``"file_name"``,
                ``"start_time"`` and ``"det_prob"``.
            freq_lims: Two-element sequence used as the lower/upper limits of
                the spectrogram's frequency axis (the axis is labelled kHz).
            allow_training: When true, the ``enter`` and ``x`` keys trigger
                classifier training / parameter retrieval.

        Raises:
            ValueError: If ``spec_slices`` is empty.
        """
        self.feats_ds = feats_ds
        self.feats = feats
        self.clf = None

        self.spec_slices = spec_slices
        self.call_info = call_info
        # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
        # The builtin ``int`` is used as dtype: the ``np.int`` alias no longer
        # exists in current NumPy releases.
        self.labels = np.zeros(len(call_info), dtype=int)
        # can populate this with 1's where we have labels
        self.annotated = np.zeros(self.labels.shape[0], dtype=int)
        self.labels_cols = [colors[label] for label in self.labels]
        self.freq_lims = freq_lims

        self.allow_training = allow_training
        self.pt_size = 5.0
        self.spec_pad = 0.2  # this much padding has been applied to the spec slices
        self.fig_width = 12
        self.fig_height = 8

        self.current_id = 0
        # All slices are displayed in a buffer as wide as the widest slice.
        self.max_width = max(ss.shape[1] for ss in self.spec_slices)
        self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))

    def plot(self, fig_id):
        """Create the figure, draw the initial plots and connect the callbacks.

        Args:
            fig_id: Passed to ``plt.subplots`` as ``num`` to identify the
                figure.
        """
        self.fig, self.ax = plt.subplots(
            nrows=1,
            ncols=2,
            num=fig_id,
            figsize=(self.fig_width, self.fig_height),
            gridspec_kw={"width_ratios": [2, 1]},
        )
        plt.tight_layout()

        # plot 2D t-SNE features
        self.low_dim_plt = self.ax[0].scatter(
            self.feats_ds[:, 0],
            self.feats_ds[:, 1],
            c=self.labels_cols,
            s=self.pt_size,
            picker=5,
        )
        self.ax[0].set_title("TSNE of Call Features")
        self.ax[0].set_xticks([])
        self.ax[0].set_yticks([])

        # plot clip from spectrogram (an empty buffer until a point is hovered)
        spec_min_max = (
            0,
            self.blank_spec.shape[1],
            self.freq_lims[0],
            self.freq_lims[1],
        )
        self.spec_im = self.ax[1].imshow(
            self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
        )
        self.ax[1].set_title("Spectrogram")
        self.ax[1].grid(color="w", linewidth=0.5)
        self.ax[1].set_xticks([])
        self.ax[1].set_ylabel("kHz")

        # Invisible placeholder so that mouse_hover always finds a patch at
        # index 0 to replace.
        bbox_orig = patches.Rectangle(
            (0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False
        )
        self.ax[1].add_patch(bbox_orig)

        self.annot = self.ax[0].annotate(
            "",
            xy=(0, 0),
            xytext=(20, 20),
            textcoords="offset points",
            bbox=dict(boxstyle="round", fc="w"),
            arrowprops=dict(arrowstyle="->"),
        )
        self.annot.set_visible(False)

        self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover)
        self.fig.canvas.mpl_connect("key_press_event", self.key_press)

    def _center_spec(self, call_id):
        """Return ``(spec, offset)``: slice ``call_id`` centred in a blank buffer.

        ``offset`` is the column at which the slice starts inside the buffer.
        """
        spec_slice = self.spec_slices[call_id]
        offset = (self.blank_spec.shape[1] - spec_slice.shape[1]) // 2
        spec = self.blank_spec.copy()
        spec[:, offset : offset + spec_slice.shape[1]] = spec_slice
        return spec, offset

    def mouse_hover(self, event):
        """Show the spectrogram and details of the scatter point under the mouse.

        Does nothing unless the cursor is inside the scatter axes and over a
        point.  Otherwise the first point hit becomes ``self.current_id`` and
        the spectrogram image, bounding box, annotation arrow and info label
        are updated for it.

        Args:
            event: The matplotlib mouse event delivered by the
                ``motion_notify_event`` callback.
        """
        if event.inaxes != self.ax[0]:
            return
        cont, ind = self.low_dim_plt.contains(event)
        if not cont:
            return

        self.current_id = ind["ind"][0]
        info = self.call_info[self.current_id]

        # copy spec into full window
        new_spec, w_diff = self._center_spec(self.current_id)
        self.spec_im.set_data(new_spec)
        self.spec_im.set_clim(vmin=0, vmax=new_spec.max())

        # draw bounding box around call (replacing the previous box)
        self.ax[1].patches[0].remove()
        # The slice width includes spec_pad (as a fraction of the call width)
        # on each side, so undo the padding to recover the call's own width.
        spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
            1.0 + 2.0 * self.spec_pad
        )
        xx = w_diff + self.spec_pad * spec_width_orig
        ww = spec_width_orig
        # Frequencies are divided by 1000 to match the kHz axis (presumably
        # they are stored in Hz -- verify against caller).
        yy = info["low_freq"] / 1000
        hh = (info["high_freq"] - info["low_freq"]) / 1000
        bbox = patches.Rectangle(
            (xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False
        )
        self.ax[1].add_patch(bbox)

        # update annotation arrow
        self.annot.xy = self.low_dim_plt.get_offsets()[self.current_id]
        self.annot.set_visible(True)

        # write call info
        info_str = (
            info["file_name"]
            + ", time="
            + str(round(info["start_time"], 3))
            + ", prob="
            + str(round(info["det_prob"], 3))
        )
        self.ax[0].set_xlabel(info_str)

        # redraw
        self.fig.canvas.draw_idle()

    def key_press(self, event):
        """Handle a key press: label the current point or drive the classifier.

        * a decimal digit assigns that class label (and colour) to the point
          hovered last and marks it as annotated;
        * ``enter`` trains the classifier (only if ``allow_training``);
        * ``x`` calls ``get_classifier_params`` (only if ``allow_training``);
          its return value is not used here.

        The scatter colours are refreshed after every recognised event.

        Args:
            event: The matplotlib key event delivered by the
                ``key_press_event`` callback.
        """
        key = event.key
        if key is None:
            # Nothing to interpret; previously this case raised
            # AttributeError on ``None.isdigit()``.
            return

        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if key.isdecimal():
            label = int(key)
            self.labels_cols[self.current_id] = colors[label]
            self.labels[self.current_id] = label
            self.annotated[self.current_id] = 1
        elif key == "enter" and self.allow_training:
            self.train_classifier()
        elif key == "x" and self.allow_training:
            self.get_classifier_params()

        # Recolour the existing collection rather than stacking a new scatter
        # plot on the axes for every key press.
        self.low_dim_plt.set_facecolor(self.labels_cols)
        self.fig.canvas.draw_idle()

    def train_classifier(self):
        """Fit a linear SVM on the annotated points and label the remaining ones.

        Requires annotations from at least two distinct classes; otherwise a
        message is printed and nothing changes.  After fitting, every point
        that is not annotated receives the predicted label and the matching
        colour.
        """
        # TODO maybe it's better to classify in 2D space - but then can't be linear ...
        inds = np.where(self.annotated == 1)[0]
        labs_un = np.unique(self.labels[inds])

        if labs_un.shape[0] <= 1:  # needs at least 2 classes
            print("Not enough data - please label more classes.")
            return

        self.clf = LinearSVC(
            C=1.0,
            penalty="l2",
            loss="squared_hinge",
            tol=0.0001,
            intercept_scaling=1.0,
            max_iter=2000,
        )
        self.clf.fit(self.feats[inds, :], self.labels[inds])

        # update labels
        inds_unlab = np.where(self.annotated == 0)[0]
        if inds_unlab.size == 0:
            # Every point is annotated already; avoid calling predict on an
            # empty array.
            return
        self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab])
        for ii in inds_unlab:
            self.labels_cols[ii] = colors[self.labels[ii]]

    def get_classifier_params(self):
        """Return the trained classifier's parameters as float32 arrays.

        Returns:
            dict: ``{"weights": coef_, "biases": intercept_}`` cast to
            ``np.float32``, or an empty dict (after printing a message) when
            no classifier has been trained yet.
        """
        res = {}
        if self.clf is None:
            print("Model not trained!")
        else:
            res["weights"] = self.clf.coef_.astype(np.float32)
            res["biases"] = self.clf.intercept_.astype(np.float32)
        return res
|