# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2025-06-29 22:51:58 +02:00
"""Post-processing of the output of the model."""
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from bat_detect.detector.models import ModelOutput
|
|
|
|
try:
|
|
from typing import TypedDict
|
|
except ImportError:
|
|
from typing_extensions import TypedDict
|
|
|
|
np.seterr(divide="ignore", invalid="ignore")
|
|
|
|
|
|
def x_coords_to_time(
    x_pos: float,
    sampling_rate: int,
    fft_win_length: float,
    fft_overlap: float,
) -> float:
    """Map a spectrogram x coordinate (in pixels) to a time in seconds.

    The conversion reconstructs the STFT hop size from the window
    length and overlap, then scales the pixel position by it.

    Args:
        x_pos: X position of the detection in pixels.
        sampling_rate: Sampling rate of the audio in Hz.
        fft_win_length: Length of the FFT window in seconds.
        fft_overlap: Overlap of the FFT windows as a fraction of the window.

    Returns:
        Time in seconds.
    """
    win_samples = int(fft_win_length * sampling_rate)
    overlap_samples = int(fft_overlap * win_samples)
    hop_samples = win_samples - overlap_samples
    return (x_pos * hop_samples + overlap_samples) / sampling_rate
|
|
|
|
|
|
def overall_class_pred(det_prob, class_prob):
    """Aggregate per-detection class scores into one normalized distribution.

    Each class's score is weighted by the detection probability, summed
    over axis 1, and the result is renormalized to sum to one.
    """
    votes = (class_prob * det_prob).sum(1)
    total = votes.sum()
    return votes / total
|
|
|
|
|
|
class NonMaximumSuppressionConfig(TypedDict):
    """Configuration for non-maximum suppression."""

    # Size of the kernel for non-maximum suppression.
    nms_kernel_size: int

    # Maximum frequency to consider in Hz.
    max_freq: int

    # Minimum frequency to consider in Hz.
    min_freq: int

    # Length of the FFT window in seconds.
    fft_win_length: float

    # Overlap of the FFT windows in seconds.
    fft_overlap: float

    # Factor by which the input was resized.
    resize_factor: float

    # Number of top detections to keep per second.
    nms_top_k_per_sec: float

    # Threshold for detection probability.
    detection_threshold: float
|
|
|
|
|
|
class PredictionResults(TypedDict):
    """Results of the prediction.

    Each key is a list of length `num_detections` containing the
    corresponding values for each detection.
    """

    # Detection probabilities.
    det_probs: np.ndarray

    # X position of the detection in pixels.
    x_pos: np.ndarray

    # Y position of the detection in pixels.
    y_pos: np.ndarray

    # Width of the detection in pixels.
    bb_width: np.ndarray

    # Height of the detection in pixels.
    bb_height: np.ndarray

    # Start times of the detections in seconds.
    start_times: np.ndarray

    # End times of the detections in seconds.
    end_times: np.ndarray

    # Low frequencies of the detections in Hz.
    low_freqs: np.ndarray

    # High frequencies of the detections in Hz.
    high_freqs: np.ndarray

    # Class probabilities (None when the model emits no class head).
    class_probs: Optional[np.ndarray]
|
|
|
|
|
|
def run_nms(
    outputs: ModelOutput,
    params: NonMaximumSuppressionConfig,
    sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
    """Run non-maximum suppression on the output of the model.

    Model outputs processed are expected to have a batch dimension.
    Each element of the batch is processed independently. The
    result is a pair of lists, one for the predictions and one for
    the features. Each element of the lists corresponds to one
    element of the batch.

    Args:
        outputs: Model output tuple; unpacked as
            (pred_det, pred_size, pred_class, _, features).
        params: NMS configuration (kernel size, frequency range, FFT
            settings, top-k rate, detection threshold).
        sampling_rate: Per-batch-element sampling rates in Hz.
            NOTE(review): annotated as np.ndarray but indexed with
            `.item()` like a tensor — confirm the actual type at the caller.

    Returns:
        Tuple of (predictions, features), one list entry per batch element.
        If the model emits no features, the features list stays empty.
    """
    pred_det, pred_size, pred_class, _, features = outputs

    # Zero out every heatmap cell that is not a local maximum.
    pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"])
    # Hz covered per heatmap row (frequency axis is dimension -2).
    freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[
        -2
    ]

    # NOTE: there will be small differences depending on which sampling rate is chosen
    # as we are choosing the same sampling rate for the entire batch
    duration = x_coords_to_time(
        pred_det.shape[-1],
        int(sampling_rate[0].item()),
        params["fft_win_length"],
        params["fft_overlap"],
    )
    # Budget of candidate detections, proportional to clip duration.
    top_k = int(duration * params["nms_top_k_per_sec"])
    scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

    # loop over batch to save outputs
    preds: List[PredictionResults] = []
    feats: List[np.ndarray] = []
    for num_detection in range(pred_det_nms.shape[0]):
        # get valid indices
        # Sort candidates by time (x) so outputs are chronological, then
        # keep only those above the detection threshold.
        inds_ord = torch.argsort(x_pos[num_detection, :])
        valid_inds = (
            scores[num_detection, inds_ord] > params["detection_threshold"]
        )
        valid_inds = inds_ord[valid_inds]

        # create result dictionary
        pred = {}
        pred["det_probs"] = scores[num_detection, valid_inds]
        pred["x_pos"] = x_pos[num_detection, valid_inds]
        pred["y_pos"] = y_pos[num_detection, valid_inds]
        # Box sizes are read off the size head at each detection's location
        # (channel 0 = width, channel 1 = height).
        pred["bb_width"] = pred_size[
            num_detection,
            0,
            pred["y_pos"],
            pred["x_pos"],
        ]
        pred["bb_height"] = pred_size[
            num_detection,
            1,
            pred["y_pos"],
            pred["x_pos"],
        ]
        # Undo any input resizing before converting pixels to seconds;
        # here the per-element sampling rate is used (cf. NOTE above).
        pred["start_times"] = x_coords_to_time(
            pred["x_pos"].float() / params["resize_factor"],
            int(sampling_rate[num_detection].item()),
            params["fft_win_length"],
            params["fft_overlap"],
        )
        pred["end_times"] = x_coords_to_time(
            (pred["x_pos"].float() + pred["bb_width"])
            / params["resize_factor"],
            int(sampling_rate[num_detection].item()),
            params["fft_win_length"],
            params["fft_overlap"],
        )
        # y axis is flipped: row 0 is the highest frequency.
        pred["low_freqs"] = (
            pred_size[num_detection].shape[1] - pred["y_pos"].float()
        ) * freq_rescale + params["min_freq"]
        pred["high_freqs"] = (
            pred["low_freqs"] + pred["bb_height"] * freq_rescale
        )

        # extract the per class votes
        if pred_class is not None:
            pred["class_probs"] = pred_class[
                num_detection,
                :,
                y_pos[num_detection, valid_inds],
                x_pos[num_detection, valid_inds],
            ]

        # extract the model features
        if features is not None:
            # (channels, detections) -> (detections, channels)
            feat = features[
                num_detection,
                :,
                y_pos[num_detection, valid_inds],
                x_pos[num_detection, valid_inds],
            ].transpose(0, 1)
            feat = feat.cpu().numpy().astype(np.float32)
            feats.append(feat)

        # convert to numpy
        # (in-place value replacement while iterating is safe: no keys change)
        for key, value in pred.items():
            pred[key] = value.cpu().numpy().astype(np.float32)

        preds.append(pred)

    return preds, feats
|
|
|
|
|
|
def non_max_suppression(heat, kernel_size):
    """Suppress non-peak responses in a detection heatmap.

    A max-pool with stride 1 computes each cell's neighbourhood maximum;
    cells that are not equal to that maximum are zeroed out, leaving only
    local peaks.

    Args:
        heat: Heatmap tensor of shape (batch, channels, height, width).
        kernel_size: Neighbourhood size — a single int for a square
            window, or a (height, width) pair.

    Returns:
        Tensor of the same shape as `heat` with non-maximum cells zeroed.
    """
    # kernel can be an int or list/tuple
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        # BUG FIX: the non-int case previously fell through without
        # assigning kernel_size_h/w, raising NameError below.
        kernel_size_h, kernel_size_w = kernel_size

    # Padding keeps the pooled output the same size as the input
    # (for odd kernel sizes).
    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = nn.functional.max_pool2d(
        heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
    )
    keep = (hmax == heat).float()

    return heat * keep
|
|
|
|
|
|
def get_topk_scores(scores, K):
    """Pick the K highest-scoring heatmap cells per batch element.

    Expects input of size: batch x 1 x height x width. Returns the top-K
    scores together with their (y, x) grid coordinates, each of shape
    (batch, K).
    """
    batch, _, height, width = scores.size()

    flat = scores.view(batch, -1)
    top_vals, flat_inds = torch.topk(flat, K)
    cell_inds = flat_inds % (height * width)

    top_ys = torch.div(cell_inds, width, rounding_mode="floor").long()
    top_xs = (cell_inds % width).long()

    return top_vals, top_ys, top_xs
|