batdetect2/src/batdetect2/postprocess/decoding.py
2025-08-08 13:06:10 +01:00

396 lines
15 KiB
Python

"""Decodes extracted detection data into standard soundevent predictions.
This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it
into standardized prediction objects based on the `soundevent` data model.
The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
objects, using a configured geometry decoder to recover each geometry from
the predicted position and sizes (`convert_xr_dataset_to_raw_prediction`).
2. Converting each `RawPrediction` into a
`soundevent.data.SoundEventPrediction`, which involves:
- Creating the `soundevent.data.SoundEvent` with geometry and features.
- Decoding the predicted class probabilities into representative tags using
a configured class decoder (`SoundEventDecoder`).
- Applying a classification threshold.
- Optionally selecting only the single highest-scoring class (top-1) or
including tags for all classes above the threshold (multi-label).
- Adding generic class tags as a baseline.
- Associating scores with the final prediction and tags.
(`convert_raw_prediction_to_sound_event_prediction`)
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
a `soundevent.data.ClipPrediction`
(`convert_raw_predictions_to_clip_prediction`).
"""
from typing import List, Optional
import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array
# Public API of this module: the names exported by a star-import.
__all__ = [
"convert_xr_dataset_to_raw_prediction",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
]
# Default value of the classification-threshold parameters of the decoding
# functions defined in this module.
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
"""Default threshold applied to classification scores.
Class predictions with scores below this value are typically ignored during
decoding.
"""
def convert_xr_dataset_to_raw_prediction(
    detection_dataset: xr.Dataset,
    geometry_decoder: GeometryDecoder,
) -> List[RawPrediction]:
    """Convert an xarray.Dataset of detections to RawPrediction objects.

    Visits every entry along the 'detection' dimension, looks up the name of
    its highest-scoring class, and calls `geometry_decoder` to rebuild the
    geometry of the detection from its position and predicted dimensions.
    The result is wrapped, together with the detection score, class scores
    and features, in a `RawPrediction`.

    Parameters
    ----------
    detection_dataset : xr.Dataset
        Dataset with a 'detection' dimension. For each detection the
        following entries are read: 'scores', 'dimensions', 'classes',
        'features', 'time' and 'frequency', plus the 'category' coordinate
        that names the entries of 'classes'.
    geometry_decoder : GeometryDecoder
        Callable invoked as
        ``geometry_decoder((time, frequency), dimensions, class_name=name)``
        where ``name`` is the 'category' label of the highest class score.
        Its return value is stored as the geometry of the prediction.

    Returns
    -------
    List[RawPrediction]
        One `RawPrediction` per detection, in the order the detections
        appear in the dataset. The score, class scores and features are
        passed through exactly as selected from the dataset.

    Raises
    ------
    AttributeError, KeyError, ValueError
        If `detection_dataset` lacks the 'detection' dimension or any of the
        entries listed above, or if `geometry_decoder` fails.
    """
    detections = []
    num_detections = detection_dataset.sizes["detection"]

    for det_num in range(num_detections):
        # `det_num` is a position along the dimension, not a label, so the
        # selection has to be positional. Label-based `.sel` would only work
        # when the dimension has no index or is labelled exactly 0..n-1.
        det_info = detection_dataset.isel(detection=det_num)

        # The class name is passed to the decoder alongside the position and
        # dimensions.
        highest_scoring_class = det_info.coords["category"][
            det_info["classes"].argmax()
        ].item()

        geom = geometry_decoder(
            (det_info.time, det_info.frequency),
            det_info.dimensions,
            class_name=highest_scoring_class,
        )

        detections.append(
            RawPrediction(
                detection_score=det_info.scores,
                geometry=geom,
                class_scores=det_info.classes,
                features=det_info.features,
            )
        )

    return detections
def convert_raw_predictions_to_clip_prediction(
    raw_predictions: List[RawPrediction],
    clip: data.Clip,
    sound_event_decoder: SoundEventDecoder,
    generic_class_tags: List[data.Tag],
    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
    top_class_only: bool = False,
) -> data.ClipPrediction:
    """Bundle the raw predictions of one clip into a ClipPrediction.

    Each element of `raw_predictions` is decoded with
    `convert_raw_prediction_to_sound_event_prediction`, using the recording
    of `clip`, and the decoded predictions are attached to a new
    `soundevent.data.ClipPrediction` for that clip.

    Parameters
    ----------
    raw_predictions : List[RawPrediction]
        Raw predictions, all assumed to come from `clip`.
    clip : data.Clip
        Clip the predictions belong to. Its `recording` is forwarded to
        every decoded sound event.
    sound_event_decoder : SoundEventDecoder
        Forwarded unchanged to the per-prediction conversion.
    generic_class_tags : List[data.Tag]
        Forwarded unchanged to the per-prediction conversion.
    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
        Forwarded unchanged to the per-prediction conversion.
    top_class_only : bool, default=False
        Forwarded unchanged to the per-prediction conversion.

    Returns
    -------
    data.ClipPrediction
        Clip prediction whose `sound_events` follow the order of
        `raw_predictions`.
    """
    sound_event_predictions = []
    for raw_prediction in raw_predictions:
        decoded = convert_raw_prediction_to_sound_event_prediction(
            raw_prediction,
            recording=clip.recording,
            sound_event_decoder=sound_event_decoder,
            generic_class_tags=generic_class_tags,
            classification_threshold=classification_threshold,
            top_class_only=top_class_only,
        )
        sound_event_predictions.append(decoded)

    return data.ClipPrediction(
        clip=clip,
        sound_events=sound_event_predictions,
    )
def convert_raw_prediction_to_sound_event_prediction(
    raw_prediction: RawPrediction,
    recording: data.Recording,
    sound_event_decoder: SoundEventDecoder,
    generic_class_tags: List[data.Tag],
    classification_threshold: Optional[
        float
    ] = DEFAULT_CLASSIFICATION_THRESHOLD,
    top_class_only: bool = False,
):
    """Decode one RawPrediction into a soundevent SoundEventPrediction.

    The conversion proceeds in three steps:

    1. A `soundevent.data.SoundEvent` is created from `recording`, the
       geometry of `raw_prediction`, and its features converted by
       `get_prediction_features`.
    2. The predicted tags are collected: first the generic tags (scored with
       the detection score, see `get_generic_tags`), then the class-specific
       tags derived from the class scores (see `get_class_tags`).
    3. Sound event, detection score and tags are combined into a
       `soundevent.data.SoundEventPrediction`.

    Parameters
    ----------
    raw_prediction : RawPrediction
        Source of `geometry`, `features`, `detection_score` and
        `class_scores`.
    recording : data.Recording
        Recording the sound event is attached to.
    sound_event_decoder : SoundEventDecoder
        Callable handed to `get_class_tags` to turn a class name into tags.
    generic_class_tags : List[data.Tag]
        Tags handed to `get_generic_tags`.
    classification_threshold : float, optional
        Passed to `get_class_tags` as `threshold`. ``None`` disables the
        class-score filtering. Defaults to
        `DEFAULT_CLASSIFICATION_THRESHOLD`.
    top_class_only : bool, default=False
        Passed to `get_class_tags`. When True, only the best-scoring class
        contributes class-specific tags.

    Returns
    -------
    data.SoundEventPrediction
        Prediction whose score is the detection score and whose tags are the
        generic tags followed by the class-specific tags.
    """
    sound_event = data.SoundEvent(
        recording=recording,
        geometry=raw_prediction.geometry,
        features=get_prediction_features(raw_prediction.features),
    )

    detection_score = raw_prediction.detection_score

    # Generic tags come first, followed by the class-specific ones.
    predicted_tags = list(
        get_generic_tags(
            detection_score,
            generic_class_tags=generic_class_tags,
        )
    )
    predicted_tags.extend(
        get_class_tags(
            raw_prediction.class_scores,
            sound_event_decoder,
            top_class_only=top_class_only,
            threshold=classification_threshold,
        )
    )

    return data.SoundEventPrediction(
        sound_event=sound_event,
        score=detection_score,
        tags=predicted_tags,
    )
def get_generic_tags(
    detection_score: float,
    generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]:
    """Wrap the generic tags in PredictedTag objects.

    Parameters
    ----------
    detection_score : float
        Score assigned to every generated `PredictedTag`.
    generic_class_tags : List[data.Tag]
        Base tags describing the generic category.

    Returns
    -------
    List[data.PredictedTag]
        One `PredictedTag` per entry of `generic_class_tags`, in the same
        order, each carrying `detection_score`. Empty when no generic tags
        are given.
    """
    predicted = []
    for generic_tag in generic_class_tags:
        predicted.append(
            data.PredictedTag(tag=generic_tag, score=detection_score)
        )
    return predicted
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
    """Turn an extracted feature vector into soundevent Feature objects.

    Parameters
    ----------
    features : xr.DataArray
        Feature values of a single detection. The array is traversed with
        `iterate_over_array`, whose items are unpacked as
        ``(feature name, value)`` pairs.

    Returns
    -------
    List[data.Feature]
        One `soundevent.data.Feature` per pair, in iteration order.

    Notes
    -----
    Every feature gets a freshly built `Term` whose name is the feature name
    prefixed with "batdetect2:" and whose label is the bare feature name.
    """
    converted = []
    for feat_name, value in iterate_over_array(features):
        term = data.Term(
            name=f"batdetect2:{feat_name}",
            label=feat_name,
            definition="Automatically extracted features by BatDetect2",
        )
        converted.append(data.Feature(term=term, value=value))
    return converted
def get_class_tags(
    class_scores: xr.DataArray,
    sound_event_decoder: SoundEventDecoder,
    top_class_only: bool = False,
    threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
    """Build class-specific PredictedTags from a vector of class scores.

    Scores that are not strictly greater than `threshold` are discarded
    (unless `threshold` is None). The remaining classes are visited from the
    highest to the lowest score; each class name is decoded into tags by
    `sound_event_decoder`, and every tag is wrapped in a `PredictedTag`
    carrying the score of its class.

    Parameters
    ----------
    class_scores : xr.DataArray
        Class scores; the coordinate of the first dimension provides the
        class names.
    sound_event_decoder : SoundEventDecoder
        Callable mapping a class name (str) to an iterable of tags.
    top_class_only : bool, default=False
        When True, stop after the first (highest-scoring) class that
        survived the threshold.
    threshold : float, optional
        Scores must be strictly greater than this value to be kept. ``None``
        keeps every class. Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.

    Returns
    -------
    List[data.PredictedTag]
        Predicted tags ordered by descending class score. Empty when no
        class passes the threshold.
    """
    selected_scores = class_scores
    if threshold is not None:
        # drop=True removes the filtered-out entries instead of leaving NaNs.
        selected_scores = class_scores.where(
            class_scores > threshold,
            drop=True,
        )

    predicted = []
    for class_name, score in _iterate_sorted(selected_scores):
        predicted.extend(
            data.PredictedTag(tag=tag, score=score)
            for tag in sound_event_decoder(class_name)
        )

        if top_class_only:
            break

    return predicted
def _iterate_sorted(array: xr.DataArray):
    """Yield ``(label, value)`` pairs of `array` by descending value.

    The labels are taken from the coordinate of the first dimension of
    `array`. Each label is yielded as ``str`` and each value as ``float``.
    """
    first_dim = array.dims[0]
    labels = array.coords[first_dim].values
    # Negating the values turns the ascending argsort into a descending one.
    descending_positions = np.argsort(-array.values)
    for position in descending_positions:
        yield str(labels[position]), float(array.values[position])