"""Module for postprocessing model outputs.""" from typing import Callable, List, Tuple, Union from pydantic import BaseModel, Field import numpy as np import torch from soundevent import data from torch import nn from batdetect2.models.typing import ModelOutput __all__ = [ "postprocess_model_outputs", "PostprocessConfig", ] NMS_KERNEL_SIZE = 9 DETECTION_THRESHOLD = 0.01 TOP_K_PER_SEC = 200 class PostprocessConfig(BaseModel): """Configuration for postprocessing model outputs.""" nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0) min_freq: int = Field(default=10000, gt=0) max_freq: int = Field(default=120000, gt=0) top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) TagFunction = Callable[[int], List[data.Tag]] def postprocess_model_outputs( outputs: ModelOutput, clips: List[data.Clip], nms_kernel_size: int = NMS_KERNEL_SIZE, detection_threshold: float = DETECTION_THRESHOLD, min_freq: int = 10000, max_freq: int = 120000, top_k_per_sec: int = TOP_K_PER_SEC, ) -> List[data.ClipPrediction]: """Postprocesses model outputs to generate clip predictions. This function takes the output from the model, applies non-maximum suppression, selects the top-k scores, computes sound events from the outputs, and returns clip predictions based on these processed outputs. Parameters ---------- outputs Output from the model containing detection probabilities, size predictions, class logits, and features. All tensors are expected to have a batch dimension. clips List of clips for which predictions are made. The number of clips must match the batch dimension of the model outputs. nms_kernel_size Size of the non-maximum suppression kernel. Default is 9. detection_threshold Detection threshold. Default is 0.01. min_freq Minimum frequency. Default is 10000. max_freq Maximum frequency. Default is 120000. top_k_per_sec Top k per second. Default is 200. Returns ------- predictions: List[data.ClipPrediction] List of clip predictions containing predicted sound events. Raises ------ ValueError If the number of predictions does not match the number of clips. """ num_predictions = len(outputs.detection_probs) if num_predictions == 0: return [] if num_predictions != len(clips): raise ValueError( "Number of predictions must match the number of clips." ) detection_probs = non_max_suppression( outputs.detection_probs, kernel_size=nms_kernel_size, ) duration = clips[0].end_time - clips[0].start_time scores_batch, y_pos_batch, x_pos_batch = get_topk_scores( detection_probs, int(top_k_per_sec * duration / 2), ) predictions: List[data.ClipPrediction] = [] for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip( scores_batch, y_pos_batch, x_pos_batch, outputs.size_preds, outputs.class_probs, outputs.features, clips, ): sound_events = compute_sound_events_from_outputs( clip, scores, y_pos, x_pos, size_preds, class_probs, features, min_freq=min_freq, max_freq=max_freq, detection_threshold=detection_threshold, ) predictions.append( data.ClipPrediction( clip=clip, sound_events=sound_events, ) ) return predictions def compute_sound_events_from_outputs( clip: data.Clip, scores: torch.Tensor, y_pos: torch.Tensor, x_pos: torch.Tensor, size_preds: torch.Tensor, class_probs: torch.Tensor, features: torch.Tensor, tag_fn: TagFunction = lambda _: [], min_freq: int = 10000, max_freq: int = 120000, detection_threshold: float = DETECTION_THRESHOLD, ) -> List[data.SoundEventPrediction]: _, freq_bins, time_bins = size_preds.shape sorted_indices = torch.argsort(x_pos) valid_indices = sorted_indices[ scores[sorted_indices] > detection_threshold ] scores = scores[valid_indices] x_pos = x_pos[valid_indices] y_pos = y_pos[valid_indices] predictions: List[data.SoundEventPrediction] = [] for score, x, y in zip(scores, x_pos, y_pos): width, height = size_preds[:, y, x] print(width, height) class_prob = class_probs[:, y, x] feature = features[:, y, x] start_time = np.interp( x.item(), [0, time_bins], [clip.start_time, clip.end_time], ) end_time = np.interp( x.item() + width.item(), [0, time_bins], [clip.start_time, clip.end_time], ) low_freq = np.interp( y.item(), [0, freq_bins], [max_freq, min_freq], ) high_freq = np.interp( y.item() - height.item(), [0, freq_bins], [max_freq, min_freq], ) predicted_tags: List[data.PredictedTag] = [] for label_id, class_score in enumerate(class_prob): corresponding_tags = tag_fn(label_id) predicted_tags.extend( [ data.PredictedTag( tag=tag, score=class_score.item(), ) for tag in corresponding_tags ] ) start_time, end_time = sorted([float(start_time), float(end_time)]) low_freq, high_freq = sorted([float(low_freq), float(high_freq)]) sound_event = data.SoundEvent( recording=clip.recording, geometry=data.BoundingBox( coordinates=[ start_time, low_freq, end_time, high_freq, ] ), features=[ data.Feature( name=f"batdetect2_{i}", value=value.item(), ) for i, value in enumerate(feature) ], ) predictions.append( data.SoundEventPrediction( sound_event=sound_event, score=score.item(), tags=predicted_tags, ) ) return predictions def non_max_suppression( tensor: torch.Tensor, kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, ) -> torch.Tensor: """Run non-maximum suppression on a tensor. This function removes values from the input tensor that are not local maxima in the neighborhood of the given kernel size. All non-maximum values are set to zero. Parameters ---------- tensor : torch.Tensor Input tensor. kernel_size : Union[int, Tuple[int, int]], optional Size of the neighborhood to consider for non-maximum suppression. If an integer is given, the neighborhood will be a square of the given size. If a tuple is given, the neighborhood will be a rectangle with the given height and width. Returns ------- torch.Tensor Tensor with non-maximum suppressed values. """ if isinstance(kernel_size, int): kernel_size_h = kernel_size kernel_size_w = kernel_size else: kernel_size_h, kernel_size_w = kernel_size pad_h = (kernel_size_h - 1) // 2 pad_w = (kernel_size_w - 1) // 2 hmax = nn.functional.max_pool2d( tensor, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w), ) keep = (hmax == tensor).float() return tensor * keep def get_topk_scores( scores: torch.Tensor, K: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the top-k scores and their indices. Parameters ---------- scores : torch.Tensor Tensor with scores. Expects input of size: `batch x 1 x height x width`. K : int Number of top scores to return. Returns ------- scores : torch.Tensor Top-k scores. ys : torch.Tensor Y coordinates of the top-k scores. xs : torch.Tensor X coordinates of the top-k scores. """ batch, _, height, width = scores.size() topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K) topk_inds = topk_inds % (height * width) topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long() topk_xs = (topk_inds % width).long() return topk_scores, topk_ys, topk_xs