mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Remove xr from postprocess
This commit is contained in:
parent
cc9e47b022
commit
281c4dcb8a
@ -140,9 +140,8 @@ def build_model(config: Optional[ModelConfig] = None):
|
||||
preprocessor = build_preprocessor(config=config.preprocess)
|
||||
postprocessor = build_postprocessor(
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
config=config.postprocess,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
|
||||
@ -6,7 +6,7 @@ from matplotlib.axes import Axes
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.plotting.common import plot_spectrogram
|
||||
from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
|
||||
from batdetect2.preprocess import build_audio_loader, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -27,7 +27,7 @@ def plot_clip(
|
||||
_, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
if preprocessor is None:
|
||||
preprocessor = get_default_preprocessor()
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
if audio_loader is None:
|
||||
audio_loader = build_audio_loader()
|
||||
|
||||
@ -8,10 +8,7 @@ from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessorProtocol,
|
||||
get_default_preprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.typing.evaluate import MatchEvaluation
|
||||
|
||||
__all__ = [
|
||||
@ -50,7 +47,7 @@ def plot_matches(
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
if preprocessor is None:
|
||||
preprocessor = get_default_preprocessor()
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
|
||||
@ -1,36 +1,7 @@
|
||||
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
||||
|
||||
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
||||
BatDetect2 neural network model and transforms them into meaningful, structured
|
||||
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
||||
containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||
model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
class probabilities, features) at the detected locations.
|
||||
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
||||
class predictions into interpretable `soundevent` objects, including
|
||||
recovering geometry (ROIs) and decoding class names back to standard tags.
|
||||
|
||||
This module provides the primary interface:
|
||||
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
||||
(thresholds, NMS kernel size, etc.).
|
||||
- `load_postprocess_config`: Function to load the configuration from a file.
|
||||
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
||||
holds the configured pipeline logic.
|
||||
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
||||
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
||||
It also re-exports key components from submodules for convenience.
|
||||
"""
|
||||
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import xarray as xr
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -38,37 +9,24 @@ from soundevent import data
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_detections_to_raw_predictions,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
from batdetect2.postprocess.detection import (
|
||||
DEFAULT_DETECTION_THRESHOLD,
|
||||
TOP_K_PER_SEC,
|
||||
extract_detections_from_array,
|
||||
get_max_detections,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import extract_prediction_tensor
|
||||
from batdetect2.postprocess.nms import (
|
||||
NMS_KERNEL_SIZE,
|
||||
non_max_suppression,
|
||||
)
|
||||
from batdetect2.postprocess.remapping import (
|
||||
classification_to_xarray,
|
||||
detection_to_xarray,
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
Detections,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
@ -81,19 +39,17 @@ __all__ = [
|
||||
"Postprocessor",
|
||||
"TOP_K_PER_SEC",
|
||||
"build_postprocessor",
|
||||
"classification_to_xarray",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"detection_to_xarray",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_detections_from_array",
|
||||
"features_to_xarray",
|
||||
"get_max_detections",
|
||||
"convert_detections_to_raw_predictions",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration settings for the postprocessing pipeline.
|
||||
@ -173,40 +129,10 @@ def load_postprocess_config(
|
||||
|
||||
def build_postprocessor(
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
max_freq: float = MAX_FREQ,
|
||||
min_freq: float = MIN_FREQ,
|
||||
) -> PostprocessorProtocol:
|
||||
"""Factory function to build the standard postprocessor.
|
||||
|
||||
Creates and initializes the `Postprocessor` instance, providing it with the
|
||||
necessary `targets` object and the `PostprocessConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
An initialized object conforming to the `TargetProtocol`, providing
|
||||
methods like `.decode()` and `.recover_roi()`, and attributes like
|
||||
`.class_names` and `.generic_class_tags`. This links postprocessing
|
||||
to the defined target semantics and geometry mappings.
|
||||
config : PostprocessConfig, optional
|
||||
Configuration object specifying postprocessing parameters (thresholds,
|
||||
NMS kernel size, etc.). If None, default settings defined in
|
||||
`PostprocessConfig` will be used.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
The minimum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig` instead for better encapsulation.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
The maximum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessorProtocol
|
||||
An initialized `Postprocessor` instance ready to process model outputs.
|
||||
"""
|
||||
"""Factory function to build the standard postprocessor."""
|
||||
config = config or PostprocessConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building postprocessor with config: \n{}",
|
||||
@ -214,303 +140,62 @@ def build_postprocessor(
|
||||
)
|
||||
return Postprocessor(
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
config=config,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
class Postprocessor(PostprocessorProtocol):
|
||||
"""Standard implementation of the postprocessing pipeline.
|
||||
|
||||
This class orchestrates the steps required to convert raw model outputs
|
||||
into interpretable `soundevent` predictions. It uses configured parameters
|
||||
and leverages functions from the `batdetect2.postprocess` submodules for
|
||||
each stage (NMS, remapping, detection, extraction, decoding).
|
||||
|
||||
It requires a `TargetProtocol` object during initialization to access
|
||||
necessary decoding information (class name to tag mapping,
|
||||
ROI recovery logic) ensuring consistency with the target definitions used
|
||||
during training or specified for inference.
|
||||
|
||||
Instances are typically created using the `build_postprocessor` factory
|
||||
function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
The configured target definition object providing decoding and ROI
|
||||
recovery.
|
||||
config : PostprocessConfig
|
||||
Configuration object holding parameters for NMS, thresholds, etc.
|
||||
min_freq : float
|
||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||
max_freq : float
|
||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||
"""
|
||||
"""Standard implementation of the postprocessing pipeline."""
|
||||
|
||||
targets: TargetProtocol
|
||||
|
||||
preprocessor: PreprocessorProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: PostprocessConfig,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
):
|
||||
"""Initialize the Postprocessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
Initialized target definition object.
|
||||
config : PostprocessConfig
|
||||
Configuration for postprocessing parameters.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) for coordinate remapping.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) for coordinate remapping.
|
||||
"""
|
||||
"""Initialize the Postprocessor."""
|
||||
self.targets = targets
|
||||
self.preprocessor = preprocessor
|
||||
self.config = config
|
||||
self.min_freq = min_freq
|
||||
self.max_freq = max_freq
|
||||
|
||||
def get_feature_arrays(
|
||||
def get_detections(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw feature tensors for a batch.
|
||||
clips: Optional[List[data.Clip]] = None,
|
||||
) -> List[Detections]:
|
||||
width = output.detection_probs.shape[-1]
|
||||
duration = width / self.preprocessor.output_samplerate
|
||||
max_detections = int(self.config.top_k_per_sec * duration)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.features` tensor for the batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware feature DataArrays, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.features` and `clips` do not match.
|
||||
"""
|
||||
if len(clips) != len(output.features):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of feature array"
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, features: {len(output.features)})"
|
||||
)
|
||||
|
||||
return [
|
||||
features_to_xarray(
|
||||
feats,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for feats, clip in zip(output.features, clips)
|
||||
]
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Apply NMS and remap detection heatmaps for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.detection_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
||||
clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.detection_probs` and `clips` do not match.
|
||||
"""
|
||||
detections = output.detection_probs
|
||||
|
||||
if len(clips) != len(output.detection_probs):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of detection array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, detection: {len(detections)})"
|
||||
)
|
||||
|
||||
detections = non_max_suppression(
|
||||
detections,
|
||||
kernel_size=self.config.nms_kernel_size,
|
||||
detections = extract_prediction_tensor(
|
||||
output,
|
||||
max_detections=max_detections,
|
||||
threshold=self.config.detection_threshold,
|
||||
)
|
||||
|
||||
return [
|
||||
detection_to_xarray(
|
||||
dets,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for dets, clip in zip(detections, clips)
|
||||
]
|
||||
|
||||
def get_classification_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw classification tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.class_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware class probability maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.class_probs` and `clips` do not match, or
|
||||
if number of classes mismatches `self.targets.class_names`.
|
||||
"""
|
||||
classifications = output.class_probs
|
||||
|
||||
if len(clips) != len(classifications):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of classification array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
||||
)
|
||||
if clips is None:
|
||||
return detections
|
||||
|
||||
return [
|
||||
classification_to_xarray(
|
||||
class_probs,
|
||||
map_detection_to_clip(
|
||||
detection,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
class_names=self.targets.class_names,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
min_freq=self.preprocessor.min_freq,
|
||||
max_freq=self.preprocessor.max_freq,
|
||||
)
|
||||
for class_probs, clip in zip(classifications, clips)
|
||||
for detection, clip in zip(detections, clips)
|
||||
]
|
||||
|
||||
def get_sizes_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw size prediction tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.size_preds` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware size prediction maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.size_preds` and `clips` do not match.
|
||||
"""
|
||||
sizes = output.size_preds
|
||||
|
||||
if len(clips) != len(sizes):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of sizes array do not match. "
|
||||
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
||||
)
|
||||
|
||||
return [
|
||||
sizes_to_xarray(
|
||||
size_preds,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for size_preds, clip in zip(sizes, clips)
|
||||
]
|
||||
|
||||
def get_detection_datasets(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
List of xarray Datasets (one per clip). Each Dataset contains
|
||||
aligned scores, dimensions, class probabilities, and features for
|
||||
detections found in that clip.
|
||||
"""
|
||||
detection_arrays = self.get_detection_arrays(output, clips)
|
||||
classification_arrays = self.get_classification_arrays(output, clips)
|
||||
size_arrays = self.get_sizes_arrays(output, clips)
|
||||
features_arrays = self.get_feature_arrays(output, clips)
|
||||
|
||||
datasets = []
|
||||
for det_array, class_array, sizes_array, feats_array in zip(
|
||||
detection_arrays,
|
||||
classification_arrays,
|
||||
size_arrays,
|
||||
features_arrays,
|
||||
):
|
||||
max_detections = get_max_detections(
|
||||
det_array,
|
||||
top_k_per_sec=self.config.top_k_per_sec,
|
||||
)
|
||||
|
||||
positions = extract_detections_from_array(
|
||||
det_array,
|
||||
max_detections=max_detections,
|
||||
threshold=self.config.detection_threshold,
|
||||
)
|
||||
|
||||
datasets.append(
|
||||
extract_detection_xr_dataset(
|
||||
positions,
|
||||
sizes_array,
|
||||
class_array,
|
||||
feats_array,
|
||||
)
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
def get_raw_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
@ -531,13 +216,13 @@ class Postprocessor(PostprocessorProtocol):
|
||||
List of lists (one inner list per input clip). Each inner list
|
||||
contains `RawPrediction` objects for detections in that clip.
|
||||
"""
|
||||
detection_datasets = self.get_detection_datasets(output, clips)
|
||||
detections = self.get_detections(output, clips)
|
||||
return [
|
||||
convert_xr_dataset_to_raw_prediction(
|
||||
convert_detections_to_raw_predictions(
|
||||
dataset,
|
||||
self.targets.decode_roi,
|
||||
targets=self.targets,
|
||||
)
|
||||
for dataset in detection_datasets
|
||||
for dataset in detections
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
|
||||
@ -1,42 +1,18 @@
|
||||
"""Decodes extracted detection data into standard soundevent predictions.
|
||||
|
||||
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into standardized prediction objects based on the `soundevent` data model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
objects, using a configured geometry builder to recover bounding boxes from
|
||||
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||
2. Converting each `RawPrediction` into a
|
||||
`soundevent.data.SoundEventPrediction`, which involves:
|
||||
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||
- Decoding the predicted class probabilities into representative tags using
|
||||
a configured class decoder (`SoundEventDecoder`).
|
||||
- Applying a classification threshold.
|
||||
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||
including tags for all classes above the threshold (multi-label).
|
||||
- Adding generic class tags as a baseline.
|
||||
- Associating scores with the final prediction and tags.
|
||||
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||
a `soundevent.data.ClipPrediction`
|
||||
(`convert_raw_predictions_to_clip_prediction`).
|
||||
"""
|
||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
|
||||
from batdetect2.typing.postprocess import (
|
||||
Detections,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"convert_detections_to_raw_predictions",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
@ -51,65 +27,29 @@ decoding.
|
||||
"""
|
||||
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_decoder: GeometryDecoder,
|
||||
def convert_detections_to_raw_predictions(
|
||||
detections: Detections,
|
||||
targets: TargetProtocol,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||
and transforms each detection entry into an intermediate `RawPrediction`
|
||||
object. This involves recovering the geometry (e.g., bounding box) from
|
||||
the predicted position and scaled size dimensions using the provided
|
||||
`geometry_decoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_dataset : xr.Dataset
|
||||
An xarray Dataset containing aligned detection information, typically
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_decoder : GeometryDecoder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[RawPrediction]
|
||||
A list of `RawPrediction` objects, each containing the detection score,
|
||||
recovered bounding box coordinates (start/end time, low/high freq),
|
||||
the vector of class scores, and the feature vector for one detection.
|
||||
|
||||
Raises
|
||||
------
|
||||
AttributeError, KeyError, ValueError
|
||||
If `detection_dataset` is missing expected variables ('scores',
|
||||
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||
associated with 'scores'), or if `geometry_decoder` fails.
|
||||
"""
|
||||
detections = []
|
||||
|
||||
categories = detection_dataset.category.values
|
||||
predictions = []
|
||||
|
||||
for score, class_scores, time, freq, dims, feats in zip(
|
||||
detection_dataset["scores"].values,
|
||||
detection_dataset["classes"].values,
|
||||
detection_dataset["time"].values,
|
||||
detection_dataset["frequency"].values,
|
||||
detection_dataset["dimensions"].values,
|
||||
detection_dataset["features"].values,
|
||||
detections.scores,
|
||||
detections.class_scores,
|
||||
detections.times,
|
||||
detections.frequencies,
|
||||
detections.sizes,
|
||||
detections.features,
|
||||
):
|
||||
highest_scoring_class = categories[class_scores.argmax()]
|
||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||
|
||||
geom = geometry_decoder(
|
||||
geom = targets.decode_roi(
|
||||
(time, freq),
|
||||
dims,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
predictions.append(
|
||||
RawPrediction(
|
||||
detection_score=score,
|
||||
geometry=geom,
|
||||
@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
)
|
||||
)
|
||||
|
||||
return detections
|
||||
return predictions
|
||||
|
||||
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||
|
||||
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_predictions : List[RawPrediction]
|
||||
List of raw prediction objects for a single clip.
|
||||
clip : data.Clip
|
||||
The original `soundevent.data.Clip` object these predictions belong to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to decode class names into representative tags.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of tags representing the generic class category.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Threshold applied to class scores during decoding.
|
||||
top_class_only : bool, default=False
|
||||
If True, only decode tags for the single highest-scoring class above
|
||||
the threshold. If False, decode tags for all classes above threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.ClipPrediction
|
||||
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||
objects corresponding to the input `raw_predictions`.
|
||||
"""
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
||||
return data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=[
|
||||
@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
):
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||
|
||||
This function performs the core decoding steps for a single detected event:
|
||||
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||
feature vectors.
|
||||
2. Initializes a list of predicted tags using the provided
|
||||
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||
`raw_prediction` to these generic tags.
|
||||
3. Processes the `class_scores` from the `raw_prediction`:
|
||||
a. Optionally filters out scores below `classification_threshold`
|
||||
(if it's not None).
|
||||
b. Sorts the remaining scores in descending order.
|
||||
c. Iterates through the sorted, thresholded class scores.
|
||||
d. For each class, uses the `sound_event_decoder` to get the
|
||||
representative base tags for that class name.
|
||||
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||
the specific `score` of that class prediction.
|
||||
f. Appends these specific predicted tags to the list.
|
||||
g. If `top_class_only` is True, stops after processing the first
|
||||
(highest-scoring) class that passed the threshold.
|
||||
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||
compiled list of `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_prediction : RawPrediction
|
||||
The raw prediction object containing score, bounds, class scores,
|
||||
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||
coordinate.
|
||||
recording : data.Recording
|
||||
The recording the sound event belongs to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Configured function mapping class names (str) to lists of base
|
||||
`data.Tag` objects.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of base tags representing the generic category.
|
||||
classification_threshold : float, optional
|
||||
The minimum score a class prediction must have to be considered
|
||||
significant enough to have its tags decoded and added. If None, no
|
||||
thresholding is applied based on class score (all predicted classes,
|
||||
or the top one if `top_class_only` is True, will be processed).
|
||||
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
top_class_only : bool, default=False
|
||||
If True, only includes tags for the single highest-scoring class that
|
||||
exceeds the threshold. If False (default), includes tags for all classes
|
||||
exceeding the threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventPrediction
|
||||
The fully formed sound event prediction object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `raw_prediction.features` has unexpected structure or if
|
||||
`data.term_from_key` (if used internally) fails.
|
||||
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||
"""
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||
sound_event = data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=raw_prediction.geometry,
|
||||
@ -273,25 +124,7 @@ def get_generic_tags(
|
||||
detection_score: float,
|
||||
generic_class_tags: List[data.Tag],
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Create PredictedTag objects for the generic category.
|
||||
|
||||
Takes the base list of generic tags and assigns the overall detection
|
||||
score to each one, wrapping them in `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_score : float
|
||||
The overall confidence score of the detection event.
|
||||
generic_class_tags : List[data.Tag]
|
||||
The list of base `soundevent.data.Tag` objects that define the
|
||||
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the generic category, each
|
||||
assigned the `detection_score`.
|
||||
"""
|
||||
"""Create PredictedTag objects for the generic category."""
|
||||
return [
|
||||
data.PredictedTag(tag=tag, score=detection_score)
|
||||
for tag in generic_class_tags
|
||||
@ -299,25 +132,7 @@ def get_generic_tags(
|
||||
|
||||
|
||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : xr.DataArray
|
||||
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
||||
named 'feature' which holds the feature names (e.g., output of selecting
|
||||
features for one detection from `extract_detection_xr_dataset`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Feature]
|
||||
A list of `soundevent.data.Feature` objects.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- This function creates basic `Term` objects using the feature coordinate
|
||||
names with a "batdetect2:" prefix.
|
||||
"""
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
|
||||
@ -1,162 +0,0 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements Step 3 within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
||||
coordinate remapping have been applied).
|
||||
|
||||
It provides functionality to:
|
||||
- Identify the locations (time, frequency) of the highest-scoring points.
|
||||
- Filter these points based on a minimum confidence score threshold.
|
||||
- Limit the maximum number of detection points returned (top-k).
|
||||
|
||||
The main output is an `xarray.DataArray` containing the scores and
|
||||
corresponding time/frequency coordinates for the extracted detection points.
|
||||
This output serves as the input for subsequent postprocessing steps, such as
|
||||
extracting predicted class probabilities and bounding box sizes at these
|
||||
specific locations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions, get_dim_width
|
||||
|
||||
__all__ = [
|
||||
"extract_detections_from_array",
|
||||
"get_max_detections",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"TOP_K_PER_SEC",
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
"""Default confidence score threshold used for filtering detections."""
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
"""Default desired maximum number of detections per second of audio."""
|
||||
|
||||
|
||||
def extract_detections_from_array(
    detection_array: xr.DataArray,
    max_detections: Optional[int] = None,
    threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
) -> xr.DataArray:
    """Extract detection locations (time, freq) and scores from a heatmap.

    Ranks the pixels of the detection heatmap by score, keeps at most
    `max_detections` of the best ones, drops those that do not exceed
    `threshold`, and returns the surviving scores together with their time
    and frequency coordinates.

    Parameters
    ----------
    detection_array : xr.DataArray
        A 2D heatmap whose first axis is frequency and whose second axis is
        time (the flat indices are unravelled in that order). It must carry
        coordinates under `Dimensions.time.value` and
        `Dimensions.frequency.value`. Higher values mean higher confidence.
    max_detections : int, optional
        Upper bound on the number of detections returned. If None (default),
        every pixel passing the threshold is returned. Values larger than
        the number of pixels are allowed and behave like None.
    threshold : float, optional
        Only detections with a score strictly greater than this value are
        kept. Defaults to `DEFAULT_DETECTION_THRESHOLD`. If None, no
        thresholding is applied.

    Returns
    -------
    xr.DataArray
        A 1D DataArray named 'score' with a 'detection' dimension. Values
        are the detection scores sorted in descending order; the time and
        frequency coordinates (indexed by 'detection') give the location of
        each detection. Empty if nothing passes the criteria.

    Raises
    ------
    ValueError
        If `max_detections` is not None and is not positive.
    """
    if max_detections is not None and max_detections <= 0:
        raise ValueError("Max detections must be positive")

    values = detection_array.values.flatten()

    if max_detections is not None and max_detections < values.size:
        # argpartition needs kth < size, so it is only used when it actually
        # reduces the amount of sorting work.
        top_indices = np.argpartition(-values, max_detections)[:max_detections]
        top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
    else:
        # Either no cap was requested or the cap covers every pixel: a plain
        # descending sort of all values gives the same result.
        top_sorted_indices = np.argsort(-values)

    top_values = values[top_sorted_indices]

    if threshold is not None:
        # Strict comparison: a score equal to the threshold is discarded.
        mask = top_values > threshold
        top_values = top_values[mask]
        top_sorted_indices = top_sorted_indices[mask]

    # Flat indices map back to (frequency, time) positions in the heatmap.
    freq_indices, time_indices = np.unravel_index(
        top_sorted_indices,
        detection_array.shape,
    )

    times = detection_array.coords[Dimensions.time.value].values[time_indices]
    freqs = detection_array.coords[Dimensions.frequency.value].values[
        freq_indices
    ]

    return xr.DataArray(
        data=top_values,
        coords={
            Dimensions.frequency.value: ("detection", freqs),
            Dimensions.time.value: ("detection", times),
        },
        dims="detection",
        name="score",
    )
|
||||
|
||||
|
||||
def get_max_detections(
    detection_array: xr.DataArray,
    top_k_per_sec: int = TOP_K_PER_SEC,
) -> int:
    """Compute the detection budget for a heatmap from its duration.

    The width of the heatmap along its time dimension is multiplied by the
    allowed number of detections per second and truncated to an integer.

    Parameters
    ----------
    detection_array : xr.DataArray
        The detection heatmap. Its time dimension is measured with
        `soundevent.arrays.get_dim_width`.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        Maximum number of detections allowed per second of audio.

    Returns
    -------
    int
        The maximum number of detections allowed over the whole heatmap.

    Raises
    ------
    ValueError
        If `top_k_per_sec` is negative. Errors raised by `get_dim_width`
        for an unusable time dimension propagate unchanged.
    """
    if top_k_per_sec < 0:
        raise ValueError("top_k_per_sec cannot be negative.")

    time_dim = Dimensions.time.value
    clip_duration = get_dim_width(detection_array, time_dim)
    budget = clip_duration * top_k_per_sec
    return int(budget)
|
||||
@ -15,108 +15,73 @@ precise time-frequency location of each detection. The final output aggregates
|
||||
all extracted information into a structured `xarray.Dataset`.
|
||||
"""
|
||||
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||
from batdetect2.typing.postprocess import Detections, ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"extract_values_at_positions",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_prediction_tensor",
|
||||
]
|
||||
|
||||
|
||||
def extract_values_at_positions(
    array: xr.DataArray,
    positions: xr.DataArray,
) -> xr.DataArray:
    """Look up values of `array` at the coordinates carried by `positions`.

    Parameters
    ----------
    array : xr.DataArray
        Source array to sample. It must be selectable by the time and
        frequency coordinates named by `Dimensions`.
    positions : xr.DataArray
        Array whose time and frequency coordinates give the locations to
        sample.

    Returns
    -------
    xr.DataArray
        The selected values, transposed so that the dimension order of the
        selection result is reversed.
    """
    freq_dim = Dimensions.frequency.value
    time_dim = Dimensions.time.value

    indexers = {
        freq_dim: positions.coords[freq_dim],
        time_dim: positions.coords[time_dim],
    }

    selected = array.sel(**indexers)
    return selected.T
|
||||
|
||||
|
||||
def extract_detection_xr_dataset(
|
||||
positions: xr.DataArray,
|
||||
sizes: xr.DataArray,
|
||||
classes: xr.DataArray,
|
||||
features: xr.DataArray,
|
||||
) -> xr.Dataset:
|
||||
"""Combine extracted detection information into a structured xr.Dataset.
|
||||
|
||||
Takes the detection positions/scores and the full model output heatmaps
|
||||
(sizes, classes, optional features), extracts the relevant data at the
|
||||
detection positions, and packages everything into a single `xarray.Dataset`
|
||||
where all variables are indexed by a common 'detection' dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : xr.DataArray
|
||||
Output from `extract_detections_from_array`, containing detection
|
||||
scores as data and 'time', 'frequency' coordinates along the
|
||||
'detection' dimension.
|
||||
sizes : xr.DataArray
|
||||
The full size prediction heatmap from the model, with dimensions like
|
||||
('dimension', 'time', 'frequency').
|
||||
classes : xr.DataArray
|
||||
The full class probability heatmap from the model, with dimensions like
|
||||
('category', 'time', 'frequency').
|
||||
features : xr.DataArray
|
||||
The full feature map from the model, with
|
||||
dimensions like ('feature', 'time', 'frequency').
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.Dataset
|
||||
An xarray Dataset containing aligned information for each detection:
|
||||
- 'scores': DataArray from `positions` (score data, time/freq coords).
|
||||
- 'dimensions': DataArray with extracted size values
|
||||
(dims: 'detection', 'dimension').
|
||||
- 'classes': DataArray with extracted class probabilities
|
||||
(dims: 'detection', 'category').
|
||||
- 'features': DataArray with extracted feature vectors
|
||||
(dims: 'detection', 'feature'), if `features` was provided. All
|
||||
DataArrays share the 'detection' dimension and associated
|
||||
time/frequency coordinates.
|
||||
"""
|
||||
sizes = extract_values_at_positions(sizes, positions)
|
||||
classes = extract_values_at_positions(classes, positions)
|
||||
features = extract_values_at_positions(features, positions)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": positions,
|
||||
"dimensions": sizes,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
}
|
||||
def extract_prediction_tensor(
    output: ModelOutput,
    max_detections: int = 200,
    threshold: Optional[float] = None,
    nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[Detections]:
    """Extract the top-scoring detections from a batch of model outputs.

    Applies non-maximum suppression to the detection heatmap, then for each
    batch item keeps the `max_detections` highest-scoring pixels and gathers
    the size, feature and class predictions at those pixels.

    Parameters
    ----------
    output : ModelOutput
        Batched model output. `detection_probs`, `size_preds`, `features`
        and `class_probs` are indexed as (batch, channel, frequency, time).
    max_detections : int, default=200
        Maximum number of detections kept per batch item.
    threshold : float, optional
        If given, only detections with a score greater than or equal to
        this value are kept. If None (default), no thresholding is applied.
    nms_kernel_size : int or tuple of int
        Kernel size forwarded to `non_max_suppression`.

    Returns
    -------
    List[Detections]
        One `Detections` per batch item, sorted by descending score. `times`
        and `frequencies` are the pixel indices divided by the heatmap width
        and height respectively, i.e. fractions of the heatmap extent.
    """
    detection_heatmap = non_max_suppression(
        output.detection_probs,
        kernel_size=nms_kernel_size,
    )

    height = detection_heatmap.shape[-2]
    width = detection_heatmap.shape[-1]

    predictions = []
    for idx, item in enumerate(detection_heatmap):
        item = item.squeeze().flatten()  # Remove channel dim
        indices = torch.argsort(item, descending=True)[:max_detections]

        detection_scores = item.take(indices)

        # Recover (frequency, time) pixel indices from the row-major flat
        # index. Deriving them from `indices` keeps everything on the same
        # device as the heatmap and avoids building a full index grid.
        detection_freqs = torch.div(indices, width, rounding_mode="floor")
        detection_times = torch.remainder(indices, width)

        sizes = output.size_preds[idx, :, detection_freqs, detection_times].T
        features = output.features[idx, :, detection_freqs, detection_times].T
        class_scores = output.class_probs[
            idx, :, detection_freqs, detection_times
        ].T

        if threshold is not None:
            # Inclusive comparison: a score equal to the threshold is kept.
            mask = detection_scores >= threshold
            detection_scores = detection_scores[mask]
            sizes = sizes[mask]
            detection_times = detection_times[mask]
            detection_freqs = detection_freqs[mask]
            features = features[mask]
            class_scores = class_scores[mask]

        predictions.append(
            Detections(
                scores=detection_scores,
                sizes=sizes,
                features=features,
                class_scores=class_scores,
                times=detection_times.to(torch.float32) / width,
                frequencies=(detection_freqs.to(torch.float32) / height),
            )
        )

    return predictions
|
||||
|
||||
@ -20,6 +20,7 @@ import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.typing.postprocess import Detections
|
||||
|
||||
__all__ = [
|
||||
"features_to_xarray",
|
||||
@ -29,6 +30,26 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def map_detection_to_clip(
    detections: Detections,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> Detections:
    """Rescale detection positions to a clip's time and frequency ranges.

    Times are mapped linearly onto `[start_time, end_time]` and frequencies
    onto `[min_freq, max_freq]`. Scores, sizes, features and class scores
    are passed through unchanged.

    Parameters
    ----------
    detections : Detections
        Detections to remap.
        NOTE(review): assumes `times` and `frequencies` are fractions of the
        heatmap extent (0 maps to the range start) - confirm with the
        producer.
    start_time, end_time : float
        Time range the detection times are mapped onto.
    min_freq, max_freq : float
        Frequency range the detection frequencies are mapped onto.

    Returns
    -------
    Detections
        A new `Detections` with rescaled `times` and `frequencies`.
    """
    duration = end_time - start_time
    bandwidth = max_freq - min_freq
    return Detections(
        scores=detections.scores,
        sizes=detections.sizes,
        features=detections.features,
        class_scores=detections.class_scores,
        times=(detections.times * duration + start_time),
        frequencies=(detections.frequencies * bandwidth + min_freq),
    )
|
||||
|
||||
|
||||
def features_to_xarray(
|
||||
features: torch.Tensor,
|
||||
start_time: float,
|
||||
|
||||
@ -53,6 +53,7 @@ from batdetect2.preprocess.spectrogram import (
|
||||
SpectrogramConfig,
|
||||
SpectrogramPipeline,
|
||||
STFTConfig,
|
||||
_spec_params_from_config,
|
||||
build_spectrogram_builder,
|
||||
build_spectrogram_pipeline,
|
||||
)
|
||||
@ -109,7 +110,9 @@ def load_preprocessing_config(
|
||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
"""Standard implementation of the `Preprocessor` protocol."""
|
||||
|
||||
samplerate: int
|
||||
input_samplerate: int
|
||||
output_samplerate: float
|
||||
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
|
||||
@ -117,22 +120,33 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
self,
|
||||
audio_pipeline: torch.nn.Module,
|
||||
spectrogram_pipeline: SpectrogramPipeline,
|
||||
samplerate: int,
|
||||
input_samplerate: int,
|
||||
output_samplerate: float,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.audio_pipeline = audio_pipeline
|
||||
self.spectrogram_pipeline = spectrogram_pipeline
|
||||
self.samplerate = samplerate
|
||||
|
||||
self.max_freq = max_freq
|
||||
self.min_freq = min_freq
|
||||
|
||||
self.input_samplerate = input_samplerate
|
||||
self.output_samplerate = output_samplerate
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
wav = self.audio_pipeline(wav)
|
||||
return self.spectrogram_pipeline(wav)
|
||||
|
||||
|
||||
def compute_output_samplerate(config: PreprocessingConfig) -> float:
    """Derive the spectrogram-side sample rate from a preprocessing config.

    The audio sample rate is scaled by the configured resize factor and
    divided by the STFT hop size returned by `_spec_params_from_config`.

    Parameters
    ----------
    config : PreprocessingConfig
        Configuration providing `audio.samplerate`, `spectrogram.stft` and
        `spectrogram.size.resize_factor`.

    Returns
    -------
    float
        `samplerate * resize_factor / hop_size` - presumably the number of
        spectrogram frames per second; confirm against the spectrogram
        pipeline.
    """
    input_rate = config.audio.samplerate
    resize = config.spectrogram.size.resize_factor

    # Only the hop size is needed; the first returned value is ignored.
    _, hop = _spec_params_from_config(input_rate, config.spectrogram.stft)

    scaled_rate = input_rate * resize
    return scaled_rate / hop
|
||||
|
||||
|
||||
def build_preprocessor(
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
) -> PreprocessorProtocol:
|
||||
@ -148,16 +162,15 @@ def build_preprocessor(
|
||||
min_freq = config.spectrogram.frequencies.min_freq
|
||||
max_freq = config.spectrogram.frequencies.max_freq
|
||||
|
||||
output_samplerate = compute_output_samplerate(config)
|
||||
|
||||
return StandardPreprocessor(
|
||||
audio_pipeline=build_audio_pipeline(config.audio),
|
||||
spectrogram_pipeline=build_spectrogram_pipeline(
|
||||
samplerate, config.spectrogram
|
||||
),
|
||||
samplerate=samplerate,
|
||||
input_samplerate=samplerate,
|
||||
output_samplerate=output_samplerate,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
def get_default_preprocessor():
    """Build a preprocessor by calling `build_preprocessor` with no config."""
    default_preprocessor = build_preprocessor()
    return default_preprocessor
|
||||
|
||||
@ -148,7 +148,7 @@ def add_echo(
|
||||
"""Add a synthetic echo to the audio waveform."""
|
||||
|
||||
audio = example.audio
|
||||
delay_steps = int(preprocessor.samplerate * delay)
|
||||
delay_steps = int(preprocessor.input_samplerate * delay)
|
||||
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
|
||||
|
||||
audio = audio + weight * audio_delay
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.train import PreprocessedExample
|
||||
from batdetect2.utils.arrays import adjust_width
|
||||
|
||||
@ -18,24 +20,26 @@ class ClipingConfig(BaseConfig):
|
||||
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
|
||||
|
||||
|
||||
class Clipper(ClipperProtocol):
|
||||
class Clipper(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: int,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
duration: float = 0.5,
|
||||
max_empty: float = 0.2,
|
||||
random: bool = True,
|
||||
):
|
||||
self.samplerate = samplerate
|
||||
super().__init__()
|
||||
self.preprocessor = preprocessor
|
||||
self.duration = duration
|
||||
self.random = random
|
||||
self.max_empty = max_empty
|
||||
|
||||
def extract_clip(
|
||||
self, example: PreprocessedExample
|
||||
def forward(
|
||||
self,
|
||||
example: PreprocessedExample,
|
||||
) -> Tuple[PreprocessedExample, float, float]:
|
||||
start_time = 0
|
||||
duration = example.audio.shape[-1] / self.samplerate
|
||||
duration = example.audio.shape[-1] / self.preprocessor.input_samplerate
|
||||
|
||||
if self.random:
|
||||
start_time = np.random.uniform(
|
||||
@ -48,7 +52,8 @@ class Clipper(ClipperProtocol):
|
||||
example,
|
||||
start=start_time,
|
||||
duration=self.duration,
|
||||
samplerate=self.samplerate,
|
||||
input_samplerate=self.preprocessor.input_samplerate,
|
||||
output_samplerate=self.preprocessor.output_samplerate,
|
||||
),
|
||||
start_time,
|
||||
start_time + self.duration,
|
||||
@ -56,7 +61,7 @@ class Clipper(ClipperProtocol):
|
||||
|
||||
|
||||
def build_clipper(
|
||||
samplerate: int,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[ClipingConfig] = None,
|
||||
random: Optional[bool] = None,
|
||||
) -> ClipperProtocol:
|
||||
@ -66,7 +71,7 @@ def build_clipper(
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
return Clipper(
|
||||
samplerate=samplerate,
|
||||
preprocessor=preprocessor,
|
||||
duration=config.duration,
|
||||
max_empty=config.max_empty,
|
||||
random=config.random if random else False,
|
||||
@ -77,11 +82,12 @@ def select_subclip(
|
||||
example: PreprocessedExample,
|
||||
start: float,
|
||||
duration: float,
|
||||
samplerate: float,
|
||||
input_samplerate: float,
|
||||
output_samplerate: float,
|
||||
fill_value: float = 0,
|
||||
) -> PreprocessedExample:
|
||||
audio_width = int(np.floor(duration * samplerate))
|
||||
audio_start = int(np.floor(start * samplerate))
|
||||
audio_width = int(np.floor(duration * input_samplerate))
|
||||
audio_start = int(np.floor(start * input_samplerate))
|
||||
|
||||
audio = adjust_width(
|
||||
example.audio[audio_start : audio_start + audio_width],
|
||||
@ -89,12 +95,8 @@ def select_subclip(
|
||||
value=fill_value,
|
||||
)
|
||||
|
||||
audio_duration = example.audio.shape[-1] / samplerate
|
||||
spec_sr = example.spectrogram.shape[-1] / audio_duration
|
||||
|
||||
spec_start = int(np.floor(start * spec_sr))
|
||||
spec_width = int(np.floor(duration * spec_sr))
|
||||
|
||||
spec_start = int(np.floor(start * output_samplerate))
|
||||
spec_width = int(np.floor(duration * output_samplerate))
|
||||
return PreprocessedExample(
|
||||
audio=audio,
|
||||
spectrogram=adjust_width(
|
||||
|
||||
@ -32,7 +32,7 @@ class LabeledDataset(Dataset):
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
example = self.get_example(idx)
|
||||
|
||||
example, start_time, end_time = self.clipper.extract_clip(example)
|
||||
example, start_time, end_time = self.clipper(example)
|
||||
|
||||
if self.augmentation:
|
||||
example = self.augmentation(example)
|
||||
@ -64,9 +64,7 @@ class LabeledDataset(Dataset):
|
||||
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
|
||||
idx = np.random.randint(0, len(self))
|
||||
dataset = self.get_example(idx)
|
||||
|
||||
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
|
||||
|
||||
dataset, start_time, end_time = self.clipper(dataset)
|
||||
return dataset, start_time, end_time
|
||||
|
||||
def get_example(self, idx) -> PreprocessedExample:
|
||||
@ -107,5 +105,5 @@ class RandomExampleSource:
|
||||
index = int(np.random.randint(len(self.filenames)))
|
||||
filename = self.filenames[index]
|
||||
example = load_preprocessed_example(filename)
|
||||
example, _, _ = self.clipper.extract_clip(example)
|
||||
example, _, _ = self.clipper(example)
|
||||
return example
|
||||
|
||||
@ -229,7 +229,7 @@ def build_train_dataset(
|
||||
config = config or TrainingConfig()
|
||||
|
||||
clipper = build_clipper(
|
||||
samplerate=preprocessor.samplerate,
|
||||
preprocessor=preprocessor,
|
||||
config=config.cliping,
|
||||
random=True,
|
||||
)
|
||||
@ -265,7 +265,7 @@ def build_val_dataset(
|
||||
logger.info("Building validation dataset...")
|
||||
config = config or TrainingConfig()
|
||||
clipper = build_clipper(
|
||||
samplerate=preprocessor.samplerate,
|
||||
preprocessor=preprocessor,
|
||||
config=config.cliping,
|
||||
random=train,
|
||||
)
|
||||
|
||||
@ -15,7 +15,7 @@ from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
@ -77,6 +77,15 @@ class RawPrediction(NamedTuple):
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
class Detections(NamedTuple):
    """Tensor-based container for a set of detections.

    Each field is a tensor; presumably they are aligned so that entry `i`
    of every field describes the same detection - confirm with the code
    that builds instances. The field order is part of the tuple's
    positional interface and must not be changed.
    """

    # Detection confidence scores.
    scores: torch.Tensor
    # Predicted size values at each detection.
    sizes: torch.Tensor
    # Class scores at each detection.
    class_scores: torch.Tensor
    # Time position of each detection.
    # NOTE(review): units (fraction of clip vs. seconds) depend on the
    # producing step - confirm with callers.
    times: torch.Tensor
    # Frequency position of each detection (same caveat on units as `times`).
    frequencies: torch.Tensor
    # Feature vectors at each detection.
    features: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatDetect2Prediction:
|
||||
raw: RawPrediction
|
||||
@ -84,154 +93,13 @@ class BatDetect2Prediction:
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||
"""Protocol defining the interface for the full postprocessing pipeline."""
|
||||
|
||||
This protocol outlines the standard methods for an object that takes raw
|
||||
output from a BatDetect2 model and the corresponding input clip metadata,
|
||||
and processes it through various stages (e.g., coordinate remapping, NMS,
|
||||
detection extraction, data extraction, decoding) to produce interpretable
|
||||
results at different levels of completion.
|
||||
|
||||
Implementations manage the configured logic for all postprocessing steps.
|
||||
"""
|
||||
|
||||
def get_feature_arrays(
|
||||
def get_detections(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap feature tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch, expected
|
||||
to contain the necessary feature tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects, one for each item in the
|
||||
processed batch. This list provides the timing, recording, and
|
||||
other metadata context needed to calculate real-world coordinates
|
||||
(seconds, Hz) for the output arrays. The length of this list must
|
||||
correspond to the batch size of the `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of xarray DataArrays, one for each input clip in the batch,
|
||||
in the same order. Each DataArray contains the feature vectors
|
||||
with dimensions like ('feature', 'time', 'frequency') and
|
||||
corresponding real-world coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap detection tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing detection heatmaps.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 2D xarray DataArrays (one per input clip, in order),
|
||||
representing the detection heatmap with 'time' and 'frequency'
|
||||
coordinates. Values typically indicate detection confidence.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_classification_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap classification tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing class probability tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing class probabilities with 'category', 'time', and
|
||||
'frequency' dimensions and coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_sizes_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap size prediction tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing predicted size tensors (e.g., width and height).
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing predicted sizes with 'dimension'
|
||||
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
|
||||
coordinates. Values represent estimated detection sizes.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_datasets(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform remapping, NMS, detection, and data extraction for a batch.
|
||||
|
||||
Processes the raw model output for a batch to identify detection peaks
|
||||
and extract all associated information (score, position, size, class
|
||||
probs, features) at those peak locations, returning a structured
|
||||
dataset for each input clip in the batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
A list of xarray Datasets (one per input clip, in order). Each
|
||||
Dataset contains multiple DataArrays ('scores', 'dimensions',
|
||||
'classes', 'features') sharing a common 'detection' dimension,
|
||||
providing aligned data for each detected event in that clip.
|
||||
"""
|
||||
...
|
||||
clips: Optional[List[data.Clip]] = None,
|
||||
) -> List[Detections]: ...
|
||||
|
||||
def get_raw_predictions(
|
||||
self,
|
||||
|
||||
@ -148,7 +148,9 @@ class PreprocessorProtocol(Protocol):
|
||||
|
||||
min_freq: float
|
||||
|
||||
samplerate: int
|
||||
input_samplerate: int
|
||||
|
||||
output_samplerate: float
|
||||
|
||||
audio_pipeline: AudioPipeline
|
||||
|
||||
|
||||
@ -96,6 +96,6 @@ class LossProtocol(Protocol):
|
||||
|
||||
|
||||
class ClipperProtocol(Protocol):
|
||||
def extract_clip(
|
||||
def __call__(
|
||||
self, example: PreprocessedExample
|
||||
) -> Tuple[PreprocessedExample, float, float]: ...
|
||||
|
||||
@ -10,7 +10,6 @@ from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
get_class_tags,
|
||||
get_generic_tags,
|
||||
get_prediction_features,
|
||||
@ -278,63 +277,6 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
||||
return [pred1, pred2, pred3]
|
||||
|
||||
|
||||
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
    """Check that a two-detection dataset converts to two RawPredictions."""
    raw_predictions = convert_xr_dataset_to_raw_prediction(
        sample_detection_dataset,
        dummy_targets.decode_roi,
    )

    assert isinstance(raw_predictions, list)
    assert len(raw_predictions) == 2

    # One row per expected prediction, in output order:
    # (score, time, frequency, time_extent, freq_extent)
    expected = [
        (0.9, 20, 300, 7, 16),
        (0.8, 10, 200, 3, 12),
    ]

    for index, (prediction, row) in enumerate(zip(raw_predictions, expected)):
        score, time, freq, time_extent, freq_extent = row

        assert isinstance(prediction, RawPrediction)
        assert prediction.detection_score == score

        # The geometry is a box centred on the detection position.
        assert prediction.geometry.coordinates == [
            time - time_extent / 2,
            freq - freq_extent / 2,
            time + time_extent / 2,
            freq + freq_extent / 2,
        ]

        np.testing.assert_allclose(
            prediction.class_scores,
            sample_detection_dataset["classes"].sel(detection=index),
        )
        np.testing.assert_allclose(
            prediction.features,
            sample_detection_dataset["features"].sel(detection=index),
        )
|
||||
|
||||
|
||||
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
|
||||
"""Test conversion of an empty dataset."""
|
||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
||||
empty_detection_dataset,
|
||||
dummy_targets.decode_roi,
|
||||
)
|
||||
assert isinstance(raw_predictions, list)
|
||||
assert len(raw_predictions) == 0
|
||||
|
||||
|
||||
def test_convert_raw_to_sound_event_basic(
|
||||
sample_raw_predictions,
|
||||
sample_recording,
|
||||
|
||||
@ -1,214 +0,0 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_array():
|
||||
"""Provides a basic 3x3 DataArray.
|
||||
Top values: 0.9 (f=200, t=30), 0.8 (f=100, t=20), 0.7 (f=300, t=30)
|
||||
"""
|
||||
array = xr.DataArray(
|
||||
np.zeros([3, 3]),
|
||||
coords={
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
},
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
)
|
||||
|
||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
||||
return array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_array_with_nans(sample_data_array: xr.DataArray):
|
||||
"""Provides a 2D DataArray containing NaN values."""
|
||||
array = sample_data_array.copy()
|
||||
array.loc[dict(time=10, frequency=300)] = np.nan
|
||||
array.loc[dict(time=30, frequency=100)] = np.nan
|
||||
return array
|
||||
|
||||
|
||||
def test_basic_extraction(sample_data_array: xr.DataArray):
|
||||
threshold = 0.1
|
||||
max_detections = 3
|
||||
|
||||
actual_result = extract_detections_from_array(
|
||||
sample_data_array,
|
||||
threshold=threshold,
|
||||
max_detections=max_detections,
|
||||
)
|
||||
|
||||
expected_values = np.array([0.9, 0.8, 0.7])
|
||||
expected_times = np.array([30, 20, 30])
|
||||
expected_freqs = np.array([200, 100, 300])
|
||||
expected_coords = {
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="score",
|
||||
)
|
||||
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_threshold_only(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
threshold = 0.5
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
||||
expected_times = np.array([30, 20, 30, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_max_detections_only(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
max_detections = 4
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, max_detections=max_detections
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
||||
expected_times = np.array([30, 20, 30, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_no_optional_args(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
actual_result = extract_detections_from_array(input_array)
|
||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02])
|
||||
expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20])
|
||||
expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
|
||||
|
||||
def test_no_values_above_threshold(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
threshold = 1.0
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold
|
||||
)
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
||||
Dimensions.frequency.value: (
|
||||
"detection",
|
||||
np.array([], dtype=np.int64),
|
||||
),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
assert actual_result.sizes["detection"] == 0
|
||||
|
||||
|
||||
def test_max_detections_zero(sample_data_array):
|
||||
input_array = sample_data_array
|
||||
max_detections = 0
|
||||
with pytest.raises(ValueError):
|
||||
extract_detections_from_array(
|
||||
input_array,
|
||||
max_detections=max_detections,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_input_array():
|
||||
empty_array = xr.DataArray(
|
||||
np.empty((0, 0)),
|
||||
coords={Dimensions.time.value: [], Dimensions.frequency.value: []},
|
||||
dims=[Dimensions.time.value, Dimensions.frequency.value],
|
||||
)
|
||||
actual_result = extract_detections_from_array(empty_array)
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
||||
Dimensions.frequency.value: (
|
||||
"detection",
|
||||
np.array([], dtype=np.int64),
|
||||
),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
np.array([], dtype=np.float64),
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
assert actual_result.sizes["detection"] == 0
|
||||
|
||||
|
||||
def test_nan_handling(data_array_with_nans):
|
||||
input_array = data_array_with_nans
|
||||
threshold = 0.1
|
||||
max_detections = 3
|
||||
actual_result = extract_detections_from_array(
|
||||
input_array, threshold=threshold, max_detections=max_detections
|
||||
)
|
||||
expected_values = np.array([0.9, 0.8, 0.7])
|
||||
expected_times = np.array([30, 20, 30])
|
||||
expected_freqs = np.array([200, 100, 300])
|
||||
expected_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
expected_result = xr.DataArray(
|
||||
expected_values,
|
||||
coords=expected_coords,
|
||||
dims="detection",
|
||||
name="detection_value",
|
||||
)
|
||||
xr.testing.assert_equal(actual_result, expected_result)
|
||||
@ -1,397 +1,2 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
extract_values_at_positions,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data_array():
|
||||
"""Provides a basic 3x3 DataArray.
|
||||
Top values: 0.9 (f=200, t=30), 0.8 (f=100, t=20), 0.7 (f=300, t=30)
|
||||
"""
|
||||
coords = {
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
}
|
||||
array = xr.DataArray(
|
||||
np.zeros([3, 3]),
|
||||
coords=coords,
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
)
|
||||
|
||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
||||
return array
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array_for_extraction():
|
||||
"""Provides a simple array (1-9) for value extraction tests."""
|
||||
data = np.arange(1, 10).reshape(3, 3)
|
||||
coords = {
|
||||
Dimensions.frequency.value: [100, 200, 300],
|
||||
Dimensions.time.value: [10, 20, 30],
|
||||
}
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords=coords,
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
name="test_values",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_positions_top3(sample_data_array):
|
||||
"""Get top 3 detection positions from sample_data_array."""
|
||||
|
||||
return extract_detections_from_array(
|
||||
sample_data_array,
|
||||
max_detections=3,
|
||||
threshold=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_positions_top2(sample_data_array):
|
||||
"""Get top 2 detection positions from sample_data_array."""
|
||||
return extract_detections_from_array(
|
||||
sample_data_array,
|
||||
max_detections=2,
|
||||
threshold=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_positions(sample_data_array):
|
||||
"""Get an empty positions array (high threshold)."""
|
||||
return extract_detections_from_array(
|
||||
sample_data_array,
|
||||
threshold=1.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sizes_array(sample_data_array):
|
||||
"""Provides a sample sizes array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
data = np.array(
|
||||
[
|
||||
[
|
||||
[0, 1, 2],
|
||||
[3, 4, 5],
|
||||
[6, 7, 8],
|
||||
],
|
||||
[
|
||||
[9, 10, 11],
|
||||
[12, 13, 14],
|
||||
[15, 16, 17],
|
||||
],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="sizes",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_classes_array(sample_data_array):
|
||||
"""Provides a sample classes array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"category": ["bat", "noise"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["category", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="class_scores",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_features_array(sample_data_array):
|
||||
"""Provides a sample features array matching sample_data_array coords."""
|
||||
coords = sample_data_array.coords
|
||||
data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
|
||||
return xr.DataArray(
|
||||
data,
|
||||
coords={
|
||||
"feature": ["f0", "f1", "f2", "f3"],
|
||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
||||
Dimensions.time.value: coords[Dimensions.time.value],
|
||||
},
|
||||
dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
|
||||
name="features",
|
||||
)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_correct(
|
||||
sample_array_for_extraction,
|
||||
sample_positions_top3,
|
||||
):
|
||||
"""Verify correct values are extracted based on positions coords."""
|
||||
expected_values = np.array(
|
||||
[
|
||||
sample_array_for_extraction.sel(time=30, frequency=200).values,
|
||||
sample_array_for_extraction.sel(time=20, frequency=100).values,
|
||||
sample_array_for_extraction.sel(time=30, frequency=300).values,
|
||||
]
|
||||
)
|
||||
|
||||
expected = xr.DataArray(
|
||||
expected_values,
|
||||
coords=sample_positions_top3.coords,
|
||||
dims="detection",
|
||||
name="test_values",
|
||||
)
|
||||
|
||||
extracted = extract_values_at_positions(
|
||||
sample_array_for_extraction, sample_positions_top3
|
||||
)
|
||||
|
||||
xr.testing.assert_allclose(extracted, expected)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_extra_dims(
|
||||
sample_sizes_array,
|
||||
sample_positions_top2,
|
||||
):
|
||||
"""Test extraction preserves other dimensions in the source array."""
|
||||
times = np.array([30, 20])
|
||||
freqs = np.array([200, 100])
|
||||
|
||||
expected_values = np.array(
|
||||
[
|
||||
sample_sizes_array.sel(time=30, frequency=200).values,
|
||||
sample_sizes_array.sel(time=20, frequency=100).values,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
expected = xr.DataArray(
|
||||
expected_values,
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: ("detection", freqs),
|
||||
Dimensions.time.value: ("detection", times),
|
||||
},
|
||||
dims=["detection", "dimension"],
|
||||
name="sizes",
|
||||
)
|
||||
|
||||
extracted = extract_values_at_positions(
|
||||
sample_sizes_array,
|
||||
sample_positions_top2,
|
||||
)
|
||||
|
||||
xr.testing.assert_equal(extracted, expected)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_empty(
|
||||
sample_array_for_extraction, empty_positions
|
||||
):
|
||||
"""Test extraction with empty positions returns empty array."""
|
||||
extracted = extract_values_at_positions(
|
||||
sample_array_for_extraction, empty_positions
|
||||
)
|
||||
assert extracted.sizes["detection"] == 0
|
||||
assert Dimensions.time.value in extracted.coords
|
||||
assert Dimensions.frequency.value in extracted.coords
|
||||
assert extracted.coords[Dimensions.time.value].size == 0
|
||||
assert extracted.coords[Dimensions.frequency.value].size == 0
|
||||
assert extracted.name == sample_array_for_extraction.name
|
||||
|
||||
|
||||
def test_extract_values_at_positions_missing_coord_in_array(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if source array misses required coordinates."""
|
||||
array_no_time = sample_array_for_extraction.copy()
|
||||
del array_no_time.coords[Dimensions.time.value]
|
||||
with pytest.raises(IndexError):
|
||||
extract_values_at_positions(array_no_time, sample_positions_top2)
|
||||
|
||||
array_no_freq = sample_array_for_extraction.copy()
|
||||
del array_no_freq.coords[Dimensions.frequency.value]
|
||||
with pytest.raises(IndexError):
|
||||
extract_values_at_positions(array_no_freq, sample_positions_top2)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_missing_coord_in_positions(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if positions array misses required coordinates."""
|
||||
positions_no_time = sample_positions_top2.copy()
|
||||
del positions_no_time.coords[Dimensions.time.value]
|
||||
with pytest.raises(KeyError):
|
||||
extract_values_at_positions(
|
||||
sample_array_for_extraction, positions_no_time
|
||||
)
|
||||
|
||||
positions_no_freq = sample_positions_top2.copy()
|
||||
del positions_no_freq.coords[Dimensions.frequency.value]
|
||||
with pytest.raises(KeyError):
|
||||
extract_values_at_positions(
|
||||
sample_array_for_extraction, positions_no_freq
|
||||
)
|
||||
|
||||
|
||||
def test_extract_values_at_positions_mismatched_coords(
|
||||
sample_array_for_extraction, sample_positions_top2
|
||||
):
|
||||
"""Test error if positions requests coords not in source array."""
|
||||
bad_positions = sample_positions_top2.copy()
|
||||
bad_positions.coords[Dimensions.time.value] = (
|
||||
"detection",
|
||||
np.array([40, 10]),
|
||||
)
|
||||
with pytest.raises(KeyError):
|
||||
extract_values_at_positions(sample_array_for_extraction, bad_positions)
|
||||
|
||||
|
||||
def test_extract_detection_xr_dataset_correct(
|
||||
sample_positions_top2,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
):
|
||||
"""Tests extracting and bundling info for top 2 detections."""
|
||||
actual_dataset = extract_detection_xr_dataset(
|
||||
sample_positions_top2,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
)
|
||||
|
||||
expected_times = np.array([30, 20])
|
||||
expected_freqs = np.array([200, 100])
|
||||
detection_coords = {
|
||||
Dimensions.time.value: ("detection", expected_times),
|
||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
||||
}
|
||||
|
||||
expected_score = sample_positions_top2
|
||||
|
||||
expected_dimensions_data = np.array(
|
||||
[
|
||||
sample_sizes_array.sel(time=30, frequency=200).values,
|
||||
sample_sizes_array.sel(time=20, frequency=100).values,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_dimensions = xr.DataArray(
|
||||
expected_dimensions_data,
|
||||
coords={**detection_coords, "dimension": ["width", "height"]},
|
||||
dims=["detection", "dimension"],
|
||||
name="dimensions",
|
||||
)
|
||||
|
||||
expected_classes_data = np.array(
|
||||
[
|
||||
sample_classes_array.sel(time=30, frequency=200).values,
|
||||
sample_classes_array.sel(time=20, frequency=100).values,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_classes = xr.DataArray(
|
||||
expected_classes_data,
|
||||
coords={**detection_coords, "category": ["bat", "noise"]},
|
||||
dims=["detection", "category"],
|
||||
name="classes",
|
||||
)
|
||||
|
||||
expected_features_data = np.array(
|
||||
[
|
||||
sample_features_array.sel(time=30, frequency=200).values,
|
||||
sample_features_array.sel(time=20, frequency=100).values,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
expected_features = xr.DataArray(
|
||||
expected_features_data,
|
||||
coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]},
|
||||
dims=["detection", "feature"],
|
||||
name="features",
|
||||
)
|
||||
|
||||
expected_dataset = xr.Dataset(
|
||||
{
|
||||
"scores": expected_score,
|
||||
"dimensions": expected_dimensions,
|
||||
"classes": expected_classes,
|
||||
"features": expected_features,
|
||||
}
|
||||
)
|
||||
expected_dataset = expected_dataset.assign_coords(detection_coords)
|
||||
|
||||
xr.testing.assert_allclose(actual_dataset, expected_dataset)
|
||||
|
||||
|
||||
def test_extract_detection_xr_dataset_empty(
|
||||
empty_positions,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
):
|
||||
"""Test extraction with empty positions yields an empty dataset."""
|
||||
actual_dataset = extract_detection_xr_dataset(
|
||||
empty_positions,
|
||||
sample_sizes_array,
|
||||
sample_classes_array,
|
||||
sample_features_array,
|
||||
)
|
||||
|
||||
assert isinstance(actual_dataset, xr.Dataset)
|
||||
assert "detection" in actual_dataset.dims
|
||||
assert actual_dataset.sizes["detection"] == 0
|
||||
|
||||
assert "scores" in actual_dataset
|
||||
assert actual_dataset["scores"].dims == ("detection",)
|
||||
assert actual_dataset["scores"].size == 0
|
||||
|
||||
assert "dimensions" in actual_dataset
|
||||
assert actual_dataset["dimensions"].dims == ("detection", "dimension")
|
||||
assert actual_dataset["dimensions"].shape == (0, 2)
|
||||
|
||||
assert "classes" in actual_dataset
|
||||
assert actual_dataset["classes"].dims == ("detection", "category")
|
||||
assert actual_dataset["classes"].shape == (0, 2)
|
||||
|
||||
assert "features" in actual_dataset
|
||||
assert actual_dataset["features"].dims == ("detection", "feature")
|
||||
assert actual_dataset["features"].shape == (0, 4)
|
||||
|
||||
assert Dimensions.time.value in actual_dataset.coords
|
||||
assert Dimensions.frequency.value in actual_dataset.coords
|
||||
assert actual_dataset.coords[Dimensions.time.value].size == 0
|
||||
assert actual_dataset.coords[Dimensions.frequency.value].size == 0
|
||||
|
||||
@ -162,7 +162,8 @@ def test_selected_random_subclip_has_the_correct_width(
|
||||
|
||||
subclip = select_subclip(
|
||||
original,
|
||||
samplerate=256_000,
|
||||
input_samplerate=256_000,
|
||||
output_samplerate=1000,
|
||||
start=0,
|
||||
duration=0.512,
|
||||
)
|
||||
|
||||
@ -39,9 +39,8 @@ def build_from_config(
|
||||
)
|
||||
postprocessor = build_postprocessor(
|
||||
targets,
|
||||
preprocessor=preprocessor,
|
||||
config=postprocessing_config,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
return targets, preprocessor, labeller, postprocessor
|
||||
@ -84,7 +83,10 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
|
||||
encoded = generate_train_example(
|
||||
clip_annotation, sample_audio_loader, preprocessor, labeller
|
||||
clip_annotation,
|
||||
sample_audio_loader,
|
||||
preprocessor,
|
||||
labeller,
|
||||
)
|
||||
predictions = postprocessor.get_predictions(
|
||||
ModelOutput(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user