Remove xr from postprocess

mbsantiago 2025-08-25 22:46:21 +01:00
parent cc9e47b022
commit 281c4dcb8a
21 changed files with 231 additions and 1692 deletions

View File

@@ -140,9 +140,8 @@ def build_model(config: Optional[ModelConfig] = None):
     preprocessor = build_preprocessor(config=config.preprocess)
     postprocessor = build_postprocessor(
         targets=targets,
+        preprocessor=preprocessor,
         config=config.postprocess,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
     )
     detector = build_detector(
         num_classes=len(targets.class_names),
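The factory now receives the preprocessor object itself instead of its unpacked frequency bounds, which also gives postprocessing access to the spectrogram's output samplerate. A minimal sketch of the updated wiring, assuming default configs (`targets` is a placeholder for any `TargetProtocol` implementation):

from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor

targets = ...  # placeholder: any object satisfying TargetProtocol
preprocessor = build_preprocessor()  # exposes min_freq, max_freq, output_samplerate
postprocessor = build_postprocessor(
    targets=targets,
    preprocessor=preprocessor,  # replaces the old min_freq=/max_freq= keywords
)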

View File

@@ -6,7 +6,7 @@ from matplotlib.axes import Axes
 from soundevent import data

 from batdetect2.plotting.common import plot_spectrogram
-from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
+from batdetect2.preprocess import build_audio_loader, build_preprocessor
 from batdetect2.typing import AudioLoader, PreprocessorProtocol

 __all__ = [
@@ -27,7 +27,7 @@ def plot_clip(
     _, ax = plt.subplots(figsize=figsize)

     if preprocessor is None:
-        preprocessor = get_default_preprocessor()
+        preprocessor = build_preprocessor()

     if audio_loader is None:
         audio_loader = build_audio_loader()

View File

@@ -8,10 +8,7 @@ from soundevent.plot.tags import TagColorMapper
 from batdetect2.plotting.clip_predictions import plot_prediction
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.preprocess import (
-    PreprocessorProtocol,
-    get_default_preprocessor,
-)
+from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
 from batdetect2.typing.evaluate import MatchEvaluation

 __all__ = [
@@ -50,7 +47,7 @@ def plot_matches(
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
     if preprocessor is None:
-        preprocessor = get_default_preprocessor()
+        preprocessor = build_preprocessor()
     ax = plot_clip(
         clip,

View File

@@ -1,36 +1,7 @@
-"""Main entry point for the BatDetect2 Postprocessing pipeline.
-
-This package (`batdetect2.postprocess`) takes the raw outputs from a trained
-BatDetect2 neural network model and transforms them into meaningful, structured
-predictions, typically in the form of `soundevent.data.ClipPrediction` objects
-containing detected sound events with associated class tags and geometry.
-
-The pipeline involves several configurable steps, implemented in submodules:
-
-1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
-2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
-   model output arrays.
-3. Detection Extraction (`.detection`): Identifies candidate detection points
-   (location and score) based on thresholds and score ranking (top-k).
-4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
-   class probabilities, features) at the detected locations.
-5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
-   class predictions into interpretable `soundevent` objects, including
-   recovering geometry (ROIs) and decoding class names back to standard tags.
-
-This module provides the primary interface:
-
-- `PostprocessConfig`: A configuration object for postprocessing parameters
-  (thresholds, NMS kernel size, etc.).
-- `load_postprocess_config`: Function to load the configuration from a file.
-- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
-  holds the configured pipeline logic.
-- `build_postprocessor`: A factory function to create a `Postprocessor`
-  instance, linking it to the necessary target definitions (`TargetProtocol`).
-
-It also re-exports key components from submodules for convenience.
-"""
+"""Main entry point for the BatDetect2 Postprocessing pipeline."""

 from typing import List, Optional

-import xarray as xr
 from loguru import logger
 from pydantic import Field
 from soundevent import data
@@ -38,37 +9,24 @@ from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
+    convert_detections_to_raw_predictions,
     convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
-    convert_xr_dataset_to_raw_prediction,
-)
-from batdetect2.postprocess.detection import (
-    DEFAULT_DETECTION_THRESHOLD,
-    TOP_K_PER_SEC,
-    extract_detections_from_array,
-    get_max_detections,
-)
-from batdetect2.postprocess.extraction import (
-    extract_detection_xr_dataset,
 )
+from batdetect2.postprocess.extraction import extract_prediction_tensor
 from batdetect2.postprocess.nms import (
     NMS_KERNEL_SIZE,
     non_max_suppression,
 )
-from batdetect2.postprocess.remapping import (
-    classification_to_xarray,
-    detection_to_xarray,
-    features_to_xarray,
-    sizes_to_xarray,
-)
+from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.models import ModelOutput
+from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
 from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
+    Detections,
     PostprocessorProtocol,
     RawPrediction,
 )
-from batdetect2.typing.targets import TargetProtocol

 __all__ = [
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -81,19 +39,17 @@ __all__ = [
     "Postprocessor",
     "TOP_K_PER_SEC",
     "build_postprocessor",
-    "classification_to_xarray",
     "convert_raw_predictions_to_clip_prediction",
-    "convert_xr_dataset_to_raw_prediction",
-    "detection_to_xarray",
-    "extract_detection_xr_dataset",
-    "extract_detections_from_array",
-    "features_to_xarray",
-    "get_max_detections",
+    "convert_detections_to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
-    "sizes_to_xarray",
 ]

+DEFAULT_DETECTION_THRESHOLD = 0.01
+TOP_K_PER_SEC = 200
+

 class PostprocessConfig(BaseConfig):
     """Configuration settings for the postprocessing pipeline.
@@ -173,40 +129,10 @@ def load_postprocess_config(
 def build_postprocessor(
     targets: TargetProtocol,
+    preprocessor: PreprocessorProtocol,
     config: Optional[PostprocessConfig] = None,
-    max_freq: float = MAX_FREQ,
-    min_freq: float = MIN_FREQ,
 ) -> PostprocessorProtocol:
-    """Factory function to build the standard postprocessor.
-
-    Creates and initializes the `Postprocessor` instance, providing it with the
-    necessary `targets` object and the `PostprocessConfig`.
-
-    Parameters
-    ----------
-    targets : TargetProtocol
-        An initialized object conforming to the `TargetProtocol`, providing
-        methods like `.decode()` and `.recover_roi()`, and attributes like
-        `.class_names` and `.generic_class_tags`. This links postprocessing
-        to the defined target semantics and geometry mappings.
-    config : PostprocessConfig, optional
-        Configuration object specifying postprocessing parameters (thresholds,
-        NMS kernel size, etc.). If None, default settings defined in
-        `PostprocessConfig` will be used.
-    min_freq : int, default=MIN_FREQ
-        The minimum frequency (Hz) corresponding to the frequency axis of the
-        model outputs. Required for coordinate remapping. Consider setting via
-        `PostprocessConfig` instead for better encapsulation.
-    max_freq : int, default=MAX_FREQ
-        The maximum frequency (Hz) corresponding to the frequency axis of the
-        model outputs. Required for coordinate remapping. Consider setting via
-        `PostprocessConfig`.
-
-    Returns
-    -------
-    PostprocessorProtocol
-        An initialized `Postprocessor` instance ready to process model outputs.
-    """
+    """Factory function to build the standard postprocessor."""
     config = config or PostprocessConfig()
     logger.opt(lazy=True).debug(
         "Building postprocessor with config: \n{}",
@@ -214,303 +140,62 @@ def build_postprocessor(
     )
     return Postprocessor(
         targets=targets,
+        preprocessor=preprocessor,
         config=config,
-        min_freq=min_freq,
-        max_freq=max_freq,
     )


 class Postprocessor(PostprocessorProtocol):
-    """Standard implementation of the postprocessing pipeline.
-
-    This class orchestrates the steps required to convert raw model outputs
-    into interpretable `soundevent` predictions. It uses configured parameters
-    and leverages functions from the `batdetect2.postprocess` submodules for
-    each stage (NMS, remapping, detection, extraction, decoding).
-
-    It requires a `TargetProtocol` object during initialization to access
-    necessary decoding information (class name to tag mapping,
-    ROI recovery logic) ensuring consistency with the target definitions used
-    during training or specified for inference.
-
-    Instances are typically created using the `build_postprocessor` factory
-    function.
-
-    Attributes
-    ----------
-    targets : TargetProtocol
-        The configured target definition object providing decoding and ROI
-        recovery.
-    config : PostprocessConfig
-        Configuration object holding parameters for NMS, thresholds, etc.
-    min_freq : float
-        Minimum frequency (Hz) assumed for the model output's frequency axis.
-    max_freq : float
-        Maximum frequency (Hz) assumed for the model output's frequency axis.
-    """
+    """Standard implementation of the postprocessing pipeline."""

     targets: TargetProtocol
+    preprocessor: PreprocessorProtocol

     def __init__(
         self,
         targets: TargetProtocol,
+        preprocessor: PreprocessorProtocol,
         config: PostprocessConfig,
-        min_freq: float = MIN_FREQ,
-        max_freq: float = MAX_FREQ,
     ):
-        """Initialize the Postprocessor.
-
-        Parameters
-        ----------
-        targets : TargetProtocol
-            Initialized target definition object.
-        config : PostprocessConfig
-            Configuration for postprocessing parameters.
-        min_freq : int, default=MIN_FREQ
-            Minimum frequency (Hz) for coordinate remapping.
-        max_freq : int, default=MAX_FREQ
-            Maximum frequency (Hz) for coordinate remapping.
-        """
+        """Initialize the Postprocessor."""
         self.targets = targets
+        self.preprocessor = preprocessor
         self.config = config
-        self.min_freq = min_freq
-        self.max_freq = max_freq
-    def get_feature_arrays(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[xr.DataArray]:
-        """Extract and remap raw feature tensors for a batch.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw model output containing `output.features` tensor for the batch.
-        clips : List[data.Clip]
-            List of Clip objects corresponding to the batch items.
-
-        Returns
-        -------
-        List[xr.DataArray]
-            List of coordinate-aware feature DataArrays, one per clip.
-
-        Raises
-        ------
-        ValueError
-            If batch sizes of `output.features` and `clips` do not match.
-        """
-        if len(clips) != len(output.features):
-            raise ValueError(
-                "Number of clips and batch size of feature array"
-                "do not match. "
-                f"(clips: {len(clips)}, features: {len(output.features)})"
-            )
-
-        return [
-            features_to_xarray(
-                feats,
-                start_time=clip.start_time,
-                end_time=clip.end_time,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for feats, clip in zip(output.features, clips)
-        ]
-
-    def get_detection_arrays(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[xr.DataArray]:
-        """Apply NMS and remap detection heatmaps for a batch.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw model output containing `output.detection_probs` tensor for the
-            batch.
-        clips : List[data.Clip]
-            List of Clip objects corresponding to the batch items.
-
-        Returns
-        -------
-        List[xr.DataArray]
-            List of NMS-applied, coordinate-aware detection heatmaps, one per
-            clip.
-
-        Raises
-        ------
-        ValueError
-            If batch sizes of `output.detection_probs` and `clips` do not match.
-        """
-        detections = output.detection_probs
-
-        if len(clips) != len(output.detection_probs):
-            raise ValueError(
-                "Number of clips and batch size of detection array "
-                "do not match. "
-                f"(clips: {len(clips)}, detection: {len(detections)})"
-            )
-
-        detections = non_max_suppression(
-            detections,
-            kernel_size=self.config.nms_kernel_size,
-        )
-
-        return [
-            detection_to_xarray(
-                dets,
-                start_time=clip.start_time,
-                end_time=clip.end_time,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for dets, clip in zip(detections, clips)
-        ]
-
-    def get_classification_arrays(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[xr.DataArray]:
-        """Extract and remap raw classification tensors for a batch.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw model output containing `output.class_probs` tensor for the
-            batch.
-        clips : List[data.Clip]
-            List of Clip objects corresponding to the batch items.
-
-        Returns
-        -------
-        List[xr.DataArray]
-            List of coordinate-aware class probability maps, one per clip.
-
-        Raises
-        ------
-        ValueError
-            If batch sizes of `output.class_probs` and `clips` do not match, or
-            if number of classes mismatches `self.targets.class_names`.
-        """
-        classifications = output.class_probs
-
-        if len(clips) != len(classifications):
-            raise ValueError(
-                "Number of clips and batch size of classification array "
-                "do not match. "
-                f"(clips: {len(clips)}, classification: {len(classifications)})"
-            )
-
-        return [
-            classification_to_xarray(
-                class_probs,
-                start_time=clip.start_time,
-                end_time=clip.end_time,
-                class_names=self.targets.class_names,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for class_probs, clip in zip(classifications, clips)
-        ]
-
-    def get_sizes_arrays(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[xr.DataArray]:
-        """Extract and remap raw size prediction tensors for a batch.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw model output containing `output.size_preds` tensor for the
-            batch.
-        clips : List[data.Clip]
-            List of Clip objects corresponding to the batch items.
-
-        Returns
-        -------
-        List[xr.DataArray]
-            List of coordinate-aware size prediction maps, one per clip.
-
-        Raises
-        ------
-        ValueError
-            If batch sizes of `output.size_preds` and `clips` do not match.
-        """
-        sizes = output.size_preds
-
-        if len(clips) != len(sizes):
-            raise ValueError(
-                "Number of clips and batch size of sizes array do not match. "
-                f"(clips: {len(clips)}, sizes: {len(sizes)})"
-            )
-
-        return [
-            sizes_to_xarray(
-                size_preds,
-                start_time=clip.start_time,
-                end_time=clip.end_time,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for size_preds, clip in zip(sizes, clips)
-        ]
-
-    def get_detection_datasets(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[xr.Dataset]:
-        """Perform NMS, remapping, detection, and data extraction for a batch.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            List of `soundevent.data.Clip` objects corresponding to the batch.
-
-        Returns
-        -------
-        List[xr.Dataset]
-            List of xarray Datasets (one per clip). Each Dataset contains
-            aligned scores, dimensions, class probabilities, and features for
-            detections found in that clip.
-        """
-        detection_arrays = self.get_detection_arrays(output, clips)
-        classification_arrays = self.get_classification_arrays(output, clips)
-        size_arrays = self.get_sizes_arrays(output, clips)
-        features_arrays = self.get_feature_arrays(output, clips)
-
-        datasets = []
-        for det_array, class_array, sizes_array, feats_array in zip(
-            detection_arrays,
-            classification_arrays,
-            size_arrays,
-            features_arrays,
-        ):
-            max_detections = get_max_detections(
-                det_array,
-                top_k_per_sec=self.config.top_k_per_sec,
-            )
-
-            positions = extract_detections_from_array(
-                det_array,
-                max_detections=max_detections,
-                threshold=self.config.detection_threshold,
-            )
-
-            datasets.append(
-                extract_detection_xr_dataset(
-                    positions,
-                    sizes_array,
-                    class_array,
-                    feats_array,
-                )
-            )
-
-        return datasets
+    def get_detections(
+        self,
+        output: ModelOutput,
+        clips: Optional[List[data.Clip]] = None,
+    ) -> List[Detections]:
+        width = output.detection_probs.shape[-1]
+        duration = width / self.preprocessor.output_samplerate
+        max_detections = int(self.config.top_k_per_sec * duration)
+        detections = extract_prediction_tensor(
+            output,
+            max_detections=max_detections,
+            threshold=self.config.detection_threshold,
+        )
+
+        if clips is None:
+            return detections
+
+        return [
+            map_detection_to_clip(
+                detection,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.preprocessor.min_freq,
+                max_freq=self.preprocessor.max_freq,
+            )
+            for detection, clip in zip(detections, clips)
+        ]

     def get_raw_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
     ) -> List[List[RawPrediction]]:
         """Extract intermediate RawPrediction objects for a batch.
@@ -531,13 +216,13 @@ class Postprocessor(PostprocessorProtocol):
         List of lists (one inner list per input clip). Each inner list
         contains `RawPrediction` objects for detections in that clip.
         """
-        detection_datasets = self.get_detection_datasets(output, clips)
+        detections = self.get_detections(output, clips)
         return [
-            convert_xr_dataset_to_raw_prediction(
+            convert_detections_to_raw_predictions(
                 dataset,
-                self.targets.decode_roi,
+                targets=self.targets,
             )
-            for dataset in detection_datasets
+            for dataset in detections
         ]

     def get_sound_event_predictions(
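With this change the batch stays in torch tensors until decoding: `get_detections` returns normalized `Detections` when called without clips, and clip-mapped ones otherwise. A hedged sketch of the intended call sequence (`model`, `spec`, `clips`, and a built `postprocessor` assumed to be in scope; not verbatim project code):

output = model(spec)  # ModelOutput with detection_probs, size_preds, class_probs, features
detections = postprocessor.get_detections(output, clips=clips)  # List[Detections], clip coordinates
raw = postprocessor.get_raw_predictions(output, clips)          # List[List[RawPrediction]] via targets.decode_roi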

View File

@@ -1,42 +1,18 @@
-"""Decodes extracted detection data into standard soundevent predictions.
-
-This module handles the final stages of the BatDetect2 postprocessing pipeline.
-It takes the structured detection data extracted by the `extraction` module
-(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
-class probabilities, and features for each detection point) and converts it
-into standardized prediction objects based on the `soundevent` data model.
-
-The process involves:
-
-1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
-   objects, using a configured geometry builder to recover bounding boxes from
-   predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
-2. Converting each `RawPrediction` into a
-   `soundevent.data.SoundEventPrediction`, which involves:
-   - Creating the `soundevent.data.SoundEvent` with geometry and features.
-   - Decoding the predicted class probabilities into representative tags using
-     a configured class decoder (`SoundEventDecoder`).
-   - Applying a classification threshold.
-   - Optionally selecting only the single highest-scoring class (top-1) or
-     including tags for all classes above the threshold (multi-label).
-   - Adding generic class tags as a baseline.
-   - Associating scores with the final prediction and tags.
-   (`convert_raw_prediction_to_sound_event_prediction`)
-3. Grouping the `SoundEventPrediction` objects for a given audio segment into
-   a `soundevent.data.ClipPrediction`
-   (`convert_raw_predictions_to_clip_prediction`).
-"""
+"""Decodes extracted detection data into standard soundevent predictions."""

 from typing import List, Optional

 import numpy as np
-import xarray as xr
 from soundevent import data

-from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
+from batdetect2.typing.postprocess import (
+    Detections,
+    RawPrediction,
+)
 from batdetect2.typing.targets import TargetProtocol

 __all__ = [
-    "convert_xr_dataset_to_raw_prediction",
+    "convert_detections_to_raw_predictions",
     "convert_raw_predictions_to_clip_prediction",
     "convert_raw_prediction_to_sound_event_prediction",
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -51,65 +27,29 @@ decoding.
 """


-def convert_xr_dataset_to_raw_prediction(
-    detection_dataset: xr.Dataset,
-    geometry_decoder: GeometryDecoder,
+def convert_detections_to_raw_predictions(
+    detections: Detections,
+    targets: TargetProtocol,
 ) -> List[RawPrediction]:
-    """Convert an xarray.Dataset of detections to RawPrediction objects.
-
-    Takes the output of the extraction step (`extract_detection_xr_dataset`)
-    and transforms each detection entry into an intermediate `RawPrediction`
-    object. This involves recovering the geometry (e.g., bounding box) from
-    the predicted position and scaled size dimensions using the provided
-    `geometry_builder` function.
-
-    Parameters
-    ----------
-    detection_dataset : xr.Dataset
-        An xarray Dataset containing aligned detection information, typically
-        output by `extract_detection_xr_dataset`. Expected variables include
-        'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
-        Must have a 'detection' dimension.
-    geometry_decoder : GeometryDecoder
-        A function that takes a position tuple `(time, freq)` and a NumPy array
-        of dimensions, and returns the corresponding reconstructed
-        `soundevent.data.Geometry`.
-
-    Returns
-    -------
-    List[RawPrediction]
-        A list of `RawPrediction` objects, each containing the detection score,
-        recovered bounding box coordinates (start/end time, low/high freq),
-        the vector of class scores, and the feature vector for one detection.
-
-    Raises
-    ------
-    AttributeError, KeyError, ValueError
-        If `detection_dataset` is missing expected variables ('scores',
-        'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
-        associated with 'scores'), or if `geometry_builder` fails.
-    """
-    detections = []
-
-    categories = detection_dataset.category.values
-
+    predictions = []
     for score, class_scores, time, freq, dims, feats in zip(
-        detection_dataset["scores"].values,
-        detection_dataset["classes"].values,
-        detection_dataset["time"].values,
-        detection_dataset["frequency"].values,
-        detection_dataset["dimensions"].values,
-        detection_dataset["features"].values,
+        detections.scores,
+        detections.class_scores,
+        detections.times,
+        detections.frequencies,
+        detections.sizes,
+        detections.features,
     ):
-        highest_scoring_class = categories[class_scores.argmax()]
-        geom = geometry_decoder(
+        highest_scoring_class = targets.class_names[class_scores.argmax()]
+        geom = targets.decode_roi(
             (time, freq),
             dims,
             class_name=highest_scoring_class,
         )
-        detections.append(
+        predictions.append(
             RawPrediction(
                 detection_score=score,
                 geometry=geom,
@@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction(
             )
         )

-    return detections
+    return predictions


 def convert_raw_predictions_to_clip_prediction(
@@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction(
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
     top_class_only: bool = False,
 ) -> data.ClipPrediction:
-    """Convert a list of RawPredictions into a soundevent ClipPrediction.
-
-    Iterates through `raw_predictions` (assumed to belong to a single clip),
-    converts each one into a `soundevent.data.SoundEventPrediction` using
-    `convert_raw_prediction_to_sound_event_prediction`, and packages them
-    into a `soundevent.data.ClipPrediction` associated with the original `clip`.
-
-    Parameters
-    ----------
-    raw_predictions : List[RawPrediction]
-        List of raw prediction objects for a single clip.
-    clip : data.Clip
-        The original `soundevent.data.Clip` object these predictions belong to.
-    sound_event_decoder : SoundEventDecoder
-        Function to decode class names into representative tags.
-    generic_class_tags : List[data.Tag]
-        List of tags representing the generic class category.
-    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
-        Threshold applied to class scores during decoding.
-    top_class_only : bool, default=False
-        If True, only decode tags for the single highest-scoring class above
-        the threshold. If False, decode tags for all classes above threshold.
-
-    Returns
-    -------
-    data.ClipPrediction
-        A `ClipPrediction` object containing a list of `SoundEventPrediction`
-        objects corresponding to the input `raw_predictions`.
-    """
+    """Convert a list of RawPredictions into a soundevent ClipPrediction."""
     return data.ClipPrediction(
         clip=clip,
         sound_events=[
@@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction(
     ] = DEFAULT_CLASSIFICATION_THRESHOLD,
     top_class_only: bool = False,
 ):
-    """Convert a single RawPrediction into a soundevent SoundEventPrediction.
-
-    This function performs the core decoding steps for a single detected event:
-
-    1. Creates a `soundevent.data.SoundEvent` containing the geometry
-       (BoundingBox derived from `raw_prediction` bounds) and any associated
-       feature vectors.
-    2. Initializes a list of predicted tags using the provided
-       `generic_class_tags`, assigning the overall `detection_score` from the
-       `raw_prediction` to these generic tags.
-    3. Processes the `class_scores` from the `raw_prediction`:
-       a. Optionally filters out scores below `classification_threshold`
-          (if it's not None).
-       b. Sorts the remaining scores in descending order.
-       c. Iterates through the sorted, thresholded class scores.
-       d. For each class, uses the `sound_event_decoder` to get the
-          representative base tags for that class name.
-       e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
-          the specific `score` of that class prediction.
-       f. Appends these specific predicted tags to the list.
-       g. If `top_class_only` is True, stops after processing the first
-          (highest-scoring) class that passed the threshold.
-    4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
-       associating the `SoundEvent`, the overall `detection_score`, and the
-       compiled list of `PredictedTag` objects.
-
-    Parameters
-    ----------
-    raw_prediction : RawPrediction
-        The raw prediction object containing score, bounds, class scores,
-        features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
-        coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
-        coordinate.
-    recording : data.Recording
-        The recording the sound event belongs to.
-    sound_event_decoder : SoundEventDecoder
-        Configured function mapping class names (str) to lists of base
-        `data.Tag` objects.
-    generic_class_tags : List[data.Tag]
-        List of base tags representing the generic category.
-    classification_threshold : float, optional
-        The minimum score a class prediction must have to be considered
-        significant enough to have its tags decoded and added. If None, no
-        thresholding is applied based on class score (all predicted classes,
-        or the top one if `top_class_only` is True, will be processed).
-        Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
-    top_class_only : bool, default=False
-        If True, only includes tags for the single highest-scoring class that
-        exceeds the threshold. If False (default), includes tags for all classes
-        exceeding the threshold.
-
-    Returns
-    -------
-    data.SoundEventPrediction
-        The fully formed sound event prediction object.
-
-    Raises
-    ------
-    ValueError
-        If `raw_prediction.features` has unexpected structure or if
-        `data.term_from_key` (if used internally) fails.
-        If `sound_event_decoder` fails for a class name and errors are raised.
-    """
+    """Convert a single RawPrediction into a soundevent SoundEventPrediction."""
     sound_event = data.SoundEvent(
         recording=recording,
         geometry=raw_prediction.geometry,
@@ -273,25 +124,7 @@ def get_generic_tags(
     detection_score: float,
     generic_class_tags: List[data.Tag],
 ) -> List[data.PredictedTag]:
-    """Create PredictedTag objects for the generic category.
-
-    Takes the base list of generic tags and assigns the overall detection
-    score to each one, wrapping them in `PredictedTag` objects.
-
-    Parameters
-    ----------
-    detection_score : float
-        The overall confidence score of the detection event.
-    generic_class_tags : List[data.Tag]
-        The list of base `soundevent.data.Tag` objects that define the
-        generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
-
-    Returns
-    -------
-    List[data.PredictedTag]
-        A list of `PredictedTag` objects for the generic category, each
-        assigned the `detection_score`.
-    """
+    """Create PredictedTag objects for the generic category."""
     return [
         data.PredictedTag(tag=tag, score=detection_score)
         for tag in generic_class_tags
@@ -299,25 +132,7 @@ def get_generic_tags(

 def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
-    """Convert an extracted feature vector DataArray into soundevent Features.
-
-    Parameters
-    ----------
-    features : xr.DataArray
-        A 1D xarray DataArray containing feature values, indexed by a coordinate
-        named 'feature' which holds the feature names (e.g., output of selecting
-        features for one detection from `extract_detection_xr_dataset`).
-
-    Returns
-    -------
-    List[data.Feature]
-        A list of `soundevent.data.Feature` objects.
-
-    Notes
-    -----
-    - This function creates basic `Term` objects using the feature coordinate
-      names with a "batdetect2:" prefix.
-    """
+    """Convert an extracted feature vector DataArray into soundevent Features."""
     return [
         data.Feature(
             term=data.Term(
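A hedged sketch of driving the tensor-based decoder directly, with a single toy detection (`targets` is assumed to be any `TargetProtocol` implementation; tensor shapes follow the extraction code):

import torch

from batdetect2.postprocess.decoding import convert_detections_to_raw_predictions
from batdetect2.typing.postprocess import Detections

dets = Detections(
    scores=torch.tensor([0.9]),
    sizes=torch.tensor([[7.0, 16.0]]),        # turned into geometry by targets.decode_roi
    class_scores=torch.tensor([[0.1, 0.8]]),  # argmax picks the class name used for ROI decoding
    times=torch.tensor([0.02]),               # already mapped to seconds
    frequencies=torch.tensor([30_000.0]),     # already mapped to Hz
    features=torch.rand(1, 32),
)
raw = convert_detections_to_raw_predictions(dets, targets=targets)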

View File

@@ -1,162 +0,0 @@
-"""Extracts candidate detection points from a model output heatmap.
-
-This module implements Step 3 within the BatDetect2 postprocessing
-pipeline. Its primary function is to identify potential sound event locations
-by finding peaks (local maxima or high-scoring points) in the detection heatmap
-produced by the neural network (usually after Non-Maximum Suppression and
-coordinate remapping have been applied).
-
-It provides functionality to:
-
-- Identify the locations (time, frequency) of the highest-scoring points.
-- Filter these points based on a minimum confidence score threshold.
-- Limit the maximum number of detection points returned (top-k).
-
-The main output is an `xarray.DataArray` containing the scores and
-corresponding time/frequency coordinates for the extracted detection points.
-This output serves as the input for subsequent postprocessing steps, such as
-extracting predicted class probabilities and bounding box sizes at these
-specific locations.
-"""
-
-from typing import Optional
-
-import numpy as np
-import xarray as xr
-from soundevent.arrays import Dimensions, get_dim_width
-
-__all__ = [
-    "extract_detections_from_array",
-    "get_max_detections",
-    "DEFAULT_DETECTION_THRESHOLD",
-    "TOP_K_PER_SEC",
-]
-
-DEFAULT_DETECTION_THRESHOLD = 0.01
-"""Default confidence score threshold used for filtering detections."""
-
-TOP_K_PER_SEC = 200
-"""Default desired maximum number of detections per second of audio."""
-
-
-def extract_detections_from_array(
-    detection_array: xr.DataArray,
-    max_detections: Optional[int] = None,
-    threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
-) -> xr.DataArray:
-    """Extract detection locations (time, freq) and scores from a heatmap.
-
-    Identifies the pixels with the highest scores in the input detection
-    heatmap, filters them based on an optional score `threshold`, limits the
-    number to an optional `max_detections`, and returns their scores along with
-    their corresponding time and frequency coordinates.
-
-    Parameters
-    ----------
-    detection_array : xr.DataArray
-        A 2D xarray DataArray representing the detection heatmap. Must have
-        dimensions and coordinates named 'time' and 'frequency'. Higher values
-        are assumed to indicate higher detection confidence.
-    max_detections : int, optional
-        The absolute maximum number of detections to return. If specified, only
-        the top `max_detections` highest-scoring detections (passing the
-        threshold) are returned. If None (default), all detections passing
-        the threshold are returned, sorted by score.
-    threshold : float, optional
-        The minimum confidence score required for a detection peak to be
-        kept. Detections with scores below this value are discarded.
-        Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
-        thresholding is applied.
-
-    Returns
-    -------
-    xr.DataArray
-        A 1D xarray DataArray named 'score' with a 'detection' dimension.
-        - The data values are the scores of the extracted detections, sorted
-          in descending order.
-        - It has coordinates 'time' and 'frequency' (also indexed by the
-          'detection' dimension) indicating the location of each detection
-          peak in the original coordinate system.
-        - Returns an empty DataArray if no detections pass the criteria.
-
-    Raises
-    ------
-    ValueError
-        If `max_detections` is not None and not a positive integer, or if
-        `detection_array` lacks required dimensions/coordinates.
-    """
-    if max_detections is not None:
-        if max_detections <= 0:
-            raise ValueError("Max detections must be positive")
-
-    values = detection_array.values.flatten()
-
-    if max_detections is not None:
-        top_indices = np.argpartition(-values, max_detections)[:max_detections]
-        top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
-    else:
-        top_sorted_indices = np.argsort(-values)
-
-    top_values = values[top_sorted_indices]
-
-    if threshold is not None:
-        mask = top_values > threshold
-        top_values = top_values[mask]
-        top_sorted_indices = top_sorted_indices[mask]
-
-    freq_indices, time_indices = np.unravel_index(
-        top_sorted_indices,
-        detection_array.shape,
-    )
-
-    times = detection_array.coords[Dimensions.time.value].values[time_indices]
-    freqs = detection_array.coords[Dimensions.frequency.value].values[
-        freq_indices
-    ]
-
-    return xr.DataArray(
-        data=top_values,
-        coords={
-            Dimensions.frequency.value: ("detection", freqs),
-            Dimensions.time.value: ("detection", times),
-        },
-        dims="detection",
-        name="score",
-    )
-
-
-def get_max_detections(
-    detection_array: xr.DataArray,
-    top_k_per_sec: int = TOP_K_PER_SEC,
-) -> int:
-    """Calculate max detections allowed based on duration and rate.
-
-    Determines the total maximum number of detections to extract from a
-    heatmap based on its time duration and a desired rate of detections
-    per second.
-
-    Parameters
-    ----------
-    detection_array : xr.DataArray
-        The detection heatmap, requiring 'time' coordinates from which the
-        total duration can be calculated using
-        `soundevent.arrays.get_dim_width`.
-    top_k_per_sec : int, default=TOP_K_PER_SEC
-        The desired maximum number of detections to allow per second of audio.
-
-    Returns
-    -------
-    int
-        The calculated total maximum number of detections allowed for the
-        entire duration of the `detection_array`.
-
-    Raises
-    ------
-    ValueError
-        If the duration cannot be calculated from the `detection_array` (e.g.,
-        missing or invalid 'time' coordinates/dimension).
-    """
-    if top_k_per_sec < 0:
-        raise ValueError("top_k_per_sec cannot be negative.")
-
-    duration = get_dim_width(detection_array, Dimensions.time.value)
-
-    return int(duration * top_k_per_sec)
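The deleted xarray helpers are superseded by torch-native selection in `extraction.extract_prediction_tensor`. A minimal sketch of the equivalent thresholded top-k on a single NMS'd heatmap (shapes assumed; `torch.unravel_index` requires torch >= 2.1):

import torch

heatmap = torch.rand(128, 512)  # (freq_bins, time_bins), post-NMS detection scores
scores, indices = torch.topk(heatmap.flatten(), k=200)  # top-k replaces np.argpartition
keep = scores >= 0.01                                   # DEFAULT_DETECTION_THRESHOLD
scores, indices = scores[keep], indices[keep]
freq_idx, time_idx = torch.unravel_index(indices, heatmap.shape)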

View File

@@ -15,108 +15,73 @@ precise time-frequency location of each detection. The final output aggregates
 all extracted information into a structured `xarray.Dataset`.
 """

-import xarray as xr
-from soundevent.arrays import Dimensions
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
+from batdetect2.typing.postprocess import Detections, ModelOutput

 __all__ = [
-    "extract_values_at_positions",
-    "extract_detection_xr_dataset",
+    "extract_prediction_tensor",
 ]


-def extract_values_at_positions(
-    array: xr.DataArray,
-    positions: xr.DataArray,
-) -> xr.DataArray:
-    """Extract values from an array at specified time-frequency positions.
-
-    Uses coordinate-based indexing to retrieve values from a source `array`
-    (e.g., class probabilities, size predictions, features) at the time and
-    frequency coordinates defined in the `positions` array.
-
-    Parameters
-    ----------
-    array : xr.DataArray
-        The source DataArray from which to extract values. Must have 'time'
-        and 'frequency' dimensions and coordinates matching the space of
-        `positions`.
-    positions : xr.DataArray
-        A 1D DataArray whose 'time' and 'frequency' coordinates specify the
-        locations from which to extract values.
-
-    Returns
-    -------
-    xr.DataArray
-        A DataArray containing the values extracted from `array` at the given
-        positions.
-
-    Raises
-    ------
-    ValueError, IndexError, KeyError
-        If dimensions or coordinates are missing or incompatible between
-        `array` and `positions`, or if selection fails.
-    """
-    return array.sel(
-        **{
-            Dimensions.frequency.value: positions.coords[
-                Dimensions.frequency.value
-            ],
-            Dimensions.time.value: positions.coords[Dimensions.time.value],
-        }
-    ).T
-
-
-def extract_detection_xr_dataset(
-    positions: xr.DataArray,
-    sizes: xr.DataArray,
-    classes: xr.DataArray,
-    features: xr.DataArray,
-) -> xr.Dataset:
-    """Combine extracted detection information into a structured xr.Dataset.
-
-    Takes the detection positions/scores and the full model output heatmaps
-    (sizes, classes, optional features), extracts the relevant data at the
-    detection positions, and packages everything into a single `xarray.Dataset`
-    where all variables are indexed by a common 'detection' dimension.
-
-    Parameters
-    ----------
-    positions : xr.DataArray
-        Output from `extract_detections_from_array`, containing detection
-        scores as data and 'time', 'frequency' coordinates along the
-        'detection' dimension.
-    sizes : xr.DataArray
-        The full size prediction heatmap from the model, with dimensions like
-        ('dimension', 'time', 'frequency').
-    classes : xr.DataArray
-        The full class probability heatmap from the model, with dimensions like
-        ('category', 'time', 'frequency').
-    features : xr.DataArray
-        The full feature map from the model, with
-        dimensions like ('feature', 'time', 'frequency').
-
-    Returns
-    -------
-    xr.Dataset
-        An xarray Dataset containing aligned information for each detection:
-        - 'scores': DataArray from `positions` (score data, time/freq coords).
-        - 'dimensions': DataArray with extracted size values
-          (dims: 'detection', 'dimension').
-        - 'classes': DataArray with extracted class probabilities
-          (dims: 'detection', 'category').
-        - 'features': DataArray with extracted feature vectors
-          (dims: 'detection', 'feature'), if `features` was provided. All
-          DataArrays share the 'detection' dimension and associated
-          time/frequency coordinates.
-    """
-    sizes = extract_values_at_positions(sizes, positions)
-    classes = extract_values_at_positions(classes, positions)
-    features = extract_values_at_positions(features, positions)
-    return xr.Dataset(
-        {
-            "scores": positions,
-            "dimensions": sizes,
-            "classes": classes,
-            "features": features,
-        }
-    )
+def extract_prediction_tensor(
+    output: ModelOutput,
+    max_detections: int = 200,
+    threshold: Optional[float] = None,
+    nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
+) -> List[Detections]:
+    detection_heatmap = non_max_suppression(
+        output.detection_probs,
+        kernel_size=nms_kernel_size,
+    )
+
+    height = detection_heatmap.shape[-2]
+    width = detection_heatmap.shape[-1]
+
+    freqs, times = torch.meshgrid(
+        torch.arange(height, dtype=torch.int32),
+        torch.arange(width, dtype=torch.int32),
+        indexing="ij",
+    )
+    freqs = freqs.flatten()
+    times = times.flatten()
+
+    predictions = []
+    for idx, item in enumerate(detection_heatmap):
+        item = item.squeeze().flatten()  # Remove channel dim
+        indices = torch.argsort(item, descending=True)[:max_detections]
+        detection_scores = item.take(indices)
+        detection_freqs = freqs.take(indices)
+        detection_times = times.take(indices)
+
+        sizes = output.size_preds[idx, :, detection_freqs, detection_times].T
+        features = output.features[idx, :, detection_freqs, detection_times].T
+        class_scores = output.class_probs[
+            idx, :, detection_freqs, detection_times
+        ].T
+
+        if threshold is not None:
+            mask = detection_scores >= threshold
+            detection_scores = detection_scores[mask]
+            sizes = sizes[mask]
+            detection_times = detection_times[mask]
+            detection_freqs = detection_freqs[mask]
+            features = features[mask]
+            class_scores = class_scores[mask]
+
+        predictions.append(
+            Detections(
+                scores=detection_scores,
+                sizes=sizes,
+                features=features,
+                class_scores=class_scores,
+                times=detection_times.to(torch.float32) / width,
+                frequencies=(detection_freqs.to(torch.float32) / height),
+            )
+        )
+
+    return predictions
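A usage sketch with a dummy batch (assuming `ModelOutput` accepts these keyword fields; shapes illustrative):

import torch

from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.typing.postprocess import ModelOutput

output = ModelOutput(
    detection_probs=torch.rand(2, 1, 128, 512),  # (batch, channel, freq, time)
    size_preds=torch.rand(2, 2, 128, 512),
    class_probs=torch.rand(2, 17, 128, 512),     # number of classes assumed
    features=torch.rand(2, 32, 128, 512),
)
detections = extract_prediction_tensor(output, max_detections=100, threshold=0.5)
print(detections[0].times)  # normalized to [0, 1]; map_detection_to_clip converts to seconds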

View File

@@ -20,6 +20,7 @@ import xarray as xr
 from soundevent.arrays import Dimensions

 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.typing.postprocess import Detections

 __all__ = [
     "features_to_xarray",
@@ -29,6 +30,26 @@ __all__ = [
 ]


+def map_detection_to_clip(
+    detections: Detections,
+    start_time: float,
+    end_time: float,
+    min_freq: float,
+    max_freq: float,
+) -> Detections:
+    duration = end_time - start_time
+    bandwidth = max_freq - min_freq
+    return Detections(
+        scores=detections.scores,
+        sizes=detections.sizes,
+        features=detections.features,
+        class_scores=detections.class_scores,
+        times=(detections.times * duration + start_time),
+        frequencies=(detections.frequencies * bandwidth + min_freq),
+    )
+
+
 def features_to_xarray(
     features: torch.Tensor,
     start_time: float,
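`map_detection_to_clip` turns the normalized coordinates produced by extraction into absolute clip coordinates: time = t_norm * (end_time - start_time) + start_time, and frequency = f_norm * (max_freq - min_freq) + min_freq. A quick worked check with illustrative values:

start_time, end_time = 2.0, 2.5           # clip spans seconds 2.0-2.5
min_freq, max_freq = 10_000.0, 120_000.0  # Hz
t_norm, f_norm = 0.5, 0.25                # as produced by extract_prediction_tensor
time = t_norm * (end_time - start_time) + start_time  # 2.25 s
freq = f_norm * (max_freq - min_freq) + min_freq      # 37_500.0 Hz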

View File

@@ -53,6 +53,7 @@ from batdetect2.preprocess.spectrogram import (
     SpectrogramConfig,
     SpectrogramPipeline,
     STFTConfig,
+    _spec_params_from_config,
     build_spectrogram_builder,
     build_spectrogram_pipeline,
 )
@@ -109,7 +110,9 @@ def load_preprocessing_config(
 class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
     """Standard implementation of the `Preprocessor` protocol."""

-    samplerate: int
+    input_samplerate: int
+    output_samplerate: float
     max_freq: float
     min_freq: float
@@ -117,22 +120,33 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
         self,
         audio_pipeline: torch.nn.Module,
         spectrogram_pipeline: SpectrogramPipeline,
-        samplerate: int,
+        input_samplerate: int,
+        output_samplerate: float,
         max_freq: float,
         min_freq: float,
     ) -> None:
         super().__init__()
         self.audio_pipeline = audio_pipeline
         self.spectrogram_pipeline = spectrogram_pipeline
-        self.samplerate = samplerate
         self.max_freq = max_freq
         self.min_freq = min_freq
+        self.input_samplerate = input_samplerate
+        self.output_samplerate = output_samplerate

     def forward(self, wav: torch.Tensor) -> torch.Tensor:
         wav = self.audio_pipeline(wav)
         return self.spectrogram_pipeline(wav)


+def compute_output_samplerate(config: PreprocessingConfig) -> float:
+    samplerate = config.audio.samplerate
+    _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
+    factor = config.spectrogram.size.resize_factor
+    return samplerate * factor / hop_size
+
+
 def build_preprocessor(
     config: Optional[PreprocessingConfig] = None,
 ) -> PreprocessorProtocol:
@@ -148,16 +162,15 @@ def build_preprocessor(
     min_freq = config.spectrogram.frequencies.min_freq
     max_freq = config.spectrogram.frequencies.max_freq
+    output_samplerate = compute_output_samplerate(config)

     return StandardPreprocessor(
         audio_pipeline=build_audio_pipeline(config.audio),
         spectrogram_pipeline=build_spectrogram_pipeline(
             samplerate, config.spectrogram
         ),
-        samplerate=samplerate,
+        input_samplerate=samplerate,
+        output_samplerate=output_samplerate,
         min_freq=min_freq,
         max_freq=max_freq,
     )
-
-
-def get_default_preprocessor():
-    return build_preprocessor()
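The "output samplerate" is the spectrogram frame rate, i.e. columns per second of audio: samplerate * resize_factor / hop_size, with the hop measured in samples. A worked example under assumed settings:

samplerate = 256_000  # input audio rate in Hz
hop_size = 512        # STFT hop length in samples (assumed)
resize_factor = 0.5   # spectrogram resized to half its width (assumed)
output_samplerate = samplerate * resize_factor / hop_size  # 250.0 frames per second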

View File

@@ -148,7 +148,7 @@ def add_echo(
     """Add a synthetic echo to the audio waveform."""
     audio = example.audio
-    delay_steps = int(preprocessor.samplerate * delay)
+    delay_steps = int(preprocessor.input_samplerate * delay)
     audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
     audio = audio + weight * audio_delay

View File

@@ -1,10 +1,12 @@
 from typing import Optional, Tuple

 import numpy as np
+import torch
 from loguru import logger

 from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
+from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.train import PreprocessedExample
 from batdetect2.utils.arrays import adjust_width
@@ -18,24 +20,26 @@ class ClipingConfig(BaseConfig):
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP


-class Clipper(ClipperProtocol):
+class Clipper(torch.nn.Module):
     def __init__(
         self,
-        samplerate: int,
+        preprocessor: PreprocessorProtocol,
         duration: float = 0.5,
         max_empty: float = 0.2,
         random: bool = True,
     ):
-        self.samplerate = samplerate
+        super().__init__()
+        self.preprocessor = preprocessor
         self.duration = duration
         self.random = random
         self.max_empty = max_empty

-    def extract_clip(
-        self, example: PreprocessedExample
+    def forward(
+        self,
+        example: PreprocessedExample,
     ) -> Tuple[PreprocessedExample, float, float]:
         start_time = 0
-        duration = example.audio.shape[-1] / self.samplerate
+        duration = example.audio.shape[-1] / self.preprocessor.input_samplerate

         if self.random:
             start_time = np.random.uniform(
@@ -48,7 +52,8 @@ class Clipper(ClipperProtocol):
             example,
             start=start_time,
             duration=self.duration,
-            samplerate=self.samplerate,
+            input_samplerate=self.preprocessor.input_samplerate,
+            output_samplerate=self.preprocessor.output_samplerate,
         ),
         start_time,
         start_time + self.duration,
@@ -56,7 +61,7 @@ class Clipper(ClipperProtocol):
 def build_clipper(
-    samplerate: int,
+    preprocessor: PreprocessorProtocol,
     config: Optional[ClipingConfig] = None,
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
@@ -66,7 +71,7 @@ def build_clipper(
         lambda: config.to_yaml_string(),
     )
     return Clipper(
-        samplerate=samplerate,
+        preprocessor=preprocessor,
         duration=config.duration,
         max_empty=config.max_empty,
         random=config.random if random else False,
@@ -77,11 +82,12 @@ def select_subclip(
     example: PreprocessedExample,
     start: float,
     duration: float,
-    samplerate: float,
+    input_samplerate: float,
+    output_samplerate: float,
     fill_value: float = 0,
 ) -> PreprocessedExample:
-    audio_width = int(np.floor(duration * samplerate))
-    audio_start = int(np.floor(start * samplerate))
+    audio_width = int(np.floor(duration * input_samplerate))
+    audio_start = int(np.floor(start * input_samplerate))

     audio = adjust_width(
         example.audio[audio_start : audio_start + audio_width],
@@ -89,12 +95,8 @@ def select_subclip(
         value=fill_value,
     )

-    audio_duration = example.audio.shape[-1] / samplerate
-    spec_sr = example.spectrogram.shape[-1] / audio_duration
-    spec_start = int(np.floor(start * spec_sr))
-    spec_width = int(np.floor(duration * spec_sr))
+    spec_start = int(np.floor(start * output_samplerate))
+    spec_width = int(np.floor(duration * output_samplerate))

     return PreprocessedExample(
         audio=audio,
         spectrogram=adjust_width(
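Since `Clipper` is now a `torch.nn.Module` whose `forward` returns the clipped example together with its start and end times, call sites invoke the instance directly. A sketch (assuming `preprocessor` and `example` are in scope):

clipper = build_clipper(preprocessor=preprocessor, config=ClipingConfig())
example, start_time, end_time = clipper(example)  # was: clipper.extract_clip(example)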

View File

@@ -32,7 +32,7 @@ class LabeledDataset(Dataset):
     def __getitem__(self, idx) -> TrainExample:
         example = self.get_example(idx)
-        example, start_time, end_time = self.clipper.extract_clip(example)
+        example, start_time, end_time = self.clipper(example)

         if self.augmentation:
             example = self.augmentation(example)
@@ -64,9 +64,7 @@ class LabeledDataset(Dataset):
     def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
         idx = np.random.randint(0, len(self))
         dataset = self.get_example(idx)
-
-        dataset, start_time, end_time = self.clipper.extract_clip(dataset)
+        dataset, start_time, end_time = self.clipper(dataset)

         return dataset, start_time, end_time

     def get_example(self, idx) -> PreprocessedExample:
@@ -107,5 +105,5 @@ class RandomExampleSource:
         index = int(np.random.randint(len(self.filenames)))
         filename = self.filenames[index]
         example = load_preprocessed_example(filename)
-        example, _, _ = self.clipper.extract_clip(example)
+        example, _, _ = self.clipper(example)
         return example

View File

@@ -229,7 +229,7 @@ def build_train_dataset(
     config = config or TrainingConfig()

     clipper = build_clipper(
-        samplerate=preprocessor.samplerate,
+        preprocessor=preprocessor,
         config=config.cliping,
         random=True,
     )
@@ -265,7 +265,7 @@ def build_val_dataset(
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
     clipper = build_clipper(
-        samplerate=preprocessor.samplerate,
+        preprocessor=preprocessor,
         config=config.cliping,
         random=train,
     )

View File

@@ -15,7 +15,7 @@ from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Protocol

 import numpy as np
-import xarray as xr
+import torch
 from soundevent import data

 from batdetect2.typing.models import ModelOutput
@@ -77,6 +77,15 @@ class RawPrediction(NamedTuple):
     features: np.ndarray


+class Detections(NamedTuple):
+    scores: torch.Tensor
+    sizes: torch.Tensor
+    class_scores: torch.Tensor
+    times: torch.Tensor
+    frequencies: torch.Tensor
+    features: torch.Tensor
+
+
 @dataclass
 class BatDetect2Prediction:
     raw: RawPrediction
@ -84,154 +93,13 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol): class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline. """Protocol defining the interface for the full postprocessing pipeline."""
This protocol outlines the standard methods for an object that takes raw def get_detections(
output from a BatDetect2 model and the corresponding input clip metadata,
and processes it through various stages (e.g., coordinate remapping, NMS,
detection extraction, data extraction, decoding) to produce interpretable
results at different levels of completion.
Implementations manage the configured logic for all postprocessing steps.
"""
def get_feature_arrays(
self, self,
output: ModelOutput, output: ModelOutput,
clips: List[data.Clip], clips: Optional[List[data.Clip]] = None,
) -> List[xr.DataArray]: ) -> List[Detections]: ...
"""Remap feature tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch, expected
to contain the necessary feature tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects, one for each item in the
processed batch. This list provides the timing, recording, and
other metadata context needed to calculate real-world coordinates
(seconds, Hz) for the output arrays. The length of this list must
correspond to the batch size of the `output`.
Returns
-------
List[xr.DataArray]
A list of xarray DataArrays, one for each input clip in the batch,
in the same order. Each DataArray contains the feature vectors
with dimensions like ('feature', 'time', 'frequency') and
corresponding real-world coordinates.
"""
...
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap detection tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing detection heatmaps.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 2D xarray DataArrays (one per input clip, in order),
representing the detection heatmap with 'time' and 'frequency'
coordinates. Values typically indicate detection confidence.
"""
...
def get_classification_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap classification tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing class probability tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing class probabilities with 'category', 'time', and
'frequency' dimensions and coordinates.
"""
...
def get_sizes_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap size prediction tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing predicted size tensors (e.g., width and height).
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing predicted sizes with 'dimension'
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
coordinates. Values represent estimated detection sizes.
"""
...
def get_detection_datasets(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.Dataset]:
"""Perform remapping, NMS, detection, and data extraction for a batch.
Processes the raw model output for a batch to identify detection peaks
and extract all associated information (score, position, size, class
probs, features) at those peak locations, returning a structured
dataset for each input clip in the batch.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[xr.Dataset]
A list of xarray Datasets (one per input clip, in order). Each
Dataset contains multiple DataArrays ('scores', 'dimensions',
'classes', 'features') sharing a common 'detection' dimension,
providing aligned data for each detected event in that clip.
"""
...
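Sketch (same assumptions as above): iterating the aligned per-detection
rows of one clip's dataset.
dataset = postprocessor.get_detection_datasets(output, clips)[0]
for i in range(dataset.sizes["detection"]):
    row = dataset.isel(detection=i)
    print(
        float(row["scores"]),             # detection confidence
        row["dimensions"].values,         # e.g. [width, height]
        row["classes"].idxmax("category").item(),
    )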
def get_raw_predictions( def get_raw_predictions(
self, self,

View File

@ -148,7 +148,9 @@ class PreprocessorProtocol(Protocol):
min_freq: float min_freq: float
samplerate: int input_samplerate: int
output_samplerate: float
audio_pipeline: AudioPipeline audio_pipeline: AudioPipeline
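The rename separates the sample rate expected of the incoming audio from
the rate of the processed output. A hedged sketch of what a conforming
class might now declare (all values illustrative; the interpretation of
output_samplerate as frames per second is an assumption, and
max_freq/audio_pipeline are elided):
class ExamplePreprocessor:
    min_freq: float = 10_000.0
    input_samplerate: int = 256_000     # raw audio samples per second
    output_samplerate: float = 1_000.0  # assumed: frames per second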

View File

@ -96,6 +96,6 @@ class LossProtocol(Protocol):
class ClipperProtocol(Protocol): class ClipperProtocol(Protocol):
def extract_clip( def __call__(
self, example: PreprocessedExample self, example: PreprocessedExample
) -> Tuple[PreprocessedExample, float, float]: ... ) -> Tuple[PreprocessedExample, float, float]: ...
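Making the clipper callable means any function or object with this
signature satisfies the protocol structurally. A minimal sketch, assuming
`PreprocessedExample` is importable from `batdetect2.typing` and that the
two trailing floats are the selected start time and duration (both
assumptions; the pass-through clipper is illustrative only):
from typing import Tuple

from batdetect2.typing import PreprocessedExample

class PassThroughClipper:
    def __call__(
        self, example: PreprocessedExample
    ) -> Tuple[PreprocessedExample, float, float]:
        # Return the example unchanged; trailing floats assumed to be
        # the selected start and duration.
        return example, 0.0, 0.512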

View File

@ -10,7 +10,6 @@ from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD, DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction, convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction, convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
get_class_tags, get_class_tags,
get_generic_tags, get_generic_tags,
get_prediction_features, get_prediction_features,
@ -278,63 +277,6 @@ def sample_raw_predictions() -> List[RawPrediction]:
return [pred1, pred2, pred3] return [pred1, pred2, pred3]
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
"""Test basic conversion of a dataset to RawPrediction list."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
sample_detection_dataset,
dummy_targets.decode_roi,
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 2
pred1 = raw_predictions[0]
assert isinstance(pred1, RawPrediction)
assert pred1.detection_score == 0.9
assert pred1.geometry.coordinates == [
20 - 7 / 2,
300 - 16 / 2,
20 + 7 / 2,
300 + 16 / 2,
]
np.testing.assert_allclose(
pred1.class_scores,
sample_detection_dataset["classes"].sel(detection=0),
)
np.testing.assert_allclose(
pred1.features, sample_detection_dataset["features"].sel(detection=0)
)
pred2 = raw_predictions[1]
assert isinstance(pred2, RawPrediction)
assert pred2.detection_score == 0.8
assert pred2.geometry.coordinates == [
10 - 3 / 2,
200 - 12 / 2,
10 + 3 / 2,
200 + 12 / 2,
]
np.testing.assert_allclose(
pred2.class_scores,
sample_detection_dataset["classes"].sel(detection=1),
)
np.testing.assert_allclose(
pred2.features, sample_detection_dataset["features"].sel(detection=1)
)
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
"""Test conversion of an empty dataset."""
raw_predictions = convert_xr_dataset_to_raw_prediction(
empty_detection_dataset,
dummy_targets.decode_roi,
)
assert isinstance(raw_predictions, list)
assert len(raw_predictions) == 0
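The two deleted tests above pin down the ROI decoding convention: a
detection at (time, freq) with predicted (width, height) becomes a
bounding box centred on the peak. In sketch form (pure arithmetic; the
function name is illustrative, not the project's decode_roi):
def decode_roi(time, freq, width, height):
    # [start_time, low_freq, end_time, high_freq], centred on the peak
    return [time - width / 2, freq - height / 2,
            time + width / 2, freq + height / 2]

assert decode_roi(20, 300, 7, 16) == [16.5, 292.0, 23.5, 308.0]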
def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_basic(
sample_raw_predictions, sample_raw_predictions,
sample_recording, sample_recording,

View File

@ -1,214 +0,0 @@
import numpy as np
import pytest
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.postprocess.detection import extract_detections_from_array
@pytest.fixture
def sample_data_array():
"""Provides a basic 3x3 DataArray.
Top values: 0.9 (f=200, t=30), 0.8 (f=100, t=20), 0.7 (f=300, t=30)
"""
array = xr.DataArray(
np.zeros([3, 3]),
coords={
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
},
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
)
array.loc[dict(time=10, frequency=100)] = 0.005
array.loc[dict(time=10, frequency=200)] = 0.5
array.loc[dict(time=10, frequency=300)] = 0.03
array.loc[dict(time=20, frequency=100)] = 0.8
array.loc[dict(time=20, frequency=200)] = 0.02
array.loc[dict(time=20, frequency=300)] = 0.6
array.loc[dict(time=30, frequency=100)] = 0.04
array.loc[dict(time=30, frequency=200)] = 0.9
array.loc[dict(time=30, frequency=300)] = 0.7
return array
@pytest.fixture
def data_array_with_nans(sample_data_array: xr.DataArray):
"""Provides a 2D DataArray containing NaN values."""
array = sample_data_array.copy()
array.loc[dict(time=10, frequency=300)] = np.nan
array.loc[dict(time=30, frequency=100)] = np.nan
return array
def test_basic_extraction(sample_data_array: xr.DataArray):
threshold = 0.1
max_detections = 3
actual_result = extract_detections_from_array(
sample_data_array,
threshold=threshold,
max_detections=max_detections,
)
expected_values = np.array([0.9, 0.8, 0.7])
expected_times = np.array([30, 20, 30])
expected_freqs = np.array([200, 100, 300])
expected_coords = {
Dimensions.frequency.value: ("detection", expected_freqs),
Dimensions.time.value: ("detection", expected_times),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="score",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_threshold_only(sample_data_array):
input_array = sample_data_array
threshold = 0.5
actual_result = extract_detections_from_array(
input_array, threshold=threshold
)
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
expected_times = np.array([30, 20, 30, 20])
expected_freqs = np.array([200, 100, 300, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_max_detections_only(sample_data_array):
input_array = sample_data_array
max_detections = 4
actual_result = extract_detections_from_array(
input_array, max_detections=max_detections
)
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
expected_times = np.array([30, 20, 30, 20])
expected_freqs = np.array([200, 100, 300, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_no_optional_args(sample_data_array):
input_array = sample_data_array
actual_result = extract_detections_from_array(input_array)
expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02])
expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20])
expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
def test_no_values_above_threshold(sample_data_array):
input_array = sample_data_array
threshold = 1.0
actual_result = extract_detections_from_array(
input_array, threshold=threshold
)
expected_coords = {
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
Dimensions.frequency.value: (
"detection",
np.array([], dtype=np.int64),
),
}
expected_result = xr.DataArray(
np.array([], dtype=np.float64),
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
assert actual_result.sizes["detection"] == 0
def test_max_detections_zero(sample_data_array):
input_array = sample_data_array
max_detections = 0
with pytest.raises(ValueError):
extract_detections_from_array(
input_array,
max_detections=max_detections,
)
def test_empty_input_array():
empty_array = xr.DataArray(
np.empty((0, 0)),
coords={Dimensions.time.value: [], Dimensions.frequency.value: []},
dims=[Dimensions.time.value, Dimensions.frequency.value],
)
actual_result = extract_detections_from_array(empty_array)
expected_coords = {
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
Dimensions.frequency.value: (
"detection",
np.array([], dtype=np.int64),
),
}
expected_result = xr.DataArray(
np.array([], dtype=np.float64),
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
assert actual_result.sizes["detection"] == 0
def test_nan_handling(data_array_with_nans):
input_array = data_array_with_nans
threshold = 0.1
max_detections = 3
actual_result = extract_detections_from_array(
input_array, threshold=threshold, max_detections=max_detections
)
expected_values = np.array([0.9, 0.8, 0.7])
expected_times = np.array([30, 20, 30])
expected_freqs = np.array([200, 100, 300])
expected_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_result = xr.DataArray(
expected_values,
coords=expected_coords,
dims="detection",
name="detection_value",
)
xr.testing.assert_equal(actual_result, expected_result)
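Consistent with this commit's goal of removing xarray from postprocessing,
the deleted tests above fully specify the extraction semantics: sort
descending, keep values strictly above the threshold, cap at
max_detections, ignore NaNs, raise ValueError for max_detections <= 0, and
return empty output for empty input. (The real function also applies a
default threshold when none is given; note the 0.005 cell is absent from
test_no_optional_args. The sketch below makes the threshold fully explicit
instead.) A plain-NumPy sketch with equivalent behaviour, for reference
only, not the actual replacement code:
from typing import Optional, Tuple

import numpy as np

def extract_detections(
    array: np.ndarray,
    threshold: Optional[float] = None,
    max_detections: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (values, freq_idx, time_idx) sorted by descending score."""
    if max_detections is not None and max_detections <= 0:
        raise ValueError("max_detections must be positive")
    # NaNs become -inf so they sort last and never pass a threshold.
    flat = np.nan_to_num(array, nan=-np.inf).ravel()
    order = np.argsort(flat)[::-1]
    if threshold is not None:
        order = order[flat[order] > threshold]  # strictly above, per tests
    if max_detections is not None:
        order = order[:max_detections]
    freq_idx, time_idx = np.unravel_index(order, array.shape)
    return flat[order], freq_idx, time_idx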

View File

@ -1,397 +1,2 @@
import numpy as np import numpy as np
import pytest import pytest
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.postprocess.detection import extract_detections_from_array
from batdetect2.postprocess.extraction import (
extract_detection_xr_dataset,
extract_values_at_positions,
)
@pytest.fixture
def sample_data_array():
"""Provides a basic 3x3 DataArray.
Top values: 0.9 (f=200, t=30), 0.8 (f=100, t=20), 0.7 (f=300, t=30)
"""
coords = {
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
}
array = xr.DataArray(
np.zeros([3, 3]),
coords=coords,
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
)
array.loc[dict(time=10, frequency=100)] = 0.005
array.loc[dict(time=10, frequency=200)] = 0.5
array.loc[dict(time=10, frequency=300)] = 0.03
array.loc[dict(time=20, frequency=100)] = 0.8
array.loc[dict(time=20, frequency=200)] = 0.02
array.loc[dict(time=20, frequency=300)] = 0.6
array.loc[dict(time=30, frequency=100)] = 0.04
array.loc[dict(time=30, frequency=200)] = 0.9
array.loc[dict(time=30, frequency=300)] = 0.7
return array
@pytest.fixture
def sample_array_for_extraction():
"""Provides a simple array (1-9) for value extraction tests."""
data = np.arange(1, 10).reshape(3, 3)
coords = {
Dimensions.frequency.value: [100, 200, 300],
Dimensions.time.value: [10, 20, 30],
}
return xr.DataArray(
data,
coords=coords,
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
name="test_values",
)
@pytest.fixture
def sample_positions_top3(sample_data_array):
"""Get top 3 detection positions from sample_data_array."""
return extract_detections_from_array(
sample_data_array,
max_detections=3,
threshold=None,
)
@pytest.fixture
def sample_positions_top2(sample_data_array):
"""Get top 2 detection positions from sample_data_array."""
return extract_detections_from_array(
sample_data_array,
max_detections=2,
threshold=None,
)
@pytest.fixture
def empty_positions(sample_data_array):
"""Get an empty positions array (high threshold)."""
return extract_detections_from_array(
sample_data_array,
threshold=1.0,
)
@pytest.fixture
def sample_sizes_array(sample_data_array):
"""Provides a sample sizes array matching sample_data_array coords."""
coords = sample_data_array.coords
data = np.array(
[
[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
],
[
[9, 10, 11],
[12, 13, 14],
[15, 16, 17],
],
],
dtype=np.float32,
)
return xr.DataArray(
data,
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
name="sizes",
)
@pytest.fixture
def sample_classes_array(sample_data_array):
"""Provides a sample classes array matching sample_data_array coords."""
coords = sample_data_array.coords
data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
return xr.DataArray(
data,
coords={
"category": ["bat", "noise"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["category", Dimensions.frequency.value, Dimensions.time.value],
name="class_scores",
)
@pytest.fixture
def sample_features_array(sample_data_array):
"""Provides a sample features array matching sample_data_array coords."""
coords = sample_data_array.coords
data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
return xr.DataArray(
data,
coords={
"feature": ["f0", "f1", "f2", "f3"],
Dimensions.frequency.value: coords[Dimensions.frequency.value],
Dimensions.time.value: coords[Dimensions.time.value],
},
dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
name="features",
)
def test_extract_values_at_positions_correct(
sample_array_for_extraction,
sample_positions_top3,
):
"""Verify correct values are extracted based on positions coords."""
expected_values = np.array(
[
sample_array_for_extraction.sel(time=30, frequency=200).values,
sample_array_for_extraction.sel(time=20, frequency=100).values,
sample_array_for_extraction.sel(time=30, frequency=300).values,
]
)
expected = xr.DataArray(
expected_values,
coords=sample_positions_top3.coords,
dims="detection",
name="test_values",
)
extracted = extract_values_at_positions(
sample_array_for_extraction, sample_positions_top3
)
xr.testing.assert_allclose(extracted, expected)
def test_extract_values_at_positions_extra_dims(
sample_sizes_array,
sample_positions_top2,
):
"""Test extraction preserves other dimensions in the source array."""
times = np.array([30, 20])
freqs = np.array([200, 100])
expected_values = np.array(
[
sample_sizes_array.sel(time=30, frequency=200).values,
sample_sizes_array.sel(time=20, frequency=100).values,
],
dtype=np.float32,
)
expected = xr.DataArray(
expected_values,
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: ("detection", freqs),
Dimensions.time.value: ("detection", times),
},
dims=["detection", "dimension"],
name="sizes",
)
extracted = extract_values_at_positions(
sample_sizes_array,
sample_positions_top2,
)
xr.testing.assert_equal(extracted, expected)
def test_extract_values_at_positions_empty(
sample_array_for_extraction, empty_positions
):
"""Test extraction with empty positions returns empty array."""
extracted = extract_values_at_positions(
sample_array_for_extraction, empty_positions
)
assert extracted.sizes["detection"] == 0
assert Dimensions.time.value in extracted.coords
assert Dimensions.frequency.value in extracted.coords
assert extracted.coords[Dimensions.time.value].size == 0
assert extracted.coords[Dimensions.frequency.value].size == 0
assert extracted.name == sample_array_for_extraction.name
def test_extract_values_at_positions_missing_coord_in_array(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if source array misses required coordinates."""
array_no_time = sample_array_for_extraction.copy()
del array_no_time.coords[Dimensions.time.value]
with pytest.raises(IndexError):
extract_values_at_positions(array_no_time, sample_positions_top2)
array_no_freq = sample_array_for_extraction.copy()
del array_no_freq.coords[Dimensions.frequency.value]
with pytest.raises(IndexError):
extract_values_at_positions(array_no_freq, sample_positions_top2)
def test_extract_values_at_positions_missing_coord_in_positions(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if positions array misses required coordinates."""
positions_no_time = sample_positions_top2.copy()
del positions_no_time.coords[Dimensions.time.value]
with pytest.raises(KeyError):
extract_values_at_positions(
sample_array_for_extraction, positions_no_time
)
positions_no_freq = sample_positions_top2.copy()
del positions_no_freq.coords[Dimensions.frequency.value]
with pytest.raises(KeyError):
extract_values_at_positions(
sample_array_for_extraction, positions_no_freq
)
def test_extract_values_at_positions_mismatched_coords(
sample_array_for_extraction, sample_positions_top2
):
"""Test error if positions requests coords not in source array."""
bad_positions = sample_positions_top2.copy()
bad_positions.coords[Dimensions.time.value] = (
"detection",
np.array([40, 10]),
)
with pytest.raises(KeyError):
extract_values_at_positions(sample_array_for_extraction, bad_positions)
def test_extract_detection_xr_dataset_correct(
sample_positions_top2,
sample_sizes_array,
sample_classes_array,
sample_features_array,
):
"""Tests extracting and bundling info for top 2 detections."""
actual_dataset = extract_detection_xr_dataset(
sample_positions_top2,
sample_sizes_array,
sample_classes_array,
sample_features_array,
)
expected_times = np.array([30, 20])
expected_freqs = np.array([200, 100])
detection_coords = {
Dimensions.time.value: ("detection", expected_times),
Dimensions.frequency.value: ("detection", expected_freqs),
}
expected_score = sample_positions_top2
expected_dimensions_data = np.array(
[
sample_sizes_array.sel(time=30, frequency=200).values,
sample_sizes_array.sel(time=20, frequency=100).values,
],
dtype=np.float32,
)
expected_dimensions = xr.DataArray(
expected_dimensions_data,
coords={**detection_coords, "dimension": ["width", "height"]},
dims=["detection", "dimension"],
name="dimensions",
)
expected_classes_data = np.array(
[
sample_classes_array.sel(time=30, frequency=200).values,
sample_classes_array.sel(time=20, frequency=100).values,
],
dtype=np.float32,
)
expected_classes = xr.DataArray(
expected_classes_data,
coords={**detection_coords, "category": ["bat", "noise"]},
dims=["detection", "category"],
name="classes",
)
expected_features_data = np.array(
[
sample_features_array.sel(time=30, frequency=200).values,
sample_features_array.sel(time=20, frequency=100).values,
],
dtype=np.float32,
)
expected_features = xr.DataArray(
expected_features_data,
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
dims=["detection", "feature"],
name="features",
)
expected_dataset = xr.Dataset(
{
"scores": expected_score,
"dimensions": expected_dimensions,
"classes": expected_classes,
"features": expected_features,
}
)
expected_dataset = expected_dataset.assign_coords(detection_coords)
xr.testing.assert_allclose(actual_dataset, expected_dataset)
def test_extract_detection_xr_dataset_empty(
empty_positions,
sample_sizes_array,
sample_classes_array,
sample_features_array,
):
"""Test extraction with empty positions yields an empty dataset."""
actual_dataset = extract_detection_xr_dataset(
empty_positions,
sample_sizes_array,
sample_classes_array,
sample_features_array,
)
assert isinstance(actual_dataset, xr.Dataset)
assert "detection" in actual_dataset.dims
assert actual_dataset.sizes["detection"] == 0
assert "scores" in actual_dataset
assert actual_dataset["scores"].dims == ("detection",)
assert actual_dataset["scores"].size == 0
assert "dimensions" in actual_dataset
assert actual_dataset["dimensions"].dims == ("detection", "dimension")
assert actual_dataset["dimensions"].shape == (0, 2)
assert "classes" in actual_dataset
assert actual_dataset["classes"].dims == ("detection", "category")
assert actual_dataset["classes"].shape == (0, 2)
assert "features" in actual_dataset
assert actual_dataset["features"].dims == ("detection", "feature")
assert actual_dataset["features"].shape == (0, 4)
assert Dimensions.time.value in actual_dataset.coords
assert Dimensions.frequency.value in actual_dataset.coords
assert actual_dataset.coords[Dimensions.time.value].size == 0
assert actual_dataset.coords[Dimensions.frequency.value].size == 0
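Once real-valued coordinates are replaced by integer indices, the
value-at-positions extraction exercised above reduces to NumPy advanced
indexing. A hedged sketch (names illustrative, not the replacement API):
import numpy as np

def take_at_positions(array: np.ndarray, freq_idx, time_idx) -> np.ndarray:
    # `array` has shape (extra_dims..., n_freq, n_time); the result has
    # shape (n_detections, extra_dims...), one row per detection.
    return np.moveaxis(array[..., freq_idx, time_idx], -1, 0)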

View File

@ -162,7 +162,8 @@ def test_selected_random_subclip_has_the_correct_width(
subclip = select_subclip( subclip = select_subclip(
original, original,
samplerate=256_000, input_samplerate=256_000,
output_samplerate=1000,
start=0, start=0,
duration=0.512, duration=0.512,
) )
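The new signature distinguishes the audio rate from the output grid rate,
since the same time window maps to a different number of samples in each.
A quick illustration with the values from this test (arithmetic only):
start, duration = 0.0, 0.512
input_samplerate, output_samplerate = 256_000, 1_000
n_audio_samples = int(duration * input_samplerate)   # 131_072
n_output_frames = int(duration * output_samplerate)  # 512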

View File

@ -39,9 +39,8 @@ def build_from_config(
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
targets, targets,
preprocessor=preprocessor,
config=postprocessing_config, config=postprocessing_config,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
) )
return targets, preprocessor, labeller, postprocessor return targets, preprocessor, labeller, postprocessor
@ -84,7 +83,10 @@ def test_encoding_decoding_roundtrip_recovers_object(
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
encoded = generate_train_example( encoded = generate_train_example(
clip_annotation, sample_audio_loader, preprocessor, labeller clip_annotation,
sample_audio_loader,
preprocessor,
labeller,
) )
predictions = postprocessor.get_predictions( predictions = postprocessor.get_predictions(
ModelOutput( ModelOutput(