mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Cleanup postprocessing
This commit is contained in:
parent
b3af70761e
commit
45ae15eed5
@ -1,6 +0,0 @@
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class ClipTransform:
    """Base for transforms applied to the detections of one audio clip.

    NOTE(review): this class is currently a stub with no behavior beyond
    construction — confirm intended API before extending.
    """

    def __init__(self, clip: "ClipDetections"):
        # Fixed: the original stub silently discarded `clip`; keep a
        # reference so methods (and subclasses) can actually use it.
        self.clip = clip
|
||||
@ -7,7 +7,6 @@ from batdetect2.postprocess.config import (
|
||||
)
|
||||
from batdetect2.postprocess.extraction import extract_detection_peaks
|
||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
PostprocessorProtocol,
|
||||
@ -92,3 +91,22 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
)
|
||||
for detection in detections
|
||||
]
|
||||
|
||||
|
||||
def map_detection_to_clip(
    detections: ClipDetectionsTensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> ClipDetectionsTensor:
    """Rescale detection coordinates into absolute clip coordinates.

    Times are mapped linearly onto [start_time, end_time] and frequencies
    onto [min_freq, max_freq]; all other fields pass through unchanged.
    Assumes the incoming times/frequencies are normalized fractions of the
    clip extent — TODO confirm against the producer of these tensors.
    """
    time_span = end_time - start_time
    freq_span = max_freq - min_freq

    mapped_times = detections.times * time_span + start_time
    mapped_freqs = detections.frequencies * freq_span + min_freq

    return ClipDetectionsTensor(
        scores=detections.scores,
        sizes=detections.sizes,
        features=detections.features,
        class_scores=detections.class_scores,
        times=mapped_times,
        frequencies=mapped_freqs,
    )
|
||||
|
||||
@ -1,385 +0,0 @@
|
||||
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||
|
||||
This module provides utility functions to convert the raw numerical outputs
|
||||
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||
`xarray.DataArray` objects. This step adds coordinate information
|
||||
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||
interpretable in the context of the original audio signal and facilitating
|
||||
subsequent processing steps.
|
||||
|
||||
Functions are provided for common BatDetect2 output types: detection heatmaps,
|
||||
classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
|
||||
__all__ = [
|
||||
"features_to_xarray",
|
||||
"detection_to_xarray",
|
||||
"classification_to_xarray",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
def to_xarray(
    array: torch.Tensor | np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: List[str] | None = None,
    extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray:
    """Wrap an (..., frequency, time) array in a coordinate-aware DataArray.

    The last two axes are interpreted as frequency and time; evenly spaced
    coordinates spanning [min_freq, max_freq) and [start_time, end_time)
    are attached to them. Any leading axes become extra dimensions, named
    either from `extra_dims` or auto-generated as ``dim_0``, ``dim_1``, ...

    Raises
    ------
    ValueError
        If the array has fewer than 2 dimensions.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    n_leading = array.ndim - 2
    if n_leading < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {array.shape}"
        )

    height, width = array.shape[-2:]

    # endpoint=False: each coordinate labels the *start* of its bin.
    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    dims = (
        list(extra_dims)
        if extra_dims is not None
        else [f"dim_{i}" for i in range(n_leading)]
    )
    dims += [Dimensions.frequency.value, Dimensions.time.value]

    coords = dict(extra_coords) if extra_coords is not None else {}
    coords[Dimensions.frequency.value] = freqs
    coords[Dimensions.time.value] = times

    return xr.DataArray(data=array, dims=dims, coords=coords, name=name)
|
||||
|
||||
|
||||
def map_detection_to_clip(
    detections: ClipDetectionsTensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> ClipDetectionsTensor:
    """Convert detection times/frequencies to clip coordinates.

    Applies an affine map: times land in [start_time, end_time],
    frequencies in [min_freq, max_freq]. Scores, sizes, features and class
    scores are forwarded untouched. Presumably the inputs are normalized
    to the unit interval — verify against the caller.
    """
    return ClipDetectionsTensor(
        scores=detections.scores,
        sizes=detections.sizes,
        features=detections.features,
        class_scores=detections.class_scores,
        times=detections.times * (end_time - start_time) + start_time,
        frequencies=(
            detections.frequencies * (max_freq - min_freq) + min_freq
        ),
    )
|
||||
|
||||
|
||||
def features_to_xarray(
    features: torch.Tensor,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    features_prefix: str = "batdetect2_feature_",
):
    """Convert a multi-channel feature tensor to a coordinate-aware DataArray.

    Assigns time, frequency, and feature coordinates to a raw feature tensor
    output by the model.

    Parameters
    ----------
    features : torch.Tensor
        The raw feature tensor from the model. Expected shape is
        (num_features, num_freq_bins, num_time_bins).
    start_time : float
        The start time (in seconds) corresponding to the first time bin of
        the tensor.
    end_time : float
        The end time (in seconds) corresponding to the *end* of the last time
        bin.
    min_freq : float, default=MIN_FREQ
        The minimum frequency (in Hz) corresponding to the first frequency bin.
    max_freq : float, default=MAX_FREQ
        The maximum frequency (in Hz) corresponding to the *end* of the last
        frequency bin.
    features_prefix : str, default="batdetect2_feature_"
        Prefix used to generate names for the feature coordinate dimension
        (e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).

    Returns
    -------
    xr.DataArray
        An xarray DataArray containing the feature data with named dimensions
        ('feature', 'frequency', 'time') and calculated coordinates.

    Raises
    ------
    ValueError
        If the input tensor does not have 3 dimensions.
    """
    if features.ndim != 3:
        # Fixed: the message previously claimed "(C, T, F)" although the
        # actual layout is (channels, frequency, time), as unpacked below.
        raise ValueError(
            "Input features tensor must have 3 dimensions (C, F, T), "
            f"got shape {features.shape}"
        )

    num_features, height, width = features.shape

    # endpoint=False: coordinates label the start of each bin.
    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    return xr.DataArray(
        data=features.detach().cpu().numpy(),
        dims=[
            Dimensions.feature.value,
            Dimensions.frequency.value,
            Dimensions.time.value,
        ],
        coords={
            Dimensions.feature.value: [
                f"{features_prefix}{i}" for i in range(num_features)
            ],
            Dimensions.frequency.value: freqs,
            Dimensions.time.value: times,
        },
        name="features",
    )
|
||||
|
||||
|
||||
def detection_to_xarray(
    detection: torch.Tensor,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
) -> xr.DataArray:
    """Convert a single-channel detection heatmap tensor to a DataArray.

    Assigns time and frequency coordinates to a raw detection heatmap tensor.

    Parameters
    ----------
    detection : torch.Tensor
        Raw detection heatmap tensor from the model. Expected shape is
        (1, num_freq_bins, num_time_bins).
    start_time : float
        Start time (seconds) corresponding to the first time bin.
    end_time : float
        End time (seconds) corresponding to the end of the last time bin.
    min_freq : float, default=MIN_FREQ
        Minimum frequency (Hz) corresponding to the first frequency bin.
    max_freq : float, default=MAX_FREQ
        Maximum frequency (Hz) corresponding to the end of the last frequency
        bin.

    Returns
    -------
    xr.DataArray
        An xarray DataArray containing the detection scores with named
        dimensions ('frequency', 'time') and calculated coordinates.

    Raises
    ------
    ValueError
        If the input tensor does not have 3 dimensions or if the first
        dimension size is not 1.
    """
    if detection.ndim != 3:
        # Fixed: the message previously claimed "(1, T, F)" although the
        # actual layout is (channel, frequency, time), as unpacked below.
        raise ValueError(
            "Input detection tensor must have 3 dimensions (1, F, T), "
            f"got shape {detection.shape}"
        )

    num_channels, height, width = detection.shape

    if num_channels != 1:
        raise ValueError(
            "Expected a single channel output, instead got "
            f"{num_channels} channels"
        )

    # endpoint=False: coordinates label the start of each bin.
    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    return xr.DataArray(
        data=detection.squeeze(dim=0).detach().cpu().numpy(),
        dims=[
            Dimensions.frequency.value,
            Dimensions.time.value,
        ],
        coords={
            Dimensions.frequency.value: freqs,
            Dimensions.time.value: times,
        },
        name="detection_score",
    )
|
||||
|
||||
|
||||
def classification_to_xarray(
    classes: torch.Tensor,
    start_time: float,
    end_time: float,
    class_names: List[str],
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
) -> xr.DataArray:
    """Convert a multi-channel class probability tensor to a DataArray.

    Attaches a 'category' coordinate (one entry per class name) plus
    frequency and time coordinates to a raw (num_classes, num_freq_bins,
    num_time_bins) tensor of class scores.

    Parameters
    ----------
    classes : torch.Tensor
        Raw class probability tensor, shape (C, F, T).
    start_time, end_time : float
        Time extent (seconds) covered by the time axis.
    class_names : List[str]
        Ordered class names; length must equal ``classes.shape[0]``.
    min_freq, max_freq : float
        Frequency extent (Hz) covered by the frequency axis.

    Returns
    -------
    xr.DataArray
        DataArray with dims ('category', 'frequency', 'time').

    Raises
    ------
    ValueError
        On a non-3-D tensor, or a class/name count mismatch.
    """
    if classes.ndim != 3:
        raise ValueError(
            "Input classes tensor must have 3 dimensions (C, F, T), "
            f"got shape {classes.shape}"
        )

    # NOTE: `num_classes` appears by name in the debug f-string below.
    num_classes, n_freq_bins, n_time_bins = classes.shape

    if num_classes != len(class_names):
        raise ValueError(
            "The number of classes does not coincide with the number of "
            "class names provided: "
            f"({num_classes = }) != ({len(class_names) = })"
        )

    time_coord = np.linspace(start_time, end_time, n_time_bins, endpoint=False)
    freq_coord = np.linspace(min_freq, max_freq, n_freq_bins, endpoint=False)

    return xr.DataArray(
        data=classes.detach().cpu().numpy(),
        dims=["category", Dimensions.frequency.value, Dimensions.time.value],
        coords={
            "category": class_names,
            Dimensions.frequency.value: freq_coord,
            Dimensions.time.value: time_coord,
        },
        name="class_scores",
    )
|
||||
|
||||
|
||||
def sizes_to_xarray(
    sizes: torch.Tensor,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
) -> xr.DataArray:
    """Convert the 2-channel size prediction tensor to a DataArray.

    Assigns dimension ('width', 'height'), frequency, and time coordinates
    to the raw size prediction tensor output by the model.

    Parameters
    ----------
    sizes : torch.Tensor
        Raw size prediction tensor. Expected shape is
        (2, num_freq_bins, num_time_bins), where the first dimension
        corresponds to predicted width and height respectively.
    start_time : float
        Start time (seconds) corresponding to the first time bin.
    end_time : float
        End time (seconds) corresponding to the end of the last time bin.
    min_freq : float, default=MIN_FREQ
        Minimum frequency (Hz) corresponding to the first frequency bin.
    max_freq : float, default=MAX_FREQ
        Maximum frequency (Hz) corresponding to the end of the last frequency
        bin.

    Returns
    -------
    xr.DataArray
        An xarray DataArray containing predicted sizes with named dimensions
        ('dimension', 'frequency', 'time') and calculated time/frequency
        coordinates. The 'dimension' coordinate will have values
        ['width', 'height'].

    Raises
    ------
    ValueError
        If the input tensor does not have 3 dimensions or if the first
        dimension size is not exactly 2.
    """
    # Fixed: the docstring promised a ValueError on non-3-D input, but no
    # explicit check existed — a confusing tuple-unpack error fired instead.
    # This check matches the sibling converters in this module.
    if sizes.ndim != 3:
        raise ValueError(
            "Input sizes tensor must have 3 dimensions (2, F, T), "
            f"got shape {sizes.shape}"
        )

    num_channels, height, width = sizes.shape

    if num_channels != 2:
        raise ValueError(
            "Expected a two-channel output, instead got "
            f"{num_channels} channels"
        )

    # endpoint=False: coordinates label the start of each bin.
    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    return xr.DataArray(
        data=sizes.detach().cpu().numpy(),
        dims=[
            "dimension",
            Dimensions.frequency.value,
            Dimensions.time.value,
        ],
        coords={
            "dimension": ["width", "height"],
            Dimensions.frequency.value: freqs,
            Dimensions.time.value: times,
        },
        # Fixed: every sibling converter names its DataArray; this one
        # previously left name=None.
        name="sizes",
    )
|
||||
Loading…
Reference in New Issue
Block a user