import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score, balanced_accuracy_score


def compute_error_auc(op_str, gt, pred, prob):

    # classification error
    pred_int = (pred > prob).astype(int)
    class_acc = (pred_int == gt).mean() * 100.0

    # ROC - area under curve
    fpr, tpr, thresholds = roc_curve(gt, pred)
    roc_auc = auc(fpr, tpr)

    print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc))
    #return class_acc, roc_auc


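# Illustrative usage of compute_error_auc (hypothetical values): with a decision
# threshold of 0.5, three of the four predictions match the labels, so the class
# accuracy is 75.0 and the ROC AUC works out to 0.75.
#   gt = np.array([0, 0, 1, 1])
#   pred = np.array([0.1, 0.4, 0.35, 0.8])
#   compute_error_auc('val', gt, pred, 0.5)
#   # prints: val, class acc = 75.000, ROC AUC = 0.750

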
def calc_average_precision(recall, precision):

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0]-2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii+1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1
    ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum()

    return float(ave_prec)


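# Worked example for calc_average_precision (hypothetical values): for a curve with
# recall = [0.5, 1.0] and precision = [1.0, 0.5], the PASCAL-style interpolation gives
# AP = (0.5 - 0.0)*1.0 + (1.0 - 0.5)*0.5 = 0.75.
#   calc_average_precision(np.array([0.5, 1.0]), np.array([1.0, 0.5]))  # -> 0.75

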
def calc_recall_at_x(recall, precision, x=0.95):
    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    inds = np.where(precision[::-1] > x)[0]
    if len(inds) > 0:
        return float(recall[::-1][inds[0]])
    else:
        return 0.0


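# Worked example for calc_recall_at_x (hypothetical values): with the default x=0.95
# the function returns the highest recall whose precision still exceeds 0.95, here 0.6.
#   calc_recall_at_x(np.array([0.2, 0.6, 1.0]), np.array([1.0, 0.96, 0.5]))  # -> 0.6

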
def compute_affinity_1d(pred_box, gt_boxes, threshold):
    # first entry is start time
    score = np.abs(pred_box[0] - gt_boxes[:, 0])
    valid_detection = np.min(score) <= threshold
    return valid_detection, np.argmin(score)


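# Worked example for compute_affinity_1d (hypothetical values): only the start times
# (first column) are compared, so the prediction starting at 0.10 s is matched to the
# ground truth box with the closest start time, provided it is within the threshold.
#   pred_box = np.array([0.10, 0.14, 20000.0, 60000.0])     # start, end, low freq, high freq
#   gt_boxes = np.array([[0.05, 0.09, 18000.0, 55000.0],
#                        [0.12, 0.16, 21000.0, 62000.0]])
#   compute_affinity_1d(pred_box, gt_boxes, 0.05)            # -> (True, 1)

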
def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end):
    """
    Computes precision and recall. Assumes that each file has been exhaustively
    annotated. Will not count a predicted detection with a start time that is
    within ignore_start_end milliseconds of the start or end of the file.

    eval_mode == 'detection'
        Returns overall detection results (not per class).

    eval_mode == 'per_class'
        Filters ground truth based on the class of interest. Predictions assigned
        to ground truth with an unknown class are ignored.

    eval_mode == 'top_class'
        Turns the problem into a binary one by selecting the top predicted class
        for each predicted detection.
    """

    # get predictions and put in array
    pred_boxes = []
    confidence = []
    pred_class = []
    file_ids = []
    for pid, pp in enumerate(preds):

        # filter predicted calls that are too near the start or end of the file
        file_dur = gts[pid]['duration']
        valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end))

        pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds],
                                     pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T)

        if eval_mode == 'detection':
            # overall detection
            confidence.append(pp['det_probs'][valid_inds])
        elif eval_mode == 'per_class':
            # per class
            confidence.append(pp['class_probs'].T[valid_inds, class_of_interest])
        elif eval_mode == 'top_class':
            # per class - note that sometimes 'class_probs' can be num_classes+1 in size
            top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1)
            confidence.append(pp['class_probs'].T[valid_inds, top_class])
            pred_class.append(top_class)

        # be careful, this assumes the order of preds is the same as gts
        file_ids.append([pid]*valid_inds.sum())

    confidence = np.hstack(confidence)
    file_ids = np.hstack(file_ids).astype(int)
    pred_boxes = np.vstack(pred_boxes)
    if len(pred_class) > 0:
        pred_class = np.hstack(pred_class)

    # extract relevant ground truth boxes
    gt_boxes = []
    gt_assigned = []
    gt_class = []
    gt_generic_class = []
    num_positives = 0
    for gg in gts:

        # filter ground truth calls that are too near the start or end of the file
        file_dur = gg['duration']
        valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end))

        # note, files with an incorrect duration will cause a problem
        if (gg['start_times'] > file_dur).sum() > 0:
            print('Error: file duration incorrect for', gg['id'])
            assert False

        boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds],
                           gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T
        gen_class = gg['class_ids'][valid_inds] == -1
        class_ids = gg['class_ids'][valid_inds]

        # keep track of the number of relevant ground truth calls
        if eval_mode == 'detection':
            # all valid ones
            num_positives += len(gg['start_times'][valid_inds])
        elif eval_mode == 'per_class':
            # all valid ones with the class of interest
            num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum()
        elif eval_mode == 'top_class':
            # all valid ones with a non generic class
            num_positives += (gg['class_ids'][valid_inds] > -1).sum()

        # find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
        if eval_mode == 'per_class':
            class_inds = (class_ids == class_of_interest) | (class_ids == -1)
            boxes = boxes[class_inds, :]
            gen_class = gen_class[class_inds]
            class_ids = class_ids[class_inds]

        gt_assigned.append(np.zeros(boxes.shape[0]))
        gt_boxes.append(boxes)
        gt_generic_class.append(gen_class)
        gt_class.append(class_ids)

    # loop through detections and keep track of those that have been assigned
    true_pos = np.zeros(confidence.shape[0])
    valid_inds = np.ones(confidence.shape[0]) == 1  # initialize to True
    sorted_inds = np.argsort(confidence)[::-1]  # sort high to low
    for ii, ind in enumerate(sorted_inds):
        gt_id = file_ids[ind]

        valid_det = False
        if gt_boxes[gt_id].shape[0] > 0:
            # compute overlap
            valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id],
                                                     threshold)

        # valid detection that has not already been assigned
        if valid_det and (gt_assigned[gt_id][det_ind] == 0):

            count_as_true_pos = True
            if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]):
                # needs to be the same class
                count_as_true_pos = False

            if count_as_true_pos:
                true_pos[ii] = 1

            gt_assigned[gt_id][det_ind] = 1

            # if the event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
            # and eval_mode != 'detection', then ignore it
            if gt_generic_class[gt_id][det_ind]:
                if eval_mode == 'per_class' or eval_mode == 'top_class':
                    valid_inds[ii] = False

    # store threshold values - used for plotting
    conf_sorted = np.sort(confidence)[::-1][valid_inds]
    thresholds = np.linspace(0.1, 0.9, 9)
    thresholds_inds = np.zeros(len(thresholds), dtype=int)
    for ii, tt in enumerate(thresholds):
        thresholds_inds[ii] = np.argmin(conf_sorted > tt)
    thresholds_inds[thresholds_inds == 0] = -1

    # compute precision and recall
    true_pos = true_pos[valid_inds]
    false_pos_c = np.cumsum(1 - true_pos)
    true_pos_c = np.cumsum(true_pos)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps)

    results = {}
    results['recall'] = recall
    results['precision'] = precision
    results['num_gt'] = num_positives

    results['thresholds'] = thresholds
    results['thresholds_inds'] = thresholds_inds

    if num_positives == 0:
        results['avg_prec'] = np.nan
        results['rec_at_x'] = np.nan
    else:
        results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5)
        results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5)

    return results


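# Sketch of the inputs compute_pre_rec expects, inferred from how the dictionaries are
# indexed above (the field values here are hypothetical). Each gt entry describes one
# file; a class_id of -1 marks the generic/unknown class. Each pred entry stores
# 'class_probs' with one row per class (sometimes num_classes+1) and one column per
# detected call.
#   gt_example = {'id': 'file_0.wav', 'duration': 5.0,
#                 'start_times': np.array([1.2]), 'end_times': np.array([1.25]),
#                 'low_freqs': np.array([30000.0]), 'high_freqs': np.array([60000.0]),
#                 'class_ids': np.array([3])}
#   pred_example = {'start_times': np.array([1.21]), 'end_times': np.array([1.26]),
#                   'low_freqs': np.array([31000.0]), 'high_freqs': np.array([59000.0]),
#                   'det_probs': np.array([0.9]),
#                   'class_probs': np.random.rand(num_classes, 1)}  # (num_classes, num_detections)

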
def compute_file_accuracy_simple(gts, preds, num_classes):
    """
    Evaluates the prediction accuracy at a file level.
    Does not include files that have more than one class (or the generic class).

    Simply chooses the class per file that has the highest probability overall.
    """

    gt_valid = []
    pred_valid = []
    for ii in range(len(gts)):
        gt_class = np.unique(gts[ii]['class_ids'])
        if len(gt_class) == 1 and gt_class[0] != -1:
            gt_valid.append(gt_class[0])
            pred = preds[ii]['class_probs'][:num_classes, :].T
            pred_valid.append(np.argmax(pred.mean(0)))
    acc = (np.array(gt_valid) == np.array(pred_valid)).mean()

    res = {}
    res['num_valid_files'] = len(gt_valid)
    res['num_total_files'] = len(gts)
    res['gt_valid_file'] = gt_valid
    res['pred_valid_file'] = pred_valid
    res['file_acc'] = np.round(acc, 5)
    return res


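# Worked example of the per-file vote in compute_file_accuracy_simple (hypothetical
# values): class_probs has shape (num_classes, num_detections); averaging over the
# detections and taking the argmax picks the file-level class.
#   class_probs = np.array([[0.2, 0.8],    # class 0 scores for 2 detections
#                           [0.6, 0.1]])   # class 1 scores
#   class_probs.T.mean(0)                  # -> array([0.5, 0.35])
#   np.argmax(class_probs.T.mean(0))       # -> 0

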
def compute_file_accuracy(gts, preds, num_classes):
    """
    Evaluates the prediction accuracy at a file level.
    Does not include files that have more than one class (or the unknown class).

    Tries several different detection thresholds and picks the best one.
    """

    # compute min and max scoring range - then threshold
    min_val = 0
    mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0]
    if len(mins) > 0:
        min_val = np.min(mins)

    max_val = 1.0
    maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0]
    if len(maxes) > 0:
        max_val = np.max(maxes)

    thresh = np.linspace(min_val, max_val, 11)[:10]

    # loop over the files and store the accuracy at different prediction thresholds
    # only include gt files that have one valid species
    gt_valid = []
    pred_valid_all = []
    for ii in range(len(gts)):
        gt_class = np.unique(gts[ii]['class_ids'])
        if len(gt_class) == 1 and gt_class[0] != -1:
            gt_valid.append(gt_class[0])
            pred = preds[ii]['class_probs'][:num_classes, :].T
            p_class = np.zeros(len(thresh))
            for tt in range(len(thresh)):
                p_class[tt] = (pred*(pred >= thresh[tt])).sum(0).argmax()
            pred_valid_all.append(p_class)

    # pick the result corresponding to the overall best threshold
    pred_valid_all = np.vstack(pred_valid_all)
    acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0)
    best_thresh = np.argmax(acc_per_thresh)
    best_acc = acc_per_thresh[best_thresh]
    pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()

    res = {}
    res['num_valid_files'] = len(gt_valid)
    res['num_total_files'] = len(gts)
    res['gt_valid_file'] = gt_valid
    res['pred_valid_file'] = pred_valid
    res['file_acc'] = np.round(best_acc, 5)

    return res


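# Worked example of the thresholded vote in compute_file_accuracy (hypothetical values):
# scores below the threshold are zeroed before summing per class, so weak detections do
# not contribute to the file-level decision.
#   pred = np.array([[0.2, 0.6],           # detection 1: scores for classes 0 and 1
#                    [0.8, 0.1]])          # detection 2
#   t = 0.5
#   (pred * (pred >= t)).sum(0)            # -> array([0.8, 0.6])
#   (pred * (pred >= t)).sum(0).argmax()   # -> 0

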
def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0):
    """
    Computes metrics derived from the precision and recall.
    Assumes that gts and preds are lists of the same length, with the ground
    truth and predictions for each file contained within.

    Returns the overall detection results and the per class results.
    """

    assert len(gts) == len(preds)
    num_classes = len(class_names)

    # evaluate detection on its own i.e. ignoring class
    det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end)
    top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end)
    det_results['top_class'] = top_class

    # per class evaluation
    det_results['class_pr'] = []
    for cc in range(num_classes):
        res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end)
        res['name'] = class_names[cc]
        det_results['class_pr'].append(res)

    # ignores classes that are not present in the test set
    det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0])
    det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5)

    # file level evaluation
    res_file = compute_file_accuracy(gts, preds, num_classes)
    det_results.update(res_file)

    return det_results
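

# Minimal usage sketch for evaluate_predictions, assuming gts and preds follow the field
# layout shown above (the variable names and the 0.01 s overlap threshold are illustrative,
# not fixed by this module):
#   results = evaluate_predictions(gts, preds, class_names, detection_overlap=0.01,
#                                  ignore_start_end=0.0)
#   print(results['avg_prec'], results['avg_prec_class'], results['file_acc'])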