mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Updated postprocess module with docstrings
This commit is contained in:
parent
089328a4f0
commit
bcf339c40d
@ -0,0 +1,566 @@
|
|||||||
|
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
||||||
|
|
||||||
|
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
||||||
|
BatDetect2 neural network model and transforms them into meaningful, structured
|
||||||
|
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
||||||
|
containing detected sound events with associated class tags and geometry.
|
||||||
|
|
||||||
|
The pipeline involves several configurable steps, implemented in submodules:
|
||||||
|
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||||
|
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
||||||
|
coordinates to raw model output arrays.
|
||||||
|
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||||
|
(location and score) based on thresholds and score ranking (top-k).
|
||||||
|
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||||
|
class probabilities, features) at the detected locations.
|
||||||
|
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
||||||
|
class predictions into interpretable `soundevent` objects, including
|
||||||
|
recovering geometry (ROIs) and decoding class names back to standard tags.
|
||||||
|
|
||||||
|
This module provides the primary interface:
|
||||||
|
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
||||||
|
(thresholds, NMS kernel size, etc.).
|
||||||
|
- `load_postprocess_config`: Function to load the configuration from a file.
|
||||||
|
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
||||||
|
holds the configured pipeline logic.
|
||||||
|
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
||||||
|
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
||||||
|
It also re-exports key components from submodules for convenience.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import xarray as xr
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.postprocess.decoding import (
|
||||||
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
convert_raw_predictions_to_clip_prediction,
|
||||||
|
convert_xr_dataset_to_raw_prediction,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.detection import (
|
||||||
|
DEFAULT_DETECTION_THRESHOLD,
|
||||||
|
TOP_K_PER_SEC,
|
||||||
|
extract_detections_from_array,
|
||||||
|
get_max_detections,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.extraction import (
|
||||||
|
extract_detection_xr_dataset,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.nms import (
|
||||||
|
NMS_KERNEL_SIZE,
|
||||||
|
non_max_suppression,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.remapping import (
|
||||||
|
classification_to_xarray,
|
||||||
|
detection_to_xarray,
|
||||||
|
features_to_xarray,
|
||||||
|
sizes_to_xarray,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
||||||
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
|
"DEFAULT_DETECTION_THRESHOLD",
|
||||||
|
"MAX_FREQ",
|
||||||
|
"MIN_FREQ",
|
||||||
|
"ModelOutput",
|
||||||
|
"NMS_KERNEL_SIZE",
|
||||||
|
"PostprocessConfig",
|
||||||
|
"Postprocessor",
|
||||||
|
"PostprocessorProtocol",
|
||||||
|
"RawPrediction",
|
||||||
|
"TOP_K_PER_SEC",
|
||||||
|
"build_postprocessor",
|
||||||
|
"classification_to_xarray",
|
||||||
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
|
"detection_to_xarray",
|
||||||
|
"extract_detection_xr_dataset",
|
||||||
|
"extract_detections_from_array",
|
||||||
|
"features_to_xarray",
|
||||||
|
"get_max_detections",
|
||||||
|
"load_postprocess_config",
|
||||||
|
"non_max_suppression",
|
||||||
|
"sizes_to_xarray",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessConfig(BaseConfig):
|
||||||
|
"""Configuration settings for the postprocessing pipeline.
|
||||||
|
|
||||||
|
Defines tunable parameters that control how raw model outputs are
|
||||||
|
converted into final detections.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
||||||
|
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
||||||
|
Used to suppress weaker detections near stronger peaks. Must be
|
||||||
|
positive.
|
||||||
|
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
||||||
|
Minimum confidence score from the detection heatmap required to
|
||||||
|
consider a point as a potential detection. Must be >= 0.
|
||||||
|
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||||
|
Minimum confidence score for a specific class prediction to be included
|
||||||
|
in the decoded tags for a detection. Must be >= 0.
|
||||||
|
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||||
|
Desired maximum number of detections per second of audio. Used by
|
||||||
|
`get_max_detections` to calculate an absolute limit based on clip
|
||||||
|
duration before applying `extract_detections_from_array`. Must be
|
||||||
|
positive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||||
|
detection_threshold: float = Field(
|
||||||
|
default=DEFAULT_DETECTION_THRESHOLD,
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
classification_threshold: float = Field(
|
||||||
|
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
def load_postprocess_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> PostprocessConfig:
|
||||||
|
"""Load the postprocessing configuration from a file.
|
||||||
|
|
||||||
|
Reads a configuration file (YAML) and validates it against the
|
||||||
|
`PostprocessConfig` schema, potentially extracting data from a nested
|
||||||
|
field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
postprocessing configuration (e.g., "inference.postprocessing").
|
||||||
|
If None, the entire file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PostprocessConfig
|
||||||
|
The loaded and validated postprocessing configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
yaml.YAMLError
|
||||||
|
If the file content is not valid YAML.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the loaded configuration data does not conform to the
|
||||||
|
`PostprocessConfig` schema.
|
||||||
|
KeyError, TypeError
|
||||||
|
If `field` specifies an invalid path within the loaded data.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=PostprocessConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def build_postprocessor(
|
||||||
|
targets: TargetProtocol,
|
||||||
|
config: Optional[PostprocessConfig] = None,
|
||||||
|
max_freq: int = MAX_FREQ,
|
||||||
|
min_freq: int = MIN_FREQ,
|
||||||
|
) -> PostprocessorProtocol:
|
||||||
|
"""Factory function to build the standard postprocessor.
|
||||||
|
|
||||||
|
Creates and initializes the `Postprocessor` instance, providing it with the
|
||||||
|
necessary `targets` object and the `PostprocessConfig`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
targets : TargetProtocol
|
||||||
|
An initialized object conforming to the `TargetProtocol`, providing
|
||||||
|
methods like `.decode()` and `.recover_roi()`, and attributes like
|
||||||
|
`.class_names` and `.generic_class_tags`. This links postprocessing
|
||||||
|
to the defined target semantics and geometry mappings.
|
||||||
|
config : PostprocessConfig, optional
|
||||||
|
Configuration object specifying postprocessing parameters (thresholds,
|
||||||
|
NMS kernel size, etc.). If None, default settings defined in
|
||||||
|
`PostprocessConfig` will be used.
|
||||||
|
min_freq : int, default=MIN_FREQ
|
||||||
|
The minimum frequency (Hz) corresponding to the frequency axis of the
|
||||||
|
model outputs. Required for coordinate remapping. Consider setting via
|
||||||
|
`PostprocessConfig` instead for better encapsulation.
|
||||||
|
max_freq : int, default=MAX_FREQ
|
||||||
|
The maximum frequency (Hz) corresponding to the frequency axis of the
|
||||||
|
model outputs. Required for coordinate remapping. Consider setting via
|
||||||
|
`PostprocessConfig`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PostprocessorProtocol
|
||||||
|
An initialized `Postprocessor` instance ready to process model outputs.
|
||||||
|
"""
|
||||||
|
return Postprocessor(
|
||||||
|
targets=targets,
|
||||||
|
config=config or PostprocessConfig(),
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Postprocessor(PostprocessorProtocol):
|
||||||
|
"""Standard implementation of the postprocessing pipeline.
|
||||||
|
|
||||||
|
This class orchestrates the steps required to convert raw model outputs
|
||||||
|
into interpretable `soundevent` predictions. It uses configured parameters
|
||||||
|
and leverages functions from the `batdetect2.postprocess` submodules for
|
||||||
|
each stage (NMS, remapping, detection, extraction, decoding).
|
||||||
|
|
||||||
|
It requires a `TargetProtocol` object during initialization to access
|
||||||
|
necessary decoding information (class name to tag mapping,
|
||||||
|
ROI recovery logic) ensuring consistency with the target definitions used
|
||||||
|
during training or specified for inference.
|
||||||
|
|
||||||
|
Instances are typically created using the `build_postprocessor` factory
|
||||||
|
function.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
targets : TargetProtocol
|
||||||
|
The configured target definition object providing decoding and ROI
|
||||||
|
recovery.
|
||||||
|
config : PostprocessConfig
|
||||||
|
Configuration object holding parameters for NMS, thresholds, etc.
|
||||||
|
min_freq : int
|
||||||
|
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
|
max_freq : int
|
||||||
|
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
targets: TargetProtocol
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
config: PostprocessConfig,
|
||||||
|
min_freq: int = MIN_FREQ,
|
||||||
|
max_freq: int = MAX_FREQ,
|
||||||
|
):
|
||||||
|
"""Initialize the Postprocessor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
targets : TargetProtocol
|
||||||
|
Initialized target definition object.
|
||||||
|
config : PostprocessConfig
|
||||||
|
Configuration for postprocessing parameters.
|
||||||
|
min_freq : int, default=MIN_FREQ
|
||||||
|
Minimum frequency (Hz) for coordinate remapping.
|
||||||
|
max_freq : int, default=MAX_FREQ
|
||||||
|
Maximum frequency (Hz) for coordinate remapping.
|
||||||
|
"""
|
||||||
|
self.targets = targets
|
||||||
|
self.config = config
|
||||||
|
self.min_freq = min_freq
|
||||||
|
self.max_freq = max_freq
|
||||||
|
|
||||||
|
def get_feature_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Extract and remap raw feature tensors for a batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw model output containing `output.features` tensor for the batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of Clip objects corresponding to the batch items.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
List of coordinate-aware feature DataArrays, one per clip.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If batch sizes of `output.features` and `clips` do not match.
|
||||||
|
"""
|
||||||
|
if len(clips) != len(output.features):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of clips and batch size of feature array"
|
||||||
|
"do not match. "
|
||||||
|
f"(clips: {len(clips)}, features: {len(output.features)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
features_to_xarray(
|
||||||
|
feats,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
end_time=clip.end_time,
|
||||||
|
min_freq=self.min_freq,
|
||||||
|
max_freq=self.max_freq,
|
||||||
|
)
|
||||||
|
for feats, clip in zip(output.features, clips)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_detection_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Apply NMS and remap detection heatmaps for a batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw model output containing `output.detection_probs` tensor for the
|
||||||
|
batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of Clip objects corresponding to the batch items.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
||||||
|
clip.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If batch sizes of `output.detection_probs` and `clips` do not match.
|
||||||
|
"""
|
||||||
|
detections = output.detection_probs
|
||||||
|
|
||||||
|
if len(clips) != len(output.detection_probs):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of clips and batch size of detection array "
|
||||||
|
"do not match. "
|
||||||
|
f"(clips: {len(clips)}, detection: {len(detections)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
detections = non_max_suppression(
|
||||||
|
detections,
|
||||||
|
kernel_size=self.config.nms_kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
detection_to_xarray(
|
||||||
|
dets,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
end_time=clip.end_time,
|
||||||
|
min_freq=self.min_freq,
|
||||||
|
max_freq=self.max_freq,
|
||||||
|
)
|
||||||
|
for dets, clip in zip(detections, clips)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_classification_arrays(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Extract and remap raw classification tensors for a batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw model output containing `output.class_probs` tensor for the
|
||||||
|
batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of Clip objects corresponding to the batch items.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
List of coordinate-aware class probability maps, one per clip.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If batch sizes of `output.class_probs` and `clips` do not match, or
|
||||||
|
if number of classes mismatches `self.targets.class_names`.
|
||||||
|
"""
|
||||||
|
classifications = output.class_probs
|
||||||
|
|
||||||
|
if len(clips) != len(classifications):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of clips and batch size of classification array "
|
||||||
|
"do not match. "
|
||||||
|
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
classification_to_xarray(
|
||||||
|
class_probs,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
end_time=clip.end_time,
|
||||||
|
class_names=self.targets.class_names,
|
||||||
|
min_freq=self.min_freq,
|
||||||
|
max_freq=self.max_freq,
|
||||||
|
)
|
||||||
|
for class_probs, clip in zip(classifications, clips)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_sizes_arrays(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Extract and remap raw size prediction tensors for a batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw model output containing `output.size_preds` tensor for the
|
||||||
|
batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of Clip objects corresponding to the batch items.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
List of coordinate-aware size prediction maps, one per clip.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If batch sizes of `output.size_preds` and `clips` do not match.
|
||||||
|
"""
|
||||||
|
sizes = output.size_preds
|
||||||
|
|
||||||
|
if len(clips) != len(sizes):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of clips and batch size of sizes array do not match. "
|
||||||
|
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
sizes_to_xarray(
|
||||||
|
size_preds,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
end_time=clip.end_time,
|
||||||
|
min_freq=self.min_freq,
|
||||||
|
max_freq=self.max_freq,
|
||||||
|
)
|
||||||
|
for size_preds, clip in zip(sizes, clips)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_detection_datasets(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[xr.Dataset]:
|
||||||
|
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.Dataset]
|
||||||
|
List of xarray Datasets (one per clip). Each Dataset contains
|
||||||
|
aligned scores, dimensions, class probabilities, and features for
|
||||||
|
detections found in that clip.
|
||||||
|
"""
|
||||||
|
detection_arrays = self.get_detection_arrays(output, clips)
|
||||||
|
classification_arrays = self.get_classification_arrays(output, clips)
|
||||||
|
size_arrays = self.get_sizes_arrays(output, clips)
|
||||||
|
features_arrays = self.get_feature_arrays(output, clips)
|
||||||
|
|
||||||
|
datasets = []
|
||||||
|
for det_array, class_array, sizes_array, feats_array in zip(
|
||||||
|
detection_arrays,
|
||||||
|
classification_arrays,
|
||||||
|
size_arrays,
|
||||||
|
features_arrays,
|
||||||
|
):
|
||||||
|
max_detections = get_max_detections(
|
||||||
|
det_array,
|
||||||
|
top_k_per_sec=self.config.top_k_per_sec,
|
||||||
|
)
|
||||||
|
|
||||||
|
positions = extract_detections_from_array(
|
||||||
|
det_array,
|
||||||
|
max_detections=max_detections,
|
||||||
|
threshold=self.config.detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
datasets.append(
|
||||||
|
extract_detection_xr_dataset(
|
||||||
|
positions,
|
||||||
|
sizes_array,
|
||||||
|
class_array,
|
||||||
|
feats_array,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return datasets
|
||||||
|
|
||||||
|
def get_raw_predictions(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[List[RawPrediction]]:
|
||||||
|
"""Extract intermediate RawPrediction objects for a batch.
|
||||||
|
|
||||||
|
Processes raw model output through remapping, NMS, detection, data
|
||||||
|
extraction, and geometry recovery via the configured
|
||||||
|
`targets.recover_roi`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[List[RawPrediction]]
|
||||||
|
List of lists (one inner list per input clip). Each inner list
|
||||||
|
contains `RawPrediction` objects for detections in that clip.
|
||||||
|
"""
|
||||||
|
detection_datasets = self.get_detection_datasets(output, clips)
|
||||||
|
return [
|
||||||
|
convert_xr_dataset_to_raw_prediction(
|
||||||
|
dataset,
|
||||||
|
self.targets.recover_roi,
|
||||||
|
)
|
||||||
|
for dataset in detection_datasets
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_predictions(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[data.ClipPrediction]:
|
||||||
|
"""Perform the full postprocessing pipeline for a batch.
|
||||||
|
|
||||||
|
Takes raw model output and corresponding clips, applies the entire
|
||||||
|
configured chain (NMS, remapping, extraction, geometry recovery, class
|
||||||
|
decoding), producing final `soundevent.data.ClipPrediction` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
Raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.ClipPrediction]
|
||||||
|
List containing one `ClipPrediction` object for each input clip,
|
||||||
|
populated with `SoundEventPrediction` objects.
|
||||||
|
"""
|
||||||
|
raw_predictions = self.get_raw_predictions(output, clips)
|
||||||
|
return [
|
||||||
|
convert_raw_predictions_to_clip_prediction(
|
||||||
|
prediction,
|
||||||
|
clip,
|
||||||
|
sound_event_decoder=self.targets.decode,
|
||||||
|
generic_class_tags=self.targets.generic_class_tags,
|
||||||
|
classification_threshold=self.config.classification_threshold,
|
||||||
|
)
|
||||||
|
for prediction, clip in zip(raw_predictions, clips)
|
||||||
|
]
|
@ -1,73 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import xarray as xr
|
|
||||||
from soundevent.arrays import Dimensions
|
|
||||||
|
|
||||||
from batdetect2.models import ModelOutput
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
|
||||||
|
|
||||||
|
|
||||||
def to_xarray(
|
|
||||||
output: ModelOutput,
|
|
||||||
start_time: float,
|
|
||||||
end_time: float,
|
|
||||||
class_names: list[str],
|
|
||||||
min_freq: float = MIN_FREQ,
|
|
||||||
max_freq: float = MAX_FREQ,
|
|
||||||
):
|
|
||||||
detection = output.detection_probs
|
|
||||||
size = output.size_preds
|
|
||||||
classes = output.class_probs
|
|
||||||
features = output.features
|
|
||||||
|
|
||||||
if len(detection.shape) == 4:
|
|
||||||
if detection.shape[0] != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Expected a non-batched output or a batch of size 1, instead "
|
|
||||||
f"got an input of shape {detection.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
detection = detection.squeeze(dim=0)
|
|
||||||
size = size.squeeze(dim=0)
|
|
||||||
classes = classes.squeeze(dim=0)
|
|
||||||
features = features.squeeze(dim=0)
|
|
||||||
|
|
||||||
_, width, height = detection.shape
|
|
||||||
|
|
||||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
|
||||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
|
||||||
|
|
||||||
if classes.shape[0] != len(class_names):
|
|
||||||
raise ValueError(
|
|
||||||
f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })"
|
|
||||||
)
|
|
||||||
|
|
||||||
return xr.Dataset(
|
|
||||||
data_vars={
|
|
||||||
"detection": (
|
|
||||||
[Dimensions.time.value, Dimensions.frequency.value],
|
|
||||||
detection.squeeze(dim=0).detach().numpy(),
|
|
||||||
),
|
|
||||||
"size": (
|
|
||||||
[
|
|
||||||
"dimension",
|
|
||||||
Dimensions.time.value,
|
|
||||||
Dimensions.frequency.value,
|
|
||||||
],
|
|
||||||
detection.detach().numpy(),
|
|
||||||
),
|
|
||||||
"classes": (
|
|
||||||
[
|
|
||||||
"category",
|
|
||||||
Dimensions.time.value,
|
|
||||||
Dimensions.frequency.value,
|
|
||||||
],
|
|
||||||
classes.detach().numpy(),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
coords={
|
|
||||||
Dimensions.time.value: times,
|
|
||||||
Dimensions.frequency.value: freqs,
|
|
||||||
"dimension": ["width", "height"],
|
|
||||||
"category": class_names,
|
|
||||||
},
|
|
||||||
)
|
|
@ -1,32 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"PostprocessConfig",
|
|
||||||
"load_postprocess_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
NMS_KERNEL_SIZE = 9
|
|
||||||
DETECTION_THRESHOLD = 0.01
|
|
||||||
TOP_K_PER_SEC = 200
|
|
||||||
|
|
||||||
|
|
||||||
class PostprocessConfig(BaseConfig):
|
|
||||||
"""Configuration for postprocessing model outputs."""
|
|
||||||
|
|
||||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
|
||||||
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
|
|
||||||
min_freq: int = Field(default=10000, gt=0)
|
|
||||||
max_freq: int = Field(default=120000, gt=0)
|
|
||||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
|
||||||
|
|
||||||
|
|
||||||
def load_postprocess_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> PostprocessConfig:
|
|
||||||
return load_config(path, schema=PostprocessConfig, field=field)
|
|
297
batdetect2/postprocess/decoding.py
Normal file
297
batdetect2/postprocess/decoding.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
"""Decodes extracted detection data into standard soundevent predictions.
|
||||||
|
|
||||||
|
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||||
|
It takes the structured detection data extracted by the `extraction` module
|
||||||
|
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||||
|
class probabilities, and features for each detection point) and converts it
|
||||||
|
into meaningful, standardized prediction objects based on the `soundevent` data
|
||||||
|
model.
|
||||||
|
|
||||||
|
The process involves:
|
||||||
|
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||||
|
objects, using a configured geometry builder to recover bounding boxes from
|
||||||
|
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||||
|
2. Converting each `RawPrediction` into a
|
||||||
|
`soundevent.data.SoundEventPrediction`, which involves:
|
||||||
|
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||||
|
- Decoding the predicted class probabilities into representative tags using
|
||||||
|
a configured class decoder (`SoundEventDecoder`).
|
||||||
|
- Applying a classification threshold.
|
||||||
|
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||||
|
including tags for all classes above the threshold (multi-label).
|
||||||
|
- Adding generic class tags as a baseline.
|
||||||
|
- Associating scores with the final prediction and tags.
|
||||||
|
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||||
|
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||||
|
a `soundevent.data.ClipPrediction`
|
||||||
|
(`convert_raw_predictions_to_clip_prediction`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||||
|
from batdetect2.targets.classes import SoundEventDecoder
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
|
"convert_raw_prediction_to_sound_event_prediction",
|
||||||
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||||
|
"""Default threshold applied to classification scores.
|
||||||
|
|
||||||
|
Class predictions with scores below this value are typically ignored during
|
||||||
|
decoding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def convert_xr_dataset_to_raw_prediction(
|
||||||
|
detection_dataset: xr.Dataset,
|
||||||
|
geometry_builder: GeometryBuilder,
|
||||||
|
) -> List[RawPrediction]:
|
||||||
|
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||||
|
|
||||||
|
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||||
|
and transforms each detection entry into an intermediate `RawPrediction`
|
||||||
|
object. This involves recovering the geometry (e.g., bounding box) from
|
||||||
|
the predicted position and scaled size dimensions using the provided
|
||||||
|
`geometry_builder` function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection_dataset : xr.Dataset
|
||||||
|
An xarray Dataset containing aligned detection information, typically
|
||||||
|
output by `extract_detection_xr_dataset`. Expected variables include
|
||||||
|
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||||
|
Must have a 'detection' dimension.
|
||||||
|
geometry_builder : GeometryBuilder
|
||||||
|
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||||
|
of dimensions, and returns the corresponding reconstructed
|
||||||
|
`soundevent.data.Geometry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[RawPrediction]
|
||||||
|
A list of `RawPrediction` objects, each containing the detection score,
|
||||||
|
recovered bounding box coordinates (start/end time, low/high freq),
|
||||||
|
the vector of class scores, and the feature vector for one detection.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
AttributeError, KeyError, ValueError
|
||||||
|
If `detection_dataset` is missing expected variables ('scores',
|
||||||
|
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||||
|
associated with 'scores'), or if `geometry_builder` fails.
|
||||||
|
"""
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
for det_num in range(detection_dataset.dims["detection"]):
|
||||||
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
|
geom = geometry_builder(
|
||||||
|
(det_info.time, det_info.freq),
|
||||||
|
det_info.dimensions,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
||||||
|
|
||||||
|
classes = det_info.classes
|
||||||
|
features = det_info.features
|
||||||
|
|
||||||
|
detections.append(
|
||||||
|
RawPrediction(
|
||||||
|
detection_score=det_info.score,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
class_scores=classes,
|
||||||
|
features=features,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
|
raw_predictions: List[RawPrediction],
|
||||||
|
clip: data.Clip,
|
||||||
|
sound_event_decoder: SoundEventDecoder,
|
||||||
|
generic_class_tags: List[data.Tag],
|
||||||
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
top_class_only: bool = False,
|
||||||
|
) -> data.ClipPrediction:
|
||||||
|
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||||
|
|
||||||
|
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||||
|
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||||
|
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||||
|
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_predictions : List[RawPrediction]
|
||||||
|
List of raw prediction objects for a single clip.
|
||||||
|
clip : data.Clip
|
||||||
|
The original `soundevent.data.Clip` object these predictions belong to.
|
||||||
|
sound_event_decoder : SoundEventDecoder
|
||||||
|
Function to decode class names into representative tags.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
List of tags representing the generic class category.
|
||||||
|
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||||
|
Threshold applied to class scores during decoding.
|
||||||
|
top_class_only : bool, default=False
|
||||||
|
If True, only decode tags for the single highest-scoring class above
|
||||||
|
the threshold. If False, decode tags for all classes above threshold.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.ClipPrediction
|
||||||
|
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||||
|
objects corresponding to the input `raw_predictions`.
|
||||||
|
"""
|
||||||
|
return data.ClipPrediction(
|
||||||
|
clip=clip,
|
||||||
|
sound_events=[
|
||||||
|
convert_raw_prediction_to_sound_event_prediction(
|
||||||
|
prediction,
|
||||||
|
recording=clip.recording,
|
||||||
|
sound_event_decoder=sound_event_decoder,
|
||||||
|
generic_class_tags=generic_class_tags,
|
||||||
|
classification_threshold=classification_threshold,
|
||||||
|
top_class_only=top_class_only,
|
||||||
|
)
|
||||||
|
for prediction in raw_predictions
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_raw_prediction_to_sound_event_prediction(
|
||||||
|
raw_prediction: RawPrediction,
|
||||||
|
recording: data.Recording,
|
||||||
|
sound_event_decoder: SoundEventDecoder,
|
||||||
|
generic_class_tags: List[data.Tag],
|
||||||
|
classification_threshold: Optional[
|
||||||
|
float
|
||||||
|
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
top_class_only: bool = False,
|
||||||
|
):
|
||||||
|
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||||
|
|
||||||
|
This function performs the core decoding steps for a single detected event:
|
||||||
|
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||||
|
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||||
|
feature vectors.
|
||||||
|
2. Initializes a list of predicted tags using the provided
|
||||||
|
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||||
|
`raw_prediction` to these generic tags.
|
||||||
|
3. Processes the `class_scores` from the `raw_prediction`:
|
||||||
|
a. Optionally filters out scores below `classification_threshold`
|
||||||
|
(if it's not None).
|
||||||
|
b. Sorts the remaining scores in descending order.
|
||||||
|
c. Iterates through the sorted, thresholded class scores.
|
||||||
|
d. For each class, uses the `sound_event_decoder` to get the
|
||||||
|
representative base tags for that class name.
|
||||||
|
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||||
|
the specific `score` of that class prediction.
|
||||||
|
f. Appends these specific predicted tags to the list.
|
||||||
|
g. If `top_class_only` is True, stops after processing the first
|
||||||
|
(highest-scoring) class that passed the threshold.
|
||||||
|
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||||
|
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||||
|
compiled list of `PredictedTag` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_prediction : RawPrediction
|
||||||
|
The raw prediction object containing score, bounds, class scores,
|
||||||
|
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||||
|
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||||
|
coordinate.
|
||||||
|
recording : data.Recording
|
||||||
|
The recording the sound event belongs to.
|
||||||
|
sound_event_decoder : SoundEventDecoder
|
||||||
|
Configured function mapping class names (str) to lists of base
|
||||||
|
`data.Tag` objects.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
List of base tags representing the generic category.
|
||||||
|
classification_threshold : float, optional
|
||||||
|
The minimum score a class prediction must have to be considered
|
||||||
|
significant enough to have its tags decoded and added. If None, no
|
||||||
|
thresholding is applied based on class score (all predicted classes,
|
||||||
|
or the top one if `top_class_only` is True, will be processed).
|
||||||
|
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||||
|
top_class_only : bool, default=False
|
||||||
|
If True, only includes tags for the single highest-scoring class that
|
||||||
|
exceeds the threshold. If False (default), includes tags for all classes
|
||||||
|
exceeding the threshold.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventPrediction
|
||||||
|
The fully formed sound event prediction object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `raw_prediction.features` has unexpected structure or if
|
||||||
|
`data.term_from_key` (if used internally) fails.
|
||||||
|
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||||
|
"""
|
||||||
|
sound_event = data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[
|
||||||
|
raw_prediction.start_time,
|
||||||
|
raw_prediction.low_freq,
|
||||||
|
raw_prediction.end_time,
|
||||||
|
raw_prediction.high_freq,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
features=[
|
||||||
|
data.Feature(term=data.term_from_key(feat_name), value=value)
|
||||||
|
for feat_name, value in raw_prediction.features
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tags = [
|
||||||
|
data.PredictedTag(tag=tag, score=raw_prediction.detection_score)
|
||||||
|
for tag in generic_class_tags
|
||||||
|
]
|
||||||
|
|
||||||
|
class_scores = raw_prediction.class_scores
|
||||||
|
|
||||||
|
if classification_threshold is not None:
|
||||||
|
class_scores = class_scores.where(
|
||||||
|
class_scores > classification_threshold,
|
||||||
|
drop=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for class_name, score in class_scores.sortby(
|
||||||
|
class_scores, ascending=False
|
||||||
|
):
|
||||||
|
class_tags = sound_event_decoder(class_name)
|
||||||
|
|
||||||
|
for tag in class_tags:
|
||||||
|
tags.append(
|
||||||
|
data.PredictedTag(
|
||||||
|
tag=tag,
|
||||||
|
score=score,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if top_class_only:
|
||||||
|
break
|
||||||
|
|
||||||
|
return data.SoundEventPrediction(
|
||||||
|
sound_event=sound_event,
|
||||||
|
score=raw_prediction.detection_score,
|
||||||
|
tags=tags,
|
||||||
|
)
|
162
batdetect2/postprocess/detection.py
Normal file
162
batdetect2/postprocess/detection.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
"""Extracts candidate detection points from a model output heatmap.
|
||||||
|
|
||||||
|
This module implements a specific step within the BatDetect2 postprocessing
|
||||||
|
pipeline. Its primary function is to identify potential sound event locations
|
||||||
|
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||||
|
produced by the neural network (usually after Non-Maximum Suppression and
|
||||||
|
coordinate remapping have been applied).
|
||||||
|
|
||||||
|
It provides functionality to:
|
||||||
|
- Identify the locations (time, frequency) of the highest-scoring points.
|
||||||
|
- Filter these points based on a minimum confidence score threshold.
|
||||||
|
- Limit the maximum number of detection points returned (top-k).
|
||||||
|
|
||||||
|
The main output is an `xarray.DataArray` containing the scores and
|
||||||
|
corresponding time/frequency coordinates for the extracted detection points.
|
||||||
|
This output serves as the input for subsequent postprocessing steps, such as
|
||||||
|
extracting predicted class probabilities and bounding box sizes at these
|
||||||
|
specific locations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent.arrays import Dimensions, get_dim_width
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"extract_detections_from_array",
|
||||||
|
"get_max_detections",
|
||||||
|
"DEFAULT_DETECTION_THRESHOLD",
|
||||||
|
"TOP_K_PER_SEC",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||||
|
"""Default confidence score threshold used for filtering detections."""
|
||||||
|
|
||||||
|
TOP_K_PER_SEC = 200
|
||||||
|
"""Default desired maximum number of detections per second of audio."""
|
||||||
|
|
||||||
|
|
||||||
|
def extract_detections_from_array(
|
||||||
|
detection_array: xr.DataArray,
|
||||||
|
max_detections: Optional[int] = None,
|
||||||
|
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
|
||||||
|
) -> xr.DataArray:
|
||||||
|
"""Extract detection locations (time, freq) and scores from a heatmap.
|
||||||
|
|
||||||
|
Identifies the pixels with the highest scores in the input detection
|
||||||
|
heatmap, filters them based on an optional score `threshold`, limits the
|
||||||
|
number to an optional `max_detections`, and returns their scores along with
|
||||||
|
their corresponding time and frequency coordinates.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection_array : xr.DataArray
|
||||||
|
A 2D xarray DataArray representing the detection heatmap. Must have
|
||||||
|
dimensions and coordinates named 'time' and 'frequency'. Higher values
|
||||||
|
are assumed to indicate higher detection confidence.
|
||||||
|
max_detections : int, optional
|
||||||
|
The absolute maximum number of detections to return. If specified, only
|
||||||
|
the top `max_detections` highest-scoring detections (passing the
|
||||||
|
threshold) are returned. If None (default), all detections passing
|
||||||
|
the threshold are returned, sorted by score.
|
||||||
|
threshold : float, optional
|
||||||
|
The minimum confidence score required for a detection peak to be
|
||||||
|
kept. Detections with scores below this value are discarded.
|
||||||
|
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
|
||||||
|
thresholding is applied.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
A 1D xarray DataArray named 'score' with a 'detection' dimension.
|
||||||
|
- The data values are the scores of the extracted detections, sorted
|
||||||
|
in descending order.
|
||||||
|
- It has coordinates 'time' and 'frequency' (also indexed by the
|
||||||
|
'detection' dimension) indicating the location of each detection
|
||||||
|
peak in the original coordinate system.
|
||||||
|
- Returns an empty DataArray if no detections pass the criteria.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `max_detections` is not None and not a positive integer, or if
|
||||||
|
`detection_array` lacks required dimensions/coordinates.
|
||||||
|
"""
|
||||||
|
if max_detections is not None:
|
||||||
|
if max_detections <= 0:
|
||||||
|
raise ValueError("Max detections must be positive")
|
||||||
|
|
||||||
|
values = detection_array.values.flatten()
|
||||||
|
|
||||||
|
if max_detections is not None:
|
||||||
|
top_indices = np.argpartition(-values, max_detections)[:max_detections]
|
||||||
|
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
|
||||||
|
else:
|
||||||
|
top_sorted_indices = np.argsort(-values)
|
||||||
|
|
||||||
|
top_values = values[top_sorted_indices]
|
||||||
|
|
||||||
|
if threshold is not None:
|
||||||
|
mask = top_values > threshold
|
||||||
|
top_values = top_values[mask]
|
||||||
|
top_sorted_indices = top_sorted_indices[mask]
|
||||||
|
|
||||||
|
time_indices, freq_indices = np.unravel_index(
|
||||||
|
top_sorted_indices,
|
||||||
|
detection_array.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
times = detection_array.coords[Dimensions.time.value].values[time_indices]
|
||||||
|
freqs = detection_array.coords[Dimensions.frequency.value].values[
|
||||||
|
freq_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
return xr.DataArray(
|
||||||
|
data=top_values,
|
||||||
|
coords={
|
||||||
|
Dimensions.frequency.value: ("detection", freqs),
|
||||||
|
Dimensions.time.value: ("detection", times),
|
||||||
|
},
|
||||||
|
dims="detection",
|
||||||
|
name="score",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_detections(
|
||||||
|
detection_array: xr.DataArray,
|
||||||
|
top_k_per_sec: int = TOP_K_PER_SEC,
|
||||||
|
) -> int:
|
||||||
|
"""Calculate max detections allowed based on duration and rate.
|
||||||
|
|
||||||
|
Determines the total maximum number of detections to extract from a
|
||||||
|
heatmap based on its time duration and a desired rate of detections
|
||||||
|
per second.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection_array : xr.DataArray
|
||||||
|
The detection heatmap, requiring 'time' coordinates from which the
|
||||||
|
total duration can be calculated using
|
||||||
|
`soundevent.arrays.get_dim_width`.
|
||||||
|
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||||
|
The desired maximum number of detections to allow per second of audio.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
The calculated total maximum number of detections allowed for the
|
||||||
|
entire duration of the `detection_array`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the duration cannot be calculated from the `detection_array` (e.g.,
|
||||||
|
missing or invalid 'time' coordinates/dimension).
|
||||||
|
"""
|
||||||
|
if top_k_per_sec < 0:
|
||||||
|
raise ValueError("top_k_per_sec cannot be negative.")
|
||||||
|
|
||||||
|
duration = get_dim_width(detection_array, Dimensions.time.value)
|
||||||
|
return int(duration * top_k_per_sec)
|
122
batdetect2/postprocess/extraction.py
Normal file
122
batdetect2/postprocess/extraction.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
"""Extracts associated data for detected points from model output arrays.
|
||||||
|
|
||||||
|
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
||||||
|
pipeline. After candidate detection points (time, frequency, score) have been
|
||||||
|
identified, this module extracts the corresponding values from other raw model
|
||||||
|
output arrays, such as:
|
||||||
|
|
||||||
|
- Predicted bounding box sizes (width, height).
|
||||||
|
- Class probability scores for each defined target class.
|
||||||
|
- Intermediate feature vectors.
|
||||||
|
|
||||||
|
It uses coordinate-based indexing provided by `xarray` to ensure that the
|
||||||
|
correct values are retrieved from the original heatmaps/feature maps at the
|
||||||
|
precise time-frequency location of each detection. The final output aggregates
|
||||||
|
all extracted information into a structured `xarray.Dataset`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"extract_values_at_positions",
|
||||||
|
"extract_detection_xr_dataset",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_values_at_positions(
|
||||||
|
array: xr.DataArray,
|
||||||
|
positions: xr.DataArray,
|
||||||
|
) -> xr.DataArray:
|
||||||
|
"""Extract values from an array at specified time-frequency positions.
|
||||||
|
|
||||||
|
Uses coordinate-based indexing to retrieve values from a source `array`
|
||||||
|
(e.g., class probabilities, size predictions, features) at the time and
|
||||||
|
frequency coordinates defined in the `positions` array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
array : xr.DataArray
|
||||||
|
The source DataArray from which to extract values. Must have 'time'
|
||||||
|
and 'frequency' dimensions and coordinates matching the space of
|
||||||
|
`positions`.
|
||||||
|
positions : xr.DataArray
|
||||||
|
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
|
||||||
|
locations from which to extract values.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
A DataArray containing the values extracted from `array` at the given
|
||||||
|
positions.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError, IndexError, KeyError
|
||||||
|
If dimensions or coordinates are missing or incompatible between
|
||||||
|
`array` and `positions`, or if selection fails.
|
||||||
|
"""
|
||||||
|
return array.sel(
|
||||||
|
**{
|
||||||
|
Dimensions.frequency.value: positions.coords[
|
||||||
|
Dimensions.frequency.value
|
||||||
|
],
|
||||||
|
Dimensions.time.value: positions.coords[Dimensions.time.value],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_detection_xr_dataset(
|
||||||
|
positions: xr.DataArray,
|
||||||
|
sizes: xr.DataArray,
|
||||||
|
classes: xr.DataArray,
|
||||||
|
features: xr.DataArray,
|
||||||
|
) -> xr.Dataset:
|
||||||
|
"""Combine extracted detection information into a structured xr.Dataset.
|
||||||
|
|
||||||
|
Takes the detection positions/scores and the full model output heatmaps
|
||||||
|
(sizes, classes, optional features), extracts the relevant data at the
|
||||||
|
detection positions, and packages everything into a single `xarray.Dataset`
|
||||||
|
where all variables are indexed by a common 'detection' dimension.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
positions : xr.DataArray
|
||||||
|
Output from `extract_detections_from_array`, containing detection
|
||||||
|
scores as data and 'time', 'frequency' coordinates along the
|
||||||
|
'detection' dimension.
|
||||||
|
sizes : xr.DataArray
|
||||||
|
The full size prediction heatmap from the model, with dimensions like
|
||||||
|
('dimension', 'time', 'frequency').
|
||||||
|
classes : xr.DataArray
|
||||||
|
The full class probability heatmap from the model, with dimensions like
|
||||||
|
('category', 'time', 'frequency').
|
||||||
|
features : xr.DataArray
|
||||||
|
The full feature map from the model, with
|
||||||
|
dimensions like ('feature', 'time', 'frequency').
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.Dataset
|
||||||
|
An xarray Dataset containing aligned information for each detection:
|
||||||
|
- 'scores': DataArray from `positions` (score data, time/freq coords).
|
||||||
|
- 'dimensions': DataArray with extracted size values
|
||||||
|
(dims: 'detection', 'dimension').
|
||||||
|
- 'classes': DataArray with extracted class probabilities
|
||||||
|
(dims: 'detection', 'category').
|
||||||
|
- 'features': DataArray with extracted feature vectors
|
||||||
|
(dims: 'detection', 'feature'), if `features` was provided. All
|
||||||
|
DataArrays share the 'detection' dimension and associated
|
||||||
|
time/frequency coordinates.
|
||||||
|
"""
|
||||||
|
sizes = extract_values_at_positions(sizes, positions).T
|
||||||
|
classes = extract_values_at_positions(classes, positions).T
|
||||||
|
features = extract_values_at_positions(features, positions).T
|
||||||
|
return xr.Dataset(
|
||||||
|
{
|
||||||
|
"scores": positions,
|
||||||
|
"dimensions": sizes,
|
||||||
|
"classes": classes,
|
||||||
|
"features": features,
|
||||||
|
}
|
||||||
|
)
|
96
batdetect2/postprocess/nms.py
Normal file
96
batdetect2/postprocess/nms.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
"""Performs Non-Maximum Suppression (NMS) on detection heatmaps.
|
||||||
|
|
||||||
|
This module provides functionality to apply Non-Maximum Suppression, a common
|
||||||
|
technique used after model inference, particularly in object detection and peak
|
||||||
|
detection tasks.
|
||||||
|
|
||||||
|
In the context of BatDetect2 postprocessing, NMS is applied
|
||||||
|
to the raw detection heatmap output by the neural network. Its purpose is to
|
||||||
|
isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap
|
||||||
|
activations that have lower scores than a local maximum. This helps prevent
|
||||||
|
multiple, overlapping detections originating from the same sound event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
NMS_KERNEL_SIZE = 9
|
||||||
|
"""Default kernel size (pixels) for Non-Maximum Suppression.
|
||||||
|
|
||||||
|
Specifies the side length of the square neighborhood used by default in
|
||||||
|
`non_max_suppression` to find local maxima. A 9x9 neighborhood is often
|
||||||
|
a reasonable starting point for typical spectrogram resolutions used in
|
||||||
|
BatDetect2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def non_max_suppression(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||||
|
|
||||||
|
This function identifies local maxima within a defined neighborhood for
|
||||||
|
each point in the input tensor. Values that are *not* the maximum within
|
||||||
|
their neighborhood are suppressed (set to zero). This is commonly used on
|
||||||
|
detection probability heatmaps to isolate distinct peaks corresponding to
|
||||||
|
individual detections and remove redundant lower scores nearby.
|
||||||
|
|
||||||
|
The implementation uses efficient 2D max pooling to find the maximum value
|
||||||
|
in the neighborhood of each point.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Input tensor, typically representing a detection heatmap. Must be a
|
||||||
|
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
||||||
|
`torch.nn.functional.max_pool2d` operation.
|
||||||
|
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
|
||||||
|
Size of the sliding window neighborhood used to find local maxima.
|
||||||
|
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
||||||
|
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
||||||
|
and width `w` is used. The kernel size should typically be odd to
|
||||||
|
have a well-defined center.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
A tensor of the same shape as the input, where only local maxima within
|
||||||
|
their respective neighborhoods (defined by `kernel_size`) retain their
|
||||||
|
original values. All other values are set to zero.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
TypeError
|
||||||
|
If `kernel_size` is not an int or a tuple of two ints.
|
||||||
|
RuntimeError
|
||||||
|
If the input `tensor` does not have 3 or 4 dimensions (as required
|
||||||
|
by `max_pool2d`).
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
- The function assumes higher values in the tensor indicate stronger peaks.
|
||||||
|
- Choosing an appropriate `kernel_size` is important. It should be large
|
||||||
|
enough to cover the typical "footprint" of a single detection peak plus
|
||||||
|
some surrounding context, effectively preventing multiple detections for
|
||||||
|
the same event. A size that is too large might suppress nearby distinct
|
||||||
|
events.
|
||||||
|
"""
|
||||||
|
if isinstance(kernel_size, int):
|
||||||
|
kernel_size_h = kernel_size
|
||||||
|
kernel_size_w = kernel_size
|
||||||
|
else:
|
||||||
|
kernel_size_h, kernel_size_w = kernel_size
|
||||||
|
|
||||||
|
pad_h = (kernel_size_h - 1) // 2
|
||||||
|
pad_w = (kernel_size_w - 1) // 2
|
||||||
|
|
||||||
|
hmax = torch.nn.functional.max_pool2d(
|
||||||
|
tensor,
|
||||||
|
(kernel_size_h, kernel_size_w),
|
||||||
|
stride=1,
|
||||||
|
padding=(pad_h, pad_w),
|
||||||
|
)
|
||||||
|
keep = (hmax == tensor).float()
|
||||||
|
return tensor * keep
|
@ -1,50 +0,0 @@
|
|||||||
from typing import Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
NMS_KERNEL_SIZE = 9
|
|
||||||
|
|
||||||
|
|
||||||
def non_max_suppression(
|
|
||||||
tensor: torch.Tensor,
|
|
||||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Run non-maximum suppression on a tensor.
|
|
||||||
|
|
||||||
This function removes values from the input tensor that are not local
|
|
||||||
maxima in the neighborhood of the given kernel size.
|
|
||||||
|
|
||||||
All non-maximum values are set to zero.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
tensor : torch.Tensor
|
|
||||||
Input tensor.
|
|
||||||
kernel_size : Union[int, Tuple[int, int]], optional
|
|
||||||
Size of the neighborhood to consider for non-maximum suppression.
|
|
||||||
If an integer is given, the neighborhood will be a square of the
|
|
||||||
given size. If a tuple is given, the neighborhood will be a
|
|
||||||
rectangle with the given height and width.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Tensor with non-maximum suppressed values.
|
|
||||||
"""
|
|
||||||
if isinstance(kernel_size, int):
|
|
||||||
kernel_size_h = kernel_size
|
|
||||||
kernel_size_w = kernel_size
|
|
||||||
else:
|
|
||||||
kernel_size_h, kernel_size_w = kernel_size
|
|
||||||
|
|
||||||
pad_h = (kernel_size_h - 1) // 2
|
|
||||||
pad_w = (kernel_size_w - 1) // 2
|
|
||||||
|
|
||||||
hmax = torch.nn.functional.max_pool2d(
|
|
||||||
tensor,
|
|
||||||
(kernel_size_h, kernel_size_w),
|
|
||||||
stride=1,
|
|
||||||
padding=(pad_h, pad_w),
|
|
||||||
)
|
|
||||||
keep = (hmax == tensor).float()
|
|
||||||
return tensor * keep
|
|
316
batdetect2/postprocess/remapping.py
Normal file
316
batdetect2/postprocess/remapping.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||||
|
|
||||||
|
This module provides utility functions to convert the raw numerical outputs
|
||||||
|
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||||
|
`xarray.DataArray` objects. This step adds coordinate information
|
||||||
|
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||||
|
interpretable in the context of the original audio signal and facilitating
|
||||||
|
subsequent processing steps.
|
||||||
|
|
||||||
|
Functions are provided for common BatDetect2 output types: detection heatmaps,
|
||||||
|
classification probability maps, size prediction maps, and potentially
|
||||||
|
intermediate features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"features_to_xarray",
|
||||||
|
"detection_to_xarray",
|
||||||
|
"classification_to_xarray",
|
||||||
|
"sizes_to_xarray",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def features_to_xarray(
|
||||||
|
features: torch.Tensor,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
min_freq: float = MIN_FREQ,
|
||||||
|
max_freq: float = MAX_FREQ,
|
||||||
|
features_prefix: str = "batdetect2_feature_",
|
||||||
|
):
|
||||||
|
"""Convert a multi-channel feature tensor to a coordinate-aware DataArray.
|
||||||
|
|
||||||
|
Assigns time, frequency, and feature coordinates to a raw feature tensor
|
||||||
|
output by the model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
features : torch.Tensor
|
||||||
|
The raw feature tensor from the model. Expected shape is
|
||||||
|
(num_features, num_freq_bins, num_time_bins).
|
||||||
|
start_time : float
|
||||||
|
The start time (in seconds) corresponding to the first time bin of
|
||||||
|
the tensor.
|
||||||
|
end_time : float
|
||||||
|
The end time (in seconds) corresponding to the *end* of the last time
|
||||||
|
bin.
|
||||||
|
min_freq : float, default=MIN_FREQ
|
||||||
|
The minimum frequency (in Hz) corresponding to the first frequency bin.
|
||||||
|
max_freq : float, default=MAX_FREQ
|
||||||
|
The maximum frequency (in Hz) corresponding to the *end* of the last
|
||||||
|
frequency bin.
|
||||||
|
features_prefix : str, default="batdetect2_feature_"
|
||||||
|
Prefix used to generate names for the feature coordinate dimension
|
||||||
|
(e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
An xarray DataArray containing the feature data with named dimensions
|
||||||
|
('feature', 'frequency', 'time') and calculated coordinates.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input tensor does not have 3 dimensions.
|
||||||
|
"""
|
||||||
|
if features.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Input features tensor must have 3 dimensions (C, T, F), "
|
||||||
|
f"got shape {features.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
num_features, height, width = features.shape
|
||||||
|
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||||
|
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||||
|
|
||||||
|
return xr.DataArray(
|
||||||
|
data=features.detach().numpy(),
|
||||||
|
dims=[
|
||||||
|
Dimensions.feature.value,
|
||||||
|
Dimensions.frequency.value,
|
||||||
|
Dimensions.time.value,
|
||||||
|
],
|
||||||
|
coords={
|
||||||
|
Dimensions.feature.value: [
|
||||||
|
f"{features_prefix}{i}" for i in range(num_features)
|
||||||
|
],
|
||||||
|
Dimensions.frequency.value: freqs,
|
||||||
|
Dimensions.time.value: times,
|
||||||
|
},
|
||||||
|
name="features",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detection_to_xarray(
|
||||||
|
detection: torch.Tensor,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
min_freq: float = MIN_FREQ,
|
||||||
|
max_freq: float = MAX_FREQ,
|
||||||
|
) -> xr.DataArray:
|
||||||
|
"""Convert a single-channel detection heatmap tensor to a DataArray.
|
||||||
|
|
||||||
|
Assigns time and frequency coordinates to a raw detection heatmap tensor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection : torch.Tensor
|
||||||
|
Raw detection heatmap tensor from the model. Expected shape is
|
||||||
|
(1, num_freq_bins, num_time_bins).
|
||||||
|
start_time : float
|
||||||
|
Start time (seconds) corresponding to the first time bin.
|
||||||
|
end_time : float
|
||||||
|
End time (seconds) corresponding to the end of the last time bin.
|
||||||
|
min_freq : float, default=MIN_FREQ
|
||||||
|
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||||
|
max_freq : float, default=MAX_FREQ
|
||||||
|
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||||
|
bin.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
An xarray DataArray containing the detection scores with named
|
||||||
|
dimensions ('frequency', 'time') and calculated coordinates.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input tensor does not have 3 dimensions or if the first
|
||||||
|
dimension size is not 1.
|
||||||
|
"""
|
||||||
|
if detection.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Input detection tensor must have 3 dimensions (1, T, F), "
|
||||||
|
f"got shape {detection.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
num_channels, height, width = detection.shape
|
||||||
|
|
||||||
|
if num_channels != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected a single channel output, instead got "
|
||||||
|
f"{num_channels} channels"
|
||||||
|
)
|
||||||
|
|
||||||
|
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||||
|
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||||
|
|
||||||
|
return xr.DataArray(
|
||||||
|
data=detection.squeeze(dim=0).detach().numpy(),
|
||||||
|
dims=[
|
||||||
|
Dimensions.frequency.value,
|
||||||
|
Dimensions.time.value,
|
||||||
|
],
|
||||||
|
coords={
|
||||||
|
Dimensions.frequency.value: freqs,
|
||||||
|
Dimensions.time.value: times,
|
||||||
|
},
|
||||||
|
name="detection_score",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def classification_to_xarray(
|
||||||
|
classes: torch.Tensor,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
class_names: List[str],
|
||||||
|
min_freq: float = MIN_FREQ,
|
||||||
|
max_freq: float = MAX_FREQ,
|
||||||
|
) -> xr.DataArray:
|
||||||
|
"""Convert multi-channel class probability tensor to a DataArray.
|
||||||
|
|
||||||
|
Assigns category (class name), frequency, and time coordinates to a raw
|
||||||
|
class probability tensor output by the model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
classes : torch.Tensor
|
||||||
|
Raw class probability tensor. Expected shape is
|
||||||
|
(num_classes, num_freq_bins, num_time_bins).
|
||||||
|
start_time : float
|
||||||
|
Start time (seconds) corresponding to the first time bin.
|
||||||
|
end_time : float
|
||||||
|
End time (seconds) corresponding to the end of the last time bin.
|
||||||
|
class_names : List[str]
|
||||||
|
Ordered list of class names corresponding to the first dimension
|
||||||
|
of the `classes` tensor. The length must match `classes.shape[0]`.
|
||||||
|
min_freq : float, default=MIN_FREQ
|
||||||
|
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||||
|
max_freq : float, default=MAX_FREQ
|
||||||
|
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||||
|
bin.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
An xarray DataArray containing class probabilities with named
|
||||||
|
dimensions ('category', 'frequency', 'time') and calculated
|
||||||
|
coordinates.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input tensor does not have 3 dimensions, or if the size of the
|
||||||
|
first dimension does not match the length of `class_names`.
|
||||||
|
"""
|
||||||
|
if classes.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Input classes tensor must have 3 dimensions (C, F, T), "
|
||||||
|
f"got shape {classes.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
num_classes, height, width = classes.shape
|
||||||
|
|
||||||
|
if num_classes != len(class_names):
|
||||||
|
raise ValueError(
|
||||||
|
"The number of classes does not coincide with the number of "
|
||||||
|
"class names provided: "
|
||||||
|
f"({num_classes = }) != ({len(class_names) = })"
|
||||||
|
)
|
||||||
|
|
||||||
|
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||||
|
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||||
|
|
||||||
|
return xr.DataArray(
|
||||||
|
data=classes.detach().numpy(),
|
||||||
|
dims=[
|
||||||
|
"category",
|
||||||
|
Dimensions.frequency.value,
|
||||||
|
Dimensions.time.value,
|
||||||
|
],
|
||||||
|
coords={
|
||||||
|
"category": class_names,
|
||||||
|
Dimensions.frequency.value: freqs,
|
||||||
|
Dimensions.time.value: times,
|
||||||
|
},
|
||||||
|
name="class_scores",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sizes_to_xarray(
|
||||||
|
sizes: torch.Tensor,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
min_freq: float = MIN_FREQ,
|
||||||
|
max_freq: float = MAX_FREQ,
|
||||||
|
) -> xr.DataArray:
|
||||||
|
"""Convert the 2-channel size prediction tensor to a DataArray.
|
||||||
|
|
||||||
|
Assigns dimension ('width', 'height'), frequency, and time coordinates
|
||||||
|
to the raw size prediction tensor output by the model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sizes : torch.Tensor
|
||||||
|
Raw size prediction tensor. Expected shape is
|
||||||
|
(2, num_freq_bins, num_time_bins), where the first dimension
|
||||||
|
corresponds to predicted width and height respectively.
|
||||||
|
start_time : float
|
||||||
|
Start time (seconds) corresponding to the first time bin.
|
||||||
|
end_time : float
|
||||||
|
End time (seconds) corresponding to the end of the last time bin.
|
||||||
|
min_freq : float, default=MIN_FREQ
|
||||||
|
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||||
|
max_freq : float, default=MAX_FREQ
|
||||||
|
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||||
|
bin.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
xr.DataArray
|
||||||
|
An xarray DataArray containing predicted sizes with named dimensions
|
||||||
|
('dimension', 'frequency', 'time') and calculated time/frequency
|
||||||
|
coordinates. The 'dimension' coordinate will have values
|
||||||
|
['width', 'height'].
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the input tensor does not have 3 dimensions or if the first
|
||||||
|
dimension size is not exactly 2.
|
||||||
|
"""
|
||||||
|
num_channels, height, width = sizes.shape
|
||||||
|
|
||||||
|
if num_channels != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected a two-channel output, instead got "
|
||||||
|
f"{num_channels} channels"
|
||||||
|
)
|
||||||
|
|
||||||
|
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||||
|
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||||
|
|
||||||
|
return xr.DataArray(
|
||||||
|
data=sizes.detach().numpy(),
|
||||||
|
dims=[
|
||||||
|
"dimension",
|
||||||
|
Dimensions.frequency.value,
|
||||||
|
Dimensions.time.value,
|
||||||
|
],
|
||||||
|
coords={
|
||||||
|
"dimension": ["width", "height"],
|
||||||
|
Dimensions.frequency.value: freqs,
|
||||||
|
Dimensions.time.value: times,
|
||||||
|
},
|
||||||
|
)
|
@ -1,21 +1,284 @@
|
|||||||
from typing import Dict, NamedTuple, Protocol
|
"""Defines shared interfaces and data structures for postprocessing.
|
||||||
|
|
||||||
|
This module centralizes the Protocol definitions and common data structures
|
||||||
|
used throughout the `batdetect2.postprocess` module.
|
||||||
|
|
||||||
|
The main component is the `PostprocessorProtocol`, which outlines the standard
|
||||||
|
interface for an object responsible for executing the entire postprocessing
|
||||||
|
pipeline. This pipeline transforms raw neural network outputs into interpretable
|
||||||
|
detections represented as `soundevent` objects. Using protocols ensures
|
||||||
|
modularity and consistent interaction between different parts of the BatDetect2
|
||||||
|
system that deal with model predictions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Callable, List, NamedTuple, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatDetect2Prediction",
|
"RawPrediction",
|
||||||
|
"PostprocessorProtocol",
|
||||||
|
"GeometryBuilder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2Prediction(NamedTuple):
|
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
||||||
|
"""Type alias for a function that recovers geometry from position and size.
|
||||||
|
|
||||||
|
This callable takes:
|
||||||
|
1. A position tuple `(time, frequency)`.
|
||||||
|
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||||
|
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||||
|
`BoundingBox`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RawPrediction(NamedTuple):
|
||||||
|
"""Intermediate representation of a single detected sound event.
|
||||||
|
|
||||||
|
Holds extracted information about a detection after initial processing
|
||||||
|
(like peak finding, coordinate remapping, geometry recovery) but before
|
||||||
|
final class decoding and conversion into a `SoundEventPrediction`. This
|
||||||
|
can be useful for evaluation or simpler data handling formats.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
start_time : float
|
||||||
|
Start time of the recovered bounding box in seconds.
|
||||||
|
end_time : float
|
||||||
|
End time of the recovered bounding box in seconds.
|
||||||
|
low_freq : float
|
||||||
|
Lowest frequency of the recovered bounding box in Hz.
|
||||||
|
high_freq : float
|
||||||
|
Highest frequency of the recovered bounding box in Hz.
|
||||||
|
detection_score : float
|
||||||
|
The confidence score associated with this detection, typically from
|
||||||
|
the detection heatmap peak.
|
||||||
|
class_scores : xr.DataArray
|
||||||
|
An xarray DataArray containing the predicted probabilities or scores
|
||||||
|
for each target class at the detection location. Indexed by a
|
||||||
|
'category' coordinate containing class names.
|
||||||
|
features : xr.DataArray
|
||||||
|
An xarray DataArray containing extracted feature vectors at the
|
||||||
|
detection location. Indexed by a 'feature' coordinate.
|
||||||
|
"""
|
||||||
|
|
||||||
start_time: float
|
start_time: float
|
||||||
end_time: float
|
end_time: float
|
||||||
low_freq: float
|
low_freq: float
|
||||||
high_freq: float
|
high_freq: float
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: Dict[str, float]
|
class_scores: xr.DataArray
|
||||||
features: np.ndarray
|
features: xr.DataArray
|
||||||
|
|
||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
pass
|
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||||
|
|
||||||
|
This protocol outlines the standard methods for an object that takes raw
|
||||||
|
output from a BatDetect2 model and the corresponding input clip metadata,
|
||||||
|
and processes it through various stages (e.g., coordinate remapping, NMS,
|
||||||
|
detection extraction, data extraction, decoding) to produce interpretable
|
||||||
|
results at different levels of completion.
|
||||||
|
|
||||||
|
Implementations manage the configured logic for all postprocessing steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_feature_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Remap feature tensors to coordinate-aware DataArrays.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch, expected
|
||||||
|
to contain the necessary feature tensors.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects, one for each item in the
|
||||||
|
processed batch. This list provides the timing, recording, and
|
||||||
|
other metadata context needed to calculate real-world coordinates
|
||||||
|
(seconds, Hz) for the output arrays. The length of this list must
|
||||||
|
correspond to the batch size of the `output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
A list of xarray DataArrays, one for each input clip in the batch,
|
||||||
|
in the same order. Each DataArray contains the feature vectors
|
||||||
|
with dimensions like ('feature', 'time', 'frequency') and
|
||||||
|
corresponding real-world coordinates.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_detection_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Remap detection tensors to coordinate-aware DataArrays.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch,
|
||||||
|
containing detection heatmaps.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing coordinate context. Must match the batch size of
|
||||||
|
`output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
A list of 2D xarray DataArrays (one per input clip, in order),
|
||||||
|
representing the detection heatmap with 'time' and 'frequency'
|
||||||
|
coordinates. Values typically indicate detection confidence.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_classification_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Remap classification tensors to coordinate-aware DataArrays.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch,
|
||||||
|
containing class probability tensors.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing coordinate context. Must match the batch size of
|
||||||
|
`output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||||
|
representing class probabilities with 'category', 'time', and
|
||||||
|
'frequency' dimensions and coordinates.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_sizes_arrays(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.DataArray]:
|
||||||
|
"""Remap size prediction tensors to coordinate-aware DataArrays.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch,
|
||||||
|
containing predicted size tensors (e.g., width and height).
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing coordinate context. Must match the batch size of
|
||||||
|
`output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.DataArray]
|
||||||
|
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||||
|
representing predicted sizes with 'dimension'
|
||||||
|
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
|
||||||
|
coordinates. Values represent estimated detection sizes.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_detection_datasets(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[xr.Dataset]:
|
||||||
|
"""Perform remapping, NMS, detection, and data extraction for a batch.
|
||||||
|
|
||||||
|
Processes the raw model output for a batch to identify detection peaks
|
||||||
|
and extract all associated information (score, position, size, class
|
||||||
|
probs, features) at those peak locations, returning a structured
|
||||||
|
dataset for each input clip in the batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing context. Must match the batch size of `output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[xr.Dataset]
|
||||||
|
A list of xarray Datasets (one per input clip, in order). Each
|
||||||
|
Dataset contains multiple DataArrays ('scores', 'dimensions',
|
||||||
|
'classes', 'features') sharing a common 'detection' dimension,
|
||||||
|
providing aligned data for each detected event in that clip.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_raw_predictions(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[List[RawPrediction]]:
|
||||||
|
"""Extract intermediate RawPrediction objects for a batch.
|
||||||
|
|
||||||
|
Processes the raw model output for a batch through remapping, NMS,
|
||||||
|
detection, data extraction, and geometry recovery to produce a list of
|
||||||
|
`RawPrediction` objects for each corresponding input clip. This provides
|
||||||
|
a simplified, intermediate representation before final tag decoding.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing context. Must match the batch size of `output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[List[RawPrediction]]
|
||||||
|
A list of lists (one inner list per input clip, in order). Each
|
||||||
|
inner list contains the `RawPrediction` objects extracted for the
|
||||||
|
corresponding input clip.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_predictions(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
|
) -> List[data.ClipPrediction]:
|
||||||
|
"""Perform the full postprocessing pipeline for a batch.
|
||||||
|
|
||||||
|
Takes raw model output for a batch and corresponding clips, applies the
|
||||||
|
entire postprocessing chain, and returns the final, interpretable
|
||||||
|
predictions as a list of `soundevent.data.ClipPrediction` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
output : ModelOutput
|
||||||
|
The raw output from the neural network model for a batch.
|
||||||
|
clips : List[data.Clip]
|
||||||
|
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||||
|
items, providing context. Must match the batch size of `output`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.ClipPrediction]
|
||||||
|
A list containing one `ClipPrediction` object for each input clip
|
||||||
|
(in the same order), populated with `SoundEventPrediction` objects
|
||||||
|
representing the final detections with decoded tags and geometry.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
Loading…
Reference in New Issue
Block a user