diff --git a/batdetect2/post_process.py b/batdetect2/post_process.py
deleted file mode 100644
index f266225..0000000
--- a/batdetect2/post_process.py
+++ /dev/null
@@ -1,398 +0,0 @@
-"""Module for postprocessing model outputs."""
-
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from pydantic import Field
-from soundevent import data
-from torch import nn
-
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.types import ModelOutput
-
-__all__ = [
-    "PostprocessConfig",
-    "load_postprocess_config",
-    "postprocess_model_outputs",
-]
-
-NMS_KERNEL_SIZE = 9
-DETECTION_THRESHOLD = 0.01
-TOP_K_PER_SEC = 200
-
-
-class PostprocessConfig(BaseConfig):
-    """Configuration for postprocessing model outputs."""
-
-    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
-    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
-    min_freq: int = Field(default=10000, gt=0)
-    max_freq: int = Field(default=120000, gt=0)
-    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
-
-
-def load_postprocess_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> PostprocessConfig:
-    return load_config(path, schema=PostprocessConfig, field=field)
-
-
-class RawPrediction(NamedTuple):
-    start_time: float
-    end_time: float
-    low_freq: float
-    high_freq: float
-    detection_score: float
-    class_scores: Dict[str, float]
-    features: np.ndarray
-
-
-def postprocess_model_outputs(
-    outputs: ModelOutput,
-    clips: List[data.Clip],
-    classes: List[str],
-    decoder: Callable[[str], List[data.Tag]],
-    config: Optional[PostprocessConfig] = None,
-) -> List[data.ClipPrediction]:
-    """Postprocess model outputs to generate clip predictions.
-
-    This function takes the output from the model, applies non-maximum
-    suppression, selects the top-k scores, computes sound events from the
-    outputs, and returns clip predictions based on these processed outputs.
-
-    Parameters
-    ----------
-    outputs
-        Output from the model containing detection probabilities, size
-        predictions, class logits, and features. All tensors are expected
-        to have a batch dimension.
-    clips
-        List of clips for which predictions are made. The number of clips
-        must match the batch dimension of the model outputs.
-    classes
-        Ordered list of class names matching the class dimension of the
-        model outputs.
-    decoder
-        Function that maps a class name to the list of tags it represents.
-    config
-        Configuration for postprocessing model outputs.
-
-    Returns
-    -------
-    predictions: List[data.ClipPrediction]
-        List of clip predictions containing predicted sound events.
-
-    Raises
-    ------
-    ValueError
-        If the number of predictions does not match the number of clips.
-    """
-
-    config = config or PostprocessConfig()
-
-    num_predictions = len(outputs.detection_probs)
-
-    if num_predictions == 0:
-        return []
-
-    if num_predictions != len(clips):
-        raise ValueError(
-            "Number of predictions must match the number of clips."
-        )
-
-    detection_probs = non_max_suppression(
-        outputs.detection_probs,
-        kernel_size=config.nms_kernel_size,
-    )
-
-    duration = clips[0].end_time - clips[0].start_time
-
-    scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
-        detection_probs,
-        int(config.top_k_per_sec * duration / 2),
-    )
-
-    predictions: List[data.ClipPrediction] = []
-    for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
-        scores_batch,
-        y_pos_batch,
-        x_pos_batch,
-        outputs.size_preds,
-        outputs.class_probs,
-        outputs.features,
-        clips,
-    ):
-        sound_events = compute_sound_events_from_outputs(
-            clip,
-            scores,
-            y_pos,
-            x_pos,
-            size_preds,
-            class_probs,
-            features,
-            classes=classes,
-            decoder=decoder,
-            min_freq=config.min_freq,
-            max_freq=config.max_freq,
-            detection_threshold=config.detection_threshold,
-        )
-
-        predictions.append(
-            data.ClipPrediction(
-                clip=clip,
-                sound_events=sound_events,
-            )
-        )
-
-    return predictions
-
-
-def compute_predictions_from_outputs(
-    start: float,
-    end: float,
-    scores: torch.Tensor,
-    y_pos: torch.Tensor,
-    x_pos: torch.Tensor,
-    size_preds: torch.Tensor,
-    class_probs: torch.Tensor,
-    features: torch.Tensor,
-    classes: List[str],
-    min_freq: int = 10000,
-    max_freq: int = 120000,
-    detection_threshold: float = DETECTION_THRESHOLD,
-) -> List[RawPrediction]:
-    _, freq_bins, time_bins = size_preds.shape
-
-    sorted_indices = torch.argsort(x_pos)
-    valid_indices = sorted_indices[
-        scores[sorted_indices] > detection_threshold
-    ]
-
-    scores = scores[valid_indices]
-    x_pos = x_pos[valid_indices]
-    y_pos = y_pos[valid_indices]
-
-    predictions: List[RawPrediction] = []
-    for score, x, y in zip(scores, x_pos, y_pos):
-        width, height = size_preds[:, y, x]
-        class_prob = class_probs[:, y, x].detach().numpy()
-        feats = features[:, y, x].detach().numpy()
-
-        start_time = np.interp(
-            x.item(),
-            [0, time_bins],
-            [start, end],
-        )
-
-        end_time = np.interp(
-            x.item() + width.item(),
-            [0, time_bins],
-            [start, end],
-        )
-
-        low_freq = np.interp(
-            y.item(),
-            [0, freq_bins],
-            [max_freq, min_freq],
-        )
-
-        high_freq = np.interp(
-            y.item() - height.item(),
-            [0, freq_bins],
-            [max_freq, min_freq],
-        )
-
-        start_time, end_time = sorted([float(start_time), float(end_time)])
-        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
-
-        predictions.append(
-            RawPrediction(
-                start_time=start_time,
-                end_time=end_time,
-                low_freq=low_freq,
-                high_freq=high_freq,
-                detection_score=score.item(),
-                features=feats,
-                class_scores={
-                    class_name: prob
-                    for class_name, prob in zip(classes, class_prob)
-                },
-            )
-        )
-
-    return predictions
-
-
-def compute_sound_events_from_outputs(
-    clip: data.Clip,
-    scores: torch.Tensor,
-    y_pos: torch.Tensor,
-    x_pos: torch.Tensor,
-    size_preds: torch.Tensor,
-    class_probs: torch.Tensor,
-    features: torch.Tensor,
-    classes: List[str],
-    decoder: Callable[[str], List[data.Tag]],
-    min_freq: int = 10000,
-    max_freq: int = 120000,
-    detection_threshold: float = DETECTION_THRESHOLD,
-) -> List[data.SoundEventPrediction]:
-    _, freq_bins, time_bins = size_preds.shape
-
-    sorted_indices = torch.argsort(x_pos)
-    valid_indices = sorted_indices[
-        scores[sorted_indices] > detection_threshold
-    ]
-
-    scores = scores[valid_indices]
-    x_pos = x_pos[valid_indices]
-    y_pos = y_pos[valid_indices]
-
-    predictions: List[data.SoundEventPrediction] = []
-    for score, x, y in zip(scores, x_pos, y_pos):
-        width, height = size_preds[:, y, x]
-        class_prob = class_probs[:, y, x]
-        feature = features[:, y, x]
-
-        start_time = np.interp(
-            x.item(),
-            [0, time_bins],
-            [clip.start_time, clip.end_time],
-        )
-
-        end_time = np.interp(
-            x.item() + width.item(),
-            [0, time_bins],
-            [clip.start_time, clip.end_time],
-        )
-
-        low_freq = np.interp(
-            y.item(),
-            [0, freq_bins],
-            [max_freq, min_freq],
-        )
-
-        high_freq = np.interp(
-            y.item() - height.item(),
-            [0, freq_bins],
-            [max_freq, min_freq],
-        )
-
-        predicted_tags: List[data.PredictedTag] = []
-
-        for label_id, class_score in enumerate(class_prob):
-            class_name = classes[label_id]
-            corresponding_tags = decoder(class_name)
-            predicted_tags.extend(
-                [
-                    data.PredictedTag(
-                        tag=tag,
-                        score=max(min(class_score.item(), 1), 0),
-                    )
-                    for tag in corresponding_tags
-                ]
-            )
-
-        start_time, end_time = sorted([float(start_time), float(end_time)])
-        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
-
-        sound_event = data.SoundEvent(
-            recording=clip.recording,
-            geometry=data.BoundingBox(
-                coordinates=[
-                    start_time,
-                    low_freq,
-                    end_time,
-                    high_freq,
-                ]
-            ),
-            features=[
-                data.Feature(
-                    term=data.term_from_key(f"batdetect2_{i}"),
-                    value=value.item(),
-                )
-                for i, value in enumerate(feature)
-            ],
-        )
-
-        predictions.append(
-            data.SoundEventPrediction(
-                sound_event=sound_event,
-                score=max(min(score.item(), 1), 0),
-                tags=predicted_tags,
-            )
-        )
-
-    return predictions
-
-
-def non_max_suppression(
-    tensor: torch.Tensor,
-    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> torch.Tensor:
-    """Run non-maximum suppression on a tensor.
-
-    This function removes values from the input tensor that are not local
-    maxima in the neighborhood of the given kernel size.
-
-    All non-maximum values are set to zero.
-
-    Parameters
-    ----------
-    tensor : torch.Tensor
-        Input tensor.
-    kernel_size : Union[int, Tuple[int, int]], optional
-        Size of the neighborhood to consider for non-maximum suppression.
-        If an integer is given, the neighborhood will be a square of the
-        given size. If a tuple is given, the neighborhood will be a
-        rectangle with the given height and width.
-
-    Returns
-    -------
-    torch.Tensor
-        Tensor with non-maximum suppressed values.
-    """
-    if isinstance(kernel_size, int):
-        kernel_size_h = kernel_size
-        kernel_size_w = kernel_size
-    else:
-        kernel_size_h, kernel_size_w = kernel_size
-
-    pad_h = (kernel_size_h - 1) // 2
-    pad_w = (kernel_size_w - 1) // 2
-
-    hmax = nn.functional.max_pool2d(
-        tensor,
-        (kernel_size_h, kernel_size_w),
-        stride=1,
-        padding=(pad_h, pad_w),
-    )
-    keep = (hmax == tensor).float()
-    return tensor * keep
-
-
-def get_topk_scores(
-    scores: torch.Tensor,
-    K: int,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Get the top-k scores and their indices.
-
-    Parameters
-    ----------
-    scores : torch.Tensor
-        Tensor with scores. Expects input of size `batch x 1 x height x width`.
-    K : int
-        Number of top scores to return.
-
-    Returns
-    -------
-    scores : torch.Tensor
-        Top-k scores.
-    ys : torch.Tensor
-        Y coordinates of the top-k scores.
-    xs : torch.Tensor
-        X coordinates of the top-k scores.
-    """
-    batch, _, height, width = scores.size()
-    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
-    topk_inds = topk_inds % (height * width)
-    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
-    topk_xs = (topk_inds % width).long()
-    return topk_scores, topk_ys, topk_xs
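The removed `non_max_suppression` and `get_topk_scores` implement the peak-picking scheme common in heatmap-based detectors (e.g. CenterNet): a stride-1 max-pool pass zeroes every cell that is not the maximum of its neighborhood, and the flat indices returned by `torch.topk` are unraveled back into `(y, x)` grid coordinates. Below is a minimal standalone sketch of that idea, outside the diff; the `nms_via_max_pool` helper is hypothetical and the `batch x 1 x height x width` layout is assumed to match the deleted module.

```python
import torch
from torch import nn


def nms_via_max_pool(heatmap: torch.Tensor, kernel_size: int = 9) -> torch.Tensor:
    # A cell survives only if it equals the maximum of its neighborhood,
    # i.e. it is a local peak; all other cells are zeroed out.
    pad = (kernel_size - 1) // 2
    pooled = nn.functional.max_pool2d(heatmap, kernel_size, stride=1, padding=pad)
    return heatmap * (pooled == heatmap).float()


heatmap = torch.rand(1, 1, 128, 512)  # batch x 1 x freq_bins x time_bins
suppressed = nms_via_max_pool(heatmap)

# Flatten each batch item, take the top-k surviving peaks, then unravel
# the flat indices back into (y, x) grid positions as get_topk_scores did.
batch, _, height, width = suppressed.shape
scores, flat_inds = torch.topk(suppressed.view(batch, -1), k=10)
ys = torch.div(flat_inds, width, rounding_mode="floor")
xs = flat_inds % width
print(scores.shape, ys.shape, xs.shape)  # each: (1, 10)
```

The grid coordinates recovered this way are what the deleted `compute_sound_events_from_outputs` interpolated into clip times and frequencies via `np.interp`.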