From 45ae15eed56d0d46429642ab3a8a763700855071 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 09:07:13 +0000 Subject: [PATCH] Cleanup postprocessing --- src/batdetect2/postprocess/clips.py | 6 - src/batdetect2/postprocess/postprocessor.py | 20 +- src/batdetect2/postprocess/remapping.py | 385 -------------------- 3 files changed, 19 insertions(+), 392 deletions(-) delete mode 100644 src/batdetect2/postprocess/clips.py delete mode 100644 src/batdetect2/postprocess/remapping.py diff --git a/src/batdetect2/postprocess/clips.py b/src/batdetect2/postprocess/clips.py deleted file mode 100644 index dcd17d5..0000000 --- a/src/batdetect2/postprocess/clips.py +++ /dev/null @@ -1,6 +0,0 @@ -from batdetect2.postprocess.types import ClipDetections - - -class ClipTransform: - def __init__(self, clip: ClipDetections): - pass diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 7d5d3a4..d30e6d1 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -7,7 +7,6 @@ from batdetect2.postprocess.config import ( ) from batdetect2.postprocess.extraction import extract_detection_peaks from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression -from batdetect2.postprocess.remapping import map_detection_to_clip from batdetect2.postprocess.types import ( ClipDetectionsTensor, PostprocessorProtocol, @@ -92,3 +91,22 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): ) for detection in detections ] + + +def map_detection_to_clip( + detections: ClipDetectionsTensor, + start_time: float, + end_time: float, + min_freq: float, + max_freq: float, +) -> ClipDetectionsTensor: + duration = end_time - start_time + bandwidth = max_freq - min_freq + return ClipDetectionsTensor( + scores=detections.scores, + sizes=detections.sizes, + features=detections.features, + class_scores=detections.class_scores, + times=(detections.times * duration + start_time), + frequencies=(detections.frequencies * bandwidth + min_freq), + ) diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py deleted file mode 100644 index d1e95c8..0000000 --- a/src/batdetect2/postprocess/remapping.py +++ /dev/null @@ -1,385 +0,0 @@ -"""Remaps raw model output tensors to coordinate-aware xarray DataArrays. - -This module provides utility functions to convert the raw numerical outputs -(typically PyTorch tensors) from the BatDetect2 DNN model into -`xarray.DataArray` objects. This step adds coordinate information -(time in seconds, frequency in Hz) back to the model's predictions, making them -interpretable in the context of the original audio signal and facilitating -subsequent processing steps. - -Functions are provided for common BatDetect2 output types: detection heatmaps, -classification probability maps, size prediction maps, and potentially -intermediate features. -""" - -from typing import Dict, List - -import numpy as np -import torch -import xarray as xr -from soundevent.arrays import Dimensions - -from batdetect2.postprocess.types import ClipDetectionsTensor -from batdetect2.preprocess import MAX_FREQ, MIN_FREQ - -__all__ = [ - "features_to_xarray", - "detection_to_xarray", - "classification_to_xarray", - "sizes_to_xarray", -] - - -def to_xarray( - array: torch.Tensor | np.ndarray, - start_time: float, - end_time: float, - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, - name: str = "xarray", - extra_dims: List[str] | None = None, - extra_coords: Dict[str, np.ndarray] | None = None, -) -> xr.DataArray: - if isinstance(array, torch.Tensor): - array = array.detach().cpu().numpy() - - extra_ndims = array.ndim - 2 - - if extra_ndims < 0: - raise ValueError( - "Input array must have at least 2 dimensions, " - f"got shape {array.shape}" - ) - - width = array.shape[-1] - height = array.shape[-2] - - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - if extra_dims is None: - extra_dims = [f"dim_{i}" for i in range(extra_ndims)] - - if extra_coords is None: - extra_coords = {} - - return xr.DataArray( - data=array, - dims=[ - *extra_dims, - Dimensions.frequency.value, - Dimensions.time.value, - ], - coords={ - **extra_coords, - Dimensions.frequency.value: freqs, - Dimensions.time.value: times, - }, - name=name, - ) - - -def map_detection_to_clip( - detections: ClipDetectionsTensor, - start_time: float, - end_time: float, - min_freq: float, - max_freq: float, -) -> ClipDetectionsTensor: - duration = end_time - start_time - bandwidth = max_freq - min_freq - return ClipDetectionsTensor( - scores=detections.scores, - sizes=detections.sizes, - features=detections.features, - class_scores=detections.class_scores, - times=(detections.times * duration + start_time), - frequencies=(detections.frequencies * bandwidth + min_freq), - ) - - -def features_to_xarray( - features: torch.Tensor, - start_time: float, - end_time: float, - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, - features_prefix: str = "batdetect2_feature_", -): - """Convert a multi-channel feature tensor to a coordinate-aware DataArray. - - Assigns time, frequency, and feature coordinates to a raw feature tensor - output by the model. - - Parameters - ---------- - features : torch.Tensor - The raw feature tensor from the model. Expected shape is - (num_features, num_freq_bins, num_time_bins). - start_time : float - The start time (in seconds) corresponding to the first time bin of - the tensor. - end_time : float - The end time (in seconds) corresponding to the *end* of the last time - bin. - min_freq : float, default=MIN_FREQ - The minimum frequency (in Hz) corresponding to the first frequency bin. - max_freq : float, default=MAX_FREQ - The maximum frequency (in Hz) corresponding to the *end* of the last - frequency bin. - features_prefix : str, default="batdetect2_feature_" - Prefix used to generate names for the feature coordinate dimension - (e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...). - - Returns - ------- - xr.DataArray - An xarray DataArray containing the feature data with named dimensions - ('feature', 'frequency', 'time') and calculated coordinates. - - Raises - ------ - ValueError - If the input tensor does not have 3 dimensions. - """ - if features.ndim != 3: - raise ValueError( - "Input features tensor must have 3 dimensions (C, T, F), " - f"got shape {features.shape}" - ) - - num_features, height, width = features.shape - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - return xr.DataArray( - data=features.detach().cpu().numpy(), - dims=[ - Dimensions.feature.value, - Dimensions.frequency.value, - Dimensions.time.value, - ], - coords={ - Dimensions.feature.value: [ - f"{features_prefix}{i}" for i in range(num_features) - ], - Dimensions.frequency.value: freqs, - Dimensions.time.value: times, - }, - name="features", - ) - - -def detection_to_xarray( - detection: torch.Tensor, - start_time: float, - end_time: float, - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, -) -> xr.DataArray: - """Convert a single-channel detection heatmap tensor to a DataArray. - - Assigns time and frequency coordinates to a raw detection heatmap tensor. - - Parameters - ---------- - detection : torch.Tensor - Raw detection heatmap tensor from the model. Expected shape is - (1, num_freq_bins, num_time_bins). - start_time : float - Start time (seconds) corresponding to the first time bin. - end_time : float - End time (seconds) corresponding to the end of the last time bin. - min_freq : float, default=MIN_FREQ - Minimum frequency (Hz) corresponding to the first frequency bin. - max_freq : float, default=MAX_FREQ - Maximum frequency (Hz) corresponding to the end of the last frequency - bin. - - Returns - ------- - xr.DataArray - An xarray DataArray containing the detection scores with named - dimensions ('frequency', 'time') and calculated coordinates. - - Raises - ------ - ValueError - If the input tensor does not have 3 dimensions or if the first - dimension size is not 1. - """ - if detection.ndim != 3: - raise ValueError( - "Input detection tensor must have 3 dimensions (1, T, F), " - f"got shape {detection.shape}" - ) - - num_channels, height, width = detection.shape - - if num_channels != 1: - raise ValueError( - "Expected a single channel output, instead got " - f"{num_channels} channels" - ) - - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - return xr.DataArray( - data=detection.squeeze(dim=0).detach().cpu().numpy(), - dims=[ - Dimensions.frequency.value, - Dimensions.time.value, - ], - coords={ - Dimensions.frequency.value: freqs, - Dimensions.time.value: times, - }, - name="detection_score", - ) - - -def classification_to_xarray( - classes: torch.Tensor, - start_time: float, - end_time: float, - class_names: List[str], - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, -) -> xr.DataArray: - """Convert multi-channel class probability tensor to a DataArray. - - Assigns category (class name), frequency, and time coordinates to a raw - class probability tensor output by the model. - - Parameters - ---------- - classes : torch.Tensor - Raw class probability tensor. Expected shape is - (num_classes, num_freq_bins, num_time_bins). - start_time : float - Start time (seconds) corresponding to the first time bin. - end_time : float - End time (seconds) corresponding to the end of the last time bin. - class_names : List[str] - Ordered list of class names corresponding to the first dimension - of the `classes` tensor. The length must match `classes.shape[0]`. - min_freq : float, default=MIN_FREQ - Minimum frequency (Hz) corresponding to the first frequency bin. - max_freq : float, default=MAX_FREQ - Maximum frequency (Hz) corresponding to the end of the last frequency - bin. - - Returns - ------- - xr.DataArray - An xarray DataArray containing class probabilities with named - dimensions ('category', 'frequency', 'time') and calculated - coordinates. - - Raises - ------ - ValueError - If the input tensor does not have 3 dimensions, or if the size of the - first dimension does not match the length of `class_names`. - """ - if classes.ndim != 3: - raise ValueError( - "Input classes tensor must have 3 dimensions (C, F, T), " - f"got shape {classes.shape}" - ) - - num_classes, height, width = classes.shape - - if num_classes != len(class_names): - raise ValueError( - "The number of classes does not coincide with the number of " - "class names provided: " - f"({num_classes = }) != ({len(class_names) = })" - ) - - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - return xr.DataArray( - data=classes.detach().cpu().numpy(), - dims=[ - "category", - Dimensions.frequency.value, - Dimensions.time.value, - ], - coords={ - "category": class_names, - Dimensions.frequency.value: freqs, - Dimensions.time.value: times, - }, - name="class_scores", - ) - - -def sizes_to_xarray( - sizes: torch.Tensor, - start_time: float, - end_time: float, - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, -) -> xr.DataArray: - """Convert the 2-channel size prediction tensor to a DataArray. - - Assigns dimension ('width', 'height'), frequency, and time coordinates - to the raw size prediction tensor output by the model. - - Parameters - ---------- - sizes : torch.Tensor - Raw size prediction tensor. Expected shape is - (2, num_freq_bins, num_time_bins), where the first dimension - corresponds to predicted width and height respectively. - start_time : float - Start time (seconds) corresponding to the first time bin. - end_time : float - End time (seconds) corresponding to the end of the last time bin. - min_freq : float, default=MIN_FREQ - Minimum frequency (Hz) corresponding to the first frequency bin. - max_freq : float, default=MAX_FREQ - Maximum frequency (Hz) corresponding to the end of the last frequency - bin. - - Returns - ------- - xr.DataArray - An xarray DataArray containing predicted sizes with named dimensions - ('dimension', 'frequency', 'time') and calculated time/frequency - coordinates. The 'dimension' coordinate will have values - ['width', 'height']. - - Raises - ------ - ValueError - If the input tensor does not have 3 dimensions or if the first - dimension size is not exactly 2. - """ - num_channels, height, width = sizes.shape - - if num_channels != 2: - raise ValueError( - "Expected a two-channel output, instead got " - f"{num_channels} channels" - ) - - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - return xr.DataArray( - data=sizes.detach().cpu().numpy(), - dims=[ - "dimension", - Dimensions.frequency.value, - Dimensions.time.value, - ], - coords={ - "dimension": ["width", "height"], - Dimensions.frequency.value: freqs, - Dimensions.time.value: times, - }, - )