mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Updated postprocess module with docstrings
This commit is contained in:
parent
089328a4f0
commit
bcf339c40d
@ -0,0 +1,566 @@
|
||||
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
||||
|
||||
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
||||
BatDetect2 neural network model and transforms them into meaningful, structured
|
||||
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
||||
containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
||||
coordinates to raw model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
class probabilities, features) at the detected locations.
|
||||
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
||||
class predictions into interpretable `soundevent` objects, including
|
||||
recovering geometry (ROIs) and decoding class names back to standard tags.
|
||||
|
||||
This module provides the primary interface:
|
||||
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
||||
(thresholds, NMS kernel size, etc.).
|
||||
- `load_postprocess_config`: Function to load the configuration from a file.
|
||||
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
||||
holds the configured pipeline logic.
|
||||
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
||||
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
||||
It also re-exports key components from submodules for convenience.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
from batdetect2.postprocess.detection import (
|
||||
DEFAULT_DETECTION_THRESHOLD,
|
||||
TOP_K_PER_SEC,
|
||||
extract_detections_from_array,
|
||||
get_max_detections,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
)
|
||||
from batdetect2.postprocess.nms import (
|
||||
NMS_KERNEL_SIZE,
|
||||
non_max_suppression,
|
||||
)
|
||||
from batdetect2.postprocess.remapping import (
|
||||
classification_to_xarray,
|
||||
detection_to_xarray,
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"ModelOutput",
|
||||
"NMS_KERNEL_SIZE",
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"PostprocessorProtocol",
|
||||
"RawPrediction",
|
||||
"TOP_K_PER_SEC",
|
||||
"build_postprocessor",
|
||||
"classification_to_xarray",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"detection_to_xarray",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_detections_from_array",
|
||||
"features_to_xarray",
|
||||
"get_max_detections",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration settings for the postprocessing pipeline.
|
||||
|
||||
Defines tunable parameters that control how raw model outputs are
|
||||
converted into final detections.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
||||
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
||||
Used to suppress weaker detections near stronger peaks. Must be
|
||||
positive.
|
||||
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
||||
Minimum confidence score from the detection heatmap required to
|
||||
consider a point as a potential detection. Must be >= 0.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Minimum confidence score for a specific class prediction to be included
|
||||
in the decoded tags for a detection. Must be >= 0.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
Desired maximum number of detections per second of audio. Used by
|
||||
`get_max_detections` to calculate an absolute limit based on clip
|
||||
duration before applying `extract_detections_from_array`. Must be
|
||||
positive.
|
||||
"""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(
|
||||
default=DEFAULT_DETECTION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
classification_threshold: float = Field(
|
||||
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
"""Load the postprocessing configuration from a file.
|
||||
|
||||
Reads a configuration file (YAML) and validates it against the
|
||||
`PostprocessConfig` schema, potentially extracting data from a nested
|
||||
field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
postprocessing configuration (e.g., "inference.postprocessing").
|
||||
If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessConfig
|
||||
The loaded and validated postprocessing configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded configuration data does not conform to the
|
||||
`PostprocessConfig` schema.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path within the loaded data.
|
||||
"""
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
||||
|
||||
|
||||
def build_postprocessor(
|
||||
targets: TargetProtocol,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
max_freq: int = MAX_FREQ,
|
||||
min_freq: int = MIN_FREQ,
|
||||
) -> PostprocessorProtocol:
|
||||
"""Factory function to build the standard postprocessor.
|
||||
|
||||
Creates and initializes the `Postprocessor` instance, providing it with the
|
||||
necessary `targets` object and the `PostprocessConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
An initialized object conforming to the `TargetProtocol`, providing
|
||||
methods like `.decode()` and `.recover_roi()`, and attributes like
|
||||
`.class_names` and `.generic_class_tags`. This links postprocessing
|
||||
to the defined target semantics and geometry mappings.
|
||||
config : PostprocessConfig, optional
|
||||
Configuration object specifying postprocessing parameters (thresholds,
|
||||
NMS kernel size, etc.). If None, default settings defined in
|
||||
`PostprocessConfig` will be used.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
The minimum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig` instead for better encapsulation.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
The maximum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessorProtocol
|
||||
An initialized `Postprocessor` instance ready to process model outputs.
|
||||
"""
|
||||
return Postprocessor(
|
||||
targets=targets,
|
||||
config=config or PostprocessConfig(),
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
class Postprocessor(PostprocessorProtocol):
|
||||
"""Standard implementation of the postprocessing pipeline.
|
||||
|
||||
This class orchestrates the steps required to convert raw model outputs
|
||||
into interpretable `soundevent` predictions. It uses configured parameters
|
||||
and leverages functions from the `batdetect2.postprocess` submodules for
|
||||
each stage (NMS, remapping, detection, extraction, decoding).
|
||||
|
||||
It requires a `TargetProtocol` object during initialization to access
|
||||
necessary decoding information (class name to tag mapping,
|
||||
ROI recovery logic) ensuring consistency with the target definitions used
|
||||
during training or specified for inference.
|
||||
|
||||
Instances are typically created using the `build_postprocessor` factory
|
||||
function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
The configured target definition object providing decoding and ROI
|
||||
recovery.
|
||||
config : PostprocessConfig
|
||||
Configuration object holding parameters for NMS, thresholds, etc.
|
||||
min_freq : int
|
||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||
max_freq : int
|
||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||
"""
|
||||
|
||||
targets: TargetProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
config: PostprocessConfig,
|
||||
min_freq: int = MIN_FREQ,
|
||||
max_freq: int = MAX_FREQ,
|
||||
):
|
||||
"""Initialize the Postprocessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
Initialized target definition object.
|
||||
config : PostprocessConfig
|
||||
Configuration for postprocessing parameters.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
Minimum frequency (Hz) for coordinate remapping.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
Maximum frequency (Hz) for coordinate remapping.
|
||||
"""
|
||||
self.targets = targets
|
||||
self.config = config
|
||||
self.min_freq = min_freq
|
||||
self.max_freq = max_freq
|
||||
|
||||
def get_feature_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw feature tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.features` tensor for the batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware feature DataArrays, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.features` and `clips` do not match.
|
||||
"""
|
||||
if len(clips) != len(output.features):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of feature array"
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, features: {len(output.features)})"
|
||||
)
|
||||
|
||||
return [
|
||||
features_to_xarray(
|
||||
feats,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for feats, clip in zip(output.features, clips)
|
||||
]
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Apply NMS and remap detection heatmaps for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.detection_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
||||
clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.detection_probs` and `clips` do not match.
|
||||
"""
|
||||
detections = output.detection_probs
|
||||
|
||||
if len(clips) != len(output.detection_probs):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of detection array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, detection: {len(detections)})"
|
||||
)
|
||||
|
||||
detections = non_max_suppression(
|
||||
detections,
|
||||
kernel_size=self.config.nms_kernel_size,
|
||||
)
|
||||
|
||||
return [
|
||||
detection_to_xarray(
|
||||
dets,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for dets, clip in zip(detections, clips)
|
||||
]
|
||||
|
||||
def get_classification_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw classification tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.class_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware class probability maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.class_probs` and `clips` do not match, or
|
||||
if number of classes mismatches `self.targets.class_names`.
|
||||
"""
|
||||
classifications = output.class_probs
|
||||
|
||||
if len(clips) != len(classifications):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of classification array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
||||
)
|
||||
|
||||
return [
|
||||
classification_to_xarray(
|
||||
class_probs,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
class_names=self.targets.class_names,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for class_probs, clip in zip(classifications, clips)
|
||||
]
|
||||
|
||||
def get_sizes_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw size prediction tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.size_preds` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware size prediction maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.size_preds` and `clips` do not match.
|
||||
"""
|
||||
sizes = output.size_preds
|
||||
|
||||
if len(clips) != len(sizes):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of sizes array do not match. "
|
||||
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
||||
)
|
||||
|
||||
return [
|
||||
sizes_to_xarray(
|
||||
size_preds,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for size_preds, clip in zip(sizes, clips)
|
||||
]
|
||||
|
||||
def get_detection_datasets(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
List of xarray Datasets (one per clip). Each Dataset contains
|
||||
aligned scores, dimensions, class probabilities, and features for
|
||||
detections found in that clip.
|
||||
"""
|
||||
detection_arrays = self.get_detection_arrays(output, clips)
|
||||
classification_arrays = self.get_classification_arrays(output, clips)
|
||||
size_arrays = self.get_sizes_arrays(output, clips)
|
||||
features_arrays = self.get_feature_arrays(output, clips)
|
||||
|
||||
datasets = []
|
||||
for det_array, class_array, sizes_array, feats_array in zip(
|
||||
detection_arrays,
|
||||
classification_arrays,
|
||||
size_arrays,
|
||||
features_arrays,
|
||||
):
|
||||
max_detections = get_max_detections(
|
||||
det_array,
|
||||
top_k_per_sec=self.config.top_k_per_sec,
|
||||
)
|
||||
|
||||
positions = extract_detections_from_array(
|
||||
det_array,
|
||||
max_detections=max_detections,
|
||||
threshold=self.config.detection_threshold,
|
||||
)
|
||||
|
||||
datasets.append(
|
||||
extract_detection_xr_dataset(
|
||||
positions,
|
||||
sizes_array,
|
||||
class_array,
|
||||
feats_array,
|
||||
)
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
def get_raw_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
Processes raw model output through remapping, NMS, detection, data
|
||||
extraction, and geometry recovery via the configured
|
||||
`targets.recover_roi`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[List[RawPrediction]]
|
||||
List of lists (one inner list per input clip). Each inner list
|
||||
contains `RawPrediction` objects for detections in that clip.
|
||||
"""
|
||||
detection_datasets = self.get_detection_datasets(output, clips)
|
||||
return [
|
||||
convert_xr_dataset_to_raw_prediction(
|
||||
dataset,
|
||||
self.targets.recover_roi,
|
||||
)
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
|
||||
def get_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Perform the full postprocessing pipeline for a batch.
|
||||
|
||||
Takes raw model output and corresponding clips, applies the entire
|
||||
configured chain (NMS, remapping, extraction, geometry recovery, class
|
||||
decoding), producing final `soundevent.data.ClipPrediction` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.ClipPrediction]
|
||||
List containing one `ClipPrediction` object for each input clip,
|
||||
populated with `SoundEventPrediction` objects.
|
||||
"""
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
sound_event_decoder=self.targets.decode,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
||||
for prediction, clip in zip(raw_predictions, clips)
|
||||
]
|
@ -1,73 +0,0 @@
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.models import ModelOutput
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
|
||||
|
||||
def to_xarray(
|
||||
output: ModelOutput,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
class_names: list[str],
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
):
|
||||
detection = output.detection_probs
|
||||
size = output.size_preds
|
||||
classes = output.class_probs
|
||||
features = output.features
|
||||
|
||||
if len(detection.shape) == 4:
|
||||
if detection.shape[0] != 1:
|
||||
raise ValueError(
|
||||
"Expected a non-batched output or a batch of size 1, instead "
|
||||
f"got an input of shape {detection.shape}"
|
||||
)
|
||||
|
||||
detection = detection.squeeze(dim=0)
|
||||
size = size.squeeze(dim=0)
|
||||
classes = classes.squeeze(dim=0)
|
||||
features = features.squeeze(dim=0)
|
||||
|
||||
_, width, height = detection.shape
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
if classes.shape[0] != len(class_names):
|
||||
raise ValueError(
|
||||
f"The number of classes does not coincide with the number of class names provided: ({classes.shape[0] = }) != ({len(class_names) = })"
|
||||
)
|
||||
|
||||
return xr.Dataset(
|
||||
data_vars={
|
||||
"detection": (
|
||||
[Dimensions.time.value, Dimensions.frequency.value],
|
||||
detection.squeeze(dim=0).detach().numpy(),
|
||||
),
|
||||
"size": (
|
||||
[
|
||||
"dimension",
|
||||
Dimensions.time.value,
|
||||
Dimensions.frequency.value,
|
||||
],
|
||||
detection.detach().numpy(),
|
||||
),
|
||||
"classes": (
|
||||
[
|
||||
"category",
|
||||
Dimensions.time.value,
|
||||
Dimensions.frequency.value,
|
||||
],
|
||||
classes.detach().numpy(),
|
||||
),
|
||||
},
|
||||
coords={
|
||||
Dimensions.time.value: times,
|
||||
Dimensions.frequency.value: freqs,
|
||||
"dimension": ["width", "height"],
|
||||
"category": class_names,
|
||||
},
|
||||
)
|
@ -1,32 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"PostprocessConfig",
|
||||
"load_postprocess_config",
|
||||
]
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration for postprocessing model outputs."""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
|
||||
min_freq: int = Field(default=10000, gt=0)
|
||||
max_freq: int = Field(default=120000, gt=0)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
297
batdetect2/postprocess/decoding.py
Normal file
297
batdetect2/postprocess/decoding.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""Decodes extracted detection data into standard soundevent predictions.
|
||||
|
||||
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
||||
model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
objects, using a configured geometry builder to recover bounding boxes from
|
||||
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||
2. Converting each `RawPrediction` into a
|
||||
`soundevent.data.SoundEventPrediction`, which involves:
|
||||
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||
- Decoding the predicted class probabilities into representative tags using
|
||||
a configured class decoder (`SoundEventDecoder`).
|
||||
- Applying a classification threshold.
|
||||
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||
including tags for all classes above the threshold (multi-label).
|
||||
- Adding generic class tags as a baseline.
|
||||
- Associating scores with the final prediction and tags.
|
||||
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||
a `soundevent.data.ClipPrediction`
|
||||
(`convert_raw_predictions_to_clip_prediction`).
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||
"""Default threshold applied to classification scores.
|
||||
|
||||
Class predictions with scores below this value are typically ignored during
|
||||
decoding.
|
||||
"""
|
||||
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_builder: GeometryBuilder,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||
and transforms each detection entry into an intermediate `RawPrediction`
|
||||
object. This involves recovering the geometry (e.g., bounding box) from
|
||||
the predicted position and scaled size dimensions using the provided
|
||||
`geometry_builder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_dataset : xr.Dataset
|
||||
An xarray Dataset containing aligned detection information, typically
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_builder : GeometryBuilder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[RawPrediction]
|
||||
A list of `RawPrediction` objects, each containing the detection score,
|
||||
recovered bounding box coordinates (start/end time, low/high freq),
|
||||
the vector of class scores, and the feature vector for one detection.
|
||||
|
||||
Raises
|
||||
------
|
||||
AttributeError, KeyError, ValueError
|
||||
If `detection_dataset` is missing expected variables ('scores',
|
||||
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||
associated with 'scores'), or if `geometry_builder` fails.
|
||||
"""
|
||||
detections = []
|
||||
|
||||
for det_num in range(detection_dataset.dims["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
geom = geometry_builder(
|
||||
(det_info.time, det_info.freq),
|
||||
det_info.dimensions,
|
||||
)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
||||
|
||||
classes = det_info.classes
|
||||
features = det_info.features
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.score,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_scores=classes,
|
||||
features=features,
|
||||
)
|
||||
)
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip: data.Clip,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||
|
||||
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_predictions : List[RawPrediction]
|
||||
List of raw prediction objects for a single clip.
|
||||
clip : data.Clip
|
||||
The original `soundevent.data.Clip` object these predictions belong to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to decode class names into representative tags.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of tags representing the generic class category.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Threshold applied to class scores during decoding.
|
||||
top_class_only : bool, default=False
|
||||
If True, only decode tags for the single highest-scoring class above
|
||||
the threshold. If False, decode tags for all classes above threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.ClipPrediction
|
||||
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||
objects corresponding to the input `raw_predictions`.
|
||||
"""
|
||||
return data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=[
|
||||
convert_raw_prediction_to_sound_event_prediction(
|
||||
prediction,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=sound_event_decoder,
|
||||
generic_class_tags=generic_class_tags,
|
||||
classification_threshold=classification_threshold,
|
||||
top_class_only=top_class_only,
|
||||
)
|
||||
for prediction in raw_predictions
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
recording: data.Recording,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
classification_threshold: Optional[
|
||||
float
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
):
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||
|
||||
This function performs the core decoding steps for a single detected event:
|
||||
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||
feature vectors.
|
||||
2. Initializes a list of predicted tags using the provided
|
||||
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||
`raw_prediction` to these generic tags.
|
||||
3. Processes the `class_scores` from the `raw_prediction`:
|
||||
a. Optionally filters out scores below `classification_threshold`
|
||||
(if it's not None).
|
||||
b. Sorts the remaining scores in descending order.
|
||||
c. Iterates through the sorted, thresholded class scores.
|
||||
d. For each class, uses the `sound_event_decoder` to get the
|
||||
representative base tags for that class name.
|
||||
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||
the specific `score` of that class prediction.
|
||||
f. Appends these specific predicted tags to the list.
|
||||
g. If `top_class_only` is True, stops after processing the first
|
||||
(highest-scoring) class that passed the threshold.
|
||||
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||
compiled list of `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_prediction : RawPrediction
|
||||
The raw prediction object containing score, bounds, class scores,
|
||||
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||
coordinate.
|
||||
recording : data.Recording
|
||||
The recording the sound event belongs to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Configured function mapping class names (str) to lists of base
|
||||
`data.Tag` objects.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of base tags representing the generic category.
|
||||
classification_threshold : float, optional
|
||||
The minimum score a class prediction must have to be considered
|
||||
significant enough to have its tags decoded and added. If None, no
|
||||
thresholding is applied based on class score (all predicted classes,
|
||||
or the top one if `top_class_only` is True, will be processed).
|
||||
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
top_class_only : bool, default=False
|
||||
If True, only includes tags for the single highest-scoring class that
|
||||
exceeds the threshold. If False (default), includes tags for all classes
|
||||
exceeding the threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventPrediction
|
||||
The fully formed sound event prediction object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `raw_prediction.features` has unexpected structure or if
|
||||
`data.term_from_key` (if used internally) fails.
|
||||
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||
"""
|
||||
sound_event = data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
raw_prediction.start_time,
|
||||
raw_prediction.low_freq,
|
||||
raw_prediction.end_time,
|
||||
raw_prediction.high_freq,
|
||||
]
|
||||
),
|
||||
features=[
|
||||
data.Feature(term=data.term_from_key(feat_name), value=value)
|
||||
for feat_name, value in raw_prediction.features
|
||||
],
|
||||
)
|
||||
|
||||
tags = [
|
||||
data.PredictedTag(tag=tag, score=raw_prediction.detection_score)
|
||||
for tag in generic_class_tags
|
||||
]
|
||||
|
||||
class_scores = raw_prediction.class_scores
|
||||
|
||||
if classification_threshold is not None:
|
||||
class_scores = class_scores.where(
|
||||
class_scores > classification_threshold,
|
||||
drop=True,
|
||||
)
|
||||
|
||||
for class_name, score in class_scores.sortby(
|
||||
class_scores, ascending=False
|
||||
):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
|
||||
for tag in class_tags:
|
||||
tags.append(
|
||||
data.PredictedTag(
|
||||
tag=tag,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
if top_class_only:
|
||||
break
|
||||
|
||||
return data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=raw_prediction.detection_score,
|
||||
tags=tags,
|
||||
)
|
162
batdetect2/postprocess/detection.py
Normal file
162
batdetect2/postprocess/detection.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements a specific step within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
||||
coordinate remapping have been applied).
|
||||
|
||||
It provides functionality to:
|
||||
- Identify the locations (time, frequency) of the highest-scoring points.
|
||||
- Filter these points based on a minimum confidence score threshold.
|
||||
- Limit the maximum number of detection points returned (top-k).
|
||||
|
||||
The main output is an `xarray.DataArray` containing the scores and
|
||||
corresponding time/frequency coordinates for the extracted detection points.
|
||||
This output serves as the input for subsequent postprocessing steps, such as
|
||||
extracting predicted class probabilities and bounding box sizes at these
|
||||
specific locations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions, get_dim_width
|
||||
|
||||
__all__ = [
|
||||
"extract_detections_from_array",
|
||||
"get_max_detections",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"TOP_K_PER_SEC",
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
"""Default confidence score threshold used for filtering detections."""
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
"""Default desired maximum number of detections per second of audio."""
|
||||
|
||||
|
||||
def extract_detections_from_array(
|
||||
detection_array: xr.DataArray,
|
||||
max_detections: Optional[int] = None,
|
||||
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
|
||||
) -> xr.DataArray:
|
||||
"""Extract detection locations (time, freq) and scores from a heatmap.
|
||||
|
||||
Identifies the pixels with the highest scores in the input detection
|
||||
heatmap, filters them based on an optional score `threshold`, limits the
|
||||
number to an optional `max_detections`, and returns their scores along with
|
||||
their corresponding time and frequency coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
A 2D xarray DataArray representing the detection heatmap. Must have
|
||||
dimensions and coordinates named 'time' and 'frequency'. Higher values
|
||||
are assumed to indicate higher detection confidence.
|
||||
max_detections : int, optional
|
||||
The absolute maximum number of detections to return. If specified, only
|
||||
the top `max_detections` highest-scoring detections (passing the
|
||||
threshold) are returned. If None (default), all detections passing
|
||||
the threshold are returned, sorted by score.
|
||||
threshold : float, optional
|
||||
The minimum confidence score required for a detection peak to be
|
||||
kept. Detections with scores below this value are discarded.
|
||||
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
|
||||
thresholding is applied.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
A 1D xarray DataArray named 'score' with a 'detection' dimension.
|
||||
- The data values are the scores of the extracted detections, sorted
|
||||
in descending order.
|
||||
- It has coordinates 'time' and 'frequency' (also indexed by the
|
||||
'detection' dimension) indicating the location of each detection
|
||||
peak in the original coordinate system.
|
||||
- Returns an empty DataArray if no detections pass the criteria.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `max_detections` is not None and not a positive integer, or if
|
||||
`detection_array` lacks required dimensions/coordinates.
|
||||
"""
|
||||
if max_detections is not None:
|
||||
if max_detections <= 0:
|
||||
raise ValueError("Max detections must be positive")
|
||||
|
||||
values = detection_array.values.flatten()
|
||||
|
||||
if max_detections is not None:
|
||||
top_indices = np.argpartition(-values, max_detections)[:max_detections]
|
||||
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
|
||||
else:
|
||||
top_sorted_indices = np.argsort(-values)
|
||||
|
||||
top_values = values[top_sorted_indices]
|
||||
|
||||
if threshold is not None:
|
||||
mask = top_values > threshold
|
||||
top_values = top_values[mask]
|
||||
top_sorted_indices = top_sorted_indices[mask]
|
||||
|
||||
time_indices, freq_indices = np.unravel_index(
|
||||
top_sorted_indices,
|
||||
detection_array.shape,
|
||||
)
|
||||
|
||||
times = detection_array.coords[Dimensions.time.value].values[time_indices]
|
||||
freqs = detection_array.coords[Dimensions.frequency.value].values[
|
||||
freq_indices
|
||||
]
|
||||
|
||||
return xr.DataArray(
|
||||
data=top_values,
|
||||
coords={
|
||||
Dimensions.frequency.value: ("detection", freqs),
|
||||
Dimensions.time.value: ("detection", times),
|
||||
},
|
||||
dims="detection",
|
||||
name="score",
|
||||
)
|
||||
|
||||
|
||||
def get_max_detections(
|
||||
detection_array: xr.DataArray,
|
||||
top_k_per_sec: int = TOP_K_PER_SEC,
|
||||
) -> int:
|
||||
"""Calculate max detections allowed based on duration and rate.
|
||||
|
||||
Determines the total maximum number of detections to extract from a
|
||||
heatmap based on its time duration and a desired rate of detections
|
||||
per second.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
The detection heatmap, requiring 'time' coordinates from which the
|
||||
total duration can be calculated using
|
||||
`soundevent.arrays.get_dim_width`.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
The desired maximum number of detections to allow per second of audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The calculated total maximum number of detections allowed for the
|
||||
entire duration of the `detection_array`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the duration cannot be calculated from the `detection_array` (e.g.,
|
||||
missing or invalid 'time' coordinates/dimension).
|
||||
"""
|
||||
if top_k_per_sec < 0:
|
||||
raise ValueError("top_k_per_sec cannot be negative.")
|
||||
|
||||
duration = get_dim_width(detection_array, Dimensions.time.value)
|
||||
return int(duration * top_k_per_sec)
|
122
batdetect2/postprocess/extraction.py
Normal file
122
batdetect2/postprocess/extraction.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""Extracts associated data for detected points from model output arrays.
|
||||
|
||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
||||
pipeline. After candidate detection points (time, frequency, score) have been
|
||||
identified, this module extracts the corresponding values from other raw model
|
||||
output arrays, such as:
|
||||
|
||||
- Predicted bounding box sizes (width, height).
|
||||
- Class probability scores for each defined target class.
|
||||
- Intermediate feature vectors.
|
||||
|
||||
It uses coordinate-based indexing provided by `xarray` to ensure that the
|
||||
correct values are retrieved from the original heatmaps/feature maps at the
|
||||
precise time-frequency location of each detection. The final output aggregates
|
||||
all extracted information into a structured `xarray.Dataset`.
|
||||
"""
|
||||
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
__all__ = [
|
||||
"extract_values_at_positions",
|
||||
"extract_detection_xr_dataset",
|
||||
]
|
||||
|
||||
|
||||
def extract_values_at_positions(
|
||||
array: xr.DataArray,
|
||||
positions: xr.DataArray,
|
||||
) -> xr.DataArray:
|
||||
"""Extract values from an array at specified time-frequency positions.
|
||||
|
||||
Uses coordinate-based indexing to retrieve values from a source `array`
|
||||
(e.g., class probabilities, size predictions, features) at the time and
|
||||
frequency coordinates defined in the `positions` array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : xr.DataArray
|
||||
The source DataArray from which to extract values. Must have 'time'
|
||||
and 'frequency' dimensions and coordinates matching the space of
|
||||
`positions`.
|
||||
positions : xr.DataArray
|
||||
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
|
||||
locations from which to extract values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
A DataArray containing the values extracted from `array` at the given
|
||||
positions.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, IndexError, KeyError
|
||||
If dimensions or coordinates are missing or incompatible between
|
||||
`array` and `positions`, or if selection fails.
|
||||
"""
|
||||
return array.sel(
|
||||
**{
|
||||
Dimensions.frequency.value: positions.coords[
|
||||
Dimensions.frequency.value
|
||||
],
|
||||
Dimensions.time.value: positions.coords[Dimensions.time.value],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def extract_detection_xr_dataset(
|
||||
positions: xr.DataArray,
|
||||
sizes: xr.DataArray,
|
||||
classes: xr.DataArray,
|
||||
features: xr.DataArray,
|
||||
) -> xr.Dataset:
|
||||
"""Combine extracted detection information into a structured xr.Dataset.
|
||||
|
||||
Takes the detection positions/scores and the full model output heatmaps
|
||||
(sizes, classes, optional features), extracts the relevant data at the
|
||||
detection positions, and packages everything into a single `xarray.Dataset`
|
||||
where all variables are indexed by a common 'detection' dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : xr.DataArray
|
||||
Output from `extract_detections_from_array`, containing detection
|
||||
scores as data and 'time', 'frequency' coordinates along the
|
||||
'detection' dimension.
|
||||
sizes : xr.DataArray
|
||||
The full size prediction heatmap from the model, with dimensions like
|
||||
('dimension', 'time', 'frequency').
|
||||
classes : xr.DataArray
|
||||
The full class probability heatmap from the model, with dimensions like
|
||||
('category', 'time', 'frequency').
|
||||
features : xr.DataArray
|
||||
The full feature map from the model, with
|
||||
dimensions like ('feature', 'time', 'frequency').
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.Dataset
|
||||
An xarray Dataset containing aligned information for each detection:
|
||||
- 'scores': DataArray from `positions` (score data, time/freq coords).
|
||||
- 'dimensions': DataArray with extracted size values
|
||||
(dims: 'detection', 'dimension').
|
||||
- 'classes': DataArray with extracted class probabilities
|
||||
(dims: 'detection', 'category').
|
||||
- 'features': DataArray with extracted feature vectors
|
||||
(dims: 'detection', 'feature'), if `features` was provided. All
|
||||
DataArrays share the 'detection' dimension and associated
|
||||
time/frequency coordinates.
|
||||
"""
|
||||
sizes = extract_values_at_positions(sizes, positions).T
|
||||
classes = extract_values_at_positions(classes, positions).T
|
||||
features = extract_values_at_positions(features, positions).T
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": positions,
|
||||
"dimensions": sizes,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
}
|
||||
)
|
96
batdetect2/postprocess/nms.py
Normal file
96
batdetect2/postprocess/nms.py
Normal file
@ -0,0 +1,96 @@
|
||||
"""Performs Non-Maximum Suppression (NMS) on detection heatmaps.
|
||||
|
||||
This module provides functionality to apply Non-Maximum Suppression, a common
|
||||
technique used after model inference, particularly in object detection and peak
|
||||
detection tasks.
|
||||
|
||||
In the context of BatDetect2 postprocessing, NMS is applied
|
||||
to the raw detection heatmap output by the neural network. Its purpose is to
|
||||
isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap
|
||||
activations that have lower scores than a local maximum. This helps prevent
|
||||
multiple, overlapping detections originating from the same sound event.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
"""Default kernel size (pixels) for Non-Maximum Suppression.
|
||||
|
||||
Specifies the side length of the square neighborhood used by default in
|
||||
`non_max_suppression` to find local maxima. A 9x9 neighborhood is often
|
||||
a reasonable starting point for typical spectrogram resolutions used in
|
||||
BatDetect2.
|
||||
"""
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||
|
||||
This function identifies local maxima within a defined neighborhood for
|
||||
each point in the input tensor. Values that are *not* the maximum within
|
||||
their neighborhood are suppressed (set to zero). This is commonly used on
|
||||
detection probability heatmaps to isolate distinct peaks corresponding to
|
||||
individual detections and remove redundant lower scores nearby.
|
||||
|
||||
The implementation uses efficient 2D max pooling to find the maximum value
|
||||
in the neighborhood of each point.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor, typically representing a detection heatmap. Must be a
|
||||
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
||||
`torch.nn.functional.max_pool2d` operation.
|
||||
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
|
||||
Size of the sliding window neighborhood used to find local maxima.
|
||||
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
||||
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
||||
and width `w` is used. The kernel size should typically be odd to
|
||||
have a well-defined center.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A tensor of the same shape as the input, where only local maxima within
|
||||
their respective neighborhoods (defined by `kernel_size`) retain their
|
||||
original values. All other values are set to zero.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If `kernel_size` is not an int or a tuple of two ints.
|
||||
RuntimeError
|
||||
If the input `tensor` does not have 3 or 4 dimensions (as required
|
||||
by `max_pool2d`).
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The function assumes higher values in the tensor indicate stronger peaks.
|
||||
- Choosing an appropriate `kernel_size` is important. It should be large
|
||||
enough to cover the typical "footprint" of a single detection peak plus
|
||||
some surrounding context, effectively preventing multiple detections for
|
||||
the same event. A size that is too large might suppress nearby distinct
|
||||
events.
|
||||
"""
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
else:
|
||||
kernel_size_h, kernel_size_w = kernel_size
|
||||
|
||||
pad_h = (kernel_size_h - 1) // 2
|
||||
pad_w = (kernel_size_w - 1) // 2
|
||||
|
||||
hmax = torch.nn.functional.max_pool2d(
|
||||
tensor,
|
||||
(kernel_size_h, kernel_size_w),
|
||||
stride=1,
|
||||
padding=(pad_h, pad_w),
|
||||
)
|
||||
keep = (hmax == tensor).float()
|
||||
return tensor * keep
|
@ -1,50 +0,0 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Run non-maximum suppression on a tensor.
|
||||
|
||||
This function removes values from the input tensor that are not local
|
||||
maxima in the neighborhood of the given kernel size.
|
||||
|
||||
All non-maximum values are set to zero.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor.
|
||||
kernel_size : Union[int, Tuple[int, int]], optional
|
||||
Size of the neighborhood to consider for non-maximum suppression.
|
||||
If an integer is given, the neighborhood will be a square of the
|
||||
given size. If a tuple is given, the neighborhood will be a
|
||||
rectangle with the given height and width.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Tensor with non-maximum suppressed values.
|
||||
"""
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
else:
|
||||
kernel_size_h, kernel_size_w = kernel_size
|
||||
|
||||
pad_h = (kernel_size_h - 1) // 2
|
||||
pad_w = (kernel_size_w - 1) // 2
|
||||
|
||||
hmax = torch.nn.functional.max_pool2d(
|
||||
tensor,
|
||||
(kernel_size_h, kernel_size_w),
|
||||
stride=1,
|
||||
padding=(pad_h, pad_w),
|
||||
)
|
||||
keep = (hmax == tensor).float()
|
||||
return tensor * keep
|
316
batdetect2/postprocess/remapping.py
Normal file
316
batdetect2/postprocess/remapping.py
Normal file
@ -0,0 +1,316 @@
|
||||
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||
|
||||
This module provides utility functions to convert the raw numerical outputs
|
||||
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||
`xarray.DataArray` objects. This step adds coordinate information
|
||||
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||
interpretable in the context of the original audio signal and facilitating
|
||||
subsequent processing steps.
|
||||
|
||||
Functions are provided for common BatDetect2 output types: detection heatmaps,
|
||||
classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
|
||||
__all__ = [
|
||||
"features_to_xarray",
|
||||
"detection_to_xarray",
|
||||
"classification_to_xarray",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
def features_to_xarray(
|
||||
features: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
features_prefix: str = "batdetect2_feature_",
|
||||
):
|
||||
"""Convert a multi-channel feature tensor to a coordinate-aware DataArray.
|
||||
|
||||
Assigns time, frequency, and feature coordinates to a raw feature tensor
|
||||
output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
The raw feature tensor from the model. Expected shape is
|
||||
(num_features, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
The start time (in seconds) corresponding to the first time bin of
|
||||
the tensor.
|
||||
end_time : float
|
||||
The end time (in seconds) corresponding to the *end* of the last time
|
||||
bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
The minimum frequency (in Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
The maximum frequency (in Hz) corresponding to the *end* of the last
|
||||
frequency bin.
|
||||
features_prefix : str, default="batdetect2_feature_"
|
||||
Prefix used to generate names for the feature coordinate dimension
|
||||
(e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing the feature data with named dimensions
|
||||
('feature', 'frequency', 'time') and calculated coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions.
|
||||
"""
|
||||
if features.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input features tensor must have 3 dimensions (C, T, F), "
|
||||
f"got shape {features.shape}"
|
||||
)
|
||||
|
||||
num_features, height, width = features.shape
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=features.detach().numpy(),
|
||||
dims=[
|
||||
Dimensions.feature.value,
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
Dimensions.feature.value: [
|
||||
f"{features_prefix}{i}" for i in range(num_features)
|
||||
],
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="features",
|
||||
)
|
||||
|
||||
|
||||
def detection_to_xarray(
|
||||
detection: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert a single-channel detection heatmap tensor to a DataArray.
|
||||
|
||||
Assigns time and frequency coordinates to a raw detection heatmap tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection : torch.Tensor
|
||||
Raw detection heatmap tensor from the model. Expected shape is
|
||||
(1, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing the detection scores with named
|
||||
dimensions ('frequency', 'time') and calculated coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions or if the first
|
||||
dimension size is not 1.
|
||||
"""
|
||||
if detection.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input detection tensor must have 3 dimensions (1, T, F), "
|
||||
f"got shape {detection.shape}"
|
||||
)
|
||||
|
||||
num_channels, height, width = detection.shape
|
||||
|
||||
if num_channels != 1:
|
||||
raise ValueError(
|
||||
"Expected a single channel output, instead got "
|
||||
f"{num_channels} channels"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=detection.squeeze(dim=0).detach().numpy(),
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="detection_score",
|
||||
)
|
||||
|
||||
|
||||
def classification_to_xarray(
|
||||
classes: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
class_names: List[str],
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert multi-channel class probability tensor to a DataArray.
|
||||
|
||||
Assigns category (class name), frequency, and time coordinates to a raw
|
||||
class probability tensor output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
classes : torch.Tensor
|
||||
Raw class probability tensor. Expected shape is
|
||||
(num_classes, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
class_names : List[str]
|
||||
Ordered list of class names corresponding to the first dimension
|
||||
of the `classes` tensor. The length must match `classes.shape[0]`.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing class probabilities with named
|
||||
dimensions ('category', 'frequency', 'time') and calculated
|
||||
coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions, or if the size of the
|
||||
first dimension does not match the length of `class_names`.
|
||||
"""
|
||||
if classes.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input classes tensor must have 3 dimensions (C, F, T), "
|
||||
f"got shape {classes.shape}"
|
||||
)
|
||||
|
||||
num_classes, height, width = classes.shape
|
||||
|
||||
if num_classes != len(class_names):
|
||||
raise ValueError(
|
||||
"The number of classes does not coincide with the number of "
|
||||
"class names provided: "
|
||||
f"({num_classes = }) != ({len(class_names) = })"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=classes.detach().numpy(),
|
||||
dims=[
|
||||
"category",
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
"category": class_names,
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="class_scores",
|
||||
)
|
||||
|
||||
|
||||
def sizes_to_xarray(
|
||||
sizes: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert the 2-channel size prediction tensor to a DataArray.
|
||||
|
||||
Assigns dimension ('width', 'height'), frequency, and time coordinates
|
||||
to the raw size prediction tensor output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sizes : torch.Tensor
|
||||
Raw size prediction tensor. Expected shape is
|
||||
(2, num_freq_bins, num_time_bins), where the first dimension
|
||||
corresponds to predicted width and height respectively.
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing predicted sizes with named dimensions
|
||||
('dimension', 'frequency', 'time') and calculated time/frequency
|
||||
coordinates. The 'dimension' coordinate will have values
|
||||
['width', 'height'].
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions or if the first
|
||||
dimension size is not exactly 2.
|
||||
"""
|
||||
num_channels, height, width = sizes.shape
|
||||
|
||||
if num_channels != 2:
|
||||
raise ValueError(
|
||||
"Expected a two-channel output, instead got "
|
||||
f"{num_channels} channels"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=sizes.detach().numpy(),
|
||||
dims=[
|
||||
"dimension",
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
)
|
@ -1,21 +1,284 @@
|
||||
from typing import Dict, NamedTuple, Protocol
|
||||
"""Defines shared interfaces and data structures for postprocessing.
|
||||
|
||||
This module centralizes the Protocol definitions and common data structures
|
||||
used throughout the `batdetect2.postprocess` module.
|
||||
|
||||
The main component is the `PostprocessorProtocol`, which outlines the standard
|
||||
interface for an object responsible for executing the entire postprocessing
|
||||
pipeline. This pipeline transforms raw neural network outputs into interpretable
|
||||
detections represented as `soundevent` objects. Using protocols ensures
|
||||
modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, NamedTuple, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2Prediction",
|
||||
"RawPrediction",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryBuilder",
|
||||
]
|
||||
|
||||
|
||||
class BatDetect2Prediction(NamedTuple):
|
||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
||||
"""Type alias for a function that recovers geometry from position and size.
|
||||
|
||||
This callable takes:
|
||||
1. A position tuple `(time, frequency)`.
|
||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||
`BoundingBox`).
|
||||
"""
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
"""Intermediate representation of a single detected sound event.
|
||||
|
||||
Holds extracted information about a detection after initial processing
|
||||
(like peak finding, coordinate remapping, geometry recovery) but before
|
||||
final class decoding and conversion into a `SoundEventPrediction`. This
|
||||
can be useful for evaluation or simpler data handling formats.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
start_time : float
|
||||
Start time of the recovered bounding box in seconds.
|
||||
end_time : float
|
||||
End time of the recovered bounding box in seconds.
|
||||
low_freq : float
|
||||
Lowest frequency of the recovered bounding box in Hz.
|
||||
high_freq : float
|
||||
Highest frequency of the recovered bounding box in Hz.
|
||||
detection_score : float
|
||||
The confidence score associated with this detection, typically from
|
||||
the detection heatmap peak.
|
||||
class_scores : xr.DataArray
|
||||
An xarray DataArray containing the predicted probabilities or scores
|
||||
for each target class at the detection location. Indexed by a
|
||||
'category' coordinate containing class names.
|
||||
features : xr.DataArray
|
||||
An xarray DataArray containing extracted feature vectors at the
|
||||
detection location. Indexed by a 'feature' coordinate.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
detection_score: float
|
||||
class_scores: Dict[str, float]
|
||||
features: np.ndarray
|
||||
class_scores: xr.DataArray
|
||||
features: xr.DataArray
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
pass
|
||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||
|
||||
This protocol outlines the standard methods for an object that takes raw
|
||||
output from a BatDetect2 model and the corresponding input clip metadata,
|
||||
and processes it through various stages (e.g., coordinate remapping, NMS,
|
||||
detection extraction, data extraction, decoding) to produce interpretable
|
||||
results at different levels of completion.
|
||||
|
||||
Implementations manage the configured logic for all postprocessing steps.
|
||||
"""
|
||||
|
||||
def get_feature_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap feature tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch, expected
|
||||
to contain the necessary feature tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects, one for each item in the
|
||||
processed batch. This list provides the timing, recording, and
|
||||
other metadata context needed to calculate real-world coordinates
|
||||
(seconds, Hz) for the output arrays. The length of this list must
|
||||
correspond to the batch size of the `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of xarray DataArrays, one for each input clip in the batch,
|
||||
in the same order. Each DataArray contains the feature vectors
|
||||
with dimensions like ('feature', 'time', 'frequency') and
|
||||
corresponding real-world coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap detection tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing detection heatmaps.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 2D xarray DataArrays (one per input clip, in order),
|
||||
representing the detection heatmap with 'time' and 'frequency'
|
||||
coordinates. Values typically indicate detection confidence.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_classification_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap classification tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing class probability tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing class probabilities with 'category', 'time', and
|
||||
'frequency' dimensions and coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_sizes_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap size prediction tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing predicted size tensors (e.g., width and height).
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing predicted sizes with 'dimension'
|
||||
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
|
||||
coordinates. Values represent estimated detection sizes.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_datasets(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform remapping, NMS, detection, and data extraction for a batch.
|
||||
|
||||
Processes the raw model output for a batch to identify detection peaks
|
||||
and extract all associated information (score, position, size, class
|
||||
probs, features) at those peak locations, returning a structured
|
||||
dataset for each input clip in the batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
A list of xarray Datasets (one per input clip, in order). Each
|
||||
Dataset contains multiple DataArrays ('scores', 'dimensions',
|
||||
'classes', 'features') sharing a common 'detection' dimension,
|
||||
providing aligned data for each detected event in that clip.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_raw_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
Processes the raw model output for a batch through remapping, NMS,
|
||||
detection, data extraction, and geometry recovery to produce a list of
|
||||
`RawPrediction` objects for each corresponding input clip. This provides
|
||||
a simplified, intermediate representation before final tag decoding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[List[RawPrediction]]
|
||||
A list of lists (one inner list per input clip, in order). Each
|
||||
inner list contains the `RawPrediction` objects extracted for the
|
||||
corresponding input clip.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Perform the full postprocessing pipeline for a batch.
|
||||
|
||||
Takes raw model output for a batch and corresponding clips, applies the
|
||||
entire postprocessing chain, and returns the final, interpretable
|
||||
predictions as a list of `soundevent.data.ClipPrediction` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.ClipPrediction]
|
||||
A list containing one `ClipPrediction` object for each input clip
|
||||
(in the same order), populated with `SoundEventPrediction` objects
|
||||
representing the final detections with decoded tags and geometry.
|
||||
"""
|
||||
...
|
||||
|
Loading…
Reference in New Issue
Block a user