"""Module for postprocessing model outputs."""
|
|
|
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from pydantic import BaseModel, Field
|
|
from soundevent import data
|
|
from torch import nn
|
|
|
|
from batdetect2.models.typing import ModelOutput
|
|
|
|
__all__ = [
|
|
"postprocess_model_outputs",
|
|
"PostprocessConfig",
|
|
]
|
|
|
|
NMS_KERNEL_SIZE = 9
|
|
DETECTION_THRESHOLD = 0.01
|
|
TOP_K_PER_SEC = 200
|
|
|
|
|
|


class PostprocessConfig(BaseModel):
    """Configuration for postprocessing model outputs."""

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
    min_freq: int = Field(default=10000, gt=0)
    max_freq: int = Field(default=120000, gt=0)
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
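
# Example (illustrative only): a custom ``PostprocessConfig`` can be built with
# keyword overrides and passed to ``postprocess_model_outputs`` below; any field
# left out keeps its default. A minimal sketch:
#
#     config = PostprocessConfig(
#         detection_threshold=0.3,
#         top_k_per_sec=100,
#     )
#     predictions = postprocess_model_outputs(
#         outputs, clips, classes, decoder, config=config
#     )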


class RawPrediction(BaseModel):
    # ``features`` holds a plain numpy array, which pydantic cannot validate
    # natively, so arbitrary types must be allowed for this model.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float
    detection_score: float
    class_scores: Dict[str, float]
    features: np.ndarray
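
# Example (illustrative only): given a ``RawPrediction``, the best class can be
# recovered from the ``class_scores`` mapping. A minimal sketch:
#
#     best_class = max(prediction.class_scores, key=prediction.class_scores.get)
#     best_score = prediction.class_scores[best_class]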


def postprocess_model_outputs(
    outputs: ModelOutput,
    clips: List[data.Clip],
    classes: List[str],
    decoder: Callable[[str], List[data.Tag]],
    config: Optional[PostprocessConfig] = None,
) -> List[data.ClipPrediction]:
    """Postprocess model outputs to generate clip predictions.

    This function takes the output from the model, applies non-maximum
    suppression, selects the top-k scores, computes sound events from the
    outputs, and returns clip predictions based on these processed outputs.

    Parameters
    ----------
    outputs
        Output from the model containing detection probabilities, size
        predictions, class logits, and features. All tensors are expected
        to have a batch dimension.
    clips
        List of clips for which predictions are made. The number of clips
        must match the batch dimension of the model outputs.
    classes
        Ordered list of class names corresponding to the class dimension
        of the model outputs.
    decoder
        Callable that maps a class name to the tags it represents.
    config
        Configuration for postprocessing model outputs. Defaults to
        ``PostprocessConfig()`` when not provided.

    Returns
    -------
    predictions: List[data.ClipPrediction]
        List of clip predictions containing predicted sound events.

    Raises
    ------
    ValueError
        If the number of predictions does not match the number of clips.
    """
    config = config or PostprocessConfig()

    num_predictions = len(outputs.detection_probs)

    if num_predictions == 0:
        return []

    if num_predictions != len(clips):
        raise ValueError(
            "Number of predictions must match the number of clips."
        )

    detection_probs = non_max_suppression(
        outputs.detection_probs,
        kernel_size=config.nms_kernel_size,
    )

    duration = clips[0].end_time - clips[0].start_time

    scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
        detection_probs,
        int(config.top_k_per_sec * duration / 2),
    )

    predictions: List[data.ClipPrediction] = []
    for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
        scores_batch,
        y_pos_batch,
        x_pos_batch,
        outputs.size_preds,
        outputs.class_probs,
        outputs.features,
        clips,
    ):
        sound_events = compute_sound_events_from_outputs(
            clip,
            scores,
            y_pos,
            x_pos,
            size_preds,
            class_probs,
            features,
            classes=classes,
            decoder=decoder,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            detection_threshold=config.detection_threshold,
        )

        predictions.append(
            data.ClipPrediction(
                clip=clip,
                sound_events=sound_events,
            )
        )

    return predictions
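
# Example (illustrative sketch; ``model``, ``spec``, ``clip``, ``class_names``,
# and ``decoder`` are assumed to be supplied by the caller, not by this
# module): turning a batch of model outputs into clip predictions.
#
#     outputs = model(spec)  # ModelOutput with a leading batch dimension
#     predictions = postprocess_model_outputs(
#         outputs,
#         clips=[clip],
#         classes=class_names,
#         decoder=decoder,
#         config=PostprocessConfig(detection_threshold=0.3),
#     )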


def compute_predictions_from_outputs(
    start: float,
    end: float,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    classes: List[str],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
    """Convert per-bin detections into raw predictions.

    Detections are sorted by time, filtered by ``detection_threshold``, and
    their positions and sizes are interpolated from spectrogram bins to
    seconds within ``[start, end]`` and to Hz within ``[min_freq, max_freq]``.
    """
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[RawPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        class_prob = class_probs[:, y, x].detach().numpy()
        feats = features[:, y, x].detach().numpy()

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [start, end],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [start, end],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        predictions.append(
            RawPrediction(
                start_time=start_time,
                end_time=end_time,
                low_freq=low_freq,
                high_freq=high_freq,
                detection_score=score.item(),
                features=feats,
                class_scores={
                    class_name: prob
                    for class_name, prob in zip(classes, class_prob)
                },
            )
        )

    return predictions
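
# Worked example (illustrative only) of the bin-to-unit mapping used above:
# with ``time_bins = 100`` and a segment spanning ``start = 0.0`` to
# ``end = 2.0`` seconds, a detection at time bin ``x = 25`` maps to
# ``np.interp(25, [0, 100], [0.0, 2.0]) == 0.5`` seconds. The frequency axis
# is interpolated with the endpoints reversed (``[max_freq, min_freq]``), so
# row 0 of the spectrogram corresponds to the highest frequency.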


def convert_raw_prediction_to_soundevent(
    prediction: RawPrediction,
) -> data.SoundEventPrediction:
    """Convert a raw prediction into a sound event prediction.

    Not yet implemented; currently a placeholder.
    """
    pass


def compute_sound_events_from_outputs(
    clip: data.Clip,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    classes: List[str],
    decoder: Callable[[str], List[data.Tag]],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[data.SoundEventPrediction]:
    """Compute sound event predictions for a single clip.

    Each detection above ``detection_threshold`` becomes a sound event with a
    bounding box in time and frequency, per-class predicted tags obtained via
    ``decoder``, and the extracted feature vector attached as features.
    """
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[data.SoundEventPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        class_prob = class_probs[:, y, x]
        feature = features[:, y, x]

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        predicted_tags: List[data.PredictedTag] = []

        for label_id, class_score in enumerate(class_prob):
            class_name = classes[label_id]
            corresponding_tags = decoder(class_name)
            predicted_tags.extend(
                [
                    data.PredictedTag(
                        tag=tag,
                        score=max(min(class_score.item(), 1), 0),
                    )
                    for tag in corresponding_tags
                ]
            )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        sound_event = data.SoundEvent(
            recording=clip.recording,
            geometry=data.BoundingBox(
                coordinates=[
                    start_time,
                    low_freq,
                    end_time,
                    high_freq,
                ]
            ),
            features=[
                data.Feature(
                    term=data.term_from_key(f"batdetect2_{i}"),
                    value=value.item(),
                )
                for i, value in enumerate(feature)
            ],
        )

        predictions.append(
            data.SoundEventPrediction(
                sound_event=sound_event,
                score=max(min(score.item(), 1), 0),
                tags=predicted_tags,
            )
        )

    return predictions
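
# Example (illustrative only): the ``decoder`` argument maps a class name to
# the tags it stands for. A minimal sketch using the term-based constructors
# already used in this module; the "species" key is a hypothetical choice:
#
#     def decoder(class_name: str) -> List[data.Tag]:
#         return [
#             data.Tag(
#                 term=data.term_from_key("species"),
#                 value=class_name,
#             )
#         ]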


def non_max_suppression(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
    """Run non-maximum suppression on a tensor.

    This function removes values from the input tensor that are not local
    maxima in the neighborhood of the given kernel size.

    All non-maximum values are set to zero.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    kernel_size : Union[int, Tuple[int, int]], optional
        Size of the neighborhood to consider for non-maximum suppression.
        If an integer is given, the neighborhood will be a square of the
        given size. If a tuple is given, the neighborhood will be a
        rectangle with the given height and width.

    Returns
    -------
    torch.Tensor
        Tensor with non-maximum suppressed values.
    """
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = nn.functional.max_pool2d(
        tensor,
        (kernel_size_h, kernel_size_w),
        stride=1,
        padding=(pad_h, pad_w),
    )
    keep = (hmax == tensor).float()
    return tensor * keep
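
# Example (illustrative only): applying NMS to a random detection heatmap of
# shape ``(batch, channels, freq_bins, time_bins)``. Only local maxima within
# each 9x9 neighborhood survive; everything else is zeroed.
#
#     heatmap = torch.rand(1, 1, 128, 512)
#     suppressed = non_max_suppression(heatmap, kernel_size=NMS_KERNEL_SIZE)
#     peak_indices = suppressed.nonzero()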


def get_topk_scores(
    scores: torch.Tensor,
    K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get the top-k scores and their indices.

    Parameters
    ----------
    scores : torch.Tensor
        Tensor with scores. Expects input of size: `batch x 1 x height x width`.
    K : int
        Number of top scores to return.

    Returns
    -------
    scores : torch.Tensor
        Top-k scores.
    ys : torch.Tensor
        Y coordinates of the top-k scores.
    xs : torch.Tensor
        X coordinates of the top-k scores.
    """
    batch, _, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
    topk_xs = (topk_inds % width).long()
    return topk_scores, topk_ys, topk_xs
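
# Worked example (illustrative only) of the flat-index decomposition above:
# for a ``1 x 1 x 4 x 5`` score map, ``scores.view(1, -1)`` flattens to 20
# values, so a flat index of 13 decodes to row ``13 // 5 == 2`` and column
# ``13 % 5 == 3``. A minimal sketch:
#
#     scores = torch.rand(1, 1, 4, 5)
#     top_scores, ys, xs = get_topk_scores(scores, K=3)
#     assert torch.allclose(top_scores[0, 0], scores[0, 0, ys[0, 0], xs[0, 0]])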