mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Remove xr from postprocess
This commit is contained in:
parent
cc9e47b022
commit
281c4dcb8a
@ -140,9 +140,8 @@ def build_model(config: Optional[ModelConfig] = None):
|
|||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from matplotlib.axes import Axes
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.plotting.common import plot_spectrogram
|
from batdetect2.plotting.common import plot_spectrogram
|
||||||
from batdetect2.preprocess import build_audio_loader, get_default_preprocessor
|
from batdetect2.preprocess import build_audio_loader, build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -27,7 +27,7 @@ def plot_clip(
|
|||||||
_, ax = plt.subplots(figsize=figsize)
|
_, ax = plt.subplots(figsize=figsize)
|
||||||
|
|
||||||
if preprocessor is None:
|
if preprocessor is None:
|
||||||
preprocessor = get_default_preprocessor()
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
if audio_loader is None:
|
if audio_loader is None:
|
||||||
audio_loader = build_audio_loader()
|
audio_loader = build_audio_loader()
|
||||||
|
|||||||
@ -8,10 +8,7 @@ from soundevent.plot.tags import TagColorMapper
|
|||||||
|
|
||||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
PreprocessorProtocol,
|
|
||||||
get_default_preprocessor,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
from batdetect2.typing.evaluate import MatchEvaluation
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -50,7 +47,7 @@ def plot_matches(
|
|||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
if preprocessor is None:
|
if preprocessor is None:
|
||||||
preprocessor = get_default_preprocessor()
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
ax = plot_clip(
|
ax = plot_clip(
|
||||||
clip,
|
clip,
|
||||||
|
|||||||
@ -1,36 +1,7 @@
|
|||||||
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
|
||||||
|
|
||||||
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
|
||||||
BatDetect2 neural network model and transforms them into meaningful, structured
|
|
||||||
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
|
||||||
containing detected sound events with associated class tags and geometry.
|
|
||||||
|
|
||||||
The pipeline involves several configurable steps, implemented in submodules:
|
|
||||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
|
||||||
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
|
||||||
model output arrays.
|
|
||||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
|
||||||
(location and score) based on thresholds and score ranking (top-k).
|
|
||||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
|
||||||
class probabilities, features) at the detected locations.
|
|
||||||
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
|
||||||
class predictions into interpretable `soundevent` objects, including
|
|
||||||
recovering geometry (ROIs) and decoding class names back to standard tags.
|
|
||||||
|
|
||||||
This module provides the primary interface:
|
|
||||||
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
|
||||||
(thresholds, NMS kernel size, etc.).
|
|
||||||
- `load_postprocess_config`: Function to load the configuration from a file.
|
|
||||||
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
|
||||||
holds the configured pipeline logic.
|
|
||||||
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
|
||||||
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
|
||||||
It also re-exports key components from submodules for convenience.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import xarray as xr
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -38,37 +9,24 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.postprocess.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
convert_detections_to_raw_predictions,
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
convert_raw_predictions_to_clip_prediction,
|
convert_raw_predictions_to_clip_prediction,
|
||||||
convert_xr_dataset_to_raw_prediction,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess.detection import (
|
|
||||||
DEFAULT_DETECTION_THRESHOLD,
|
|
||||||
TOP_K_PER_SEC,
|
|
||||||
extract_detections_from_array,
|
|
||||||
get_max_detections,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess.extraction import (
|
|
||||||
extract_detection_xr_dataset,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.postprocess.extraction import extract_prediction_tensor
|
||||||
from batdetect2.postprocess.nms import (
|
from batdetect2.postprocess.nms import (
|
||||||
NMS_KERNEL_SIZE,
|
NMS_KERNEL_SIZE,
|
||||||
non_max_suppression,
|
non_max_suppression,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.remapping import (
|
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||||
classification_to_xarray,
|
|
||||||
detection_to_xarray,
|
|
||||||
features_to_xarray,
|
|
||||||
sizes_to_xarray,
|
|
||||||
)
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
|
Detections,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
RawPrediction,
|
RawPrediction,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
@ -81,19 +39,17 @@ __all__ = [
|
|||||||
"Postprocessor",
|
"Postprocessor",
|
||||||
"TOP_K_PER_SEC",
|
"TOP_K_PER_SEC",
|
||||||
"build_postprocessor",
|
"build_postprocessor",
|
||||||
"classification_to_xarray",
|
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_detections_to_raw_predictions",
|
||||||
"detection_to_xarray",
|
|
||||||
"extract_detection_xr_dataset",
|
|
||||||
"extract_detections_from_array",
|
|
||||||
"features_to_xarray",
|
|
||||||
"get_max_detections",
|
|
||||||
"load_postprocess_config",
|
"load_postprocess_config",
|
||||||
"non_max_suppression",
|
"non_max_suppression",
|
||||||
"sizes_to_xarray",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
TOP_K_PER_SEC = 200
|
||||||
|
|
||||||
|
|
||||||
class PostprocessConfig(BaseConfig):
|
class PostprocessConfig(BaseConfig):
|
||||||
"""Configuration settings for the postprocessing pipeline.
|
"""Configuration settings for the postprocessing pipeline.
|
||||||
@ -173,40 +129,10 @@ def load_postprocess_config(
|
|||||||
|
|
||||||
def build_postprocessor(
|
def build_postprocessor(
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
config: Optional[PostprocessConfig] = None,
|
config: Optional[PostprocessConfig] = None,
|
||||||
max_freq: float = MAX_FREQ,
|
|
||||||
min_freq: float = MIN_FREQ,
|
|
||||||
) -> PostprocessorProtocol:
|
) -> PostprocessorProtocol:
|
||||||
"""Factory function to build the standard postprocessor.
|
"""Factory function to build the standard postprocessor."""
|
||||||
|
|
||||||
Creates and initializes the `Postprocessor` instance, providing it with the
|
|
||||||
necessary `targets` object and the `PostprocessConfig`.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
targets : TargetProtocol
|
|
||||||
An initialized object conforming to the `TargetProtocol`, providing
|
|
||||||
methods like `.decode()` and `.recover_roi()`, and attributes like
|
|
||||||
`.class_names` and `.generic_class_tags`. This links postprocessing
|
|
||||||
to the defined target semantics and geometry mappings.
|
|
||||||
config : PostprocessConfig, optional
|
|
||||||
Configuration object specifying postprocessing parameters (thresholds,
|
|
||||||
NMS kernel size, etc.). If None, default settings defined in
|
|
||||||
`PostprocessConfig` will be used.
|
|
||||||
min_freq : int, default=MIN_FREQ
|
|
||||||
The minimum frequency (Hz) corresponding to the frequency axis of the
|
|
||||||
model outputs. Required for coordinate remapping. Consider setting via
|
|
||||||
`PostprocessConfig` instead for better encapsulation.
|
|
||||||
max_freq : int, default=MAX_FREQ
|
|
||||||
The maximum frequency (Hz) corresponding to the frequency axis of the
|
|
||||||
model outputs. Required for coordinate remapping. Consider setting via
|
|
||||||
`PostprocessConfig`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
PostprocessorProtocol
|
|
||||||
An initialized `Postprocessor` instance ready to process model outputs.
|
|
||||||
"""
|
|
||||||
config = config or PostprocessConfig()
|
config = config or PostprocessConfig()
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building postprocessor with config: \n{}",
|
"Building postprocessor with config: \n{}",
|
||||||
@ -214,303 +140,62 @@ def build_postprocessor(
|
|||||||
)
|
)
|
||||||
return Postprocessor(
|
return Postprocessor(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
config=config,
|
config=config,
|
||||||
min_freq=min_freq,
|
|
||||||
max_freq=max_freq,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Postprocessor(PostprocessorProtocol):
|
class Postprocessor(PostprocessorProtocol):
|
||||||
"""Standard implementation of the postprocessing pipeline.
|
"""Standard implementation of the postprocessing pipeline."""
|
||||||
|
|
||||||
This class orchestrates the steps required to convert raw model outputs
|
|
||||||
into interpretable `soundevent` predictions. It uses configured parameters
|
|
||||||
and leverages functions from the `batdetect2.postprocess` submodules for
|
|
||||||
each stage (NMS, remapping, detection, extraction, decoding).
|
|
||||||
|
|
||||||
It requires a `TargetProtocol` object during initialization to access
|
|
||||||
necessary decoding information (class name to tag mapping,
|
|
||||||
ROI recovery logic) ensuring consistency with the target definitions used
|
|
||||||
during training or specified for inference.
|
|
||||||
|
|
||||||
Instances are typically created using the `build_postprocessor` factory
|
|
||||||
function.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
targets : TargetProtocol
|
|
||||||
The configured target definition object providing decoding and ROI
|
|
||||||
recovery.
|
|
||||||
config : PostprocessConfig
|
|
||||||
Configuration object holding parameters for NMS, thresholds, etc.
|
|
||||||
min_freq : float
|
|
||||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
|
||||||
max_freq : float
|
|
||||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
|
||||||
"""
|
|
||||||
|
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
|
|
||||||
|
preprocessor: PreprocessorProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
config: PostprocessConfig,
|
config: PostprocessConfig,
|
||||||
min_freq: float = MIN_FREQ,
|
|
||||||
max_freq: float = MAX_FREQ,
|
|
||||||
):
|
):
|
||||||
"""Initialize the Postprocessor.
|
"""Initialize the Postprocessor."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
targets : TargetProtocol
|
|
||||||
Initialized target definition object.
|
|
||||||
config : PostprocessConfig
|
|
||||||
Configuration for postprocessing parameters.
|
|
||||||
min_freq : int, default=MIN_FREQ
|
|
||||||
Minimum frequency (Hz) for coordinate remapping.
|
|
||||||
max_freq : int, default=MAX_FREQ
|
|
||||||
Maximum frequency (Hz) for coordinate remapping.
|
|
||||||
"""
|
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.preprocessor = preprocessor
|
||||||
self.config = config
|
self.config = config
|
||||||
self.min_freq = min_freq
|
|
||||||
self.max_freq = max_freq
|
|
||||||
|
|
||||||
def get_feature_arrays(
|
def get_detections(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
clips: List[data.Clip],
|
clips: Optional[List[data.Clip]] = None,
|
||||||
) -> List[xr.DataArray]:
|
) -> List[Detections]:
|
||||||
"""Extract and remap raw feature tensors for a batch.
|
width = output.detection_probs.shape[-1]
|
||||||
|
duration = width / self.preprocessor.output_samplerate
|
||||||
|
max_detections = int(self.config.top_k_per_sec * duration)
|
||||||
|
|
||||||
Parameters
|
detections = extract_prediction_tensor(
|
||||||
----------
|
output,
|
||||||
output : ModelOutput
|
max_detections=max_detections,
|
||||||
Raw model output containing `output.features` tensor for the batch.
|
threshold=self.config.detection_threshold,
|
||||||
clips : List[data.Clip]
|
|
||||||
List of Clip objects corresponding to the batch items.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
List of coordinate-aware feature DataArrays, one per clip.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If batch sizes of `output.features` and `clips` do not match.
|
|
||||||
"""
|
|
||||||
if len(clips) != len(output.features):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of clips and batch size of feature array"
|
|
||||||
"do not match. "
|
|
||||||
f"(clips: {len(clips)}, features: {len(output.features)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
features_to_xarray(
|
|
||||||
feats,
|
|
||||||
start_time=clip.start_time,
|
|
||||||
end_time=clip.end_time,
|
|
||||||
min_freq=self.min_freq,
|
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
|
||||||
for feats, clip in zip(output.features, clips)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_detection_arrays(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Apply NMS and remap detection heatmaps for a batch.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
Raw model output containing `output.detection_probs` tensor for the
|
|
||||||
batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
List of Clip objects corresponding to the batch items.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
|
||||||
clip.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If batch sizes of `output.detection_probs` and `clips` do not match.
|
|
||||||
"""
|
|
||||||
detections = output.detection_probs
|
|
||||||
|
|
||||||
if len(clips) != len(output.detection_probs):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of clips and batch size of detection array "
|
|
||||||
"do not match. "
|
|
||||||
f"(clips: {len(clips)}, detection: {len(detections)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
detections = non_max_suppression(
|
|
||||||
detections,
|
|
||||||
kernel_size=self.config.nms_kernel_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
if clips is None:
|
||||||
detection_to_xarray(
|
return detections
|
||||||
dets,
|
|
||||||
start_time=clip.start_time,
|
|
||||||
end_time=clip.end_time,
|
|
||||||
min_freq=self.min_freq,
|
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
|
||||||
for dets, clip in zip(detections, clips)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_classification_arrays(
|
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Extract and remap raw classification tensors for a batch.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
Raw model output containing `output.class_probs` tensor for the
|
|
||||||
batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
List of Clip objects corresponding to the batch items.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
List of coordinate-aware class probability maps, one per clip.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If batch sizes of `output.class_probs` and `clips` do not match, or
|
|
||||||
if number of classes mismatches `self.targets.class_names`.
|
|
||||||
"""
|
|
||||||
classifications = output.class_probs
|
|
||||||
|
|
||||||
if len(clips) != len(classifications):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of clips and batch size of classification array "
|
|
||||||
"do not match. "
|
|
||||||
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
classification_to_xarray(
|
map_detection_to_clip(
|
||||||
class_probs,
|
detection,
|
||||||
start_time=clip.start_time,
|
start_time=clip.start_time,
|
||||||
end_time=clip.end_time,
|
end_time=clip.end_time,
|
||||||
class_names=self.targets.class_names,
|
min_freq=self.preprocessor.min_freq,
|
||||||
min_freq=self.min_freq,
|
max_freq=self.preprocessor.max_freq,
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
)
|
||||||
for class_probs, clip in zip(classifications, clips)
|
for detection, clip in zip(detections, clips)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_sizes_arrays(
|
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Extract and remap raw size prediction tensors for a batch.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
Raw model output containing `output.size_preds` tensor for the
|
|
||||||
batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
List of Clip objects corresponding to the batch items.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
List of coordinate-aware size prediction maps, one per clip.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If batch sizes of `output.size_preds` and `clips` do not match.
|
|
||||||
"""
|
|
||||||
sizes = output.size_preds
|
|
||||||
|
|
||||||
if len(clips) != len(sizes):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of clips and batch size of sizes array do not match. "
|
|
||||||
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
sizes_to_xarray(
|
|
||||||
size_preds,
|
|
||||||
start_time=clip.start_time,
|
|
||||||
end_time=clip.end_time,
|
|
||||||
min_freq=self.min_freq,
|
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
|
||||||
for size_preds, clip in zip(sizes, clips)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_detection_datasets(
|
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
|
||||||
) -> List[xr.Dataset]:
|
|
||||||
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
Raw output from the neural network model for a batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.Dataset]
|
|
||||||
List of xarray Datasets (one per clip). Each Dataset contains
|
|
||||||
aligned scores, dimensions, class probabilities, and features for
|
|
||||||
detections found in that clip.
|
|
||||||
"""
|
|
||||||
detection_arrays = self.get_detection_arrays(output, clips)
|
|
||||||
classification_arrays = self.get_classification_arrays(output, clips)
|
|
||||||
size_arrays = self.get_sizes_arrays(output, clips)
|
|
||||||
features_arrays = self.get_feature_arrays(output, clips)
|
|
||||||
|
|
||||||
datasets = []
|
|
||||||
for det_array, class_array, sizes_array, feats_array in zip(
|
|
||||||
detection_arrays,
|
|
||||||
classification_arrays,
|
|
||||||
size_arrays,
|
|
||||||
features_arrays,
|
|
||||||
):
|
|
||||||
max_detections = get_max_detections(
|
|
||||||
det_array,
|
|
||||||
top_k_per_sec=self.config.top_k_per_sec,
|
|
||||||
)
|
|
||||||
|
|
||||||
positions = extract_detections_from_array(
|
|
||||||
det_array,
|
|
||||||
max_detections=max_detections,
|
|
||||||
threshold=self.config.detection_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
datasets.append(
|
|
||||||
extract_detection_xr_dataset(
|
|
||||||
positions,
|
|
||||||
sizes_array,
|
|
||||||
class_array,
|
|
||||||
feats_array,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
def get_raw_predictions(
|
def get_raw_predictions(
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
) -> List[List[RawPrediction]]:
|
) -> List[List[RawPrediction]]:
|
||||||
"""Extract intermediate RawPrediction objects for a batch.
|
"""Extract intermediate RawPrediction objects for a batch.
|
||||||
|
|
||||||
@ -531,13 +216,13 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
List of lists (one inner list per input clip). Each inner list
|
List of lists (one inner list per input clip). Each inner list
|
||||||
contains `RawPrediction` objects for detections in that clip.
|
contains `RawPrediction` objects for detections in that clip.
|
||||||
"""
|
"""
|
||||||
detection_datasets = self.get_detection_datasets(output, clips)
|
detections = self.get_detections(output, clips)
|
||||||
return [
|
return [
|
||||||
convert_xr_dataset_to_raw_prediction(
|
convert_detections_to_raw_predictions(
|
||||||
dataset,
|
dataset,
|
||||||
self.targets.decode_roi,
|
targets=self.targets,
|
||||||
)
|
)
|
||||||
for dataset in detection_datasets
|
for dataset in detections
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_sound_event_predictions(
|
def get_sound_event_predictions(
|
||||||
|
|||||||
@ -1,42 +1,18 @@
|
|||||||
"""Decodes extracted detection data into standard soundevent predictions.
|
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||||
|
|
||||||
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
|
||||||
It takes the structured detection data extracted by the `extraction` module
|
|
||||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
|
||||||
class probabilities, and features for each detection point) and converts it
|
|
||||||
into standardized prediction objects based on the `soundevent` data model.
|
|
||||||
|
|
||||||
The process involves:
|
|
||||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
|
||||||
objects, using a configured geometry builder to recover bounding boxes from
|
|
||||||
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
|
||||||
2. Converting each `RawPrediction` into a
|
|
||||||
`soundevent.data.SoundEventPrediction`, which involves:
|
|
||||||
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
|
||||||
- Decoding the predicted class probabilities into representative tags using
|
|
||||||
a configured class decoder (`SoundEventDecoder`).
|
|
||||||
- Applying a classification threshold.
|
|
||||||
- Optionally selecting only the single highest-scoring class (top-1) or
|
|
||||||
including tags for all classes above the threshold (multi-label).
|
|
||||||
- Adding generic class tags as a baseline.
|
|
||||||
- Associating scores with the final prediction and tags.
|
|
||||||
(`convert_raw_prediction_to_sound_event_prediction`)
|
|
||||||
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
|
||||||
a `soundevent.data.ClipPrediction`
|
|
||||||
(`convert_raw_predictions_to_clip_prediction`).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction
|
from batdetect2.typing.postprocess import (
|
||||||
|
Detections,
|
||||||
|
RawPrediction,
|
||||||
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_detections_to_raw_predictions",
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
"convert_raw_prediction_to_sound_event_prediction",
|
"convert_raw_prediction_to_sound_event_prediction",
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
@ -51,65 +27,29 @@ decoding.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def convert_xr_dataset_to_raw_prediction(
|
def convert_detections_to_raw_predictions(
|
||||||
detection_dataset: xr.Dataset,
|
detections: Detections,
|
||||||
geometry_decoder: GeometryDecoder,
|
targets: TargetProtocol,
|
||||||
) -> List[RawPrediction]:
|
) -> List[RawPrediction]:
|
||||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
predictions = []
|
||||||
|
|
||||||
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
|
||||||
and transforms each detection entry into an intermediate `RawPrediction`
|
|
||||||
object. This involves recovering the geometry (e.g., bounding box) from
|
|
||||||
the predicted position and scaled size dimensions using the provided
|
|
||||||
`geometry_builder` function.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
detection_dataset : xr.Dataset
|
|
||||||
An xarray Dataset containing aligned detection information, typically
|
|
||||||
output by `extract_detection_xr_dataset`. Expected variables include
|
|
||||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
|
||||||
Must have a 'detection' dimension.
|
|
||||||
geometry_decoder : GeometryDecoder
|
|
||||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
|
||||||
of dimensions, and returns the corresponding reconstructed
|
|
||||||
`soundevent.data.Geometry`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[RawPrediction]
|
|
||||||
A list of `RawPrediction` objects, each containing the detection score,
|
|
||||||
recovered bounding box coordinates (start/end time, low/high freq),
|
|
||||||
the vector of class scores, and the feature vector for one detection.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
AttributeError, KeyError, ValueError
|
|
||||||
If `detection_dataset` is missing expected variables ('scores',
|
|
||||||
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
|
||||||
associated with 'scores'), or if `geometry_builder` fails.
|
|
||||||
"""
|
|
||||||
detections = []
|
|
||||||
|
|
||||||
categories = detection_dataset.category.values
|
|
||||||
|
|
||||||
for score, class_scores, time, freq, dims, feats in zip(
|
for score, class_scores, time, freq, dims, feats in zip(
|
||||||
detection_dataset["scores"].values,
|
detections.scores,
|
||||||
detection_dataset["classes"].values,
|
detections.class_scores,
|
||||||
detection_dataset["time"].values,
|
detections.times,
|
||||||
detection_dataset["frequency"].values,
|
detections.frequencies,
|
||||||
detection_dataset["dimensions"].values,
|
detections.sizes,
|
||||||
detection_dataset["features"].values,
|
detections.features,
|
||||||
):
|
):
|
||||||
highest_scoring_class = categories[class_scores.argmax()]
|
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||||
|
|
||||||
geom = geometry_decoder(
|
geom = targets.decode_roi(
|
||||||
(time, freq),
|
(time, freq),
|
||||||
dims,
|
dims,
|
||||||
class_name=highest_scoring_class,
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
detections.append(
|
predictions.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=score,
|
detection_score=score,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return detections
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def convert_raw_predictions_to_clip_prediction(
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
||||||
|
|
||||||
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
|
||||||
converts each one into a `soundevent.data.SoundEventPrediction` using
|
|
||||||
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
|
||||||
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
raw_predictions : List[RawPrediction]
|
|
||||||
List of raw prediction objects for a single clip.
|
|
||||||
clip : data.Clip
|
|
||||||
The original `soundevent.data.Clip` object these predictions belong to.
|
|
||||||
sound_event_decoder : SoundEventDecoder
|
|
||||||
Function to decode class names into representative tags.
|
|
||||||
generic_class_tags : List[data.Tag]
|
|
||||||
List of tags representing the generic class category.
|
|
||||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
|
||||||
Threshold applied to class scores during decoding.
|
|
||||||
top_class_only : bool, default=False
|
|
||||||
If True, only decode tags for the single highest-scoring class above
|
|
||||||
the threshold. If False, decode tags for all classes above threshold.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.ClipPrediction
|
|
||||||
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
|
||||||
objects corresponding to the input `raw_predictions`.
|
|
||||||
"""
|
|
||||||
return data.ClipPrediction(
|
return data.ClipPrediction(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
sound_events=[
|
sound_events=[
|
||||||
@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
):
|
):
|
||||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||||
|
|
||||||
This function performs the core decoding steps for a single detected event:
|
|
||||||
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
|
||||||
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
|
||||||
feature vectors.
|
|
||||||
2. Initializes a list of predicted tags using the provided
|
|
||||||
`generic_class_tags`, assigning the overall `detection_score` from the
|
|
||||||
`raw_prediction` to these generic tags.
|
|
||||||
3. Processes the `class_scores` from the `raw_prediction`:
|
|
||||||
a. Optionally filters out scores below `classification_threshold`
|
|
||||||
(if it's not None).
|
|
||||||
b. Sorts the remaining scores in descending order.
|
|
||||||
c. Iterates through the sorted, thresholded class scores.
|
|
||||||
d. For each class, uses the `sound_event_decoder` to get the
|
|
||||||
representative base tags for that class name.
|
|
||||||
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
|
||||||
the specific `score` of that class prediction.
|
|
||||||
f. Appends these specific predicted tags to the list.
|
|
||||||
g. If `top_class_only` is True, stops after processing the first
|
|
||||||
(highest-scoring) class that passed the threshold.
|
|
||||||
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
|
||||||
associating the `SoundEvent`, the overall `detection_score`, and the
|
|
||||||
compiled list of `PredictedTag` objects.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
raw_prediction : RawPrediction
|
|
||||||
The raw prediction object containing score, bounds, class scores,
|
|
||||||
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
|
||||||
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
|
||||||
coordinate.
|
|
||||||
recording : data.Recording
|
|
||||||
The recording the sound event belongs to.
|
|
||||||
sound_event_decoder : SoundEventDecoder
|
|
||||||
Configured function mapping class names (str) to lists of base
|
|
||||||
`data.Tag` objects.
|
|
||||||
generic_class_tags : List[data.Tag]
|
|
||||||
List of base tags representing the generic category.
|
|
||||||
classification_threshold : float, optional
|
|
||||||
The minimum score a class prediction must have to be considered
|
|
||||||
significant enough to have its tags decoded and added. If None, no
|
|
||||||
thresholding is applied based on class score (all predicted classes,
|
|
||||||
or the top one if `top_class_only` is True, will be processed).
|
|
||||||
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
|
||||||
top_class_only : bool, default=False
|
|
||||||
If True, only includes tags for the single highest-scoring class that
|
|
||||||
exceeds the threshold. If False (default), includes tags for all classes
|
|
||||||
exceeding the threshold.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.SoundEventPrediction
|
|
||||||
The fully formed sound event prediction object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `raw_prediction.features` has unexpected structure or if
|
|
||||||
`data.term_from_key` (if used internally) fails.
|
|
||||||
If `sound_event_decoder` fails for a class name and errors are raised.
|
|
||||||
"""
|
|
||||||
sound_event = data.SoundEvent(
|
sound_event = data.SoundEvent(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
geometry=raw_prediction.geometry,
|
geometry=raw_prediction.geometry,
|
||||||
@ -273,25 +124,7 @@ def get_generic_tags(
|
|||||||
detection_score: float,
|
detection_score: float,
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
"""Create PredictedTag objects for the generic category.
|
"""Create PredictedTag objects for the generic category."""
|
||||||
|
|
||||||
Takes the base list of generic tags and assigns the overall detection
|
|
||||||
score to each one, wrapping them in `PredictedTag` objects.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
detection_score : float
|
|
||||||
The overall confidence score of the detection event.
|
|
||||||
generic_class_tags : List[data.Tag]
|
|
||||||
The list of base `soundevent.data.Tag` objects that define the
|
|
||||||
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.PredictedTag]
|
|
||||||
A list of `PredictedTag` objects for the generic category, each
|
|
||||||
assigned the `detection_score`.
|
|
||||||
"""
|
|
||||||
return [
|
return [
|
||||||
data.PredictedTag(tag=tag, score=detection_score)
|
data.PredictedTag(tag=tag, score=detection_score)
|
||||||
for tag in generic_class_tags
|
for tag in generic_class_tags
|
||||||
@ -299,25 +132,7 @@ def get_generic_tags(
|
|||||||
|
|
||||||
|
|
||||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
features : xr.DataArray
|
|
||||||
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
|
||||||
named 'feature' which holds the feature names (e.g., output of selecting
|
|
||||||
features for one detection from `extract_detection_xr_dataset`).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.Feature]
|
|
||||||
A list of `soundevent.data.Feature` objects.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
-----
|
|
||||||
- This function creates basic `Term` objects using the feature coordinate
|
|
||||||
names with a "batdetect2:" prefix.
|
|
||||||
"""
|
|
||||||
return [
|
return [
|
||||||
data.Feature(
|
data.Feature(
|
||||||
term=data.Term(
|
term=data.Term(
|
||||||
|
|||||||
@ -1,162 +0,0 @@
|
|||||||
"""Extracts candidate detection points from a model output heatmap.
|
|
||||||
|
|
||||||
This module implements Step 3 within the BatDetect2 postprocessing
|
|
||||||
pipeline. Its primary function is to identify potential sound event locations
|
|
||||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
|
||||||
produced by the neural network (usually after Non-Maximum Suppression and
|
|
||||||
coordinate remapping have been applied).
|
|
||||||
|
|
||||||
It provides functionality to:
|
|
||||||
- Identify the locations (time, frequency) of the highest-scoring points.
|
|
||||||
- Filter these points based on a minimum confidence score threshold.
|
|
||||||
- Limit the maximum number of detection points returned (top-k).
|
|
||||||
|
|
||||||
The main output is an `xarray.DataArray` containing the scores and
|
|
||||||
corresponding time/frequency coordinates for the extracted detection points.
|
|
||||||
This output serves as the input for subsequent postprocessing steps, such as
|
|
||||||
extracting predicted class probabilities and bounding box sizes at these
|
|
||||||
specific locations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import xarray as xr
|
|
||||||
from soundevent.arrays import Dimensions, get_dim_width
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"extract_detections_from_array",
|
|
||||||
"get_max_detections",
|
|
||||||
"DEFAULT_DETECTION_THRESHOLD",
|
|
||||||
"TOP_K_PER_SEC",
|
|
||||||
]
|
|
||||||
|
|
||||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
|
||||||
"""Default confidence score threshold used for filtering detections."""
|
|
||||||
|
|
||||||
TOP_K_PER_SEC = 200
|
|
||||||
"""Default desired maximum number of detections per second of audio."""
|
|
||||||
|
|
||||||
|
|
||||||
def extract_detections_from_array(
|
|
||||||
detection_array: xr.DataArray,
|
|
||||||
max_detections: Optional[int] = None,
|
|
||||||
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
|
|
||||||
) -> xr.DataArray:
|
|
||||||
"""Extract detection locations (time, freq) and scores from a heatmap.
|
|
||||||
|
|
||||||
Identifies the pixels with the highest scores in the input detection
|
|
||||||
heatmap, filters them based on an optional score `threshold`, limits the
|
|
||||||
number to an optional `max_detections`, and returns their scores along with
|
|
||||||
their corresponding time and frequency coordinates.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
detection_array : xr.DataArray
|
|
||||||
A 2D xarray DataArray representing the detection heatmap. Must have
|
|
||||||
dimensions and coordinates named 'time' and 'frequency'. Higher values
|
|
||||||
are assumed to indicate higher detection confidence.
|
|
||||||
max_detections : int, optional
|
|
||||||
The absolute maximum number of detections to return. If specified, only
|
|
||||||
the top `max_detections` highest-scoring detections (passing the
|
|
||||||
threshold) are returned. If None (default), all detections passing
|
|
||||||
the threshold are returned, sorted by score.
|
|
||||||
threshold : float, optional
|
|
||||||
The minimum confidence score required for a detection peak to be
|
|
||||||
kept. Detections with scores below this value are discarded.
|
|
||||||
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
|
|
||||||
thresholding is applied.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
xr.DataArray
|
|
||||||
A 1D xarray DataArray named 'score' with a 'detection' dimension.
|
|
||||||
- The data values are the scores of the extracted detections, sorted
|
|
||||||
in descending order.
|
|
||||||
- It has coordinates 'time' and 'frequency' (also indexed by the
|
|
||||||
'detection' dimension) indicating the location of each detection
|
|
||||||
peak in the original coordinate system.
|
|
||||||
- Returns an empty DataArray if no detections pass the criteria.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `max_detections` is not None and not a positive integer, or if
|
|
||||||
`detection_array` lacks required dimensions/coordinates.
|
|
||||||
"""
|
|
||||||
if max_detections is not None:
|
|
||||||
if max_detections <= 0:
|
|
||||||
raise ValueError("Max detections must be positive")
|
|
||||||
|
|
||||||
values = detection_array.values.flatten()
|
|
||||||
|
|
||||||
if max_detections is not None:
|
|
||||||
top_indices = np.argpartition(-values, max_detections)[:max_detections]
|
|
||||||
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
|
|
||||||
else:
|
|
||||||
top_sorted_indices = np.argsort(-values)
|
|
||||||
|
|
||||||
top_values = values[top_sorted_indices]
|
|
||||||
|
|
||||||
if threshold is not None:
|
|
||||||
mask = top_values > threshold
|
|
||||||
top_values = top_values[mask]
|
|
||||||
top_sorted_indices = top_sorted_indices[mask]
|
|
||||||
|
|
||||||
freq_indices, time_indices = np.unravel_index(
|
|
||||||
top_sorted_indices,
|
|
||||||
detection_array.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
times = detection_array.coords[Dimensions.time.value].values[time_indices]
|
|
||||||
freqs = detection_array.coords[Dimensions.frequency.value].values[
|
|
||||||
freq_indices
|
|
||||||
]
|
|
||||||
|
|
||||||
return xr.DataArray(
|
|
||||||
data=top_values,
|
|
||||||
coords={
|
|
||||||
Dimensions.frequency.value: ("detection", freqs),
|
|
||||||
Dimensions.time.value: ("detection", times),
|
|
||||||
},
|
|
||||||
dims="detection",
|
|
||||||
name="score",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_max_detections(
|
|
||||||
detection_array: xr.DataArray,
|
|
||||||
top_k_per_sec: int = TOP_K_PER_SEC,
|
|
||||||
) -> int:
|
|
||||||
"""Calculate max detections allowed based on duration and rate.
|
|
||||||
|
|
||||||
Determines the total maximum number of detections to extract from a
|
|
||||||
heatmap based on its time duration and a desired rate of detections
|
|
||||||
per second.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
detection_array : xr.DataArray
|
|
||||||
The detection heatmap, requiring 'time' coordinates from which the
|
|
||||||
total duration can be calculated using
|
|
||||||
`soundevent.arrays.get_dim_width`.
|
|
||||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
|
||||||
The desired maximum number of detections to allow per second of audio.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
int
|
|
||||||
The calculated total maximum number of detections allowed for the
|
|
||||||
entire duration of the `detection_array`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the duration cannot be calculated from the `detection_array` (e.g.,
|
|
||||||
missing or invalid 'time' coordinates/dimension).
|
|
||||||
"""
|
|
||||||
if top_k_per_sec < 0:
|
|
||||||
raise ValueError("top_k_per_sec cannot be negative.")
|
|
||||||
|
|
||||||
duration = get_dim_width(detection_array, Dimensions.time.value)
|
|
||||||
return int(duration * top_k_per_sec)
|
|
||||||
@ -15,108 +15,73 @@ precise time-frequency location of each detection. The final output aggregates
|
|||||||
all extracted information into a structured `xarray.Dataset`.
|
all extracted information into a structured `xarray.Dataset`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import xarray as xr
|
from typing import List, Optional, Tuple, Union
|
||||||
from soundevent.arrays import Dimensions
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||||
|
from batdetect2.typing.postprocess import Detections, ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_values_at_positions",
|
"extract_prediction_tensor",
|
||||||
"extract_detection_xr_dataset",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def extract_values_at_positions(
|
def extract_prediction_tensor(
|
||||||
array: xr.DataArray,
|
output: ModelOutput,
|
||||||
positions: xr.DataArray,
|
max_detections: int = 200,
|
||||||
) -> xr.DataArray:
|
threshold: Optional[float] = None,
|
||||||
"""Extract values from an array at specified time-frequency positions.
|
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||||
|
) -> List[Detections]:
|
||||||
Uses coordinate-based indexing to retrieve values from a source `array`
|
detection_heatmap = non_max_suppression(
|
||||||
(e.g., class probabilities, size predictions, features) at the time and
|
output.detection_probs,
|
||||||
frequency coordinates defined in the `positions` array.
|
kernel_size=nms_kernel_size,
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
array : xr.DataArray
|
|
||||||
The source DataArray from which to extract values. Must have 'time'
|
|
||||||
and 'frequency' dimensions and coordinates matching the space of
|
|
||||||
`positions`.
|
|
||||||
positions : xr.DataArray
|
|
||||||
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
|
|
||||||
locations from which to extract values.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
xr.DataArray
|
|
||||||
A DataArray containing the values extracted from `array` at the given
|
|
||||||
positions.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError, IndexError, KeyError
|
|
||||||
If dimensions or coordinates are missing or incompatible between
|
|
||||||
`array` and `positions`, or if selection fails.
|
|
||||||
"""
|
|
||||||
return array.sel(
|
|
||||||
**{
|
|
||||||
Dimensions.frequency.value: positions.coords[
|
|
||||||
Dimensions.frequency.value
|
|
||||||
],
|
|
||||||
Dimensions.time.value: positions.coords[Dimensions.time.value],
|
|
||||||
}
|
|
||||||
).T
|
|
||||||
|
|
||||||
|
|
||||||
def extract_detection_xr_dataset(
|
|
||||||
positions: xr.DataArray,
|
|
||||||
sizes: xr.DataArray,
|
|
||||||
classes: xr.DataArray,
|
|
||||||
features: xr.DataArray,
|
|
||||||
) -> xr.Dataset:
|
|
||||||
"""Combine extracted detection information into a structured xr.Dataset.
|
|
||||||
|
|
||||||
Takes the detection positions/scores and the full model output heatmaps
|
|
||||||
(sizes, classes, optional features), extracts the relevant data at the
|
|
||||||
detection positions, and packages everything into a single `xarray.Dataset`
|
|
||||||
where all variables are indexed by a common 'detection' dimension.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
positions : xr.DataArray
|
|
||||||
Output from `extract_detections_from_array`, containing detection
|
|
||||||
scores as data and 'time', 'frequency' coordinates along the
|
|
||||||
'detection' dimension.
|
|
||||||
sizes : xr.DataArray
|
|
||||||
The full size prediction heatmap from the model, with dimensions like
|
|
||||||
('dimension', 'time', 'frequency').
|
|
||||||
classes : xr.DataArray
|
|
||||||
The full class probability heatmap from the model, with dimensions like
|
|
||||||
('category', 'time', 'frequency').
|
|
||||||
features : xr.DataArray
|
|
||||||
The full feature map from the model, with
|
|
||||||
dimensions like ('feature', 'time', 'frequency').
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
xr.Dataset
|
|
||||||
An xarray Dataset containing aligned information for each detection:
|
|
||||||
- 'scores': DataArray from `positions` (score data, time/freq coords).
|
|
||||||
- 'dimensions': DataArray with extracted size values
|
|
||||||
(dims: 'detection', 'dimension').
|
|
||||||
- 'classes': DataArray with extracted class probabilities
|
|
||||||
(dims: 'detection', 'category').
|
|
||||||
- 'features': DataArray with extracted feature vectors
|
|
||||||
(dims: 'detection', 'feature'), if `features` was provided. All
|
|
||||||
DataArrays share the 'detection' dimension and associated
|
|
||||||
time/frequency coordinates.
|
|
||||||
"""
|
|
||||||
sizes = extract_values_at_positions(sizes, positions)
|
|
||||||
classes = extract_values_at_positions(classes, positions)
|
|
||||||
features = extract_values_at_positions(features, positions)
|
|
||||||
return xr.Dataset(
|
|
||||||
{
|
|
||||||
"scores": positions,
|
|
||||||
"dimensions": sizes,
|
|
||||||
"classes": classes,
|
|
||||||
"features": features,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
height = detection_heatmap.shape[-2]
|
||||||
|
width = detection_heatmap.shape[-1]
|
||||||
|
|
||||||
|
freqs, times = torch.meshgrid(
|
||||||
|
torch.arange(height, dtype=torch.int32),
|
||||||
|
torch.arange(width, dtype=torch.int32),
|
||||||
|
indexing="ij",
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = freqs.flatten()
|
||||||
|
times = times.flatten()
|
||||||
|
|
||||||
|
predictions = []
|
||||||
|
for idx, item in enumerate(detection_heatmap):
|
||||||
|
item = item.squeeze().flatten() # Remove channel dim
|
||||||
|
indices = torch.argsort(item, descending=True)[:max_detections]
|
||||||
|
|
||||||
|
detection_scores = item.take(indices)
|
||||||
|
detection_freqs = freqs.take(indices)
|
||||||
|
detection_times = times.take(indices)
|
||||||
|
sizes = output.size_preds[idx, :, detection_freqs, detection_times].T
|
||||||
|
features = output.features[idx, :, detection_freqs, detection_times].T
|
||||||
|
class_scores = output.class_probs[
|
||||||
|
idx, :, detection_freqs, detection_times
|
||||||
|
].T
|
||||||
|
|
||||||
|
if threshold is not None:
|
||||||
|
mask = detection_scores >= threshold
|
||||||
|
detection_scores = detection_scores[mask]
|
||||||
|
sizes = sizes[mask]
|
||||||
|
detection_times = detection_times[mask]
|
||||||
|
detection_freqs = detection_freqs[mask]
|
||||||
|
features = features[mask]
|
||||||
|
class_scores = class_scores[mask]
|
||||||
|
|
||||||
|
predictions.append(
|
||||||
|
Detections(
|
||||||
|
scores=detection_scores,
|
||||||
|
sizes=sizes,
|
||||||
|
features=features,
|
||||||
|
class_scores=class_scores,
|
||||||
|
times=detection_times.to(torch.float32) / width,
|
||||||
|
frequencies=(detection_freqs.to(torch.float32) / height),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import xarray as xr
|
|||||||
from soundevent.arrays import Dimensions
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
|
from batdetect2.typing.postprocess import Detections
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"features_to_xarray",
|
"features_to_xarray",
|
||||||
@ -29,6 +30,26 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def map_detection_to_clip(
|
||||||
|
detections: Detections,
|
||||||
|
start_time: float,
|
||||||
|
end_time: float,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
|
) -> Detections:
|
||||||
|
duration = end_time - start_time
|
||||||
|
bandwidth = max_freq - min_freq
|
||||||
|
print(f"{bandwidth=} {min_freq=} {detections.frequencies=}")
|
||||||
|
return Detections(
|
||||||
|
scores=detections.scores,
|
||||||
|
sizes=detections.sizes,
|
||||||
|
features=detections.features,
|
||||||
|
class_scores=detections.class_scores,
|
||||||
|
times=(detections.times * duration + start_time),
|
||||||
|
frequencies=(detections.frequencies * bandwidth + min_freq),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def features_to_xarray(
|
def features_to_xarray(
|
||||||
features: torch.Tensor,
|
features: torch.Tensor,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
|
|||||||
@ -53,6 +53,7 @@ from batdetect2.preprocess.spectrogram import (
|
|||||||
SpectrogramConfig,
|
SpectrogramConfig,
|
||||||
SpectrogramPipeline,
|
SpectrogramPipeline,
|
||||||
STFTConfig,
|
STFTConfig,
|
||||||
|
_spec_params_from_config,
|
||||||
build_spectrogram_builder,
|
build_spectrogram_builder,
|
||||||
build_spectrogram_pipeline,
|
build_spectrogram_pipeline,
|
||||||
)
|
)
|
||||||
@ -109,7 +110,9 @@ def load_preprocessing_config(
|
|||||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||||
"""Standard implementation of the `Preprocessor` protocol."""
|
"""Standard implementation of the `Preprocessor` protocol."""
|
||||||
|
|
||||||
samplerate: int
|
input_samplerate: int
|
||||||
|
output_samplerate: float
|
||||||
|
|
||||||
max_freq: float
|
max_freq: float
|
||||||
min_freq: float
|
min_freq: float
|
||||||
|
|
||||||
@ -117,22 +120,33 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
|||||||
self,
|
self,
|
||||||
audio_pipeline: torch.nn.Module,
|
audio_pipeline: torch.nn.Module,
|
||||||
spectrogram_pipeline: SpectrogramPipeline,
|
spectrogram_pipeline: SpectrogramPipeline,
|
||||||
samplerate: int,
|
input_samplerate: int,
|
||||||
|
output_samplerate: float,
|
||||||
max_freq: float,
|
max_freq: float,
|
||||||
min_freq: float,
|
min_freq: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.audio_pipeline = audio_pipeline
|
self.audio_pipeline = audio_pipeline
|
||||||
self.spectrogram_pipeline = spectrogram_pipeline
|
self.spectrogram_pipeline = spectrogram_pipeline
|
||||||
self.samplerate = samplerate
|
|
||||||
self.max_freq = max_freq
|
self.max_freq = max_freq
|
||||||
self.min_freq = min_freq
|
self.min_freq = min_freq
|
||||||
|
|
||||||
|
self.input_samplerate = input_samplerate
|
||||||
|
self.output_samplerate = output_samplerate
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
wav = self.audio_pipeline(wav)
|
wav = self.audio_pipeline(wav)
|
||||||
return self.spectrogram_pipeline(wav)
|
return self.spectrogram_pipeline(wav)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_output_samplerate(config: PreprocessingConfig) -> float:
|
||||||
|
samplerate = config.audio.samplerate
|
||||||
|
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
|
||||||
|
factor = config.spectrogram.size.resize_factor
|
||||||
|
return samplerate * factor / hop_size
|
||||||
|
|
||||||
|
|
||||||
def build_preprocessor(
|
def build_preprocessor(
|
||||||
config: Optional[PreprocessingConfig] = None,
|
config: Optional[PreprocessingConfig] = None,
|
||||||
) -> PreprocessorProtocol:
|
) -> PreprocessorProtocol:
|
||||||
@ -148,16 +162,15 @@ def build_preprocessor(
|
|||||||
min_freq = config.spectrogram.frequencies.min_freq
|
min_freq = config.spectrogram.frequencies.min_freq
|
||||||
max_freq = config.spectrogram.frequencies.max_freq
|
max_freq = config.spectrogram.frequencies.max_freq
|
||||||
|
|
||||||
|
output_samplerate = compute_output_samplerate(config)
|
||||||
|
|
||||||
return StandardPreprocessor(
|
return StandardPreprocessor(
|
||||||
audio_pipeline=build_audio_pipeline(config.audio),
|
audio_pipeline=build_audio_pipeline(config.audio),
|
||||||
spectrogram_pipeline=build_spectrogram_pipeline(
|
spectrogram_pipeline=build_spectrogram_pipeline(
|
||||||
samplerate, config.spectrogram
|
samplerate, config.spectrogram
|
||||||
),
|
),
|
||||||
samplerate=samplerate,
|
input_samplerate=samplerate,
|
||||||
|
output_samplerate=output_samplerate,
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
max_freq=max_freq,
|
max_freq=max_freq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_default_preprocessor():
|
|
||||||
return build_preprocessor()
|
|
||||||
|
|||||||
@ -148,7 +148,7 @@ def add_echo(
|
|||||||
"""Add a synthetic echo to the audio waveform."""
|
"""Add a synthetic echo to the audio waveform."""
|
||||||
|
|
||||||
audio = example.audio
|
audio = example.audio
|
||||||
delay_steps = int(preprocessor.samplerate * delay)
|
delay_steps = int(preprocessor.input_samplerate * delay)
|
||||||
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
|
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
|
||||||
|
|
||||||
audio = audio + weight * audio_delay
|
audio = audio + weight * audio_delay
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.typing import ClipperProtocol
|
from batdetect2.typing import ClipperProtocol
|
||||||
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.typing.train import PreprocessedExample
|
from batdetect2.typing.train import PreprocessedExample
|
||||||
from batdetect2.utils.arrays import adjust_width
|
from batdetect2.utils.arrays import adjust_width
|
||||||
|
|
||||||
@ -18,24 +20,26 @@ class ClipingConfig(BaseConfig):
|
|||||||
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
|
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
|
||||||
|
|
||||||
|
|
||||||
class Clipper(ClipperProtocol):
|
class Clipper(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
samplerate: int,
|
preprocessor: PreprocessorProtocol,
|
||||||
duration: float = 0.5,
|
duration: float = 0.5,
|
||||||
max_empty: float = 0.2,
|
max_empty: float = 0.2,
|
||||||
random: bool = True,
|
random: bool = True,
|
||||||
):
|
):
|
||||||
self.samplerate = samplerate
|
super().__init__()
|
||||||
|
self.preprocessor = preprocessor
|
||||||
self.duration = duration
|
self.duration = duration
|
||||||
self.random = random
|
self.random = random
|
||||||
self.max_empty = max_empty
|
self.max_empty = max_empty
|
||||||
|
|
||||||
def extract_clip(
|
def forward(
|
||||||
self, example: PreprocessedExample
|
self,
|
||||||
|
example: PreprocessedExample,
|
||||||
) -> Tuple[PreprocessedExample, float, float]:
|
) -> Tuple[PreprocessedExample, float, float]:
|
||||||
start_time = 0
|
start_time = 0
|
||||||
duration = example.audio.shape[-1] / self.samplerate
|
duration = example.audio.shape[-1] / self.preprocessor.input_samplerate
|
||||||
|
|
||||||
if self.random:
|
if self.random:
|
||||||
start_time = np.random.uniform(
|
start_time = np.random.uniform(
|
||||||
@ -48,7 +52,8 @@ class Clipper(ClipperProtocol):
|
|||||||
example,
|
example,
|
||||||
start=start_time,
|
start=start_time,
|
||||||
duration=self.duration,
|
duration=self.duration,
|
||||||
samplerate=self.samplerate,
|
input_samplerate=self.preprocessor.input_samplerate,
|
||||||
|
output_samplerate=self.preprocessor.output_samplerate,
|
||||||
),
|
),
|
||||||
start_time,
|
start_time,
|
||||||
start_time + self.duration,
|
start_time + self.duration,
|
||||||
@ -56,7 +61,7 @@ class Clipper(ClipperProtocol):
|
|||||||
|
|
||||||
|
|
||||||
def build_clipper(
|
def build_clipper(
|
||||||
samplerate: int,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: Optional[ClipingConfig] = None,
|
config: Optional[ClipingConfig] = None,
|
||||||
random: Optional[bool] = None,
|
random: Optional[bool] = None,
|
||||||
) -> ClipperProtocol:
|
) -> ClipperProtocol:
|
||||||
@ -66,7 +71,7 @@ def build_clipper(
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
return Clipper(
|
return Clipper(
|
||||||
samplerate=samplerate,
|
preprocessor=preprocessor,
|
||||||
duration=config.duration,
|
duration=config.duration,
|
||||||
max_empty=config.max_empty,
|
max_empty=config.max_empty,
|
||||||
random=config.random if random else False,
|
random=config.random if random else False,
|
||||||
@ -77,11 +82,12 @@ def select_subclip(
|
|||||||
example: PreprocessedExample,
|
example: PreprocessedExample,
|
||||||
start: float,
|
start: float,
|
||||||
duration: float,
|
duration: float,
|
||||||
samplerate: float,
|
input_samplerate: float,
|
||||||
|
output_samplerate: float,
|
||||||
fill_value: float = 0,
|
fill_value: float = 0,
|
||||||
) -> PreprocessedExample:
|
) -> PreprocessedExample:
|
||||||
audio_width = int(np.floor(duration * samplerate))
|
audio_width = int(np.floor(duration * input_samplerate))
|
||||||
audio_start = int(np.floor(start * samplerate))
|
audio_start = int(np.floor(start * input_samplerate))
|
||||||
|
|
||||||
audio = adjust_width(
|
audio = adjust_width(
|
||||||
example.audio[audio_start : audio_start + audio_width],
|
example.audio[audio_start : audio_start + audio_width],
|
||||||
@ -89,12 +95,8 @@ def select_subclip(
|
|||||||
value=fill_value,
|
value=fill_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_duration = example.audio.shape[-1] / samplerate
|
spec_start = int(np.floor(start * output_samplerate))
|
||||||
spec_sr = example.spectrogram.shape[-1] / audio_duration
|
spec_width = int(np.floor(duration * output_samplerate))
|
||||||
|
|
||||||
spec_start = int(np.floor(start * spec_sr))
|
|
||||||
spec_width = int(np.floor(duration * spec_sr))
|
|
||||||
|
|
||||||
return PreprocessedExample(
|
return PreprocessedExample(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
spectrogram=adjust_width(
|
spectrogram=adjust_width(
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class LabeledDataset(Dataset):
|
|||||||
def __getitem__(self, idx) -> TrainExample:
|
def __getitem__(self, idx) -> TrainExample:
|
||||||
example = self.get_example(idx)
|
example = self.get_example(idx)
|
||||||
|
|
||||||
example, start_time, end_time = self.clipper.extract_clip(example)
|
example, start_time, end_time = self.clipper(example)
|
||||||
|
|
||||||
if self.augmentation:
|
if self.augmentation:
|
||||||
example = self.augmentation(example)
|
example = self.augmentation(example)
|
||||||
@ -64,9 +64,7 @@ class LabeledDataset(Dataset):
|
|||||||
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
|
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
|
||||||
idx = np.random.randint(0, len(self))
|
idx = np.random.randint(0, len(self))
|
||||||
dataset = self.get_example(idx)
|
dataset = self.get_example(idx)
|
||||||
|
dataset, start_time, end_time = self.clipper(dataset)
|
||||||
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
|
|
||||||
|
|
||||||
return dataset, start_time, end_time
|
return dataset, start_time, end_time
|
||||||
|
|
||||||
def get_example(self, idx) -> PreprocessedExample:
|
def get_example(self, idx) -> PreprocessedExample:
|
||||||
@ -107,5 +105,5 @@ class RandomExampleSource:
|
|||||||
index = int(np.random.randint(len(self.filenames)))
|
index = int(np.random.randint(len(self.filenames)))
|
||||||
filename = self.filenames[index]
|
filename = self.filenames[index]
|
||||||
example = load_preprocessed_example(filename)
|
example = load_preprocessed_example(filename)
|
||||||
example, _, _ = self.clipper.extract_clip(example)
|
example, _, _ = self.clipper(example)
|
||||||
return example
|
return example
|
||||||
|
|||||||
@ -229,7 +229,7 @@ def build_train_dataset(
|
|||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
clipper = build_clipper(
|
clipper = build_clipper(
|
||||||
samplerate=preprocessor.samplerate,
|
preprocessor=preprocessor,
|
||||||
config=config.cliping,
|
config=config.cliping,
|
||||||
random=True,
|
random=True,
|
||||||
)
|
)
|
||||||
@ -265,7 +265,7 @@ def build_val_dataset(
|
|||||||
logger.info("Building validation dataset...")
|
logger.info("Building validation dataset...")
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
clipper = build_clipper(
|
clipper = build_clipper(
|
||||||
samplerate=preprocessor.samplerate,
|
preprocessor=preprocessor,
|
||||||
config=config.cliping,
|
config=config.cliping,
|
||||||
random=train,
|
random=train,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from dataclasses import dataclass
|
|||||||
from typing import List, NamedTuple, Optional, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
@ -77,6 +77,15 @@ class RawPrediction(NamedTuple):
|
|||||||
features: np.ndarray
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class Detections(NamedTuple):
|
||||||
|
scores: torch.Tensor
|
||||||
|
sizes: torch.Tensor
|
||||||
|
class_scores: torch.Tensor
|
||||||
|
times: torch.Tensor
|
||||||
|
frequencies: torch.Tensor
|
||||||
|
features: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatDetect2Prediction:
|
class BatDetect2Prediction:
|
||||||
raw: RawPrediction
|
raw: RawPrediction
|
||||||
@ -84,154 +93,13 @@ class BatDetect2Prediction:
|
|||||||
|
|
||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
"""Protocol defining the interface for the full postprocessing pipeline."""
|
||||||
|
|
||||||
This protocol outlines the standard methods for an object that takes raw
|
def get_detections(
|
||||||
output from a BatDetect2 model and the corresponding input clip metadata,
|
|
||||||
and processes it through various stages (e.g., coordinate remapping, NMS,
|
|
||||||
detection extraction, data extraction, decoding) to produce interpretable
|
|
||||||
results at different levels of completion.
|
|
||||||
|
|
||||||
Implementations manage the configured logic for all postprocessing steps.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_feature_arrays(
|
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
clips: List[data.Clip],
|
clips: Optional[List[data.Clip]] = None,
|
||||||
) -> List[xr.DataArray]:
|
) -> List[Detections]: ...
|
||||||
"""Remap feature tensors to coordinate-aware DataArrays.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
The raw output from the neural network model for a batch, expected
|
|
||||||
to contain the necessary feature tensors.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
A list of `soundevent.data.Clip` objects, one for each item in the
|
|
||||||
processed batch. This list provides the timing, recording, and
|
|
||||||
other metadata context needed to calculate real-world coordinates
|
|
||||||
(seconds, Hz) for the output arrays. The length of this list must
|
|
||||||
correspond to the batch size of the `output`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
A list of xarray DataArrays, one for each input clip in the batch,
|
|
||||||
in the same order. Each DataArray contains the feature vectors
|
|
||||||
with dimensions like ('feature', 'time', 'frequency') and
|
|
||||||
corresponding real-world coordinates.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_detection_arrays(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Remap detection tensors to coordinate-aware DataArrays.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
The raw output from the neural network model for a batch,
|
|
||||||
containing detection heatmaps.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
|
||||||
items, providing coordinate context. Must match the batch size of
|
|
||||||
`output`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
A list of 2D xarray DataArrays (one per input clip, in order),
|
|
||||||
representing the detection heatmap with 'time' and 'frequency'
|
|
||||||
coordinates. Values typically indicate detection confidence.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_classification_arrays(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Remap classification tensors to coordinate-aware DataArrays.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
The raw output from the neural network model for a batch,
|
|
||||||
containing class probability tensors.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
|
||||||
items, providing coordinate context. Must match the batch size of
|
|
||||||
`output`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
|
||||||
representing class probabilities with 'category', 'time', and
|
|
||||||
'frequency' dimensions and coordinates.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_sizes_arrays(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
) -> List[xr.DataArray]:
|
|
||||||
"""Remap size prediction tensors to coordinate-aware DataArrays.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
The raw output from the neural network model for a batch,
|
|
||||||
containing predicted size tensors (e.g., width and height).
|
|
||||||
clips : List[data.Clip]
|
|
||||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
|
||||||
items, providing coordinate context. Must match the batch size of
|
|
||||||
`output`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.DataArray]
|
|
||||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
|
||||||
representing predicted sizes with 'dimension'
|
|
||||||
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
|
|
||||||
coordinates. Values represent estimated detection sizes.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_detection_datasets(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
) -> List[xr.Dataset]:
|
|
||||||
"""Perform remapping, NMS, detection, and data extraction for a batch.
|
|
||||||
|
|
||||||
Processes the raw model output for a batch to identify detection peaks
|
|
||||||
and extract all associated information (score, position, size, class
|
|
||||||
probs, features) at those peak locations, returning a structured
|
|
||||||
dataset for each input clip in the batch.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
The raw output from the neural network model for a batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
|
||||||
items, providing context. Must match the batch size of `output`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[xr.Dataset]
|
|
||||||
A list of xarray Datasets (one per input clip, in order). Each
|
|
||||||
Dataset contains multiple DataArrays ('scores', 'dimensions',
|
|
||||||
'classes', 'features') sharing a common 'detection' dimension,
|
|
||||||
providing aligned data for each detected event in that clip.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_raw_predictions(
|
def get_raw_predictions(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -148,7 +148,9 @@ class PreprocessorProtocol(Protocol):
|
|||||||
|
|
||||||
min_freq: float
|
min_freq: float
|
||||||
|
|
||||||
samplerate: int
|
input_samplerate: int
|
||||||
|
|
||||||
|
output_samplerate: float
|
||||||
|
|
||||||
audio_pipeline: AudioPipeline
|
audio_pipeline: AudioPipeline
|
||||||
|
|
||||||
|
|||||||
@ -96,6 +96,6 @@ class LossProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class ClipperProtocol(Protocol):
|
class ClipperProtocol(Protocol):
|
||||||
def extract_clip(
|
def __call__(
|
||||||
self, example: PreprocessedExample
|
self, example: PreprocessedExample
|
||||||
) -> Tuple[PreprocessedExample, float, float]: ...
|
) -> Tuple[PreprocessedExample, float, float]: ...
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from batdetect2.postprocess.decoding import (
|
|||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
convert_raw_predictions_to_clip_prediction,
|
convert_raw_predictions_to_clip_prediction,
|
||||||
convert_xr_dataset_to_raw_prediction,
|
|
||||||
get_class_tags,
|
get_class_tags,
|
||||||
get_generic_tags,
|
get_generic_tags,
|
||||||
get_prediction_features,
|
get_prediction_features,
|
||||||
@ -278,63 +277,6 @@ def sample_raw_predictions() -> List[RawPrediction]:
|
|||||||
return [pred1, pred2, pred3]
|
return [pred1, pred2, pred3]
|
||||||
|
|
||||||
|
|
||||||
def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets):
|
|
||||||
"""Test basic conversion of a dataset to RawPrediction list."""
|
|
||||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
|
||||||
sample_detection_dataset,
|
|
||||||
dummy_targets.decode_roi,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(raw_predictions, list)
|
|
||||||
assert len(raw_predictions) == 2
|
|
||||||
|
|
||||||
pred1 = raw_predictions[0]
|
|
||||||
assert isinstance(pred1, RawPrediction)
|
|
||||||
assert pred1.detection_score == 0.9
|
|
||||||
|
|
||||||
assert pred1.geometry.coordinates == [
|
|
||||||
20 - 7 / 2,
|
|
||||||
300 - 16 / 2,
|
|
||||||
20 + 7 / 2,
|
|
||||||
300 + 16 / 2,
|
|
||||||
]
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
pred1.class_scores,
|
|
||||||
sample_detection_dataset["classes"].sel(detection=0),
|
|
||||||
)
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
pred1.features, sample_detection_dataset["features"].sel(detection=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
pred2 = raw_predictions[1]
|
|
||||||
assert isinstance(pred2, RawPrediction)
|
|
||||||
assert pred2.detection_score == 0.8
|
|
||||||
|
|
||||||
assert pred2.geometry.coordinates == [
|
|
||||||
10 - 3 / 2,
|
|
||||||
200 - 12 / 2,
|
|
||||||
10 + 3 / 2,
|
|
||||||
200 + 12 / 2,
|
|
||||||
]
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
pred2.class_scores,
|
|
||||||
sample_detection_dataset["classes"].sel(detection=1),
|
|
||||||
)
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
pred2.features, sample_detection_dataset["features"].sel(detection=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets):
|
|
||||||
"""Test conversion of an empty dataset."""
|
|
||||||
raw_predictions = convert_xr_dataset_to_raw_prediction(
|
|
||||||
empty_detection_dataset,
|
|
||||||
dummy_targets.decode_roi,
|
|
||||||
)
|
|
||||||
assert isinstance(raw_predictions, list)
|
|
||||||
assert len(raw_predictions) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_raw_to_sound_event_basic(
|
def test_convert_raw_to_sound_event_basic(
|
||||||
sample_raw_predictions,
|
sample_raw_predictions,
|
||||||
sample_recording,
|
sample_recording,
|
||||||
|
|||||||
@ -1,214 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
import xarray as xr
|
|
||||||
from soundevent.arrays import Dimensions
|
|
||||||
|
|
||||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_data_array():
|
|
||||||
"""Provides a basic 3x3 DataArray.
|
|
||||||
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
|
|
||||||
"""
|
|
||||||
array = xr.DataArray(
|
|
||||||
np.zeros([3, 3]),
|
|
||||||
coords={
|
|
||||||
Dimensions.frequency.value: [100, 200, 300],
|
|
||||||
Dimensions.time.value: [10, 20, 30],
|
|
||||||
},
|
|
||||||
dims=[
|
|
||||||
Dimensions.frequency.value,
|
|
||||||
Dimensions.time.value,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
|
||||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
|
||||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
|
||||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
|
||||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
|
||||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
|
||||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
|
||||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
|
||||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
|
||||||
return array
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def data_array_with_nans(sample_data_array: xr.DataArray):
|
|
||||||
"""Provides a 2D DataArray containing NaN values."""
|
|
||||||
array = sample_data_array.copy()
|
|
||||||
array.loc[dict(time=10, frequency=300)] = np.nan
|
|
||||||
array.loc[dict(time=30, frequency=100)] = np.nan
|
|
||||||
return array
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_extraction(sample_data_array: xr.DataArray):
|
|
||||||
threshold = 0.1
|
|
||||||
max_detections = 3
|
|
||||||
|
|
||||||
actual_result = extract_detections_from_array(
|
|
||||||
sample_data_array,
|
|
||||||
threshold=threshold,
|
|
||||||
max_detections=max_detections,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_values = np.array([0.9, 0.8, 0.7])
|
|
||||||
expected_times = np.array([30, 20, 30])
|
|
||||||
expected_freqs = np.array([200, 100, 300])
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
|
||||||
Dimensions.time.value: ("detection", expected_times),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="score",
|
|
||||||
)
|
|
||||||
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_threshold_only(sample_data_array):
|
|
||||||
input_array = sample_data_array
|
|
||||||
threshold = 0.5
|
|
||||||
actual_result = extract_detections_from_array(
|
|
||||||
input_array, threshold=threshold
|
|
||||||
)
|
|
||||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
|
||||||
expected_times = np.array([30, 20, 30, 20])
|
|
||||||
expected_freqs = np.array([200, 100, 300, 300])
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", expected_times),
|
|
||||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_max_detections_only(sample_data_array):
|
|
||||||
input_array = sample_data_array
|
|
||||||
max_detections = 4
|
|
||||||
actual_result = extract_detections_from_array(
|
|
||||||
input_array, max_detections=max_detections
|
|
||||||
)
|
|
||||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6])
|
|
||||||
expected_times = np.array([30, 20, 30, 20])
|
|
||||||
expected_freqs = np.array([200, 100, 300, 300])
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", expected_times),
|
|
||||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_optional_args(sample_data_array):
|
|
||||||
input_array = sample_data_array
|
|
||||||
actual_result = extract_detections_from_array(input_array)
|
|
||||||
expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02])
|
|
||||||
expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20])
|
|
||||||
expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200])
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", expected_times),
|
|
||||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_values_above_threshold(sample_data_array):
|
|
||||||
input_array = sample_data_array
|
|
||||||
threshold = 1.0
|
|
||||||
actual_result = extract_detections_from_array(
|
|
||||||
input_array, threshold=threshold
|
|
||||||
)
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
|
||||||
Dimensions.frequency.value: (
|
|
||||||
"detection",
|
|
||||||
np.array([], dtype=np.int64),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
np.array([], dtype=np.float64),
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
assert actual_result.sizes["detection"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_max_detections_zero(sample_data_array):
|
|
||||||
input_array = sample_data_array
|
|
||||||
max_detections = 0
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
extract_detections_from_array(
|
|
||||||
input_array,
|
|
||||||
max_detections=max_detections,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_input_array():
|
|
||||||
empty_array = xr.DataArray(
|
|
||||||
np.empty((0, 0)),
|
|
||||||
coords={Dimensions.time.value: [], Dimensions.frequency.value: []},
|
|
||||||
dims=[Dimensions.time.value, Dimensions.frequency.value],
|
|
||||||
)
|
|
||||||
actual_result = extract_detections_from_array(empty_array)
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)),
|
|
||||||
Dimensions.frequency.value: (
|
|
||||||
"detection",
|
|
||||||
np.array([], dtype=np.int64),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
np.array([], dtype=np.float64),
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
assert actual_result.sizes["detection"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_nan_handling(data_array_with_nans):
|
|
||||||
input_array = data_array_with_nans
|
|
||||||
threshold = 0.1
|
|
||||||
max_detections = 3
|
|
||||||
actual_result = extract_detections_from_array(
|
|
||||||
input_array, threshold=threshold, max_detections=max_detections
|
|
||||||
)
|
|
||||||
expected_values = np.array([0.9, 0.8, 0.7])
|
|
||||||
expected_times = np.array([30, 20, 30])
|
|
||||||
expected_freqs = np.array([200, 100, 300])
|
|
||||||
expected_coords = {
|
|
||||||
Dimensions.time.value: ("detection", expected_times),
|
|
||||||
Dimensions.frequency.value: ("detection", expected_freqs),
|
|
||||||
}
|
|
||||||
expected_result = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=expected_coords,
|
|
||||||
dims="detection",
|
|
||||||
name="detection_value",
|
|
||||||
)
|
|
||||||
xr.testing.assert_equal(actual_result, expected_result)
|
|
||||||
@ -1,397 +1,2 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import xarray as xr
|
|
||||||
from soundevent.arrays import Dimensions
|
|
||||||
|
|
||||||
from batdetect2.postprocess.detection import extract_detections_from_array
|
|
||||||
from batdetect2.postprocess.extraction import (
|
|
||||||
extract_detection_xr_dataset,
|
|
||||||
extract_values_at_positions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_data_array():
|
|
||||||
"""Provides a basic 3x3 DataArray.
|
|
||||||
Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30)
|
|
||||||
"""
|
|
||||||
coords = {
|
|
||||||
Dimensions.frequency.value: [100, 200, 300],
|
|
||||||
Dimensions.time.value: [10, 20, 30],
|
|
||||||
}
|
|
||||||
array = xr.DataArray(
|
|
||||||
np.zeros([3, 3]),
|
|
||||||
coords=coords,
|
|
||||||
dims=[
|
|
||||||
Dimensions.frequency.value,
|
|
||||||
Dimensions.time.value,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
array.loc[dict(time=10, frequency=100)] = 0.005
|
|
||||||
array.loc[dict(time=10, frequency=200)] = 0.5
|
|
||||||
array.loc[dict(time=10, frequency=300)] = 0.03
|
|
||||||
array.loc[dict(time=20, frequency=100)] = 0.8
|
|
||||||
array.loc[dict(time=20, frequency=200)] = 0.02
|
|
||||||
array.loc[dict(time=20, frequency=300)] = 0.6
|
|
||||||
array.loc[dict(time=30, frequency=100)] = 0.04
|
|
||||||
array.loc[dict(time=30, frequency=200)] = 0.9
|
|
||||||
array.loc[dict(time=30, frequency=300)] = 0.7
|
|
||||||
return array
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_array_for_extraction():
|
|
||||||
"""Provides a simple array (1-9) for value extraction tests."""
|
|
||||||
data = np.arange(1, 10).reshape(3, 3)
|
|
||||||
coords = {
|
|
||||||
Dimensions.frequency.value: [100, 200, 300],
|
|
||||||
Dimensions.time.value: [10, 20, 30],
|
|
||||||
}
|
|
||||||
return xr.DataArray(
|
|
||||||
data,
|
|
||||||
coords=coords,
|
|
||||||
dims=[
|
|
||||||
Dimensions.frequency.value,
|
|
||||||
Dimensions.time.value,
|
|
||||||
],
|
|
||||||
name="test_values",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_positions_top3(sample_data_array):
|
|
||||||
"""Get top 3 detection positions from sample_data_array."""
|
|
||||||
|
|
||||||
return extract_detections_from_array(
|
|
||||||
sample_data_array,
|
|
||||||
max_detections=3,
|
|
||||||
threshold=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_positions_top2(sample_data_array):
|
|
||||||
"""Get top 2 detection positions from sample_data_array."""
|
|
||||||
return extract_detections_from_array(
|
|
||||||
sample_data_array,
|
|
||||||
max_detections=2,
|
|
||||||
threshold=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def empty_positions(sample_data_array):
|
|
||||||
"""Get an empty positions array (high threshold)."""
|
|
||||||
return extract_detections_from_array(
|
|
||||||
sample_data_array,
|
|
||||||
threshold=1.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_sizes_array(sample_data_array):
|
|
||||||
"""Provides a sample sizes array matching sample_data_array coords."""
|
|
||||||
coords = sample_data_array.coords
|
|
||||||
data = np.array(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[0, 1, 2],
|
|
||||||
[3, 4, 5],
|
|
||||||
[6, 7, 8],
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[9, 10, 11],
|
|
||||||
[12, 13, 14],
|
|
||||||
[15, 16, 17],
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
return xr.DataArray(
|
|
||||||
data,
|
|
||||||
coords={
|
|
||||||
"dimension": ["width", "height"],
|
|
||||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
|
||||||
Dimensions.time.value: coords[Dimensions.time.value],
|
|
||||||
},
|
|
||||||
dims=["dimension", Dimensions.frequency.value, Dimensions.time.value],
|
|
||||||
name="sizes",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_classes_array(sample_data_array):
|
|
||||||
"""Provides a sample classes array matching sample_data_array coords."""
|
|
||||||
coords = sample_data_array.coords
|
|
||||||
data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3)
|
|
||||||
return xr.DataArray(
|
|
||||||
data,
|
|
||||||
coords={
|
|
||||||
"category": ["bat", "noise"],
|
|
||||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
|
||||||
Dimensions.time.value: coords[Dimensions.time.value],
|
|
||||||
},
|
|
||||||
dims=["category", Dimensions.frequency.value, Dimensions.time.value],
|
|
||||||
name="class_scores",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_features_array(sample_data_array):
|
|
||||||
"""Provides a sample features array matching sample_data_array coords."""
|
|
||||||
coords = sample_data_array.coords
|
|
||||||
data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3)
|
|
||||||
return xr.DataArray(
|
|
||||||
data,
|
|
||||||
coords={
|
|
||||||
"feature": ["f0", "f1", "f2", "f3"],
|
|
||||||
Dimensions.frequency.value: coords[Dimensions.frequency.value],
|
|
||||||
Dimensions.time.value: coords[Dimensions.time.value],
|
|
||||||
},
|
|
||||||
dims=["feature", Dimensions.frequency.value, Dimensions.time.value],
|
|
||||||
name="features",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_correct(
|
|
||||||
sample_array_for_extraction,
|
|
||||||
sample_positions_top3,
|
|
||||||
):
|
|
||||||
"""Verify correct values are extracted based on positions coords."""
|
|
||||||
expected_values = np.array(
|
|
||||||
[
|
|
||||||
sample_array_for_extraction.sel(time=30, frequency=200).values,
|
|
||||||
sample_array_for_extraction.sel(time=20, frequency=100).values,
|
|
||||||
sample_array_for_extraction.sel(time=30, frequency=300).values,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
expected = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords=sample_positions_top3.coords,
|
|
||||||
dims="detection",
|
|
||||||
name="test_values",
|
|
||||||
)
|
|
||||||
|
|
||||||
extracted = extract_values_at_positions(
|
|
||||||
sample_array_for_extraction, sample_positions_top3
|
|
||||||
)
|
|
||||||
|
|
||||||
xr.testing.assert_allclose(extracted, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_extra_dims(
|
|
||||||
sample_sizes_array,
|
|
||||||
sample_positions_top2,
|
|
||||||
):
|
|
||||||
"""Test extraction preserves other dimensions in the source array."""
|
|
||||||
times = np.array([30, 20])
|
|
||||||
freqs = np.array([200, 100])
|
|
||||||
|
|
||||||
expected_values = np.array(
|
|
||||||
[
|
|
||||||
sample_sizes_array.sel(time=30, frequency=200).values,
|
|
||||||
sample_sizes_array.sel(time=20, frequency=100).values,
|
|
||||||
],
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
expected = xr.DataArray(
|
|
||||||
expected_values,
|
|
||||||
coords={
|
|
||||||
"dimension": ["width", "height"],
|
|
||||||
Dimensions.frequency.value: ("detection", freqs),
|
|
||||||
Dimensions.time.value: ("detection", times),
|
|
||||||
},
|
|
||||||
dims=["detection", "dimension"],
|
|
||||||
name="sizes",
|
|
||||||
)
|
|
||||||
|
|
||||||
extracted = extract_values_at_positions(
|
|
||||||
sample_sizes_array,
|
|
||||||
sample_positions_top2,
|
|
||||||
)
|
|
||||||
|
|
||||||
xr.testing.assert_equal(extracted, expected)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_empty(
|
|
||||||
sample_array_for_extraction, empty_positions
|
|
||||||
):
|
|
||||||
"""Test extraction with empty positions returns empty array."""
|
|
||||||
extracted = extract_values_at_positions(
|
|
||||||
sample_array_for_extraction, empty_positions
|
|
||||||
)
|
|
||||||
assert extracted.sizes["detection"] == 0
|
|
||||||
assert Dimensions.time.value in extracted.coords
|
|
||||||
assert Dimensions.frequency.value in extracted.coords
|
|
||||||
assert extracted.coords[Dimensions.time.value].size == 0
|
|
||||||
assert extracted.coords[Dimensions.frequency.value].size == 0
|
|
||||||
assert extracted.name == sample_array_for_extraction.name
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_missing_coord_in_array(
|
|
||||||
sample_array_for_extraction, sample_positions_top2
|
|
||||||
):
|
|
||||||
"""Test error if source array misses required coordinates."""
|
|
||||||
array_no_time = sample_array_for_extraction.copy()
|
|
||||||
del array_no_time.coords[Dimensions.time.value]
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
extract_values_at_positions(array_no_time, sample_positions_top2)
|
|
||||||
|
|
||||||
array_no_freq = sample_array_for_extraction.copy()
|
|
||||||
del array_no_freq.coords[Dimensions.frequency.value]
|
|
||||||
with pytest.raises(IndexError):
|
|
||||||
extract_values_at_positions(array_no_freq, sample_positions_top2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_missing_coord_in_positions(
|
|
||||||
sample_array_for_extraction, sample_positions_top2
|
|
||||||
):
|
|
||||||
"""Test error if positions array misses required coordinates."""
|
|
||||||
positions_no_time = sample_positions_top2.copy()
|
|
||||||
del positions_no_time.coords[Dimensions.time.value]
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
extract_values_at_positions(
|
|
||||||
sample_array_for_extraction, positions_no_time
|
|
||||||
)
|
|
||||||
|
|
||||||
positions_no_freq = sample_positions_top2.copy()
|
|
||||||
del positions_no_freq.coords[Dimensions.frequency.value]
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
extract_values_at_positions(
|
|
||||||
sample_array_for_extraction, positions_no_freq
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_values_at_positions_mismatched_coords(
|
|
||||||
sample_array_for_extraction, sample_positions_top2
|
|
||||||
):
|
|
||||||
"""Test error if positions requests coords not in source array."""
|
|
||||||
bad_positions = sample_positions_top2.copy()
|
|
||||||
bad_positions.coords[Dimensions.time.value] = (
|
|
||||||
"detection",
|
|
||||||
np.array([40, 10]),
|
|
||||||
)
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
extract_values_at_positions(sample_array_for_extraction, bad_positions)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_detection_xr_dataset_correct(
    sample_positions_top2,
    sample_sizes_array,
    sample_classes_array,
    sample_features_array,
):
    """Compare the bundled detection dataset with a hand-assembled one.

    The expected dataset is built by selecting, from each source array, the
    values at the two (time, frequency) points listed in ``peaks`` and
    attaching the shared per-detection coordinates.
    """
    actual_dataset = extract_detection_xr_dataset(
        sample_positions_top2,
        sample_sizes_array,
        sample_classes_array,
        sample_features_array,
    )

    # (time, frequency) of each expected detection, in detection order.
    peaks = [(30, 200), (20, 100)]
    detection_coords = {
        Dimensions.time.value: (
            "detection",
            np.array([time for time, _ in peaks]),
        ),
        Dimensions.frequency.value: (
            "detection",
            np.array([freq for _, freq in peaks]),
        ),
    }

    def build_expected(source, name, axis, labels):
        # Stack the source values found at every peak into one array with a
        # leading "detection" dimension followed by the labelled axis.
        stacked = np.array(
            [
                source.sel(time=time, frequency=freq).values
                for time, freq in peaks
            ],
            dtype=np.float32,
        )
        return xr.DataArray(
            stacked,
            coords={**detection_coords, axis: labels},
            dims=["detection", axis],
            name=name,
        )

    expected_dataset = xr.Dataset(
        {
            # The scores are the input positions array itself.
            "scores": sample_positions_top2,
            "dimensions": build_expected(
                sample_sizes_array,
                "dimensions",
                "dimension",
                ["width", "height"],
            ),
            "classes": build_expected(
                sample_classes_array,
                "classes",
                "category",
                ["bat", "noise"],
            ),
            "features": build_expected(
                sample_features_array,
                "features",
                "feature",
                ["f0", "f1", "f2", "f3"],
            ),
        }
    ).assign_coords(detection_coords)

    xr.testing.assert_allclose(actual_dataset, expected_dataset)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_detection_xr_dataset_empty(
    empty_positions,
    sample_sizes_array,
    sample_classes_array,
    sample_features_array,
):
    """An empty positions input yields a dataset with zero detections.

    Every data variable must still be present with its expected dimension
    names, and the time/frequency coordinates must exist but be empty.
    """
    actual_dataset = extract_detection_xr_dataset(
        empty_positions,
        sample_sizes_array,
        sample_classes_array,
        sample_features_array,
    )

    assert isinstance(actual_dataset, xr.Dataset)
    assert "detection" in actual_dataset.dims
    assert actual_dataset.sizes["detection"] == 0

    # Scores are one-dimensional, so only their size is checked.
    assert "scores" in actual_dataset
    scores = actual_dataset["scores"]
    assert scores.dims == ("detection",)
    assert scores.size == 0

    # variable name -> (second dimension name, length of that dimension)
    expected_layout = {
        "dimensions": ("dimension", 2),
        "classes": ("category", 2),
        "features": ("feature", 4),
    }
    for variable, (axis, width) in expected_layout.items():
        assert variable in actual_dataset
        assert actual_dataset[variable].dims == ("detection", axis)
        assert actual_dataset[variable].shape == (0, width)

    coord_names = (Dimensions.time.value, Dimensions.frequency.value)
    for coord_name in coord_names:
        assert coord_name in actual_dataset.coords
    for coord_name in coord_names:
        assert actual_dataset.coords[coord_name].size == 0
|
|
||||||
|
|||||||
@ -162,7 +162,8 @@ def test_selected_random_subclip_has_the_correct_width(
|
|||||||
|
|
||||||
subclip = select_subclip(
|
subclip = select_subclip(
|
||||||
original,
|
original,
|
||||||
samplerate=256_000,
|
input_samplerate=256_000,
|
||||||
|
output_samplerate=1000,
|
||||||
start=0,
|
start=0,
|
||||||
duration=0.512,
|
duration=0.512,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -39,9 +39,8 @@ def build_from_config(
|
|||||||
)
|
)
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
targets,
|
targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
config=postprocessing_config,
|
config=postprocessing_config,
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return targets, preprocessor, labeller, postprocessor
|
return targets, preprocessor, labeller, postprocessor
|
||||||
@ -84,7 +83,10 @@ def test_encoding_decoding_roundtrip_recovers_object(
|
|||||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
|
|
||||||
encoded = generate_train_example(
|
encoded = generate_train_example(
|
||||||
clip_annotation, sample_audio_loader, preprocessor, labeller
|
clip_annotation,
|
||||||
|
sample_audio_loader,
|
||||||
|
preprocessor,
|
||||||
|
labeller,
|
||||||
)
|
)
|
||||||
predictions = postprocessor.get_predictions(
|
predictions = postprocessor.get_predictions(
|
||||||
ModelOutput(
|
ModelOutput(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user