Updated postprocess module with docstrings

This commit is contained in:
mbsantiago 2025-04-20 13:56:18 +01:00
parent 089328a4f0
commit bcf339c40d
10 changed files with 1828 additions and 161 deletions

View File

@ -0,0 +1,566 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline.
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
BatDetect2 neural network model and transforms them into meaningful, structured
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
coordinates to raw model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
class probabilities, features) at the detected locations.
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
class predictions into interpretable `soundevent` objects, including
recovering geometry (ROIs) and decoding class names back to standard tags.
This module provides the primary interface:
- `PostprocessConfig`: A configuration object for postprocessing parameters
(thresholds, NMS kernel size, etc.).
- `load_postprocess_config`: Function to load the configuration from a file.
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
holds the configured pipeline logic.
- `build_postprocessor`: A factory function to create a `Postprocessor`
instance, linking it to the necessary target definitions (`TargetProtocol`).
It also re-exports key components from submodules for convenience.
"""
from typing import List, Optional
import xarray as xr
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
from batdetect2.postprocess.detection import (
DEFAULT_DETECTION_THRESHOLD,
TOP_K_PER_SEC,
extract_detections_from_array,
get_max_detections,
)
from batdetect2.postprocess.extraction import (
extract_detection_xr_dataset,
)
from batdetect2.postprocess.nms import (
NMS_KERNEL_SIZE,
non_max_suppression,
)
from batdetect2.postprocess.remapping import (
classification_to_xarray,
detection_to_xarray,
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
"DEFAULT_DETECTION_THRESHOLD",
"MAX_FREQ",
"MIN_FREQ",
"ModelOutput",
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"PostprocessorProtocol",
"RawPrediction",
"TOP_K_PER_SEC",
"build_postprocessor",
"classification_to_xarray",
"convert_raw_predictions_to_clip_prediction",
"convert_xr_dataset_to_raw_prediction",
"detection_to_xarray",
"extract_detection_xr_dataset",
"extract_detections_from_array",
"features_to_xarray",
"get_max_detections",
"load_postprocess_config",
"non_max_suppression",
"sizes_to_xarray",
]
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
def build_postprocessor(
targets: TargetProtocol,
config: Optional[PostprocessConfig] = None,
max_freq: int = MAX_FREQ,
min_freq: int = MIN_FREQ,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor.
Creates and initializes the `Postprocessor` instance, providing it with the
necessary `targets` object and the `PostprocessConfig`.
Parameters
----------
targets : TargetProtocol
An initialized object conforming to the `TargetProtocol`, providing
methods like `.decode()` and `.recover_roi()`, and attributes like
`.class_names` and `.generic_class_tags`. This links postprocessing
to the defined target semantics and geometry mappings.
config : PostprocessConfig, optional
Configuration object specifying postprocessing parameters (thresholds,
NMS kernel size, etc.). If None, default settings defined in
`PostprocessConfig` will be used.
min_freq : int, default=MIN_FREQ
The minimum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig` instead for better encapsulation.
max_freq : int, default=MAX_FREQ
The maximum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig`.
Returns
-------
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
"""
return Postprocessor(
targets=targets,
config=config or PostprocessConfig(),
min_freq=min_freq,
max_freq=max_freq,
)
class Postprocessor(PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline.
This class orchestrates the steps required to convert raw model outputs
into interpretable `soundevent` predictions. It uses configured parameters
and leverages functions from the `batdetect2.postprocess` submodules for
each stage (NMS, remapping, detection, extraction, decoding).
It requires a `TargetProtocol` object during initialization to access
necessary decoding information (class name to tag mapping,
ROI recovery logic) ensuring consistency with the target definitions used
during training or specified for inference.
Instances are typically created using the `build_postprocessor` factory
function.
Attributes
----------
targets : TargetProtocol
The configured target definition object providing decoding and ROI
recovery.
config : PostprocessConfig
Configuration object holding parameters for NMS, thresholds, etc.
min_freq : int
Minimum frequency (Hz) assumed for the model output's frequency axis.
max_freq : int
Maximum frequency (Hz) assumed for the model output's frequency axis.
"""
targets: TargetProtocol
def __init__(
self,
targets: TargetProtocol,
config: PostprocessConfig,
min_freq: int = MIN_FREQ,
max_freq: int = MAX_FREQ,
):
"""Initialize the Postprocessor.
Parameters
----------
targets : TargetProtocol
Initialized target definition object.
config : PostprocessConfig
Configuration for postprocessing parameters.
min_freq : int, default=MIN_FREQ
Minimum frequency (Hz) for coordinate remapping.
max_freq : int, default=MAX_FREQ
Maximum frequency (Hz) for coordinate remapping.
"""
self.targets = targets
self.config = config
self.min_freq = min_freq
self.max_freq = max_freq
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Extract and remap raw feature tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.features` tensor for the batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware feature DataArrays, one per clip.
Raises
------
ValueError
If batch sizes of `output.features` and `clips` do not match.
"""
if len(clips) != len(output.features):
raise ValueError(
"Number of clips and batch size of feature array"
"do not match. "
f"(clips: {len(clips)}, features: {len(output.features)})"
)
return [
features_to_xarray(
feats,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for feats, clip in zip(output.features, clips)
]
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Apply NMS and remap detection heatmaps for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.detection_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of NMS-applied, coordinate-aware detection heatmaps, one per
clip.
Raises
------
ValueError
If batch sizes of `output.detection_probs` and `clips` do not match.
"""
detections = output.detection_probs
if len(clips) != len(output.detection_probs):
raise ValueError(
"Number of clips and batch size of detection array "
"do not match. "
f"(clips: {len(clips)}, detection: {len(detections)})"
)
detections = non_max_suppression(
detections,
kernel_size=self.config.nms_kernel_size,
)
return [
detection_to_xarray(
dets,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for dets, clip in zip(detections, clips)
]
def get_classification_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw classification tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.class_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware class probability maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.class_probs` and `clips` do not match, or
if number of classes mismatches `self.targets.class_names`.
"""
classifications = output.class_probs
if len(clips) != len(classifications):
raise ValueError(
"Number of clips and batch size of classification array "
"do not match. "
f"(clips: {len(clips)}, classification: {len(classifications)})"
)
return [
classification_to_xarray(
class_probs,
start_time=clip.start_time,
end_time=clip.end_time,
class_names=self.targets.class_names,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for class_probs, clip in zip(classifications, clips)
]
def get_sizes_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw size prediction tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.size_preds` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware size prediction maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.size_preds` and `clips` do not match.
"""
sizes = output.size_preds
if len(clips) != len(sizes):
raise ValueError(
"Number of clips and batch size of sizes array do not match. "
f"(clips: {len(clips)}, sizes: {len(sizes)})"
)
return [
sizes_to_xarray(
size_preds,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for size_preds, clip in zip(sizes, clips)
]
def get_detection_datasets(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.Dataset]:
"""Perform NMS, remapping, detection, and data extraction for a batch.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[xr.Dataset]
List of xarray Datasets (one per clip). Each Dataset contains
aligned scores, dimensions, class probabilities, and features for
detections found in that clip.
"""
detection_arrays = self.get_detection_arrays(output, clips)
classification_arrays = self.get_classification_arrays(output, clips)
size_arrays = self.get_sizes_arrays(output, clips)
features_arrays = self.get_feature_arrays(output, clips)
datasets = []
for det_array, class_array, sizes_array, feats_array in zip(
detection_arrays,
classification_arrays,
size_arrays,
features_arrays,
):
max_detections = get_max_detections(
det_array,
top_k_per_sec=self.config.top_k_per_sec,
)
positions = extract_detections_from_array(
det_array,
max_detections=max_detections,
threshold=self.config.detection_threshold,
)
datasets.append(
extract_detection_xr_dataset(
positions,
sizes_array,
class_array,
feats_array,
)
)
return datasets
def get_raw_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data
extraction, and geometry recovery via the configured
`targets.recover_roi`.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[List[RawPrediction]]
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detection_datasets = self.get_detection_datasets(output, clips)
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.recover_roi,
)
for dataset in detection_datasets
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = self.get_raw_predictions(output, clips)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
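Below is a brief, illustrative usage sketch (not part of the commit) showing how these pieces fit together. It assumes an already-built `targets` object satisfying `TargetProtocol`; `model`, `spectrograms`, and `clips` are hypothetical stand-ins for a trained network, its input batch, and the matching `Clip` list.

from batdetect2.postprocess import PostprocessConfig, build_postprocessor

config = PostprocessConfig(detection_threshold=0.3, top_k_per_sec=100)
postprocessor = build_postprocessor(targets, config=config)

output = model(spectrograms)  # hypothetical forward pass -> ModelOutput
clip_predictions = postprocessor.get_predictions(output, clips)
for clip_prediction in clip_predictions:
    for prediction in clip_prediction.sound_events:
        print(prediction.score, [pt.tag for pt in prediction.tags])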

View File

@ -1,73 +0,0 @@
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.models import ModelOutput
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
def to_xarray(
output: ModelOutput,
start_time: float,
end_time: float,
class_names: list[str],
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
):
detection = output.detection_probs
size = output.size_preds
classes = output.class_probs
features = output.features
if len(detection.shape) == 4:
if detection.shape[0] != 1:
raise ValueError(
"Expected a non-batched output or a batch of size 1, instead "
f"got an input of shape {detection.shape}"
)
detection = detection.squeeze(dim=0)
size = size.squeeze(dim=0)
classes = classes.squeeze(dim=0)
features = features.squeeze(dim=0)
_, width, height = detection.shape
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
if classes.shape[0] != len(class_names):
raise ValueError(
f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })"
)
return xr.Dataset(
data_vars={
"detection": (
[Dimensions.time.value, Dimensions.frequency.value],
detection.squeeze(dim=0).detach().numpy(),
),
"size": (
[
"dimension",
Dimensions.time.value,
Dimensions.frequency.value,
],
detection.detach().numpy(),
),
"classes": (
[
"category",
Dimensions.time.value,
Dimensions.frequency.value,
],
classes.detach().numpy(),
),
},
coords={
Dimensions.time.value: times,
Dimensions.frequency.value: freqs,
"dimension": ["width", "height"],
"category": class_names,
},
)

View File

@ -1,32 +0,0 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
]
NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200
class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
min_freq: int = Field(default=10000, gt=0)
max_freq: int = Field(default=120000, gt=0)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
return load_config(path, schema=PostprocessConfig, field=field)

View File

@ -0,0 +1,297 @@
"""Decodes extracted detection data into standard soundevent predictions.
This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it
into meaningful, standardized prediction objects based on the `soundevent` data
model.
The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
objects, using a configured geometry builder to recover bounding boxes from
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
2. Converting each `RawPrediction` into a
`soundevent.data.SoundEventPrediction`, which involves:
- Creating the `soundevent.data.SoundEvent` with geometry and features.
- Decoding the predicted class probabilities into representative tags using
a configured class decoder (`SoundEventDecoder`).
- Applying a classification threshold.
- Optionally selecting only the single highest-scoring class (top-1) or
including tags for all classes above the threshold (multi-label).
- Adding generic class tags as a baseline.
- Associating scores with the final prediction and tags.
(`convert_raw_prediction_to_sound_event_prediction`)
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
a `soundevent.data.ClipPrediction`
(`convert_raw_predictions_to_clip_prediction`).
"""
from typing import List, Optional
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
__all__ = [
"convert_xr_dataset_to_raw_prediction",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
]
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
"""Default threshold applied to classification scores.
Class predictions with scores below this value are typically ignored during
decoding.
"""
def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset,
geometry_builder: GeometryBuilder,
) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects.
Takes the output of the extraction step (`extract_detection_xr_dataset`)
and transforms each detection entry into an intermediate `RawPrediction`
object. This involves recovering the geometry (e.g., bounding box) from
the predicted position and scaled size dimensions using the provided
`geometry_builder` function.
Parameters
----------
detection_dataset : xr.Dataset
An xarray Dataset containing aligned detection information, typically
output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension.
geometry_builder : GeometryBuilder
A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`.
Returns
-------
List[RawPrediction]
A list of `RawPrediction` objects, each containing the detection score,
recovered bounding box coordinates (start/end time, low/high freq),
the vector of class scores, and the feature vector for one detection.
Raises
------
AttributeError, KeyError, ValueError
If `detection_dataset` is missing expected variables ('scores',
'dimensions', 'classes', 'features') or coordinates ('time',
'frequency' associated with 'scores'), or if `geometry_builder` fails.
"""
detections = []
    for det_num in range(detection_dataset.sizes["detection"]):
        det_info = detection_dataset.isel(detection=det_num)
        geom = geometry_builder(
            (det_info.time, det_info.frequency),
det_info.dimensions,
)
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
classes = det_info.classes
features = det_info.features
detections.append(
RawPrediction(
                detection_score=float(det_info.scores),
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_scores=classes,
features=features,
)
)
return detections
def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction],
clip: data.Clip,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
Iterates through `raw_predictions` (assumed to belong to a single clip),
converts each one into a `soundevent.data.SoundEventPrediction` using
`convert_raw_prediction_to_sound_event_prediction`, and packages them
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
Parameters
----------
raw_predictions : List[RawPrediction]
List of raw prediction objects for a single clip.
clip : data.Clip
The original `soundevent.data.Clip` object these predictions belong to.
sound_event_decoder : SoundEventDecoder
Function to decode class names into representative tags.
generic_class_tags : List[data.Tag]
List of tags representing the generic class category.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Threshold applied to class scores during decoding.
top_class_only : bool, default=False
If True, only decode tags for the single highest-scoring class above
the threshold. If False, decode tags for all classes above threshold.
Returns
-------
data.ClipPrediction
A `ClipPrediction` object containing a list of `SoundEventPrediction`
objects corresponding to the input `raw_predictions`.
"""
return data.ClipPrediction(
clip=clip,
sound_events=[
convert_raw_prediction_to_sound_event_prediction(
prediction,
recording=clip.recording,
sound_event_decoder=sound_event_decoder,
generic_class_tags=generic_class_tags,
classification_threshold=classification_threshold,
top_class_only=top_class_only,
)
for prediction in raw_predictions
],
)
def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
recording: data.Recording,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: Optional[
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
This function performs the core decoding steps for a single detected event:
1. Creates a `soundevent.data.SoundEvent` containing the geometry
(BoundingBox derived from `raw_prediction` bounds) and any associated
feature vectors.
2. Initializes a list of predicted tags using the provided
`generic_class_tags`, assigning the overall `detection_score` from the
`raw_prediction` to these generic tags.
3. Processes the `class_scores` from the `raw_prediction`:
a. Optionally filters out scores below `classification_threshold`
(if it's not None).
b. Sorts the remaining scores in descending order.
c. Iterates through the sorted, thresholded class scores.
d. For each class, uses the `sound_event_decoder` to get the
representative base tags for that class name.
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
the specific `score` of that class prediction.
f. Appends these specific predicted tags to the list.
g. If `top_class_only` is True, stops after processing the first
(highest-scoring) class that passed the threshold.
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
associating the `SoundEvent`, the overall `detection_score`, and the
compiled list of `PredictedTag` objects.
Parameters
----------
raw_prediction : RawPrediction
The raw prediction object containing score, bounds, class scores,
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
coordinate.
recording : data.Recording
The recording the sound event belongs to.
sound_event_decoder : SoundEventDecoder
Configured function mapping class names (str) to lists of base
`data.Tag` objects.
generic_class_tags : List[data.Tag]
List of base tags representing the generic category.
classification_threshold : float, optional
The minimum score a class prediction must have to be considered
significant enough to have its tags decoded and added. If None, no
thresholding is applied based on class score (all predicted classes,
or the top one if `top_class_only` is True, will be processed).
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
top_class_only : bool, default=False
If True, only includes tags for the single highest-scoring class that
exceeds the threshold. If False (default), includes tags for all classes
exceeding the threshold.
Returns
-------
data.SoundEventPrediction
The fully formed sound event prediction object.
Raises
------
ValueError
If `raw_prediction.features` has unexpected structure or if
`data.term_from_key` (if used internally) fails.
If `sound_event_decoder` fails for a class name and errors are raised.
"""
sound_event = data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(
coordinates=[
raw_prediction.start_time,
raw_prediction.low_freq,
raw_prediction.end_time,
raw_prediction.high_freq,
]
),
        features=[
            data.Feature(term=data.term_from_key(feat_name), value=float(value))
            for feat_name, value in zip(
                raw_prediction.features.coords["feature"].values,
                raw_prediction.features.values,
            )
        ],
)
tags = [
data.PredictedTag(tag=tag, score=raw_prediction.detection_score)
for tag in generic_class_tags
]
class_scores = raw_prediction.class_scores
if classification_threshold is not None:
class_scores = class_scores.where(
class_scores > classification_threshold,
drop=True,
)
    sorted_scores = class_scores.sortby(class_scores, ascending=False)
    for class_name, score in zip(
        sorted_scores.coords["category"].values,
        sorted_scores.values,
    ):
class_tags = sound_event_decoder(class_name)
for tag in class_tags:
tags.append(
data.PredictedTag(
tag=tag,
score=score,
)
)
if top_class_only:
break
return data.SoundEventPrediction(
sound_event=sound_event,
score=raw_prediction.detection_score,
tags=tags,
)
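For reference, here is a small self-contained sketch of the thresholding and ordering logic used above, with made-up class names:

import xarray as xr

class_scores = xr.DataArray(
    [0.05, 0.6, 0.3],
    coords={"category": ["class_a", "class_b", "class_c"]},
    dims="category",
)
threshold = 0.1

kept = class_scores.where(class_scores > threshold, drop=True)
kept = kept.sortby(kept, ascending=False)
for name, score in zip(kept.coords["category"].values, kept.values):
    print(name, float(score))  # class_b 0.6, then class_c 0.3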

View File

@ -0,0 +1,162 @@
"""Extracts candidate detection points from a model output heatmap.
This module implements a specific step within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and
coordinate remapping have been applied).
It provides functionality to:
- Identify the locations (time, frequency) of the highest-scoring points.
- Filter these points based on a minimum confidence score threshold.
- Limit the maximum number of detection points returned (top-k).
The main output is an `xarray.DataArray` containing the scores and
corresponding time/frequency coordinates for the extracted detection points.
This output serves as the input for subsequent postprocessing steps, such as
extracting predicted class probabilities and bounding box sizes at these
specific locations.
"""
from typing import Optional
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions, get_dim_width
__all__ = [
"extract_detections_from_array",
"get_max_detections",
"DEFAULT_DETECTION_THRESHOLD",
"TOP_K_PER_SEC",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
"""Default confidence score threshold used for filtering detections."""
TOP_K_PER_SEC = 200
"""Default desired maximum number of detections per second of audio."""
def extract_detections_from_array(
detection_array: xr.DataArray,
max_detections: Optional[int] = None,
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
) -> xr.DataArray:
"""Extract detection locations (time, freq) and scores from a heatmap.
Identifies the pixels with the highest scores in the input detection
heatmap, filters them based on an optional score `threshold`, limits the
number to an optional `max_detections`, and returns their scores along with
their corresponding time and frequency coordinates.
Parameters
----------
detection_array : xr.DataArray
        A 2D xarray DataArray representing the detection heatmap, with
        dimensions ('frequency', 'time') and matching coordinates, as
        produced by `detection_to_xarray`. Higher values are assumed to
        indicate higher detection confidence.
max_detections : int, optional
The absolute maximum number of detections to return. If specified, only
the top `max_detections` highest-scoring detections (passing the
threshold) are returned. If None (default), all detections passing
the threshold are returned, sorted by score.
threshold : float, optional
The minimum confidence score required for a detection peak to be
kept. Detections with scores below this value are discarded.
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
thresholding is applied.
Returns
-------
xr.DataArray
A 1D xarray DataArray named 'score' with a 'detection' dimension.
- The data values are the scores of the extracted detections, sorted
in descending order.
- It has coordinates 'time' and 'frequency' (also indexed by the
'detection' dimension) indicating the location of each detection
peak in the original coordinate system.
- Returns an empty DataArray if no detections pass the criteria.
Raises
------
ValueError
If `max_detections` is not None and not a positive integer, or if
`detection_array` lacks required dimensions/coordinates.
"""
if max_detections is not None:
if max_detections <= 0:
raise ValueError("Max detections must be positive")
values = detection_array.values.flatten()
    if max_detections is not None and max_detections < values.size:
        # A partial partition is cheaper than a full sort when only the
        # top-k scores are needed; argpartition requires kth < values.size.
        top_indices = np.argpartition(-values, max_detections)[:max_detections]
        top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
    else:
        top_sorted_indices = np.argsort(-values)
top_values = values[top_sorted_indices]
if threshold is not None:
mask = top_values > threshold
top_values = top_values[mask]
top_sorted_indices = top_sorted_indices[mask]
    # The detection array follows the (frequency, time) layout produced by
    # `detection_to_xarray`, so axis 0 indexes frequency and axis 1 time.
    freq_indices, time_indices = np.unravel_index(
        top_sorted_indices,
        detection_array.shape,
    )
times = detection_array.coords[Dimensions.time.value].values[time_indices]
freqs = detection_array.coords[Dimensions.frequency.value].values[
freq_indices
]
return xr.DataArray(
data=top_values,
coords={
Dimensions.frequency.value: ("detection", freqs),
Dimensions.time.value: ("detection", times),
},
dims="detection",
name="score",
)
def get_max_detections(
detection_array: xr.DataArray,
top_k_per_sec: int = TOP_K_PER_SEC,
) -> int:
"""Calculate max detections allowed based on duration and rate.
Determines the total maximum number of detections to extract from a
heatmap based on its time duration and a desired rate of detections
per second.
Parameters
----------
detection_array : xr.DataArray
The detection heatmap, requiring 'time' coordinates from which the
total duration can be calculated using
`soundevent.arrays.get_dim_width`.
top_k_per_sec : int, default=TOP_K_PER_SEC
The desired maximum number of detections to allow per second of audio.
Returns
-------
int
The calculated total maximum number of detections allowed for the
entire duration of the `detection_array`.
Raises
------
ValueError
If the duration cannot be calculated from the `detection_array` (e.g.,
missing or invalid 'time' coordinates/dimension).
"""
if top_k_per_sec < 0:
raise ValueError("top_k_per_sec cannot be negative.")
duration = get_dim_width(detection_array, Dimensions.time.value)
return int(duration * top_k_per_sec)
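As a quick illustration (assuming the (frequency, time) layout produced by `detection_to_xarray`), here is how `extract_detections_from_array` behaves on a tiny synthetic heatmap:

import numpy as np
import xarray as xr
from batdetect2.postprocess.detection import extract_detections_from_array

heatmap = xr.DataArray(
    np.array([[0.0, 0.9, 0.0], [0.0, 0.0, 0.4]]),
    coords={"frequency": [10_000.0, 60_000.0], "time": [0.0, 0.1, 0.2]},
    dims=["frequency", "time"],
)
detections = extract_detections_from_array(heatmap, threshold=0.1)
# detections.values -> [0.9, 0.4], located at (time=0.1, frequency=10_000)
# and (time=0.2, frequency=60_000) respectively.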

View File

@ -0,0 +1,122 @@
"""Extracts associated data for detected points from model output arrays.
This module implements a key step (Step 4) in the BatDetect2 postprocessing
pipeline. After candidate detection points (time, frequency, score) have been
identified, this module extracts the corresponding values from other raw model
output arrays, such as:
- Predicted bounding box sizes (width, height).
- Class probability scores for each defined target class.
- Intermediate feature vectors.
It uses coordinate-based indexing provided by `xarray` to ensure that the
correct values are retrieved from the original heatmaps/feature maps at the
precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
import xarray as xr
from soundevent.arrays import Dimensions
__all__ = [
"extract_values_at_positions",
"extract_detection_xr_dataset",
]
def extract_values_at_positions(
array: xr.DataArray,
positions: xr.DataArray,
) -> xr.DataArray:
"""Extract values from an array at specified time-frequency positions.
Uses coordinate-based indexing to retrieve values from a source `array`
(e.g., class probabilities, size predictions, features) at the time and
frequency coordinates defined in the `positions` array.
Parameters
----------
array : xr.DataArray
The source DataArray from which to extract values. Must have 'time'
and 'frequency' dimensions and coordinates matching the space of
`positions`.
positions : xr.DataArray
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
locations from which to extract values.
Returns
-------
xr.DataArray
A DataArray containing the values extracted from `array` at the given
positions.
Raises
------
ValueError, IndexError, KeyError
If dimensions or coordinates are missing or incompatible between
`array` and `positions`, or if selection fails.
"""
return array.sel(
**{
Dimensions.frequency.value: positions.coords[
Dimensions.frequency.value
],
Dimensions.time.value: positions.coords[Dimensions.time.value],
}
)
def extract_detection_xr_dataset(
positions: xr.DataArray,
sizes: xr.DataArray,
classes: xr.DataArray,
features: xr.DataArray,
) -> xr.Dataset:
"""Combine extracted detection information into a structured xr.Dataset.
Takes the detection positions/scores and the full model output heatmaps
(sizes, classes, optional features), extracts the relevant data at the
detection positions, and packages everything into a single `xarray.Dataset`
where all variables are indexed by a common 'detection' dimension.
Parameters
----------
positions : xr.DataArray
Output from `extract_detections_from_array`, containing detection
scores as data and 'time', 'frequency' coordinates along the
'detection' dimension.
sizes : xr.DataArray
The full size prediction heatmap from the model, with dimensions like
('dimension', 'time', 'frequency').
classes : xr.DataArray
The full class probability heatmap from the model, with dimensions like
('category', 'time', 'frequency').
features : xr.DataArray
The full feature map from the model, with
dimensions like ('feature', 'time', 'frequency').
Returns
-------
xr.Dataset
An xarray Dataset containing aligned information for each detection:
- 'scores': DataArray from `positions` (score data, time/freq coords).
- 'dimensions': DataArray with extracted size values
(dims: 'detection', 'dimension').
- 'classes': DataArray with extracted class probabilities
(dims: 'detection', 'category').
- 'features': DataArray with extracted feature vectors
(dims: 'detection', 'feature'), if `features` was provided. All
DataArrays share the 'detection' dimension and associated
time/frequency coordinates.
"""
sizes = extract_values_at_positions(sizes, positions).T
classes = extract_values_at_positions(classes, positions).T
features = extract_values_at_positions(features, positions).T
return xr.Dataset(
{
"scores": positions,
"dimensions": sizes,
"classes": classes,
"features": features,
}
)
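The coordinate-based lookup in `extract_values_at_positions` boils down to xarray's vectorized (pointwise) selection. A minimal self-contained sketch:

import numpy as np
import xarray as xr

classes = xr.DataArray(
    np.arange(12, dtype=float).reshape(2, 2, 3),
    coords={
        "category": ["class_a", "class_b"],
        "frequency": [10_000.0, 60_000.0],
        "time": [0.0, 0.1, 0.2],
    },
    dims=["category", "frequency", "time"],
)
positions = xr.DataArray(
    [0.9, 0.4],
    coords={
        "frequency": ("detection", [10_000.0, 60_000.0]),
        "time": ("detection", [0.1, 0.2]),
    },
    dims="detection",
)
extracted = classes.sel(
    frequency=positions.coords["frequency"],
    time=positions.coords["time"],
)
# `extracted` has dims ('category', 'detection'): the class scores at each
# detection's (time, frequency) cell, aligned along 'detection'.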

View File

@ -0,0 +1,96 @@
"""Performs Non-Maximum Suppression (NMS) on detection heatmaps.
This module provides functionality to apply Non-Maximum Suppression, a common
technique used after model inference, particularly in object detection and peak
detection tasks.
In the context of BatDetect2 postprocessing, NMS is applied
to the raw detection heatmap output by the neural network. Its purpose is to
isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap
activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event.
"""
from typing import Tuple, Union
import torch
NMS_KERNEL_SIZE = 9
"""Default kernel size (pixels) for Non-Maximum Suppression.
Specifies the side length of the square neighborhood used by default in
`non_max_suppression` to find local maxima. A 9x9 neighborhood is often
a reasonable starting point for typical spectrogram resolutions used in
BatDetect2.
"""
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
This function identifies local maxima within a defined neighborhood for
each point in the input tensor. Values that are *not* the maximum within
their neighborhood are suppressed (set to zero). This is commonly used on
detection probability heatmaps to isolate distinct peaks corresponding to
individual detections and remove redundant lower scores nearby.
The implementation uses efficient 2D max pooling to find the maximum value
in the neighborhood of each point.
Parameters
----------
tensor : torch.Tensor
Input tensor, typically representing a detection heatmap. Must be a
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
`torch.nn.functional.max_pool2d` operation.
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
Size of the sliding window neighborhood used to find local maxima.
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
and width `w` is used. The kernel size should typically be odd to
have a well-defined center.
Returns
-------
torch.Tensor
A tensor of the same shape as the input, where only local maxima within
their respective neighborhoods (defined by `kernel_size`) retain their
original values. All other values are set to zero.
Raises
------
TypeError
If `kernel_size` is not an int or a tuple of two ints.
RuntimeError
If the input `tensor` does not have 3 or 4 dimensions (as required
by `max_pool2d`).
Notes
-----
- The function assumes higher values in the tensor indicate stronger peaks.
- Choosing an appropriate `kernel_size` is important. It should be large
enough to cover the typical "footprint" of a single detection peak plus
some surrounding context, effectively preventing multiple detections for
the same event. A size that is too large might suppress nearby distinct
events.
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = torch.nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep
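A quick sketch of the suppression behaviour on a toy heatmap:

import torch
from batdetect2.postprocess.nms import non_max_suppression

heatmap = torch.zeros(1, 1, 5, 5)
heatmap[0, 0, 2, 2] = 0.9  # strong peak
heatmap[0, 0, 2, 3] = 0.5  # weaker neighbour inside the 3x3 window
heatmap[0, 0, 0, 0] = 0.3  # isolated value, a local maximum of its own

suppressed = non_max_suppression(heatmap, kernel_size=3)
# The 0.9 peak and the isolated 0.3 survive; the adjacent 0.5 is zeroed.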

View File

@ -1,50 +0,0 @@
from typing import Tuple, Union
import torch
NMS_KERNEL_SIZE = 9
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Run non-maximum suppression on a tensor.
This function removes values from the input tensor that are not local
maxima in the neighborhood of the given kernel size.
All non-maximum values are set to zero.
Parameters
----------
tensor : torch.Tensor
Input tensor.
kernel_size : Union[int, Tuple[int, int]], optional
Size of the neighborhood to consider for non-maximum suppression.
If an integer is given, the neighborhood will be a square of the
given size. If a tuple is given, the neighborhood will be a
rectangle with the given height and width.
Returns
-------
torch.Tensor
Tensor with non-maximum suppressed values.
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = torch.nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep

View File

@ -0,0 +1,316 @@
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
This module provides utility functions to convert the raw numerical outputs
(typically PyTorch tensors) from the BatDetect2 DNN model into
`xarray.DataArray` objects. This step adds coordinate information
(time in seconds, frequency in Hz) back to the model's predictions, making them
interpretable in the context of the original audio signal and facilitating
subsequent processing steps.
Functions are provided for common BatDetect2 output types: detection heatmaps,
classification probability maps, size prediction maps, and potentially
intermediate features.
"""
from typing import List
import numpy as np
import torch
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
__all__ = [
"features_to_xarray",
"detection_to_xarray",
"classification_to_xarray",
"sizes_to_xarray",
]
def features_to_xarray(
features: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
features_prefix: str = "batdetect2_feature_",
) -> xr.DataArray:
"""Convert a multi-channel feature tensor to a coordinate-aware DataArray.
Assigns time, frequency, and feature coordinates to a raw feature tensor
output by the model.
Parameters
----------
features : torch.Tensor
The raw feature tensor from the model. Expected shape is
(num_features, num_freq_bins, num_time_bins).
start_time : float
The start time (in seconds) corresponding to the first time bin of
the tensor.
end_time : float
The end time (in seconds) corresponding to the *end* of the last time
bin.
min_freq : float, default=MIN_FREQ
The minimum frequency (in Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
The maximum frequency (in Hz) corresponding to the *end* of the last
frequency bin.
features_prefix : str, default="batdetect2_feature_"
Prefix used to generate names for the feature coordinate dimension
(e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
Returns
-------
xr.DataArray
An xarray DataArray containing the feature data with named dimensions
('feature', 'frequency', 'time') and calculated coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions.
"""
if features.ndim != 3:
raise ValueError(
"Input features tensor must have 3 dimensions (C, T, F), "
f"got shape {features.shape}"
)
num_features, height, width = features.shape
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=features.detach().numpy(),
dims=[
Dimensions.feature.value,
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
Dimensions.feature.value: [
f"{features_prefix}{i}" for i in range(num_features)
],
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="features",
)
def detection_to_xarray(
detection: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert a single-channel detection heatmap tensor to a DataArray.
Assigns time and frequency coordinates to a raw detection heatmap tensor.
Parameters
----------
detection : torch.Tensor
Raw detection heatmap tensor from the model. Expected shape is
(1, num_freq_bins, num_time_bins).
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing the detection scores with named
dimensions ('frequency', 'time') and calculated coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions or if the first
dimension size is not 1.
"""
if detection.ndim != 3:
raise ValueError(
"Input detection tensor must have 3 dimensions (1, T, F), "
f"got shape {detection.shape}"
)
num_channels, height, width = detection.shape
if num_channels != 1:
raise ValueError(
"Expected a single channel output, instead got "
f"{num_channels} channels"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=detection.squeeze(dim=0).detach().numpy(),
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="detection_score",
)
def classification_to_xarray(
classes: torch.Tensor,
start_time: float,
end_time: float,
class_names: List[str],
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert multi-channel class probability tensor to a DataArray.
Assigns category (class name), frequency, and time coordinates to a raw
class probability tensor output by the model.
Parameters
----------
classes : torch.Tensor
Raw class probability tensor. Expected shape is
(num_classes, num_freq_bins, num_time_bins).
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
class_names : List[str]
Ordered list of class names corresponding to the first dimension
of the `classes` tensor. The length must match `classes.shape[0]`.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing class probabilities with named
dimensions ('category', 'frequency', 'time') and calculated
coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions, or if the size of the
first dimension does not match the length of `class_names`.
"""
if classes.ndim != 3:
raise ValueError(
"Input classes tensor must have 3 dimensions (C, F, T), "
f"got shape {classes.shape}"
)
num_classes, height, width = classes.shape
if num_classes != len(class_names):
raise ValueError(
"The number of classes does not coincide with the number of "
"class names provided: "
f"({num_classes = }) != ({len(class_names) = })"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=classes.detach().numpy(),
dims=[
"category",
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
"category": class_names,
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="class_scores",
)
def sizes_to_xarray(
sizes: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert the 2-channel size prediction tensor to a DataArray.
Assigns dimension ('width', 'height'), frequency, and time coordinates
to the raw size prediction tensor output by the model.
Parameters
----------
sizes : torch.Tensor
Raw size prediction tensor. Expected shape is
(2, num_freq_bins, num_time_bins), where the first dimension
corresponds to predicted width and height respectively.
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing predicted sizes with named dimensions
('dimension', 'frequency', 'time') and calculated time/frequency
coordinates. The 'dimension' coordinate will have values
['width', 'height'].
Raises
------
ValueError
If the input tensor does not have 3 dimensions or if the first
dimension size is not exactly 2.
"""
    if sizes.ndim != 3:
        raise ValueError(
            "Input sizes tensor must have 3 dimensions (2, F, T), "
            f"got shape {sizes.shape}"
        )
    num_channels, height, width = sizes.shape
if num_channels != 2:
raise ValueError(
"Expected a two-channel output, instead got "
f"{num_channels} channels"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=sizes.detach().numpy(),
dims=[
"dimension",
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
)
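A short sketch of the remapping on a random tensor, just to show the resulting coordinates:

import torch
from batdetect2.postprocess.remapping import detection_to_xarray

raw = torch.rand(1, 128, 512)  # (channel, freq bins, time bins)
heatmap = detection_to_xarray(
    raw, start_time=0.0, end_time=5.0, min_freq=10_000, max_freq=120_000
)
# heatmap.dims == ('frequency', 'time'); time spans [0, 5) seconds and
# frequency spans [10_000, 120_000) Hz, with left-aligned bin coordinates.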

View File

@ -1,21 +1,284 @@
from typing import Dict, NamedTuple, Protocol """Defines shared interfaces and data structures for postprocessing.
This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.
The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into interpretable
detections represented as `soundevent` objects. Using protocols ensures
modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from typing import Callable, List, NamedTuple, Protocol
import numpy as np import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
__all__ = [ __all__ = [
"BatDetect2Prediction", "RawPrediction",
"PostprocessorProtocol",
"GeometryBuilder",
] ]
class BatDetect2Prediction(NamedTuple): GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
"""Type alias for a function that recovers geometry from position and size.
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
class RawPrediction(NamedTuple):
"""Intermediate representation of a single detected sound event.
Holds extracted information about a detection after initial processing
(like peak finding, coordinate remapping, geometry recovery) but before
final class decoding and conversion into a `SoundEventPrediction`. This
can be useful for evaluation or simpler data handling formats.
Attributes
----------
start_time : float
Start time of the recovered bounding box in seconds.
end_time : float
End time of the recovered bounding box in seconds.
low_freq : float
Lowest frequency of the recovered bounding box in Hz.
high_freq : float
Highest frequency of the recovered bounding box in Hz.
detection_score : float
The confidence score associated with this detection, typically from
the detection heatmap peak.
class_scores : xr.DataArray
An xarray DataArray containing the predicted probabilities or scores
for each target class at the detection location. Indexed by a
'category' coordinate containing class names.
features : xr.DataArray
An xarray DataArray containing extracted feature vectors at the
detection location. Indexed by a 'feature' coordinate.
"""
start_time: float start_time: float
end_time: float end_time: float
low_freq: float low_freq: float
high_freq: float high_freq: float
detection_score: float detection_score: float
class_scores: Dict[str, float] class_scores: xr.DataArray
features: np.ndarray features: xr.DataArray
class PostprocessorProtocol(Protocol): class PostprocessorProtocol(Protocol):
pass """Protocol defining the interface for the full postprocessing pipeline.
This protocol outlines the standard methods for an object that takes raw
output from a BatDetect2 model and the corresponding input clip metadata,
and processes it through various stages (e.g., coordinate remapping, NMS,
detection extraction, data extraction, decoding) to produce interpretable
results at different levels of completion.
Implementations manage the configured logic for all postprocessing steps.
"""
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap feature tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch, expected
to contain the necessary feature tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects, one for each item in the
processed batch. This list provides the timing, recording, and
other metadata context needed to calculate real-world coordinates
(seconds, Hz) for the output arrays. The length of this list must
correspond to the batch size of the `output`.
Returns
-------
List[xr.DataArray]
A list of xarray DataArrays, one for each input clip in the batch,
in the same order. Each DataArray contains the feature vectors
with dimensions like ('feature', 'time', 'frequency') and
corresponding real-world coordinates.
"""
...
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap detection tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing detection heatmaps.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 2D xarray DataArrays (one per input clip, in order),
representing the detection heatmap with 'time' and 'frequency'
coordinates. Values typically indicate detection confidence.
"""
...
def get_classification_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap classification tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing class probability tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing class probabilities with 'category', 'time', and
'frequency' dimensions and coordinates.
"""
...
def get_sizes_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap size prediction tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing predicted size tensors (e.g., width and height).
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing predicted sizes with 'dimension'
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
coordinates. Values represent estimated detection sizes.
"""
...
def get_detection_datasets(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.Dataset]:
"""Perform remapping, NMS, detection, and data extraction for a batch.
Processes the raw model output for a batch to identify detection peaks
and extract all associated information (score, position, size, class
probs, features) at those peak locations, returning a structured
dataset for each input clip in the batch.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[xr.Dataset]
A list of xarray Datasets (one per input clip, in order). Each
Dataset contains multiple DataArrays ('scores', 'dimensions',
'classes', 'features') sharing a common 'detection' dimension,
providing aligned data for each detected event in that clip.
"""
...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...
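To make the `GeometryBuilder` contract concrete, here is a hypothetical implementation. The actual position-and-size-to-box convention is defined by the `TargetProtocol`'s ROI recovery logic; the corner-anchored convention below is assumed purely for illustration.

import numpy as np
from soundevent import data

def simple_geometry_builder(
    position: tuple[float, float],
    dims: np.ndarray,
) -> data.Geometry:
    # Assumed convention: `position` is the start-time/low-frequency corner
    # and `dims` holds [duration_s, bandwidth_hz] already in real units.
    time, freq = position
    width, height = dims
    return data.BoundingBox(
        coordinates=[time, freq, time + width, freq + height]
    )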