import numpy as np import torch import torch.nn as nn import torch.nn.functional as F np.seterr(divide="ignore", invalid="ignore") def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): nfft = int(fft_win_length * sampling_rate) noverlap = int(fft_overlap * nfft) return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window def overall_class_pred(det_prob, class_prob): weighted_pred = (class_prob * det_prob).sum(1) return weighted_pred / weighted_pred.sum() def run_nms(outputs, params, sampling_rate): pred_det = outputs["pred_det"] # probability of box pred_size = outputs["pred_size"] # box size pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"]) freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[ -2 ] # NOTE there will be small differences depending on which sampling rate is chosen # as we are choosing the same sampling rate for the entire batch duration = x_coords_to_time( pred_det.shape[-1], sampling_rate[0].item(), params["fft_win_length"], params["fft_overlap"], ) top_k = int(duration * params["nms_top_k_per_sec"]) scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) # loop over batch to save outputs preds = [] feats = [] for ii in range(pred_det_nms.shape[0]): # get valid indices inds_ord = torch.argsort(x_pos[ii, :]) valid_inds = scores[ii, inds_ord] > params["detection_threshold"] valid_inds = inds_ord[valid_inds] # create result dictionary pred = {} pred["det_probs"] = scores[ii, valid_inds] pred["x_pos"] = x_pos[ii, valid_inds] pred["y_pos"] = y_pos[ii, valid_inds] pred["bb_width"] = pred_size[ii, 0, pred["y_pos"], pred["x_pos"]] pred["bb_height"] = pred_size[ii, 1, pred["y_pos"], pred["x_pos"]] pred["start_times"] = x_coords_to_time( pred["x_pos"].float() / params["resize_factor"], sampling_rate[ii].item(), params["fft_win_length"], params["fft_overlap"], ) pred["end_times"] = x_coords_to_time( (pred["x_pos"].float() + pred["bb_width"]) / params["resize_factor"], sampling_rate[ii].item(), params["fft_win_length"], params["fft_overlap"], ) pred["low_freqs"] = ( pred_size[ii].shape[1] - pred["y_pos"].float() ) * freq_rescale + params["min_freq"] pred["high_freqs"] = ( pred["low_freqs"] + pred["bb_height"] * freq_rescale ) # extract the per class votes if "pred_class" in outputs: pred["class_probs"] = outputs["pred_class"][ ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds] ] # extract the model features if "features" in outputs: feat = outputs["features"][ ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds] ].transpose(0, 1) feat = feat.cpu().numpy().astype(np.float32) feats.append(feat) # convert to numpy for kk in pred.keys(): pred[kk] = pred[kk].cpu().numpy().astype(np.float32) preds.append(pred) return preds, feats def non_max_suppression(heat, kernel_size): # kernel can be an int or list/tuple if type(kernel_size) is int: kernel_size_h = kernel_size kernel_size_w = kernel_size pad_h = (kernel_size_h - 1) // 2 pad_w = (kernel_size_w - 1) // 2 hmax = nn.functional.max_pool2d( heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w) ) keep = (hmax == heat).float() return heat * keep def get_topk_scores(scores, K): # expects input of size: batch x 1 x height x width batch, _, height, width = scores.size() topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) topk_inds = topk_inds % (height * width) topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long() topk_xs = (topk_inds % width).long() return topk_scores, topk_ys, topk_xs