diff --git a/batdetect2/postprocess/__init__.py b/batdetect2/postprocess/__init__.py index e69de29..b9050ad 100644 --- a/batdetect2/postprocess/__init__.py +++ b/batdetect2/postprocess/__init__.py @@ -0,0 +1,566 @@ +"""Main entry point for the BatDetect2 Postprocessing pipeline. + +This package (`batdetect2.postprocess`) takes the raw outputs from a trained +BatDetect2 neural network model and transforms them into meaningful, structured +predictions, typically in the form of `soundevent.data.ClipPrediction` objects +containing detected sound events with associated class tags and geometry. + +The pipeline involves several configurable steps, implemented in submodules: +1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks. +2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency + coordinates to raw model output arrays. +3. Detection Extraction (`.detection`): Identifies candidate detection points + (location and score) based on thresholds and score ranking (top-k). +4. Data Extraction (`.extraction`): Gathers associated model outputs (size, + class probabilities, features) at the detected locations. +5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and + class predictions into interpretable `soundevent` objects, including + recovering geometry (ROIs) and decoding class names back to standard tags. + +This module provides the primary interface: +- `PostprocessConfig`: A configuration object for postprocessing parameters + (thresholds, NMS kernel size, etc.). +- `load_postprocess_config`: Function to load the configuration from a file. +- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that + holds the configured pipeline logic. +- `build_postprocessor`: A factory function to create a `Postprocessor` + instance, linking it to the necessary target definitions (`TargetProtocol`). +It also re-exports key components from submodules for convenience. 
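+
+Examples
+--------
+A minimal sketch of the typical flow (assuming `targets` is an object
+implementing `TargetProtocol`, and `output`/`clips` come from running the
+model on a batch of clips):
+
+>>> config = PostprocessConfig(detection_threshold=0.3)
+>>> postprocessor = build_postprocessor(targets, config=config)
+>>> clip_predictions = postprocessor.get_predictions(output, clips)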
+""" + +from typing import List, Optional + +import xarray as xr +from pydantic import Field +from soundevent import data + +from batdetect2.configs import BaseConfig, load_config +from batdetect2.models.types import ModelOutput +from batdetect2.postprocess.decoding import ( + DEFAULT_CLASSIFICATION_THRESHOLD, + convert_raw_predictions_to_clip_prediction, + convert_xr_dataset_to_raw_prediction, +) +from batdetect2.postprocess.detection import ( + DEFAULT_DETECTION_THRESHOLD, + TOP_K_PER_SEC, + extract_detections_from_array, + get_max_detections, +) +from batdetect2.postprocess.extraction import ( + extract_detection_xr_dataset, +) +from batdetect2.postprocess.nms import ( + NMS_KERNEL_SIZE, + non_max_suppression, +) +from batdetect2.postprocess.remapping import ( + classification_to_xarray, + detection_to_xarray, + features_to_xarray, + sizes_to_xarray, +) +from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction +from batdetect2.preprocess import MAX_FREQ, MIN_FREQ +from batdetect2.targets.types import TargetProtocol + +__all__ = [ + "DEFAULT_CLASSIFICATION_THRESHOLD", + "DEFAULT_DETECTION_THRESHOLD", + "MAX_FREQ", + "MIN_FREQ", + "ModelOutput", + "NMS_KERNEL_SIZE", + "PostprocessConfig", + "Postprocessor", + "PostprocessorProtocol", + "RawPrediction", + "TOP_K_PER_SEC", + "build_postprocessor", + "classification_to_xarray", + "convert_raw_predictions_to_clip_prediction", + "convert_xr_dataset_to_raw_prediction", + "detection_to_xarray", + "extract_detection_xr_dataset", + "extract_detections_from_array", + "features_to_xarray", + "get_max_detections", + "load_postprocess_config", + "non_max_suppression", + "sizes_to_xarray", +] + + +class PostprocessConfig(BaseConfig): + """Configuration settings for the postprocessing pipeline. + + Defines tunable parameters that control how raw model outputs are + converted into final detections. + + Attributes + ---------- + nms_kernel_size : int, default=NMS_KERNEL_SIZE + Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression. + Used to suppress weaker detections near stronger peaks. Must be + positive. + detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD + Minimum confidence score from the detection heatmap required to + consider a point as a potential detection. Must be >= 0. + classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD + Minimum confidence score for a specific class prediction to be included + in the decoded tags for a detection. Must be >= 0. + top_k_per_sec : int, default=TOP_K_PER_SEC + Desired maximum number of detections per second of audio. Used by + `get_max_detections` to calculate an absolute limit based on clip + duration before applying `extract_detections_from_array`. Must be + positive. + """ + + nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) + detection_threshold: float = Field( + default=DEFAULT_DETECTION_THRESHOLD, + ge=0, + ) + classification_threshold: float = Field( + default=DEFAULT_CLASSIFICATION_THRESHOLD, + ge=0, + ) + top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) + + +def load_postprocess_config( + path: data.PathLike, + field: Optional[str] = None, +) -> PostprocessConfig: + """Load the postprocessing configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `PostprocessConfig` schema, potentially extracting data from a nested + field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. 
+ field : str, optional + Dot-separated path to a nested section within the file containing the + postprocessing configuration (e.g., "inference.postprocessing"). + If None, the entire file content is used. + + Returns + ------- + PostprocessConfig + The loaded and validated postprocessing configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded configuration data does not conform to the + `PostprocessConfig` schema. + KeyError, TypeError + If `field` specifies an invalid path within the loaded data. + """ + return load_config(path, schema=PostprocessConfig, field=field) + + +def build_postprocessor( + targets: TargetProtocol, + config: Optional[PostprocessConfig] = None, + max_freq: int = MAX_FREQ, + min_freq: int = MIN_FREQ, +) -> PostprocessorProtocol: + """Factory function to build the standard postprocessor. + + Creates and initializes the `Postprocessor` instance, providing it with the + necessary `targets` object and the `PostprocessConfig`. + + Parameters + ---------- + targets : TargetProtocol + An initialized object conforming to the `TargetProtocol`, providing + methods like `.decode()` and `.recover_roi()`, and attributes like + `.class_names` and `.generic_class_tags`. This links postprocessing + to the defined target semantics and geometry mappings. + config : PostprocessConfig, optional + Configuration object specifying postprocessing parameters (thresholds, + NMS kernel size, etc.). If None, default settings defined in + `PostprocessConfig` will be used. + min_freq : int, default=MIN_FREQ + The minimum frequency (Hz) corresponding to the frequency axis of the + model outputs. Required for coordinate remapping. Consider setting via + `PostprocessConfig` instead for better encapsulation. + max_freq : int, default=MAX_FREQ + The maximum frequency (Hz) corresponding to the frequency axis of the + model outputs. Required for coordinate remapping. Consider setting via + `PostprocessConfig`. + + Returns + ------- + PostprocessorProtocol + An initialized `Postprocessor` instance ready to process model outputs. + """ + return Postprocessor( + targets=targets, + config=config or PostprocessConfig(), + min_freq=min_freq, + max_freq=max_freq, + ) + + +class Postprocessor(PostprocessorProtocol): + """Standard implementation of the postprocessing pipeline. + + This class orchestrates the steps required to convert raw model outputs + into interpretable `soundevent` predictions. It uses configured parameters + and leverages functions from the `batdetect2.postprocess` submodules for + each stage (NMS, remapping, detection, extraction, decoding). + + It requires a `TargetProtocol` object during initialization to access + necessary decoding information (class name to tag mapping, + ROI recovery logic) ensuring consistency with the target definitions used + during training or specified for inference. + + Instances are typically created using the `build_postprocessor` factory + function. + + Attributes + ---------- + targets : TargetProtocol + The configured target definition object providing decoding and ROI + recovery. + config : PostprocessConfig + Configuration object holding parameters for NMS, thresholds, etc. + min_freq : int + Minimum frequency (Hz) assumed for the model output's frequency axis. + max_freq : int + Maximum frequency (Hz) assumed for the model output's frequency axis. 
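+
+    Examples
+    --------
+    A sketch of typical usage (assuming `targets`, `output`, and `clips`
+    are available from the surrounding pipeline):
+
+    >>> postprocessor = build_postprocessor(targets)
+    >>> datasets = postprocessor.get_detection_datasets(output, clips)
+    >>> raw = postprocessor.get_raw_predictions(output, clips)
+    >>> predictions = postprocessor.get_predictions(output, clips)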
+    """
+
+    targets: TargetProtocol
+
+    def __init__(
+        self,
+        targets: TargetProtocol,
+        config: PostprocessConfig,
+        min_freq: int = MIN_FREQ,
+        max_freq: int = MAX_FREQ,
+    ):
+        """Initialize the Postprocessor.
+
+        Parameters
+        ----------
+        targets : TargetProtocol
+            Initialized target definition object.
+        config : PostprocessConfig
+            Configuration for postprocessing parameters.
+        min_freq : int, default=MIN_FREQ
+            Minimum frequency (Hz) for coordinate remapping.
+        max_freq : int, default=MAX_FREQ
+            Maximum frequency (Hz) for coordinate remapping.
+        """
+        self.targets = targets
+        self.config = config
+        self.min_freq = min_freq
+        self.max_freq = max_freq
+
+    def get_feature_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Extract and remap raw feature tensors for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing the `output.features` tensor for
+            the batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of coordinate-aware feature DataArrays, one per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.features` and `clips` do not match.
+        """
+        if len(clips) != len(output.features):
+            raise ValueError(
+                "Number of clips and batch size of feature array "
+                "do not match. "
+                f"(clips: {len(clips)}, features: {len(output.features)})"
+            )
+
+        return [
+            features_to_xarray(
+                feats,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for feats, clip in zip(output.features, clips)
+        ]
+
+    def get_detection_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Apply NMS and remap detection heatmaps for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing the `output.detection_probs` tensor
+            for the batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of NMS-applied, coordinate-aware detection heatmaps, one
+            per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.detection_probs` and `clips` do not
+            match.
+        """
+        detections = output.detection_probs
+
+        if len(clips) != len(output.detection_probs):
+            raise ValueError(
+                "Number of clips and batch size of detection array "
+                "do not match. "
+                f"(clips: {len(clips)}, detection: {len(detections)})"
+            )
+
+        detections = non_max_suppression(
+            detections,
+            kernel_size=self.config.nms_kernel_size,
+        )
+
+        return [
+            detection_to_xarray(
+                dets,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for dets, clip in zip(detections, clips)
+        ]
+
+    def get_classification_arrays(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[xr.DataArray]:
+        """Extract and remap raw classification tensors for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing the `output.class_probs` tensor for
+            the batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of coordinate-aware class probability maps, one per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.class_probs` and `clips` do not match,
+            or if the number of classes does not match
+            `self.targets.class_names`.
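+
+        Examples
+        --------
+        A sketch of the expected output structure (values are
+        illustrative):
+
+        >>> class_arrays = postprocessor.get_classification_arrays(
+        ...     output, clips
+        ... )
+        >>> class_arrays[0].dims
+        ('category', 'frequency', 'time')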
+ """ + classifications = output.class_probs + + if len(clips) != len(classifications): + raise ValueError( + "Number of clips and batch size of classification array " + "do not match. " + f"(clips: {len(clips)}, classification: {len(classifications)})" + ) + + return [ + classification_to_xarray( + class_probs, + start_time=clip.start_time, + end_time=clip.end_time, + class_names=self.targets.class_names, + min_freq=self.min_freq, + max_freq=self.max_freq, + ) + for class_probs, clip in zip(classifications, clips) + ] + + def get_sizes_arrays( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[xr.DataArray]: + """Extract and remap raw size prediction tensors for a batch. + + Parameters + ---------- + output : ModelOutput + Raw model output containing `output.size_preds` tensor for the + batch. + clips : List[data.Clip] + List of Clip objects corresponding to the batch items. + + Returns + ------- + List[xr.DataArray] + List of coordinate-aware size prediction maps, one per clip. + + Raises + ------ + ValueError + If batch sizes of `output.size_preds` and `clips` do not match. + """ + sizes = output.size_preds + + if len(clips) != len(sizes): + raise ValueError( + "Number of clips and batch size of sizes array do not match. " + f"(clips: {len(clips)}, sizes: {len(sizes)})" + ) + + return [ + sizes_to_xarray( + size_preds, + start_time=clip.start_time, + end_time=clip.end_time, + min_freq=self.min_freq, + max_freq=self.max_freq, + ) + for size_preds, clip in zip(sizes, clips) + ] + + def get_detection_datasets( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[xr.Dataset]: + """Perform NMS, remapping, detection, and data extraction for a batch. + + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. + + Returns + ------- + List[xr.Dataset] + List of xarray Datasets (one per clip). Each Dataset contains + aligned scores, dimensions, class probabilities, and features for + detections found in that clip. + """ + detection_arrays = self.get_detection_arrays(output, clips) + classification_arrays = self.get_classification_arrays(output, clips) + size_arrays = self.get_sizes_arrays(output, clips) + features_arrays = self.get_feature_arrays(output, clips) + + datasets = [] + for det_array, class_array, sizes_array, feats_array in zip( + detection_arrays, + classification_arrays, + size_arrays, + features_arrays, + ): + max_detections = get_max_detections( + det_array, + top_k_per_sec=self.config.top_k_per_sec, + ) + + positions = extract_detections_from_array( + det_array, + max_detections=max_detections, + threshold=self.config.detection_threshold, + ) + + datasets.append( + extract_detection_xr_dataset( + positions, + sizes_array, + class_array, + feats_array, + ) + ) + + return datasets + + def get_raw_predictions( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[List[RawPrediction]]: + """Extract intermediate RawPrediction objects for a batch. + + Processes raw model output through remapping, NMS, detection, data + extraction, and geometry recovery via the configured + `targets.recover_roi`. + + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. + + Returns + ------- + List[List[RawPrediction]] + List of lists (one inner list per input clip). 
Each inner list + contains `RawPrediction` objects for detections in that clip. + """ + detection_datasets = self.get_detection_datasets(output, clips) + return [ + convert_xr_dataset_to_raw_prediction( + dataset, + self.targets.recover_roi, + ) + for dataset in detection_datasets + ] + + def get_predictions( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[data.ClipPrediction]: + """Perform the full postprocessing pipeline for a batch. + + Takes raw model output and corresponding clips, applies the entire + configured chain (NMS, remapping, extraction, geometry recovery, class + decoding), producing final `soundevent.data.ClipPrediction` objects. + + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. + + Returns + ------- + List[data.ClipPrediction] + List containing one `ClipPrediction` object for each input clip, + populated with `SoundEventPrediction` objects. + """ + raw_predictions = self.get_raw_predictions(output, clips) + return [ + convert_raw_predictions_to_clip_prediction( + prediction, + clip, + sound_event_decoder=self.targets.decode, + generic_class_tags=self.targets.generic_class_tags, + classification_threshold=self.config.classification_threshold, + ) + for prediction, clip in zip(raw_predictions, clips) + ] diff --git a/batdetect2/postprocess/arrays.py b/batdetect2/postprocess/arrays.py deleted file mode 100644 index f657e28..0000000 --- a/batdetect2/postprocess/arrays.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -import xarray as xr -from soundevent.arrays import Dimensions - -from batdetect2.models import ModelOutput -from batdetect2.preprocess import MAX_FREQ, MIN_FREQ - - -def to_xarray( - output: ModelOutput, - start_time: float, - end_time: float, - class_names: list[str], - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, -): - detection = output.detection_probs - size = output.size_preds - classes = output.class_probs - features = output.features - - if len(detection.shape) == 4: - if detection.shape[0] != 1: - raise ValueError( - "Expected a non-batched output or a batch of size 1, instead " - f"got an input of shape {detection.shape}" - ) - - detection = detection.squeeze(dim=0) - size = size.squeeze(dim=0) - classes = classes.squeeze(dim=0) - features = features.squeeze(dim=0) - - _, width, height = detection.shape - - times = np.linspace(start_time, end_time, width, endpoint=False) - freqs = np.linspace(min_freq, max_freq, height, endpoint=False) - - if classes.shape[0] != len(class_names): - raise ValueError( - f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })" - ) - - return xr.Dataset( - data_vars={ - "detection": ( - [Dimensions.time.value, Dimensions.frequency.value], - detection.squeeze(dim=0).detach().numpy(), - ), - "size": ( - [ - "dimension", - Dimensions.time.value, - Dimensions.frequency.value, - ], - detection.detach().numpy(), - ), - "classes": ( - [ - "category", - Dimensions.time.value, - Dimensions.frequency.value, - ], - classes.detach().numpy(), - ), - }, - coords={ - Dimensions.time.value: times, - Dimensions.frequency.value: freqs, - "dimension": ["width", "height"], - "category": class_names, - }, - ) diff --git a/batdetect2/postprocess/config.py b/batdetect2/postprocess/config.py deleted file mode 100644 index 3c4bada..0000000 --- a/batdetect2/postprocess/config.py +++ 
/dev/null @@ -1,32 +0,0 @@ -from typing import Optional - -from pydantic import Field -from soundevent import data - -from batdetect2.configs import BaseConfig, load_config - -__all__ = [ - "PostprocessConfig", - "load_postprocess_config", -] - -NMS_KERNEL_SIZE = 9 -DETECTION_THRESHOLD = 0.01 -TOP_K_PER_SEC = 200 - - -class PostprocessConfig(BaseConfig): - """Configuration for postprocessing model outputs.""" - - nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) - detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0) - min_freq: int = Field(default=10000, gt=0) - max_freq: int = Field(default=120000, gt=0) - top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) - - -def load_postprocess_config( - path: data.PathLike, - field: Optional[str] = None, -) -> PostprocessConfig: - return load_config(path, schema=PostprocessConfig, field=field) diff --git a/batdetect2/postprocess/decoding.py b/batdetect2/postprocess/decoding.py new file mode 100644 index 0000000..09537f7 --- /dev/null +++ b/batdetect2/postprocess/decoding.py @@ -0,0 +1,297 @@ +"""Decodes extracted detection data into standard soundevent predictions. + +This module handles the final stages of the BatDetect2 postprocessing pipeline. +It takes the structured detection data extracted by the `extraction` module +(typically an `xarray.Dataset` containing scores, positions, predicted sizes, + class probabilities, and features for each detection point) and converts it +into meaningful, standardized prediction objects based on the `soundevent` data +model. + +The process involves: +1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction` + objects, using a configured geometry builder to recover bounding boxes from + predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`). +2. Converting each `RawPrediction` into a + `soundevent.data.SoundEventPrediction`, which involves: + - Creating the `soundevent.data.SoundEvent` with geometry and features. + - Decoding the predicted class probabilities into representative tags using + a configured class decoder (`SoundEventDecoder`). + - Applying a classification threshold. + - Optionally selecting only the single highest-scoring class (top-1) or + including tags for all classes above the threshold (multi-label). + - Adding generic class tags as a baseline. + - Associating scores with the final prediction and tags. + (`convert_raw_prediction_to_sound_event_prediction`) +3. Grouping the `SoundEventPrediction` objects for a given audio segment into + a `soundevent.data.ClipPrediction` + (`convert_raw_predictions_to_clip_prediction`). +""" + +from typing import List, Optional + +import xarray as xr +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.postprocess.types import GeometryBuilder, RawPrediction +from batdetect2.targets.classes import SoundEventDecoder + +__all__ = [ + "convert_xr_dataset_to_raw_prediction", + "convert_raw_predictions_to_clip_prediction", + "convert_raw_prediction_to_sound_event_prediction", + "DEFAULT_CLASSIFICATION_THRESHOLD", +] + + +DEFAULT_CLASSIFICATION_THRESHOLD = 0.1 +"""Default threshold applied to classification scores. + +Class predictions with scores below this value are typically ignored during +decoding. +""" + + +def convert_xr_dataset_to_raw_prediction( + detection_dataset: xr.Dataset, + geometry_builder: GeometryBuilder, +) -> List[RawPrediction]: + """Convert an xarray.Dataset of detections to RawPrediction objects. 
+
+    Takes the output of the extraction step (`extract_detection_xr_dataset`)
+    and transforms each detection entry into an intermediate `RawPrediction`
+    object. This involves recovering the geometry (e.g., bounding box) from
+    the predicted position and scaled size dimensions using the provided
+    `geometry_builder` function.
+
+    Parameters
+    ----------
+    detection_dataset : xr.Dataset
+        An xarray Dataset containing aligned detection information, typically
+        output by `extract_detection_xr_dataset`. Expected variables include
+        'scores' (with time/frequency coords), 'dimensions', 'classes',
+        'features'. Must have a 'detection' dimension.
+    geometry_builder : GeometryBuilder
+        A function that takes a position tuple `(time, freq)` and a NumPy
+        array of dimensions, and returns the corresponding reconstructed
+        `soundevent.data.Geometry`.
+
+    Returns
+    -------
+    List[RawPrediction]
+        A list of `RawPrediction` objects, each containing the detection
+        score, recovered bounding box coordinates (start/end time, low/high
+        freq), the vector of class scores, and the feature vector for one
+        detection.
+
+    Raises
+    ------
+    AttributeError, KeyError, ValueError
+        If `detection_dataset` is missing expected variables ('scores',
+        'dimensions', 'classes', 'features') or coordinates ('time',
+        'frequency' associated with 'scores'), or if `geometry_builder`
+        fails.
+    """
+    detections = []
+
+    for det_num in range(detection_dataset.sizes["detection"]):
+        det_info = detection_dataset.sel(detection=det_num)
+
+        geom = geometry_builder(
+            (float(det_info.time), float(det_info.frequency)),
+            det_info.dimensions,
+        )
+
+        start_time, low_freq, end_time, high_freq = compute_bounds(geom)
+
+        classes = det_info.classes
+        features = det_info.features
+
+        detections.append(
+            RawPrediction(
+                detection_score=float(det_info.scores),
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                class_scores=classes,
+                features=features,
+            )
+        )
+
+    return detections
+
+
+def convert_raw_predictions_to_clip_prediction(
+    raw_predictions: List[RawPrediction],
+    clip: data.Clip,
+    sound_event_decoder: SoundEventDecoder,
+    generic_class_tags: List[data.Tag],
+    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
+    top_class_only: bool = False,
+) -> data.ClipPrediction:
+    """Convert a list of RawPredictions into a soundevent ClipPrediction.
+
+    Iterates through `raw_predictions` (assumed to belong to a single clip),
+    converts each one into a `soundevent.data.SoundEventPrediction` using
+    `convert_raw_prediction_to_sound_event_prediction`, and packages them
+    into a `soundevent.data.ClipPrediction` associated with the original
+    `clip`.
+
+    Parameters
+    ----------
+    raw_predictions : List[RawPrediction]
+        List of raw prediction objects for a single clip.
+    clip : data.Clip
+        The original `soundevent.data.Clip` object these predictions belong
+        to.
+    sound_event_decoder : SoundEventDecoder
+        Function to decode class names into representative tags.
+    generic_class_tags : List[data.Tag]
+        List of tags representing the generic class category.
+    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
+        Threshold applied to class scores during decoding.
+    top_class_only : bool, default=False
+        If True, only decode tags for the single highest-scoring class above
+        the threshold. If False, decode tags for all classes above the
+        threshold.
+
+    Returns
+    -------
+    data.ClipPrediction
+        A `ClipPrediction` object containing a list of `SoundEventPrediction`
+        objects corresponding to the input `raw_predictions`.
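+
+    Examples
+    --------
+    A hedged sketch (assumes `targets` provides the decoder and generic
+    tags, and `raw_predictions` came from
+    `convert_xr_dataset_to_raw_prediction`):
+
+    >>> clip_prediction = convert_raw_predictions_to_clip_prediction(
+    ...     raw_predictions,
+    ...     clip,
+    ...     sound_event_decoder=targets.decode,
+    ...     generic_class_tags=targets.generic_class_tags,
+    ...     top_class_only=True,
+    ... )
+    >>> len(clip_prediction.sound_events) == len(raw_predictions)
+    True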
+ """ + return data.ClipPrediction( + clip=clip, + sound_events=[ + convert_raw_prediction_to_sound_event_prediction( + prediction, + recording=clip.recording, + sound_event_decoder=sound_event_decoder, + generic_class_tags=generic_class_tags, + classification_threshold=classification_threshold, + top_class_only=top_class_only, + ) + for prediction in raw_predictions + ], + ) + + +def convert_raw_prediction_to_sound_event_prediction( + raw_prediction: RawPrediction, + recording: data.Recording, + sound_event_decoder: SoundEventDecoder, + generic_class_tags: List[data.Tag], + classification_threshold: Optional[ + float + ] = DEFAULT_CLASSIFICATION_THRESHOLD, + top_class_only: bool = False, +): + """Convert a single RawPrediction into a soundevent SoundEventPrediction. + + This function performs the core decoding steps for a single detected event: + 1. Creates a `soundevent.data.SoundEvent` containing the geometry + (BoundingBox derived from `raw_prediction` bounds) and any associated + feature vectors. + 2. Initializes a list of predicted tags using the provided + `generic_class_tags`, assigning the overall `detection_score` from the + `raw_prediction` to these generic tags. + 3. Processes the `class_scores` from the `raw_prediction`: + a. Optionally filters out scores below `classification_threshold` + (if it's not None). + b. Sorts the remaining scores in descending order. + c. Iterates through the sorted, thresholded class scores. + d. For each class, uses the `sound_event_decoder` to get the + representative base tags for that class name. + e. Wraps these base tags in `soundevent.data.PredictedTag`, associating + the specific `score` of that class prediction. + f. Appends these specific predicted tags to the list. + g. If `top_class_only` is True, stops after processing the first + (highest-scoring) class that passed the threshold. + 4. Creates and returns the final `soundevent.data.SoundEventPrediction`, + associating the `SoundEvent`, the overall `detection_score`, and the + compiled list of `PredictedTag` objects. + + Parameters + ---------- + raw_prediction : RawPrediction + The raw prediction object containing score, bounds, class scores, + features. Assumes `class_scores` is an `xr.DataArray` with a 'category' + coordinate. Assumes `features` is an `xr.DataArray` with a 'feature' + coordinate. + recording : data.Recording + The recording the sound event belongs to. + sound_event_decoder : SoundEventDecoder + Configured function mapping class names (str) to lists of base + `data.Tag` objects. + generic_class_tags : List[data.Tag] + List of base tags representing the generic category. + classification_threshold : float, optional + The minimum score a class prediction must have to be considered + significant enough to have its tags decoded and added. If None, no + thresholding is applied based on class score (all predicted classes, + or the top one if `top_class_only` is True, will be processed). + Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`. + top_class_only : bool, default=False + If True, only includes tags for the single highest-scoring class that + exceeds the threshold. If False (default), includes tags for all classes + exceeding the threshold. + + Returns + ------- + data.SoundEventPrediction + The fully formed sound event prediction object. + + Raises + ------ + ValueError + If `raw_prediction.features` has unexpected structure or if + `data.term_from_key` (if used internally) fails. + If `sound_event_decoder` fails for a class name and errors are raised. 
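+
+    Examples
+    --------
+    Illustrative only: a real decoder and generic tags come from the
+    configured targets; here `raw_prediction`, `recording`, and
+    `generic_tags` are assumed to exist, and a trivial stand-in decoder
+    maps every class name to an empty tag list:
+
+    >>> prediction = convert_raw_prediction_to_sound_event_prediction(
+    ...     raw_prediction,
+    ...     recording=recording,
+    ...     sound_event_decoder=lambda name: [],
+    ...     generic_class_tags=generic_tags,
+    ...     top_class_only=True,
+    ... )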
+    """
+    sound_event = data.SoundEvent(
+        recording=recording,
+        geometry=data.BoundingBox(
+            coordinates=[
+                raw_prediction.start_time,
+                raw_prediction.low_freq,
+                raw_prediction.end_time,
+                raw_prediction.high_freq,
+            ]
+        ),
+        features=[
+            data.Feature(
+                term=data.term_from_key(str(feat_name)),
+                value=float(value),
+            )
+            for feat_name, value in zip(
+                raw_prediction.features.coords["feature"].values,
+                raw_prediction.features.values,
+            )
+        ],
+    )
+
+    tags = [
+        data.PredictedTag(tag=tag, score=raw_prediction.detection_score)
+        for tag in generic_class_tags
+    ]
+
+    class_scores = raw_prediction.class_scores
+
+    if classification_threshold is not None:
+        class_scores = class_scores.where(
+            class_scores > classification_threshold,
+            drop=True,
+        )
+
+    sorted_scores = class_scores.sortby(class_scores, ascending=False)
+
+    for class_name, score in zip(
+        sorted_scores.coords["category"].values,
+        sorted_scores.values,
+    ):
+        class_tags = sound_event_decoder(str(class_name))
+
+        for tag in class_tags:
+            tags.append(
+                data.PredictedTag(
+                    tag=tag,
+                    score=float(score),
+                )
+            )
+
+        if top_class_only:
+            break
+
+    return data.SoundEventPrediction(
+        sound_event=sound_event,
+        score=raw_prediction.detection_score,
+        tags=tags,
+    )
diff --git a/batdetect2/postprocess/detection.py b/batdetect2/postprocess/detection.py
new file mode 100644
index 0000000..96ac430
--- /dev/null
+++ b/batdetect2/postprocess/detection.py
@@ -0,0 +1,162 @@
+"""Extracts candidate detection points from a model output heatmap.
+
+This module implements a specific step within the BatDetect2 postprocessing
+pipeline. Its primary function is to identify potential sound event locations
+by finding peaks (local maxima or high-scoring points) in the detection
+heatmap produced by the neural network (usually after Non-Maximum Suppression
+and coordinate remapping have been applied).
+
+It provides functionality to:
+- Identify the locations (time, frequency) of the highest-scoring points.
+- Filter these points based on a minimum confidence score threshold.
+- Limit the maximum number of detection points returned (top-k).
+
+The main output is an `xarray.DataArray` containing the scores and
+corresponding time/frequency coordinates for the extracted detection points.
+This output serves as the input for subsequent postprocessing steps, such as
+extracting predicted class probabilities and bounding box sizes at these
+specific locations.
+"""
+
+from typing import Optional
+
+import numpy as np
+import xarray as xr
+from soundevent.arrays import Dimensions, get_dim_width
+
+__all__ = [
+    "extract_detections_from_array",
+    "get_max_detections",
+    "DEFAULT_DETECTION_THRESHOLD",
+    "TOP_K_PER_SEC",
+]
+
+DEFAULT_DETECTION_THRESHOLD = 0.01
+"""Default confidence score threshold used for filtering detections."""
+
+TOP_K_PER_SEC = 200
+"""Default desired maximum number of detections per second of audio."""
+
+
+def extract_detections_from_array(
+    detection_array: xr.DataArray,
+    max_detections: Optional[int] = None,
+    threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
+) -> xr.DataArray:
+    """Extract detection locations (time, freq) and scores from a heatmap.
+
+    Identifies the pixels with the highest scores in the input detection
+    heatmap, filters them based on an optional score `threshold`, limits the
+    number to an optional `max_detections`, and returns their scores along
+    with their corresponding time and frequency coordinates.
+
+    Parameters
+    ----------
+    detection_array : xr.DataArray
+        A 2D xarray DataArray representing the detection heatmap. Must have
+        dimensions and coordinates named 'time' and 'frequency'. Higher
+        values are assumed to indicate higher detection confidence.
+    max_detections : int, optional
+        The absolute maximum number of detections to return. If specified,
+        only the top `max_detections` highest-scoring detections (passing
+        the threshold) are returned. If None (default), all detections
+        passing the threshold are returned, sorted by score.
+    threshold : float, optional
+        The minimum confidence score required for a detection peak to be
+        kept. Detections with scores below this value are discarded.
+        Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
+        thresholding is applied.
+
+    Returns
+    -------
+    xr.DataArray
+        A 1D xarray DataArray named 'score' with a 'detection' dimension.
+        - The data values are the scores of the extracted detections, sorted
+          in descending order.
+        - It has coordinates 'time' and 'frequency' (also indexed by the
+          'detection' dimension) indicating the location of each detection
+          peak in the original coordinate system.
+        - Returns an empty DataArray if no detections pass the criteria.
+
+    Raises
+    ------
+    ValueError
+        If `max_detections` is not None and not a positive integer, or if
+        `detection_array` lacks required dimensions/coordinates.
+    """
+    if max_detections is not None and max_detections <= 0:
+        raise ValueError("Max detections must be positive")
+
+    values = detection_array.values.flatten()
+
+    if max_detections is not None and max_detections < values.size:
+        top_indices = np.argpartition(-values, max_detections)[
+            :max_detections
+        ]
+        top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
+    else:
+        top_sorted_indices = np.argsort(-values)
+
+    top_values = values[top_sorted_indices]
+
+    if threshold is not None:
+        mask = top_values > threshold
+        top_values = top_values[mask]
+        top_sorted_indices = top_sorted_indices[mask]
+
+    indices = np.unravel_index(top_sorted_indices, detection_array.shape)
+    axis_indices = dict(zip(detection_array.dims, indices))
+    time_indices = axis_indices[Dimensions.time.value]
+    freq_indices = axis_indices[Dimensions.frequency.value]
+
+    times = detection_array.coords[Dimensions.time.value].values[time_indices]
+    freqs = detection_array.coords[Dimensions.frequency.value].values[
+        freq_indices
+    ]
+
+    return xr.DataArray(
+        data=top_values,
+        coords={
+            Dimensions.frequency.value: ("detection", freqs),
+            Dimensions.time.value: ("detection", times),
+        },
+        dims="detection",
+        name="score",
+    )
+
+
+def get_max_detections(
+    detection_array: xr.DataArray,
+    top_k_per_sec: int = TOP_K_PER_SEC,
+) -> int:
+    """Calculate the maximum number of detections allowed for a heatmap.
+
+    Determines the total maximum number of detections to extract from a
+    heatmap based on its time duration and a desired rate of detections
+    per second.
+
+    Parameters
+    ----------
+    detection_array : xr.DataArray
+        The detection heatmap, requiring 'time' coordinates from which the
+        total duration can be calculated using
+        `soundevent.arrays.get_dim_width`.
+    top_k_per_sec : int, default=TOP_K_PER_SEC
+        The desired maximum number of detections to allow per second of
+        audio.
+
+    Returns
+    -------
+    int
+        The calculated total maximum number of detections allowed for the
+        entire duration of the `detection_array`.
+
+    Raises
+    ------
+    ValueError
+        If `top_k_per_sec` is negative, or if the duration cannot be
+        calculated from the `detection_array` (e.g., missing or invalid
+        'time' coordinates/dimension).
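+
+    Examples
+    --------
+    For a heatmap spanning 2.0 seconds of audio at the default rate of 200
+    detections per second (values are illustrative):
+
+    >>> get_max_detections(detection_array, top_k_per_sec=200)
+    400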
+ """ + if top_k_per_sec < 0: + raise ValueError("top_k_per_sec cannot be negative.") + + duration = get_dim_width(detection_array, Dimensions.time.value) + return int(duration * top_k_per_sec) diff --git a/batdetect2/postprocess/extraction.py b/batdetect2/postprocess/extraction.py new file mode 100644 index 0000000..2592e59 --- /dev/null +++ b/batdetect2/postprocess/extraction.py @@ -0,0 +1,122 @@ +"""Extracts associated data for detected points from model output arrays. + +This module implements a key step (Step 4) in the BatDetect2 postprocessing +pipeline. After candidate detection points (time, frequency, score) have been +identified, this module extracts the corresponding values from other raw model +output arrays, such as: + +- Predicted bounding box sizes (width, height). +- Class probability scores for each defined target class. +- Intermediate feature vectors. + +It uses coordinate-based indexing provided by `xarray` to ensure that the +correct values are retrieved from the original heatmaps/feature maps at the +precise time-frequency location of each detection. The final output aggregates +all extracted information into a structured `xarray.Dataset`. +""" + +import xarray as xr +from soundevent.arrays import Dimensions + +__all__ = [ + "extract_values_at_positions", + "extract_detection_xr_dataset", +] + + +def extract_values_at_positions( + array: xr.DataArray, + positions: xr.DataArray, +) -> xr.DataArray: + """Extract values from an array at specified time-frequency positions. + + Uses coordinate-based indexing to retrieve values from a source `array` + (e.g., class probabilities, size predictions, features) at the time and + frequency coordinates defined in the `positions` array. + + Parameters + ---------- + array : xr.DataArray + The source DataArray from which to extract values. Must have 'time' + and 'frequency' dimensions and coordinates matching the space of + `positions`. + positions : xr.DataArray + A 1D DataArray whose 'time' and 'frequency' coordinates specify the + locations from which to extract values. + + Returns + ------- + xr.DataArray + A DataArray containing the values extracted from `array` at the given + positions. + + Raises + ------ + ValueError, IndexError, KeyError + If dimensions or coordinates are missing or incompatible between + `array` and `positions`, or if selection fails. + """ + return array.sel( + **{ + Dimensions.frequency.value: positions.coords[ + Dimensions.frequency.value + ], + Dimensions.time.value: positions.coords[Dimensions.time.value], + } + ) + + +def extract_detection_xr_dataset( + positions: xr.DataArray, + sizes: xr.DataArray, + classes: xr.DataArray, + features: xr.DataArray, +) -> xr.Dataset: + """Combine extracted detection information into a structured xr.Dataset. + + Takes the detection positions/scores and the full model output heatmaps + (sizes, classes, optional features), extracts the relevant data at the + detection positions, and packages everything into a single `xarray.Dataset` + where all variables are indexed by a common 'detection' dimension. + + Parameters + ---------- + positions : xr.DataArray + Output from `extract_detections_from_array`, containing detection + scores as data and 'time', 'frequency' coordinates along the + 'detection' dimension. + sizes : xr.DataArray + The full size prediction heatmap from the model, with dimensions like + ('dimension', 'time', 'frequency'). 
+ classes : xr.DataArray + The full class probability heatmap from the model, with dimensions like + ('category', 'time', 'frequency'). + features : xr.DataArray + The full feature map from the model, with + dimensions like ('feature', 'time', 'frequency'). + + Returns + ------- + xr.Dataset + An xarray Dataset containing aligned information for each detection: + - 'scores': DataArray from `positions` (score data, time/freq coords). + - 'dimensions': DataArray with extracted size values + (dims: 'detection', 'dimension'). + - 'classes': DataArray with extracted class probabilities + (dims: 'detection', 'category'). + - 'features': DataArray with extracted feature vectors + (dims: 'detection', 'feature'), if `features` was provided. All + DataArrays share the 'detection' dimension and associated + time/frequency coordinates. + """ + sizes = extract_values_at_positions(sizes, positions).T + classes = extract_values_at_positions(classes, positions).T + features = extract_values_at_positions(features, positions).T + return xr.Dataset( + { + "scores": positions, + "dimensions": sizes, + "classes": classes, + "features": features, + } + ) diff --git a/batdetect2/postprocess/nms.py b/batdetect2/postprocess/nms.py new file mode 100644 index 0000000..f92800d --- /dev/null +++ b/batdetect2/postprocess/nms.py @@ -0,0 +1,96 @@ +"""Performs Non-Maximum Suppression (NMS) on detection heatmaps. + +This module provides functionality to apply Non-Maximum Suppression, a common +technique used after model inference, particularly in object detection and peak +detection tasks. + +In the context of BatDetect2 postprocessing, NMS is applied +to the raw detection heatmap output by the neural network. Its purpose is to +isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap +activations that have lower scores than a local maximum. This helps prevent +multiple, overlapping detections originating from the same sound event. +""" + +from typing import Tuple, Union + +import torch + +NMS_KERNEL_SIZE = 9 +"""Default kernel size (pixels) for Non-Maximum Suppression. + +Specifies the side length of the square neighborhood used by default in +`non_max_suppression` to find local maxima. A 9x9 neighborhood is often +a reasonable starting point for typical spectrogram resolutions used in +BatDetect2. +""" + + +def non_max_suppression( + tensor: torch.Tensor, + kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, +) -> torch.Tensor: + """Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap. + + This function identifies local maxima within a defined neighborhood for + each point in the input tensor. Values that are *not* the maximum within + their neighborhood are suppressed (set to zero). This is commonly used on + detection probability heatmaps to isolate distinct peaks corresponding to + individual detections and remove redundant lower scores nearby. + + The implementation uses efficient 2D max pooling to find the maximum value + in the neighborhood of each point. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor, typically representing a detection heatmap. Must be a + 3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying + `torch.nn.functional.max_pool2d` operation. + kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE + Size of the sliding window neighborhood used to find local maxima. + If an integer `k` is provided, a square kernel of size `(k, k)` is used. 
+ If a tuple `(h, w)` is provided, a rectangular kernel of height `h` + and width `w` is used. The kernel size should typically be odd to + have a well-defined center. + + Returns + ------- + torch.Tensor + A tensor of the same shape as the input, where only local maxima within + their respective neighborhoods (defined by `kernel_size`) retain their + original values. All other values are set to zero. + + Raises + ------ + TypeError + If `kernel_size` is not an int or a tuple of two ints. + RuntimeError + If the input `tensor` does not have 3 or 4 dimensions (as required + by `max_pool2d`). + + Notes + ----- + - The function assumes higher values in the tensor indicate stronger peaks. + - Choosing an appropriate `kernel_size` is important. It should be large + enough to cover the typical "footprint" of a single detection peak plus + some surrounding context, effectively preventing multiple detections for + the same event. A size that is too large might suppress nearby distinct + events. + """ + if isinstance(kernel_size, int): + kernel_size_h = kernel_size + kernel_size_w = kernel_size + else: + kernel_size_h, kernel_size_w = kernel_size + + pad_h = (kernel_size_h - 1) // 2 + pad_w = (kernel_size_w - 1) // 2 + + hmax = torch.nn.functional.max_pool2d( + tensor, + (kernel_size_h, kernel_size_w), + stride=1, + padding=(pad_h, pad_w), + ) + keep = (hmax == tensor).float() + return tensor * keep diff --git a/batdetect2/postprocess/non_max_supression.py b/batdetect2/postprocess/non_max_supression.py deleted file mode 100644 index 1893a12..0000000 --- a/batdetect2/postprocess/non_max_supression.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Tuple, Union - -import torch - -NMS_KERNEL_SIZE = 9 - - -def non_max_suppression( - tensor: torch.Tensor, - kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, -) -> torch.Tensor: - """Run non-maximum suppression on a tensor. - - This function removes values from the input tensor that are not local - maxima in the neighborhood of the given kernel size. - - All non-maximum values are set to zero. - - Parameters - ---------- - tensor : torch.Tensor - Input tensor. - kernel_size : Union[int, Tuple[int, int]], optional - Size of the neighborhood to consider for non-maximum suppression. - If an integer is given, the neighborhood will be a square of the - given size. If a tuple is given, the neighborhood will be a - rectangle with the given height and width. - - Returns - ------- - torch.Tensor - Tensor with non-maximum suppressed values. - """ - if isinstance(kernel_size, int): - kernel_size_h = kernel_size - kernel_size_w = kernel_size - else: - kernel_size_h, kernel_size_w = kernel_size - - pad_h = (kernel_size_h - 1) // 2 - pad_w = (kernel_size_w - 1) // 2 - - hmax = torch.nn.functional.max_pool2d( - tensor, - (kernel_size_h, kernel_size_w), - stride=1, - padding=(pad_h, pad_w), - ) - keep = (hmax == tensor).float() - return tensor * keep diff --git a/batdetect2/postprocess/remapping.py b/batdetect2/postprocess/remapping.py new file mode 100644 index 0000000..51560ea --- /dev/null +++ b/batdetect2/postprocess/remapping.py @@ -0,0 +1,316 @@ +"""Remaps raw model output tensors to coordinate-aware xarray DataArrays. + +This module provides utility functions to convert the raw numerical outputs +(typically PyTorch tensors) from the BatDetect2 DNN model into +`xarray.DataArray` objects. 
This step adds coordinate information
+(time in seconds, frequency in Hz) back to the model's predictions, making
+them interpretable in the context of the original audio signal and
+facilitating subsequent processing steps.
+
+Functions are provided for common BatDetect2 output types: detection
+heatmaps, classification probability maps, size prediction maps, and
+potentially intermediate features.
+"""
+
+from typing import List
+
+import numpy as np
+import torch
+import xarray as xr
+from soundevent.arrays import Dimensions
+
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+
+__all__ = [
+    "features_to_xarray",
+    "detection_to_xarray",
+    "classification_to_xarray",
+    "sizes_to_xarray",
+]
+
+
+def features_to_xarray(
+    features: torch.Tensor,
+    start_time: float,
+    end_time: float,
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
+    features_prefix: str = "batdetect2_feature_",
+) -> xr.DataArray:
+    """Convert a multi-channel feature tensor to a coordinate-aware DataArray.
+
+    Assigns time, frequency, and feature coordinates to a raw feature tensor
+    output by the model.
+
+    Parameters
+    ----------
+    features : torch.Tensor
+        The raw feature tensor from the model. Expected shape is
+        (num_features, num_freq_bins, num_time_bins).
+    start_time : float
+        The start time (in seconds) corresponding to the first time bin of
+        the tensor.
+    end_time : float
+        The end time (in seconds) corresponding to the *end* of the last
+        time bin.
+    min_freq : float, default=MIN_FREQ
+        The minimum frequency (in Hz) corresponding to the first frequency
+        bin.
+    max_freq : float, default=MAX_FREQ
+        The maximum frequency (in Hz) corresponding to the *end* of the last
+        frequency bin.
+    features_prefix : str, default="batdetect2_feature_"
+        Prefix used to generate names for the feature coordinate dimension
+        (e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
+
+    Returns
+    -------
+    xr.DataArray
+        An xarray DataArray containing the feature data with named dimensions
+        ('feature', 'frequency', 'time') and calculated coordinates.
+
+    Raises
+    ------
+    ValueError
+        If the input tensor does not have 3 dimensions.
+    """
+    if features.ndim != 3:
+        raise ValueError(
+            "Input features tensor must have 3 dimensions (C, F, T), "
+            f"got shape {features.shape}"
+        )
+
+    num_features, height, width = features.shape
+    times = np.linspace(start_time, end_time, width, endpoint=False)
+    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
+
+    return xr.DataArray(
+        data=features.detach().numpy(),
+        dims=[
+            Dimensions.feature.value,
+            Dimensions.frequency.value,
+            Dimensions.time.value,
+        ],
+        coords={
+            Dimensions.feature.value: [
+                f"{features_prefix}{i}" for i in range(num_features)
+            ],
+            Dimensions.frequency.value: freqs,
+            Dimensions.time.value: times,
+        },
+        name="features",
+    )
+
+
+def detection_to_xarray(
+    detection: torch.Tensor,
+    start_time: float,
+    end_time: float,
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
+) -> xr.DataArray:
+    """Convert a single-channel detection heatmap tensor to a DataArray.
+
+    Assigns time and frequency coordinates to a raw detection heatmap tensor.
+
+    Parameters
+    ----------
+    detection : torch.Tensor
+        Raw detection heatmap tensor from the model. Expected shape is
+        (1, num_freq_bins, num_time_bins).
+    start_time : float
+        Start time (seconds) corresponding to the first time bin.
+    end_time : float
+        End time (seconds) corresponding to the end of the last time bin.
+    min_freq : float, default=MIN_FREQ
+        Minimum frequency (Hz) corresponding to the first frequency bin.
+    max_freq : float, default=MAX_FREQ
+        Maximum frequency (Hz) corresponding to the end of the last
+        frequency bin.
+
+    Returns
+    -------
+    xr.DataArray
+        An xarray DataArray containing the detection scores with named
+        dimensions ('frequency', 'time') and calculated coordinates.
+
+    Raises
+    ------
+    ValueError
+        If the input tensor does not have 3 dimensions or if the first
+        dimension size is not 1.
+    """
+    if detection.ndim != 3:
+        raise ValueError(
+            "Input detection tensor must have 3 dimensions (1, F, T), "
+            f"got shape {detection.shape}"
+        )
+
+    num_channels, height, width = detection.shape
+
+    if num_channels != 1:
+        raise ValueError(
+            "Expected a single channel output, instead got "
+            f"{num_channels} channels"
+        )
+
+    times = np.linspace(start_time, end_time, width, endpoint=False)
+    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
+
+    return xr.DataArray(
+        data=detection.squeeze(dim=0).detach().numpy(),
+        dims=[
+            Dimensions.frequency.value,
+            Dimensions.time.value,
+        ],
+        coords={
+            Dimensions.frequency.value: freqs,
+            Dimensions.time.value: times,
+        },
+        name="detection_score",
+    )
+
+
+def classification_to_xarray(
+    classes: torch.Tensor,
+    start_time: float,
+    end_time: float,
+    class_names: List[str],
+    min_freq: float = MIN_FREQ,
+    max_freq: float = MAX_FREQ,
+) -> xr.DataArray:
+    """Convert a multi-channel class probability tensor to a DataArray.
+
+    Assigns category (class name), frequency, and time coordinates to a raw
+    class probability tensor output by the model.
+
+    Parameters
+    ----------
+    classes : torch.Tensor
+        Raw class probability tensor. Expected shape is
+        (num_classes, num_freq_bins, num_time_bins).
+    start_time : float
+        Start time (seconds) corresponding to the first time bin.
+    end_time : float
+        End time (seconds) corresponding to the end of the last time bin.
+    class_names : List[str]
+        Ordered list of class names corresponding to the first dimension
+        of the `classes` tensor. The length must match `classes.shape[0]`.
+    min_freq : float, default=MIN_FREQ
+        Minimum frequency (Hz) corresponding to the first frequency bin.
+    max_freq : float, default=MAX_FREQ
+        Maximum frequency (Hz) corresponding to the end of the last
+        frequency bin.
+
+    Returns
+    -------
+    xr.DataArray
+        An xarray DataArray containing class probabilities with named
+        dimensions ('category', 'frequency', 'time') and calculated
+        coordinates.
+
+    Raises
+    ------
+    ValueError
+        If the input tensor does not have 3 dimensions, or if the size of
+        the first dimension does not match the length of `class_names`.
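+
+    Examples
+    --------
+    A minimal sketch with a random tensor standing in for real model
+    output:
+
+    >>> import torch
+    >>> probs = torch.rand(2, 128, 512)
+    >>> arr = classification_to_xarray(
+    ...     probs,
+    ...     start_time=0.0,
+    ...     end_time=5.0,
+    ...     class_names=["Myotis myotis", "Pipistrellus pipistrellus"],
+    ... )
+    >>> arr.dims
+    ('category', 'frequency', 'time')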
+ """ + if classes.ndim != 3: + raise ValueError( + "Input classes tensor must have 3 dimensions (C, F, T), " + f"got shape {classes.shape}" + ) + + num_classes, height, width = classes.shape + + if num_classes != len(class_names): + raise ValueError( + "The number of classes does not coincide with the number of " + "class names provided: " + f"({num_classes = }) != ({len(class_names) = })" + ) + + times = np.linspace(start_time, end_time, width, endpoint=False) + freqs = np.linspace(min_freq, max_freq, height, endpoint=False) + + return xr.DataArray( + data=classes.detach().numpy(), + dims=[ + "category", + Dimensions.frequency.value, + Dimensions.time.value, + ], + coords={ + "category": class_names, + Dimensions.frequency.value: freqs, + Dimensions.time.value: times, + }, + name="class_scores", + ) + + +def sizes_to_xarray( + sizes: torch.Tensor, + start_time: float, + end_time: float, + min_freq: float = MIN_FREQ, + max_freq: float = MAX_FREQ, +) -> xr.DataArray: + """Convert the 2-channel size prediction tensor to a DataArray. + + Assigns dimension ('width', 'height'), frequency, and time coordinates + to the raw size prediction tensor output by the model. + + Parameters + ---------- + sizes : torch.Tensor + Raw size prediction tensor. Expected shape is + (2, num_freq_bins, num_time_bins), where the first dimension + corresponds to predicted width and height respectively. + start_time : float + Start time (seconds) corresponding to the first time bin. + end_time : float + End time (seconds) corresponding to the end of the last time bin. + min_freq : float, default=MIN_FREQ + Minimum frequency (Hz) corresponding to the first frequency bin. + max_freq : float, default=MAX_FREQ + Maximum frequency (Hz) corresponding to the end of the last frequency + bin. + + Returns + ------- + xr.DataArray + An xarray DataArray containing predicted sizes with named dimensions + ('dimension', 'frequency', 'time') and calculated time/frequency + coordinates. The 'dimension' coordinate will have values + ['width', 'height']. + + Raises + ------ + ValueError + If the input tensor does not have 3 dimensions or if the first + dimension size is not exactly 2. + """ + num_channels, height, width = sizes.shape + + if num_channels != 2: + raise ValueError( + "Expected a two-channel output, instead got " + f"{num_channels} channels" + ) + + times = np.linspace(start_time, end_time, width, endpoint=False) + freqs = np.linspace(min_freq, max_freq, height, endpoint=False) + + return xr.DataArray( + data=sizes.detach().numpy(), + dims=[ + "dimension", + Dimensions.frequency.value, + Dimensions.time.value, + ], + coords={ + "dimension": ["width", "height"], + Dimensions.frequency.value: freqs, + Dimensions.time.value: times, + }, + ) diff --git a/batdetect2/postprocess/types.py b/batdetect2/postprocess/types.py index aa16d06..1269a21 100644 --- a/batdetect2/postprocess/types.py +++ b/batdetect2/postprocess/types.py @@ -1,21 +1,284 @@ -from typing import Dict, NamedTuple, Protocol +"""Defines shared interfaces and data structures for postprocessing. + +This module centralizes the Protocol definitions and common data structures +used throughout the `batdetect2.postprocess` module. + +The main component is the `PostprocessorProtocol`, which outlines the standard +interface for an object responsible for executing the entire postprocessing +pipeline. This pipeline transforms raw neural network outputs into interpretable +detections represented as `soundevent` objects. 
Using protocols ensures +modularity and consistent interaction between different parts of the BatDetect2 +system that deal with model predictions. +""" + +from typing import Callable, List, NamedTuple, Protocol import numpy as np +import xarray as xr +from soundevent import data + +from batdetect2.models.types import ModelOutput __all__ = [ - "BatDetect2Prediction", + "RawPrediction", + "PostprocessorProtocol", + "GeometryBuilder", ] -class BatDetect2Prediction(NamedTuple): +GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry] +"""Type alias for a function that recovers geometry from position and size. + +This callable takes: +1. A position tuple `(time, frequency)`. +2. A NumPy array of size dimensions (e.g., `[width, height]`). +It should return the reconstructed `soundevent.data.Geometry` (typically a +`BoundingBox`). +""" + + +class RawPrediction(NamedTuple): + """Intermediate representation of a single detected sound event. + + Holds extracted information about a detection after initial processing + (like peak finding, coordinate remapping, geometry recovery) but before + final class decoding and conversion into a `SoundEventPrediction`. This + can be useful for evaluation or simpler data handling formats. + + Attributes + ---------- + start_time : float + Start time of the recovered bounding box in seconds. + end_time : float + End time of the recovered bounding box in seconds. + low_freq : float + Lowest frequency of the recovered bounding box in Hz. + high_freq : float + Highest frequency of the recovered bounding box in Hz. + detection_score : float + The confidence score associated with this detection, typically from + the detection heatmap peak. + class_scores : xr.DataArray + An xarray DataArray containing the predicted probabilities or scores + for each target class at the detection location. Indexed by a + 'category' coordinate containing class names. + features : xr.DataArray + An xarray DataArray containing extracted feature vectors at the + detection location. Indexed by a 'feature' coordinate. + """ + start_time: float end_time: float low_freq: float high_freq: float detection_score: float - class_scores: Dict[str, float] - features: np.ndarray + class_scores: xr.DataArray + features: xr.DataArray class PostprocessorProtocol(Protocol): - pass + """Protocol defining the interface for the full postprocessing pipeline. + + This protocol outlines the standard methods for an object that takes raw + output from a BatDetect2 model and the corresponding input clip metadata, + and processes it through various stages (e.g., coordinate remapping, NMS, + detection extraction, data extraction, decoding) to produce interpretable + results at different levels of completion. + + Implementations manage the configured logic for all postprocessing steps. + """ + + def get_feature_arrays( + self, + output: ModelOutput, + clips: List[data.Clip], + ) -> List[xr.DataArray]: + """Remap feature tensors to coordinate-aware DataArrays. + + Parameters + ---------- + output : ModelOutput + The raw output from the neural network model for a batch, expected + to contain the necessary feature tensors. + clips : List[data.Clip] + A list of `soundevent.data.Clip` objects, one for each item in the + processed batch. This list provides the timing, recording, and + other metadata context needed to calculate real-world coordinates + (seconds, Hz) for the output arrays. The length of this list must + correspond to the batch size of the `output`. 
+"""
+
+
+class RawPrediction(NamedTuple):
+    """Intermediate representation of a single detected sound event.
+
+    Holds extracted information about a detection after initial processing
+    (like peak finding, coordinate remapping, geometry recovery) but before
+    final class decoding and conversion into a `SoundEventPrediction`. This
+    can be useful for evaluation or simpler data handling formats.
+
+    Attributes
+    ----------
+    start_time : float
+        Start time of the recovered bounding box in seconds.
+    end_time : float
+        End time of the recovered bounding box in seconds.
+    low_freq : float
+        Lowest frequency of the recovered bounding box in Hz.
+    high_freq : float
+        Highest frequency of the recovered bounding box in Hz.
+    detection_score : float
+        The confidence score associated with this detection, typically from
+        the detection heatmap peak.
+    class_scores : xr.DataArray
+        An xarray DataArray containing the predicted probabilities or scores
+        for each target class at the detection location. Indexed by a
+        'category' coordinate containing class names.
+    features : xr.DataArray
+        An xarray DataArray containing extracted feature vectors at the
+        detection location. Indexed by a 'feature' coordinate.
+    """
+
     start_time: float
     end_time: float
     low_freq: float
     high_freq: float
     detection_score: float
-    class_scores: Dict[str, float]
-    features: np.ndarray
+    class_scores: xr.DataArray
+    features: xr.DataArray


 class PostprocessorProtocol(Protocol):
-    pass
+    """Protocol defining the interface for the full postprocessing pipeline.
+
+    This protocol outlines the standard methods for an object that takes raw
+    output from a BatDetect2 model and the corresponding input clip metadata,
+    and processes it through various stages (e.g., coordinate remapping, NMS,
+    detection extraction, data extraction, decoding) to produce interpretable
+    results at different levels of completion.
+
+    Implementations manage the configured logic for all postprocessing steps.
+    """
+
+    def get_feature_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Remap feature tensors to coordinate-aware DataArrays.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch,
+            expected to contain the necessary feature tensors.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects, one for each item in
+            the processed batch. This list provides the timing, recording,
+            and other metadata context needed to calculate real-world
+            coordinates (seconds, Hz) for the output arrays. The length of
+            this list must correspond to the batch size of the `output`.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            A list of xarray DataArrays, one for each input clip in the
+            batch, in the same order. Each DataArray contains the feature
+            vectors with dimensions like ('feature', 'time', 'frequency')
+            and corresponding real-world coordinates.
+        """
+        ...
+
+    def get_detection_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Remap detection tensors to coordinate-aware DataArrays.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch,
+            containing detection heatmaps.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing coordinate context. Must match the batch
+            size of `output`.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            A list of 2D xarray DataArrays (one per input clip, in order),
+            representing the detection heatmap with 'time' and 'frequency'
+            coordinates. Values typically indicate detection confidence.
+        """
+        ...
+
+    def get_classification_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Remap classification tensors to coordinate-aware DataArrays.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch,
+            containing class probability tensors.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing coordinate context. Must match the batch
+            size of `output`.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            A list of 3D xarray DataArrays (one per input clip, in order),
+            representing class probabilities with 'category', 'time', and
+            'frequency' dimensions and coordinates.
+        """
+        ...
+
+    def get_sizes_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Remap size prediction tensors to coordinate-aware DataArrays.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch,
+            containing predicted size tensors (e.g., width and height).
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing coordinate context. Must match the batch
+            size of `output`.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            A list of 3D xarray DataArrays (one per input clip, in order),
+            representing predicted sizes with 'dimension'
+            (e.g., ['width', 'height']), 'time', and 'frequency' dimensions
+            and coordinates. Values represent estimated detection sizes.
+        """
+        ...
+
+    def get_detection_datasets(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.Dataset]:
+        """Perform remapping, NMS, detection, and data extraction for a batch.
+
+        Processes the raw model output for a batch to identify detection
+        peaks and extract all associated information (score, position, size,
+        class probs, features) at those peak locations, returning a
+        structured dataset for each input clip in the batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing context. Must match the batch size of
+            `output`.
+
+        Returns
+        -------
+        List[xr.Dataset]
+            A list of xarray Datasets (one per input clip, in order). Each
+            Dataset contains multiple DataArrays ('scores', 'dimensions',
+            'classes', 'features') sharing a common 'detection' dimension,
+            providing aligned data for each detected event in that clip.
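+
+        Examples
+        --------
+        A sketch of how a returned Dataset might be consumed (hypothetical
+        names; `postprocessor`, `output`, and `clip` are assumed to exist):
+
+        >>> dataset = postprocessor.get_detection_datasets(output, [clip])[0]
+        >>> scores = dataset["scores"]  # aligned along 'detection'
+        >>> best = dataset.sortby("scores", ascending=False).isel(detection=0)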
+        """
+        ...
+
+    def get_raw_predictions(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[List[RawPrediction]]:
+        """Extract intermediate RawPrediction objects for a batch.
+
+        Processes the raw model output for a batch through remapping, NMS,
+        detection, data extraction, and geometry recovery to produce a list
+        of `RawPrediction` objects for each corresponding input clip. This
+        provides a simplified, intermediate representation before final tag
+        decoding.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing context. Must match the batch size of
+            `output`.
+
+        Returns
+        -------
+        List[List[RawPrediction]]
+            A list of lists (one inner list per input clip, in order). Each
+            inner list contains the `RawPrediction` objects extracted for
+            the corresponding input clip.
+        """
+        ...
+
+    def get_predictions(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[data.ClipPrediction]:
+        """Perform the full postprocessing pipeline for a batch.
+
+        Takes raw model output for a batch and corresponding clips, applies
+        the entire postprocessing chain, and returns the final,
+        interpretable predictions as a list of
+        `soundevent.data.ClipPrediction` objects.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            The raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            A list of `soundevent.data.Clip` objects corresponding to the
+            batch items, providing context. Must match the batch size of
+            `output`.
+
+        Returns
+        -------
+        List[data.ClipPrediction]
+            A list containing one `ClipPrediction` object for each input
+            clip (in the same order), populated with `SoundEventPrediction`
+            objects representing the final detections with decoded tags and
+            geometry.
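+
+        Examples
+        --------
+        A sketch of a typical inference loop (hypothetical names;
+        `postprocessor` is any implementation of this protocol, and
+        `model`, `batch`, and `clips` are assumed to exist):
+
+        >>> output = model(batch)
+        >>> predictions = postprocessor.get_predictions(output, clips)
+        >>> for clip_prediction in predictions:
+        ...     for sound_event in clip_prediction.sound_events:
+        ...         print(sound_event.score)
+        """
+        ...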