mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Remove old post_process module
This commit is contained in:
parent
3abebc9c17
commit
4aa2e6905c
@ -1,398 +0,0 @@
|
|||||||
"""Module for postprocessing model outputs."""
|
|
||||||
|
|
||||||
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.models.types import ModelOutput
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"PostprocessConfig",
|
|
||||||
"load_postprocess_config",
|
|
||||||
"postprocess_model_outputs",
|
|
||||||
]
|
|
||||||
|
|
||||||
NMS_KERNEL_SIZE = 9
|
|
||||||
DETECTION_THRESHOLD = 0.01
|
|
||||||
TOP_K_PER_SEC = 200
|
|
||||||
|
|
||||||
|
|
||||||
class PostprocessConfig(BaseConfig):
|
|
||||||
"""Configuration for postprocessing model outputs."""
|
|
||||||
|
|
||||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
|
||||||
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
|
|
||||||
min_freq: int = Field(default=10000, gt=0)
|
|
||||||
max_freq: int = Field(default=120000, gt=0)
|
|
||||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
|
||||||
|
|
||||||
|
|
||||||
def load_postprocess_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> PostprocessConfig:
|
|
||||||
return load_config(path, schema=PostprocessConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
class RawPrediction(NamedTuple):
|
|
||||||
start_time: float
|
|
||||||
end_time: float
|
|
||||||
low_freq: float
|
|
||||||
high_freq: float
|
|
||||||
detection_score: float
|
|
||||||
class_scores: Dict[str, float]
|
|
||||||
features: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
def postprocess_model_outputs(
|
|
||||||
outputs: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
classes: List[str],
|
|
||||||
decoder: Callable[[str], List[data.Tag]],
|
|
||||||
config: Optional[PostprocessConfig] = None,
|
|
||||||
) -> List[data.ClipPrediction]:
|
|
||||||
"""Postprocesses model outputs to generate clip predictions.
|
|
||||||
|
|
||||||
This function takes the output from the model, applies non-maximum suppression,
|
|
||||||
selects the top-k scores, computes sound events from the outputs, and returns
|
|
||||||
clip predictions based on these processed outputs.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
outputs
|
|
||||||
Output from the model containing detection probabilities, size
|
|
||||||
predictions, class logits, and features. All tensors are expected
|
|
||||||
to have a batch dimension.
|
|
||||||
clips
|
|
||||||
List of clips for which predictions are made. The number of clips
|
|
||||||
must match the batch dimension of the model outputs.
|
|
||||||
config
|
|
||||||
Configuration for postprocessing model outputs.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
predictions: List[data.ClipPrediction]
|
|
||||||
List of clip predictions containing predicted sound events.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the number of predictions does not match the number of clips.
|
|
||||||
"""
|
|
||||||
|
|
||||||
config = config or PostprocessConfig()
|
|
||||||
|
|
||||||
num_predictions = len(outputs.detection_probs)
|
|
||||||
|
|
||||||
if num_predictions == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if num_predictions != len(clips):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of predictions must match the number of clips."
|
|
||||||
)
|
|
||||||
|
|
||||||
detection_probs = non_max_suppression(
|
|
||||||
outputs.detection_probs,
|
|
||||||
kernel_size=config.nms_kernel_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = clips[0].end_time - clips[0].start_time
|
|
||||||
|
|
||||||
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
|
|
||||||
detection_probs,
|
|
||||||
int(config.top_k_per_sec * duration / 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
predictions: List[data.ClipPrediction] = []
|
|
||||||
for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
|
|
||||||
scores_batch,
|
|
||||||
y_pos_batch,
|
|
||||||
x_pos_batch,
|
|
||||||
outputs.size_preds,
|
|
||||||
outputs.class_probs,
|
|
||||||
outputs.features,
|
|
||||||
clips,
|
|
||||||
):
|
|
||||||
sound_events = compute_sound_events_from_outputs(
|
|
||||||
clip,
|
|
||||||
scores,
|
|
||||||
y_pos,
|
|
||||||
x_pos,
|
|
||||||
size_preds,
|
|
||||||
class_probs,
|
|
||||||
features,
|
|
||||||
classes=classes,
|
|
||||||
decoder=decoder,
|
|
||||||
min_freq=config.min_freq,
|
|
||||||
max_freq=config.max_freq,
|
|
||||||
detection_threshold=config.detection_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
predictions.append(
|
|
||||||
data.ClipPrediction(
|
|
||||||
clip=clip,
|
|
||||||
sound_events=sound_events,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return predictions
|
|
||||||
|
|
||||||
|
|
||||||
def compute_predictions_from_outputs(
|
|
||||||
start: float,
|
|
||||||
end: float,
|
|
||||||
scores: torch.Tensor,
|
|
||||||
y_pos: torch.Tensor,
|
|
||||||
x_pos: torch.Tensor,
|
|
||||||
size_preds: torch.Tensor,
|
|
||||||
class_probs: torch.Tensor,
|
|
||||||
features: torch.Tensor,
|
|
||||||
classes: List[str],
|
|
||||||
min_freq: int = 10000,
|
|
||||||
max_freq: int = 120000,
|
|
||||||
detection_threshold: float = DETECTION_THRESHOLD,
|
|
||||||
) -> List[RawPrediction]:
|
|
||||||
_, freq_bins, time_bins = size_preds.shape
|
|
||||||
|
|
||||||
sorted_indices = torch.argsort(x_pos)
|
|
||||||
valid_indices = sorted_indices[
|
|
||||||
scores[sorted_indices] > detection_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
scores = scores[valid_indices]
|
|
||||||
x_pos = x_pos[valid_indices]
|
|
||||||
y_pos = y_pos[valid_indices]
|
|
||||||
|
|
||||||
predictions: List[RawPrediction] = []
|
|
||||||
for score, x, y in zip(scores, x_pos, y_pos):
|
|
||||||
width, height = size_preds[:, y, x]
|
|
||||||
class_prob = class_probs[:, y, x].detach().numpy()
|
|
||||||
feats = features[:, y, x].detach().numpy()
|
|
||||||
|
|
||||||
start_time = np.interp(
|
|
||||||
x.item(),
|
|
||||||
[0, time_bins],
|
|
||||||
[start, end],
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = np.interp(
|
|
||||||
x.item() + width.item(),
|
|
||||||
[0, time_bins],
|
|
||||||
[start, end],
|
|
||||||
)
|
|
||||||
|
|
||||||
low_freq = np.interp(
|
|
||||||
y.item(),
|
|
||||||
[0, freq_bins],
|
|
||||||
[max_freq, min_freq],
|
|
||||||
)
|
|
||||||
|
|
||||||
high_freq = np.interp(
|
|
||||||
y.item() - height.item(),
|
|
||||||
[0, freq_bins],
|
|
||||||
[max_freq, min_freq],
|
|
||||||
)
|
|
||||||
|
|
||||||
start_time, end_time = sorted([float(start_time), float(end_time)])
|
|
||||||
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
|
|
||||||
|
|
||||||
predictions.append(
|
|
||||||
RawPrediction(
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
low_freq=low_freq,
|
|
||||||
high_freq=high_freq,
|
|
||||||
detection_score=score.item(),
|
|
||||||
features=feats,
|
|
||||||
class_scores={
|
|
||||||
class_name: prob
|
|
||||||
for class_name, prob in zip(classes, class_prob)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return predictions
|
|
||||||
|
|
||||||
|
|
||||||
def compute_sound_events_from_outputs(
|
|
||||||
clip: data.Clip,
|
|
||||||
scores: torch.Tensor,
|
|
||||||
y_pos: torch.Tensor,
|
|
||||||
x_pos: torch.Tensor,
|
|
||||||
size_preds: torch.Tensor,
|
|
||||||
class_probs: torch.Tensor,
|
|
||||||
features: torch.Tensor,
|
|
||||||
classes: List[str],
|
|
||||||
decoder: Callable[[str], List[data.Tag]],
|
|
||||||
min_freq: int = 10000,
|
|
||||||
max_freq: int = 120000,
|
|
||||||
detection_threshold: float = DETECTION_THRESHOLD,
|
|
||||||
) -> List[data.SoundEventPrediction]:
|
|
||||||
_, freq_bins, time_bins = size_preds.shape
|
|
||||||
|
|
||||||
sorted_indices = torch.argsort(x_pos)
|
|
||||||
valid_indices = sorted_indices[
|
|
||||||
scores[sorted_indices] > detection_threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
scores = scores[valid_indices]
|
|
||||||
x_pos = x_pos[valid_indices]
|
|
||||||
y_pos = y_pos[valid_indices]
|
|
||||||
|
|
||||||
predictions: List[data.SoundEventPrediction] = []
|
|
||||||
for score, x, y in zip(scores, x_pos, y_pos):
|
|
||||||
width, height = size_preds[:, y, x]
|
|
||||||
class_prob = class_probs[:, y, x]
|
|
||||||
feature = features[:, y, x]
|
|
||||||
|
|
||||||
start_time = np.interp(
|
|
||||||
x.item(),
|
|
||||||
[0, time_bins],
|
|
||||||
[clip.start_time, clip.end_time],
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = np.interp(
|
|
||||||
x.item() + width.item(),
|
|
||||||
[0, time_bins],
|
|
||||||
[clip.start_time, clip.end_time],
|
|
||||||
)
|
|
||||||
|
|
||||||
low_freq = np.interp(
|
|
||||||
y.item(),
|
|
||||||
[0, freq_bins],
|
|
||||||
[max_freq, min_freq],
|
|
||||||
)
|
|
||||||
|
|
||||||
high_freq = np.interp(
|
|
||||||
y.item() - height.item(),
|
|
||||||
[0, freq_bins],
|
|
||||||
[max_freq, min_freq],
|
|
||||||
)
|
|
||||||
|
|
||||||
predicted_tags: List[data.PredictedTag] = []
|
|
||||||
|
|
||||||
for label_id, class_score in enumerate(class_prob):
|
|
||||||
class_name = classes[label_id]
|
|
||||||
corresponding_tags = decoder(class_name)
|
|
||||||
predicted_tags.extend(
|
|
||||||
[
|
|
||||||
data.PredictedTag(
|
|
||||||
tag=tag,
|
|
||||||
score=max(min(class_score.item(), 1), 0),
|
|
||||||
)
|
|
||||||
for tag in corresponding_tags
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
start_time, end_time = sorted([float(start_time), float(end_time)])
|
|
||||||
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
|
|
||||||
|
|
||||||
sound_event = data.SoundEvent(
|
|
||||||
recording=clip.recording,
|
|
||||||
geometry=data.BoundingBox(
|
|
||||||
coordinates=[
|
|
||||||
start_time,
|
|
||||||
low_freq,
|
|
||||||
end_time,
|
|
||||||
high_freq,
|
|
||||||
]
|
|
||||||
),
|
|
||||||
features=[
|
|
||||||
data.Feature(
|
|
||||||
term=data.term_from_key(f"batdetect2_{i}"),
|
|
||||||
value=value.item(),
|
|
||||||
)
|
|
||||||
for i, value in enumerate(feature)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
predictions.append(
|
|
||||||
data.SoundEventPrediction(
|
|
||||||
sound_event=sound_event,
|
|
||||||
score=max(min(score.item(), 1), 0),
|
|
||||||
tags=predicted_tags,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return predictions
|
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(
|
|
||||||
tensor: torch.Tensor,
|
|
||||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Run non-maximum suppression on a tensor.
|
|
||||||
|
|
||||||
This function removes values from the input tensor that are not local
|
|
||||||
maxima in the neighborhood of the given kernel size.
|
|
||||||
|
|
||||||
All non-maximum values are set to zero.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
tensor : torch.Tensor
|
|
||||||
Input tensor.
|
|
||||||
kernel_size : Union[int, Tuple[int, int]], optional
|
|
||||||
Size of the neighborhood to consider for non-maximum suppression.
|
|
||||||
If an integer is given, the neighborhood will be a square of the
|
|
||||||
given size. If a tuple is given, the neighborhood will be a
|
|
||||||
rectangle with the given height and width.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Tensor with non-maximum suppressed values.
|
|
||||||
"""
|
|
||||||
if isinstance(kernel_size, int):
|
|
||||||
kernel_size_h = kernel_size
|
|
||||||
kernel_size_w = kernel_size
|
|
||||||
else:
|
|
||||||
kernel_size_h, kernel_size_w = kernel_size
|
|
||||||
|
|
||||||
pad_h = (kernel_size_h - 1) // 2
|
|
||||||
pad_w = (kernel_size_w - 1) // 2
|
|
||||||
|
|
||||||
hmax = nn.functional.max_pool2d(
|
|
||||||
tensor,
|
|
||||||
(kernel_size_h, kernel_size_w),
|
|
||||||
stride=1,
|
|
||||||
padding=(pad_h, pad_w),
|
|
||||||
)
|
|
||||||
keep = (hmax == tensor).float()
|
|
||||||
return tensor * keep
|
|
||||||
|
|
||||||
|
|
||||||
def get_topk_scores(
|
|
||||||
scores: torch.Tensor,
|
|
||||||
K: int,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""Get the top-k scores and their indices.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
scores : torch.Tensor
|
|
||||||
Tensor with scores. Expects input of size: `batch x 1 x height x width`.
|
|
||||||
K : int
|
|
||||||
Number of top scores to return.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
scores : torch.Tensor
|
|
||||||
Top-k scores.
|
|
||||||
ys : torch.Tensor
|
|
||||||
Y coordinates of the top-k scores.
|
|
||||||
xs : torch.Tensor
|
|
||||||
X coordinates of the top-k scores.
|
|
||||||
"""
|
|
||||||
batch, _, height, width = scores.size()
|
|
||||||
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
|
|
||||||
topk_inds = topk_inds % (height * width)
|
|
||||||
topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
|
|
||||||
topk_xs = (topk_inds % width).long()
|
|
||||||
return topk_scores, topk_ys, topk_xs
|
|
Loading…
Reference in New Issue
Block a user