# batdetect2/bat_detect/detector/post_process.py
"""Post-processing of the output of the model."""
from typing import List, Tuple, Union
import numpy as np
import torch
from torch import nn
from bat_detect.detector.models import ModelOutput
from bat_detect.types import NonMaximumSuppressionConfig, PredictionResults
np.seterr(divide="ignore", invalid="ignore")
def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
"""Convert x coordinates of spectrogram to time in seconds.
Args:
x_pos: X position of the detection in pixels.
sampling_rate: Sampling rate of the audio in Hz.
fft_win_length: Length of the FFT window in seconds.
fft_overlap: Overlap of the FFT windows in seconds.
Returns:
Time in seconds.
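
    Example (illustrative values, not taken from any real config):
        >>> x_coords_to_time(100, 256000, 0.004, 0.75)
        0.103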
"""
nfft = int(fft_win_length * sampling_rate)
noverlap = int(fft_overlap * nfft)
    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate


def overall_class_pred(det_prob, class_prob):
    """Weight class probabilities by detection confidence and normalise."""
    weighted_pred = (class_prob * det_prob).sum(1)
    return weighted_pred / weighted_pred.sum()
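
# For example (illustrative shapes, matching how class probabilities are
# gathered in `run_nms` below): with `class_prob` of shape
# (num_classes, num_detections) and `det_prob` of shape (num_detections,),
# the result is a length num_classes probability distribution in which
# confident detections carry more weight.
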
def run_nms(
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
"""Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension.
Each element of the batch is processed independently. The
result is a pair of lists, one for the predictions and one for
the features. Each element of the lists corresponds to one
element of the batch.
"""
pred_det, pred_size, pred_class, _, features = outputs
pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
-2
]
    # NOTE: the duration is computed from the first item's sampling rate,
    # so results can differ slightly when a batch mixes sampling rates.
duration = x_coords_to_time(
pred_det.shape[-1],
int(sampling_rate[0].item()),
params["fft_win_length"],
params["fft_overlap"],
)
top_k = int(duration * params["nms_top_k_per_sec"])
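    # For example (illustrative): a 0.5 s clip with nms_top_k_per_sec=200
    # keeps at most 100 candidate peaks per batch element.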
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs
preds: List[PredictionResults] = []
feats: List[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]):
        # sort detections by time and keep those above the detection threshold
        inds_ord = torch.argsort(x_pos[num_detection, :])
        valid_inds = (
            scores[num_detection, inds_ord] > params["detection_threshold"]
        )
        valid_inds = inds_ord[valid_inds]
# create result dictionary
pred = {}
pred["det_probs"] = scores[num_detection, valid_inds]
pred["x_pos"] = x_pos[num_detection, valid_inds]
pred["y_pos"] = y_pos[num_detection, valid_inds]
pred["bb_width"] = pred_size[
num_detection,
0,
pred["y_pos"],
pred["x_pos"],
]
pred["bb_height"] = pred_size[
num_detection,
1,
pred["y_pos"],
pred["x_pos"],
]
pred["start_times"] = x_coords_to_time(
pred["x_pos"].float() / params["resize_factor"],
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
pred["end_times"] = x_coords_to_time(
(pred["x_pos"].float() + pred["bb_width"])
/ params["resize_factor"],
int(sampling_rate[num_detection].item()),
params["fft_win_length"],
params["fft_overlap"],
)
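        # NOTE: the y axis of the heatmap is flipped, so row 0 corresponds
        # to the highest frequency; hence the subtraction from the height.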
pred["low_freqs"] = (
pred_size[num_detection].shape[1] - pred["y_pos"].float()
) * freq_rescale + params["min_freq"]
pred["high_freqs"] = (
pred["low_freqs"] + pred["bb_height"] * freq_rescale
)
# extract the per class votes
if pred_class is not None:
pred["class_probs"] = pred_class[
num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
]
# extract the model features
if features is not None:
feat = features[
num_detection,
:,
y_pos[num_detection, valid_inds],
x_pos[num_detection, valid_inds],
].transpose(0, 1)
            feat = feat.detach().cpu().numpy().astype(np.float32)
feats.append(feat)
        # convert to numpy (moving to CPU first in case the model ran on GPU)
        for key, value in pred.items():
            pred[key] = value.detach().cpu().numpy().astype(np.float32)
preds.append(pred) # type: ignore
return preds, feats
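
# A minimal usage sketch (hypothetical parameter values; the real settings
# come from the bat_detect configuration): given a `ModelOutput` for a batch
# of two clips recorded at 256 kHz, NMS could be run as
#
#     nms_params = {
#         "nms_kernel_size": 9,
#         "max_freq": 120000,
#         "min_freq": 10000,
#         "fft_win_length": 0.002,
#         "fft_overlap": 0.75,
#         "resize_factor": 0.5,
#         "nms_top_k_per_sec": 200,
#         "detection_threshold": 0.3,
#     }
#     preds, feats = run_nms(outputs, nms_params, np.array([256000, 256000]))
#
# where `preds[0]["start_times"]` etc. are float32 numpy arrays with one
# entry per detection that survived suppression in the first clip.
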
def non_max_suppression(
    heat: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]],
) -> torch.Tensor:
    """Zero out heatmap values that are not local maxima.

    A value survives only if it equals the maximum of its
    neighbourhood, computed with a single max-pooling pass.
    """
    # kernel size can be an int or a (height, width) tuple/list
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
)
keep = (hmax == heat).float()
return heat * keep
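
# For example (illustrative): in a 1 x 1 x 3 x 3 heatmap with a single peak,
# a 3 x 3 kernel keeps only that peak and zeroes everything else:
#
#     heat = torch.tensor([[[[0.1, 0.2, 0.1],
#                            [0.2, 0.9, 0.2],
#                            [0.1, 0.2, 0.1]]]])
#     non_max_suppression(heat, 3)  # only the central 0.9 survives
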
def get_topk_scores(scores, K):
    """Return the top K scores and their (y, x) positions per batch element.

    Expects input of size: batch x 1 x height x width.
    """
    batch, _, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_inds = topk_inds % (height * width)  # no-op here: single channel
topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
topk_xs = (topk_inds % width).long()
return topk_scores, topk_ys, topk_xs
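
# Index decoding example (illustrative): on a 1 x 1 x 4 x 5 score map the
# flattened index 13 decodes to y = 13 // 5 = 2 and x = 13 % 5 = 3.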