From 281c4dcb8aa17a48ba9f4bd6ac8d5339ffa4c35f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 22:46:21 +0100 Subject: [PATCH] Remove xr from postprocess --- src/batdetect2/models/__init__.py | 3 +- src/batdetect2/plotting/clips.py | 4 +- src/batdetect2/plotting/matches.py | 7 +- src/batdetect2/postprocess/__init__.py | 405 +++---------------- src/batdetect2/postprocess/decoding.py | 233 ++--------- src/batdetect2/postprocess/detection.py | 162 -------- src/batdetect2/postprocess/extraction.py | 163 +++----- src/batdetect2/postprocess/remapping.py | 21 + src/batdetect2/preprocess/__init__.py | 29 +- src/batdetect2/train/augmentations.py | 2 +- src/batdetect2/train/clips.py | 38 +- src/batdetect2/train/dataset.py | 8 +- src/batdetect2/train/train.py | 4 +- src/batdetect2/typing/postprocess.py | 160 +------- src/batdetect2/typing/preprocess.py | 4 +- src/batdetect2/typing/train.py | 2 +- tests/test_postprocessing/test_decoding.py | 58 --- tests/test_postprocessing/test_detection.py | 214 ---------- tests/test_postprocessing/test_extraction.py | 395 ------------------ tests/test_train/test_augmentations.py | 3 +- tests/test_train/test_preprocessing.py | 8 +- 21 files changed, 231 insertions(+), 1692 deletions(-) delete mode 100644 src/batdetect2/postprocess/detection.py delete mode 100644 tests/test_postprocessing/test_detection.py diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 471fd0b..47d6328 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -140,9 +140,8 @@ def build_model(config: Optional[ModelConfig] = None): preprocessor = build_preprocessor(config=config.preprocess) postprocessor = build_postprocessor( targets=targets, + preprocessor=preprocessor, config=config.postprocess, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, ) detector = build_detector( num_classes=len(targets.class_names), diff --git a/src/batdetect2/plotting/clips.py 
b/src/batdetect2/plotting/clips.py index 28a806f..11cc670 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -6,7 +6,7 @@ from matplotlib.axes import Axes from soundevent import data from batdetect2.plotting.common import plot_spectrogram -from batdetect2.preprocess import build_audio_loader, get_default_preprocessor +from batdetect2.preprocess import build_audio_loader, build_preprocessor from batdetect2.typing import AudioLoader, PreprocessorProtocol __all__ = [ @@ -27,7 +27,7 @@ def plot_clip( _, ax = plt.subplots(figsize=figsize) if preprocessor is None: - preprocessor = get_default_preprocessor() + preprocessor = build_preprocessor() if audio_loader is None: audio_loader = build_audio_loader() diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index 63561ce..58216fd 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -8,10 +8,7 @@ from soundevent.plot.tags import TagColorMapper from batdetect2.plotting.clip_predictions import plot_prediction from batdetect2.plotting.clips import plot_clip -from batdetect2.preprocess import ( - PreprocessorProtocol, - get_default_preprocessor, -) +from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.typing.evaluate import MatchEvaluation __all__ = [ @@ -50,7 +47,7 @@ def plot_matches( prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: if preprocessor is None: - preprocessor = get_default_preprocessor() + preprocessor = build_preprocessor() ax = plot_clip( clip, diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index e724e2b..07f0551 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -1,36 +1,7 @@ -"""Main entry point for the BatDetect2 Postprocessing pipeline. 
- -This package (`batdetect2.postprocess`) takes the raw outputs from a trained -BatDetect2 neural network model and transforms them into meaningful, structured -predictions, typically in the form of `soundevent.data.ClipPrediction` objects -containing detected sound events with associated class tags and geometry. - -The pipeline involves several configurable steps, implemented in submodules: -1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks. -2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw - model output arrays. -3. Detection Extraction (`.detection`): Identifies candidate detection points - (location and score) based on thresholds and score ranking (top-k). -4. Data Extraction (`.extraction`): Gathers associated model outputs (size, - class probabilities, features) at the detected locations. -5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and - class predictions into interpretable `soundevent` objects, including - recovering geometry (ROIs) and decoding class names back to standard tags. - -This module provides the primary interface: -- `PostprocessConfig`: A configuration object for postprocessing parameters - (thresholds, NMS kernel size, etc.). -- `load_postprocess_config`: Function to load the configuration from a file. -- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that - holds the configured pipeline logic. -- `build_postprocessor`: A factory function to create a `Postprocessor` - instance, linking it to the necessary target definitions (`TargetProtocol`). -It also re-exports key components from submodules for convenience. 
-""" +"""Main entry point for the BatDetect2 Postprocessing pipeline.""" from typing import List, Optional -import xarray as xr from loguru import logger from pydantic import Field from soundevent import data @@ -38,37 +9,24 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, + convert_detections_to_raw_predictions, convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, - convert_xr_dataset_to_raw_prediction, -) -from batdetect2.postprocess.detection import ( - DEFAULT_DETECTION_THRESHOLD, - TOP_K_PER_SEC, - extract_detections_from_array, - get_max_detections, -) -from batdetect2.postprocess.extraction import ( - extract_detection_xr_dataset, ) +from batdetect2.postprocess.extraction import extract_prediction_tensor from batdetect2.postprocess.nms import ( NMS_KERNEL_SIZE, non_max_suppression, ) -from batdetect2.postprocess.remapping import ( - classification_to_xarray, - detection_to_xarray, - features_to_xarray, - sizes_to_xarray, -) +from batdetect2.postprocess.remapping import map_detection_to_clip from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing.models import ModelOutput +from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol from batdetect2.typing.postprocess import ( BatDetect2Prediction, + Detections, PostprocessorProtocol, RawPrediction, ) -from batdetect2.typing.targets import TargetProtocol __all__ = [ "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -81,19 +39,17 @@ __all__ = [ "Postprocessor", "TOP_K_PER_SEC", "build_postprocessor", - "classification_to_xarray", "convert_raw_predictions_to_clip_prediction", - "convert_xr_dataset_to_raw_prediction", - "detection_to_xarray", - "extract_detection_xr_dataset", - "extract_detections_from_array", - "features_to_xarray", - "get_max_detections", + "convert_detections_to_raw_predictions", "load_postprocess_config", 
"non_max_suppression", - "sizes_to_xarray", ] +DEFAULT_DETECTION_THRESHOLD = 0.01 + + +TOP_K_PER_SEC = 200 + class PostprocessConfig(BaseConfig): """Configuration settings for the postprocessing pipeline. @@ -173,40 +129,10 @@ def load_postprocess_config( def build_postprocessor( targets: TargetProtocol, + preprocessor: PreprocessorProtocol, config: Optional[PostprocessConfig] = None, - max_freq: float = MAX_FREQ, - min_freq: float = MIN_FREQ, ) -> PostprocessorProtocol: - """Factory function to build the standard postprocessor. - - Creates and initializes the `Postprocessor` instance, providing it with the - necessary `targets` object and the `PostprocessConfig`. - - Parameters - ---------- - targets : TargetProtocol - An initialized object conforming to the `TargetProtocol`, providing - methods like `.decode()` and `.recover_roi()`, and attributes like - `.class_names` and `.generic_class_tags`. This links postprocessing - to the defined target semantics and geometry mappings. - config : PostprocessConfig, optional - Configuration object specifying postprocessing parameters (thresholds, - NMS kernel size, etc.). If None, default settings defined in - `PostprocessConfig` will be used. - min_freq : int, default=MIN_FREQ - The minimum frequency (Hz) corresponding to the frequency axis of the - model outputs. Required for coordinate remapping. Consider setting via - `PostprocessConfig` instead for better encapsulation. - max_freq : int, default=MAX_FREQ - The maximum frequency (Hz) corresponding to the frequency axis of the - model outputs. Required for coordinate remapping. Consider setting via - `PostprocessConfig`. - - Returns - ------- - PostprocessorProtocol - An initialized `Postprocessor` instance ready to process model outputs. 
- """ + """Factory function to build the standard postprocessor.""" config = config or PostprocessConfig() logger.opt(lazy=True).debug( "Building postprocessor with config: \n{}", @@ -214,303 +140,62 @@ def build_postprocessor( ) return Postprocessor( targets=targets, + preprocessor=preprocessor, config=config, - min_freq=min_freq, - max_freq=max_freq, ) class Postprocessor(PostprocessorProtocol): - """Standard implementation of the postprocessing pipeline. - - This class orchestrates the steps required to convert raw model outputs - into interpretable `soundevent` predictions. It uses configured parameters - and leverages functions from the `batdetect2.postprocess` submodules for - each stage (NMS, remapping, detection, extraction, decoding). - - It requires a `TargetProtocol` object during initialization to access - necessary decoding information (class name to tag mapping, - ROI recovery logic) ensuring consistency with the target definitions used - during training or specified for inference. - - Instances are typically created using the `build_postprocessor` factory - function. - - Attributes - ---------- - targets : TargetProtocol - The configured target definition object providing decoding and ROI - recovery. - config : PostprocessConfig - Configuration object holding parameters for NMS, thresholds, etc. - min_freq : float - Minimum frequency (Hz) assumed for the model output's frequency axis. - max_freq : float - Maximum frequency (Hz) assumed for the model output's frequency axis. - """ + """Standard implementation of the postprocessing pipeline.""" targets: TargetProtocol + preprocessor: PreprocessorProtocol + def __init__( self, targets: TargetProtocol, + preprocessor: PreprocessorProtocol, config: PostprocessConfig, - min_freq: float = MIN_FREQ, - max_freq: float = MAX_FREQ, ): - """Initialize the Postprocessor. - - Parameters - ---------- - targets : TargetProtocol - Initialized target definition object. 
- config : PostprocessConfig - Configuration for postprocessing parameters. - min_freq : int, default=MIN_FREQ - Minimum frequency (Hz) for coordinate remapping. - max_freq : int, default=MAX_FREQ - Maximum frequency (Hz) for coordinate remapping. - """ + """Initialize the Postprocessor.""" self.targets = targets + self.preprocessor = preprocessor self.config = config - self.min_freq = min_freq - self.max_freq = max_freq - def get_feature_arrays( + def get_detections( self, output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Extract and remap raw feature tensors for a batch. + clips: Optional[List[data.Clip]] = None, + ) -> List[Detections]: + width = output.detection_probs.shape[-1] + duration = width / self.preprocessor.output_samplerate + max_detections = int(self.config.top_k_per_sec * duration) - Parameters - ---------- - output : ModelOutput - Raw model output containing `output.features` tensor for the batch. - clips : List[data.Clip] - List of Clip objects corresponding to the batch items. - - Returns - ------- - List[xr.DataArray] - List of coordinate-aware feature DataArrays, one per clip. - - Raises - ------ - ValueError - If batch sizes of `output.features` and `clips` do not match. - """ - if len(clips) != len(output.features): - raise ValueError( - "Number of clips and batch size of feature array" - "do not match. " - f"(clips: {len(clips)}, features: {len(output.features)})" - ) - - return [ - features_to_xarray( - feats, - start_time=clip.start_time, - end_time=clip.end_time, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for feats, clip in zip(output.features, clips) - ] - - def get_detection_arrays( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Apply NMS and remap detection heatmaps for a batch. - - Parameters - ---------- - output : ModelOutput - Raw model output containing `output.detection_probs` tensor for the - batch. 
- clips : List[data.Clip] - List of Clip objects corresponding to the batch items. - - Returns - ------- - List[xr.DataArray] - List of NMS-applied, coordinate-aware detection heatmaps, one per - clip. - - Raises - ------ - ValueError - If batch sizes of `output.detection_probs` and `clips` do not match. - """ - detections = output.detection_probs - - if len(clips) != len(output.detection_probs): - raise ValueError( - "Number of clips and batch size of detection array " - "do not match. " - f"(clips: {len(clips)}, detection: {len(detections)})" - ) - - detections = non_max_suppression( - detections, - kernel_size=self.config.nms_kernel_size, + detections = extract_prediction_tensor( + output, + max_detections=max_detections, + threshold=self.config.detection_threshold, ) - return [ - detection_to_xarray( - dets, - start_time=clip.start_time, - end_time=clip.end_time, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for dets, clip in zip(detections, clips) - ] - - def get_classification_arrays( - self, output: ModelOutput, clips: List[data.Clip] - ) -> List[xr.DataArray]: - """Extract and remap raw classification tensors for a batch. - - Parameters - ---------- - output : ModelOutput - Raw model output containing `output.class_probs` tensor for the - batch. - clips : List[data.Clip] - List of Clip objects corresponding to the batch items. - - Returns - ------- - List[xr.DataArray] - List of coordinate-aware class probability maps, one per clip. - - Raises - ------ - ValueError - If batch sizes of `output.class_probs` and `clips` do not match, or - if number of classes mismatches `self.targets.class_names`. - """ - classifications = output.class_probs - - if len(clips) != len(classifications): - raise ValueError( - "Number of clips and batch size of classification array " - "do not match. 
" - f"(clips: {len(clips)}, classification: {len(classifications)})" - ) + if clips is None: + return detections return [ - classification_to_xarray( - class_probs, + map_detection_to_clip( + detection, start_time=clip.start_time, end_time=clip.end_time, - class_names=self.targets.class_names, - min_freq=self.min_freq, - max_freq=self.max_freq, + min_freq=self.preprocessor.min_freq, + max_freq=self.preprocessor.max_freq, ) - for class_probs, clip in zip(classifications, clips) + for detection, clip in zip(detections, clips) ] - def get_sizes_arrays( - self, output: ModelOutput, clips: List[data.Clip] - ) -> List[xr.DataArray]: - """Extract and remap raw size prediction tensors for a batch. - - Parameters - ---------- - output : ModelOutput - Raw model output containing `output.size_preds` tensor for the - batch. - clips : List[data.Clip] - List of Clip objects corresponding to the batch items. - - Returns - ------- - List[xr.DataArray] - List of coordinate-aware size prediction maps, one per clip. - - Raises - ------ - ValueError - If batch sizes of `output.size_preds` and `clips` do not match. - """ - sizes = output.size_preds - - if len(clips) != len(sizes): - raise ValueError( - "Number of clips and batch size of sizes array do not match. " - f"(clips: {len(clips)}, sizes: {len(sizes)})" - ) - - return [ - sizes_to_xarray( - size_preds, - start_time=clip.start_time, - end_time=clip.end_time, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for size_preds, clip in zip(sizes, clips) - ] - - def get_detection_datasets( - self, output: ModelOutput, clips: List[data.Clip] - ) -> List[xr.Dataset]: - """Perform NMS, remapping, detection, and data extraction for a batch. - - Parameters - ---------- - output : ModelOutput - Raw output from the neural network model for a batch. - clips : List[data.Clip] - List of `soundevent.data.Clip` objects corresponding to the batch. - - Returns - ------- - List[xr.Dataset] - List of xarray Datasets (one per clip). 
Each Dataset contains - aligned scores, dimensions, class probabilities, and features for - detections found in that clip. - """ - detection_arrays = self.get_detection_arrays(output, clips) - classification_arrays = self.get_classification_arrays(output, clips) - size_arrays = self.get_sizes_arrays(output, clips) - features_arrays = self.get_feature_arrays(output, clips) - - datasets = [] - for det_array, class_array, sizes_array, feats_array in zip( - detection_arrays, - classification_arrays, - size_arrays, - features_arrays, - ): - max_detections = get_max_detections( - det_array, - top_k_per_sec=self.config.top_k_per_sec, - ) - - positions = extract_detections_from_array( - det_array, - max_detections=max_detections, - threshold=self.config.detection_threshold, - ) - - datasets.append( - extract_detection_xr_dataset( - positions, - sizes_array, - class_array, - feats_array, - ) - ) - - return datasets - def get_raw_predictions( - self, output: ModelOutput, clips: List[data.Clip] + self, + output: ModelOutput, + clips: List[data.Clip], ) -> List[List[RawPrediction]]: """Extract intermediate RawPrediction objects for a batch. @@ -531,13 +216,13 @@ class Postprocessor(PostprocessorProtocol): List of lists (one inner list per input clip). Each inner list contains `RawPrediction` objects for detections in that clip. 
""" - detection_datasets = self.get_detection_datasets(output, clips) + detections = self.get_detections(output, clips) return [ - convert_xr_dataset_to_raw_prediction( + convert_detections_to_raw_predictions( dataset, - self.targets.decode_roi, + targets=self.targets, ) - for dataset in detection_datasets + for dataset in detections ] def get_sound_event_predictions( diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index d105080..499b2de 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -1,42 +1,18 @@ -"""Decodes extracted detection data into standard soundevent predictions. - -This module handles the final stages of the BatDetect2 postprocessing pipeline. -It takes the structured detection data extracted by the `extraction` module -(typically an `xarray.Dataset` containing scores, positions, predicted sizes, - class probabilities, and features for each detection point) and converts it -into standardized prediction objects based on the `soundevent` data model. - -The process involves: -1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction` - objects, using a configured geometry builder to recover bounding boxes from - predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`). -2. Converting each `RawPrediction` into a - `soundevent.data.SoundEventPrediction`, which involves: - - Creating the `soundevent.data.SoundEvent` with geometry and features. - - Decoding the predicted class probabilities into representative tags using - a configured class decoder (`SoundEventDecoder`). - - Applying a classification threshold. - - Optionally selecting only the single highest-scoring class (top-1) or - including tags for all classes above the threshold (multi-label). - - Adding generic class tags as a baseline. - - Associating scores with the final prediction and tags. - (`convert_raw_prediction_to_sound_event_prediction`) -3. 
Grouping the `SoundEventPrediction` objects for a given audio segment into - a `soundevent.data.ClipPrediction` - (`convert_raw_predictions_to_clip_prediction`). -""" +"""Decodes extracted detection data into standard soundevent predictions.""" from typing import List, Optional import numpy as np -import xarray as xr from soundevent import data -from batdetect2.typing.postprocess import GeometryDecoder, RawPrediction +from batdetect2.typing.postprocess import ( + Detections, + RawPrediction, +) from batdetect2.typing.targets import TargetProtocol __all__ = [ - "convert_xr_dataset_to_raw_prediction", + "convert_detections_to_raw_predictions", "convert_raw_predictions_to_clip_prediction", "convert_raw_prediction_to_sound_event_prediction", "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -51,65 +27,29 @@ decoding. """ -def convert_xr_dataset_to_raw_prediction( - detection_dataset: xr.Dataset, - geometry_decoder: GeometryDecoder, +def convert_detections_to_raw_predictions( + detections: Detections, + targets: TargetProtocol, ) -> List[RawPrediction]: - """Convert an xarray.Dataset of detections to RawPrediction objects. - - Takes the output of the extraction step (`extract_detection_xr_dataset`) - and transforms each detection entry into an intermediate `RawPrediction` - object. This involves recovering the geometry (e.g., bounding box) from - the predicted position and scaled size dimensions using the provided - `geometry_builder` function. - - Parameters - ---------- - detection_dataset : xr.Dataset - An xarray Dataset containing aligned detection information, typically - output by `extract_detection_xr_dataset`. Expected variables include - 'scores' (with time/freq coords), 'dimensions', 'classes', 'features'. - Must have a 'detection' dimension. - geometry_decoder : GeometryDecoder - A function that takes a position tuple `(time, freq)` and a NumPy array - of dimensions, and returns the corresponding reconstructed - `soundevent.data.Geometry`. 
- - Returns - ------- - List[RawPrediction] - A list of `RawPrediction` objects, each containing the detection score, - recovered bounding box coordinates (start/end time, low/high freq), - the vector of class scores, and the feature vector for one detection. - - Raises - ------ - AttributeError, KeyError, ValueError - If `detection_dataset` is missing expected variables ('scores', - 'dimensions', 'classes', 'features') or coordinates ('time', 'freq' - associated with 'scores'), or if `geometry_builder` fails. - """ - detections = [] - - categories = detection_dataset.category.values + predictions = [] for score, class_scores, time, freq, dims, feats in zip( - detection_dataset["scores"].values, - detection_dataset["classes"].values, - detection_dataset["time"].values, - detection_dataset["frequency"].values, - detection_dataset["dimensions"].values, - detection_dataset["features"].values, + detections.scores, + detections.class_scores, + detections.times, + detections.frequencies, + detections.sizes, + detections.features, ): - highest_scoring_class = categories[class_scores.argmax()] + highest_scoring_class = targets.class_names[class_scores.argmax()] - geom = geometry_decoder( + geom = targets.decode_roi( (time, freq), dims, class_name=highest_scoring_class, ) - detections.append( + predictions.append( RawPrediction( detection_score=score, geometry=geom, @@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction( ) ) - return detections + return predictions def convert_raw_predictions_to_clip_prediction( @@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction( classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ) -> data.ClipPrediction: - """Convert a list of RawPredictions into a soundevent ClipPrediction. 
- - Iterates through `raw_predictions` (assumed to belong to a single clip), - converts each one into a `soundevent.data.SoundEventPrediction` using - `convert_raw_prediction_to_sound_event_prediction`, and packages them - into a `soundevent.data.ClipPrediction` associated with the original `clip`. - - Parameters - ---------- - raw_predictions : List[RawPrediction] - List of raw prediction objects for a single clip. - clip : data.Clip - The original `soundevent.data.Clip` object these predictions belong to. - sound_event_decoder : SoundEventDecoder - Function to decode class names into representative tags. - generic_class_tags : List[data.Tag] - List of tags representing the generic class category. - classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD - Threshold applied to class scores during decoding. - top_class_only : bool, default=False - If True, only decode tags for the single highest-scoring class above - the threshold. If False, decode tags for all classes above threshold. - - Returns - ------- - data.ClipPrediction - A `ClipPrediction` object containing a list of `SoundEventPrediction` - objects corresponding to the input `raw_predictions`. - """ + """Convert a list of RawPredictions into a soundevent ClipPrediction.""" return data.ClipPrediction( clip=clip, sound_events=[ @@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction( ] = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ): - """Convert a single RawPrediction into a soundevent SoundEventPrediction. - - This function performs the core decoding steps for a single detected event: - 1. Creates a `soundevent.data.SoundEvent` containing the geometry - (BoundingBox derived from `raw_prediction` bounds) and any associated - feature vectors. - 2. Initializes a list of predicted tags using the provided - `generic_class_tags`, assigning the overall `detection_score` from the - `raw_prediction` to these generic tags. - 3. 
Processes the `class_scores` from the `raw_prediction`: - a. Optionally filters out scores below `classification_threshold` - (if it's not None). - b. Sorts the remaining scores in descending order. - c. Iterates through the sorted, thresholded class scores. - d. For each class, uses the `sound_event_decoder` to get the - representative base tags for that class name. - e. Wraps these base tags in `soundevent.data.PredictedTag`, associating - the specific `score` of that class prediction. - f. Appends these specific predicted tags to the list. - g. If `top_class_only` is True, stops after processing the first - (highest-scoring) class that passed the threshold. - 4. Creates and returns the final `soundevent.data.SoundEventPrediction`, - associating the `SoundEvent`, the overall `detection_score`, and the - compiled list of `PredictedTag` objects. - - Parameters - ---------- - raw_prediction : RawPrediction - The raw prediction object containing score, bounds, class scores, - features. Assumes `class_scores` is an `xr.DataArray` with a 'category' - coordinate. Assumes `features` is an `xr.DataArray` with a 'feature' - coordinate. - recording : data.Recording - The recording the sound event belongs to. - sound_event_decoder : SoundEventDecoder - Configured function mapping class names (str) to lists of base - `data.Tag` objects. - generic_class_tags : List[data.Tag] - List of base tags representing the generic category. - classification_threshold : float, optional - The minimum score a class prediction must have to be considered - significant enough to have its tags decoded and added. If None, no - thresholding is applied based on class score (all predicted classes, - or the top one if `top_class_only` is True, will be processed). - Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`. - top_class_only : bool, default=False - If True, only includes tags for the single highest-scoring class that - exceeds the threshold. 
If False (default), includes tags for all classes - exceeding the threshold. - - Returns - ------- - data.SoundEventPrediction - The fully formed sound event prediction object. - - Raises - ------ - ValueError - If `raw_prediction.features` has unexpected structure or if - `data.term_from_key` (if used internally) fails. - If `sound_event_decoder` fails for a class name and errors are raised. - """ + """Convert a single RawPrediction into a soundevent SoundEventPrediction.""" sound_event = data.SoundEvent( recording=recording, geometry=raw_prediction.geometry, @@ -273,25 +124,7 @@ def get_generic_tags( detection_score: float, generic_class_tags: List[data.Tag], ) -> List[data.PredictedTag]: - """Create PredictedTag objects for the generic category. - - Takes the base list of generic tags and assigns the overall detection - score to each one, wrapping them in `PredictedTag` objects. - - Parameters - ---------- - detection_score : float - The overall confidence score of the detection event. - generic_class_tags : List[data.Tag] - The list of base `soundevent.data.Tag` objects that define the - generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']). - - Returns - ------- - List[data.PredictedTag] - A list of `PredictedTag` objects for the generic category, each - assigned the `detection_score`. - """ + """Create PredictedTag objects for the generic category.""" return [ data.PredictedTag(tag=tag, score=detection_score) for tag in generic_class_tags @@ -299,25 +132,7 @@ def get_generic_tags( def get_prediction_features(features: np.ndarray) -> List[data.Feature]: - """Convert an extracted feature vector DataArray into soundevent Features. - - Parameters - ---------- - features : xr.DataArray - A 1D xarray DataArray containing feature values, indexed by a coordinate - named 'feature' which holds the feature names (e.g., output of selecting - features for one detection from `extract_detection_xr_dataset`). 
- - Returns - ------- - List[data.Feature] - A list of `soundevent.data.Feature` objects. - - Notes - ----- - - This function creates basic `Term` objects using the feature coordinate - names with a "batdetect2:" prefix. - """ + """Convert an extracted feature vector (NumPy array) into soundevent Features.""" return [ data.Feature( term=data.Term( diff --git a/src/batdetect2/postprocess/detection.py b/src/batdetect2/postprocess/detection.py deleted file mode 100644 index 9b2a185..0000000 --- a/src/batdetect2/postprocess/detection.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Extracts candidate detection points from a model output heatmap. - -This module implements Step 3 within the BatDetect2 postprocessing -pipeline. Its primary function is to identify potential sound event locations -by finding peaks (local maxima or high-scoring points) in the detection heatmap -produced by the neural network (usually after Non-Maximum Suppression and -coordinate remapping have been applied). - -It provides functionality to: -- Identify the locations (time, frequency) of the highest-scoring points. -- Filter these points based on a minimum confidence score threshold. -- Limit the maximum number of detection points returned (top-k). - -The main output is an `xarray.DataArray` containing the scores and -corresponding time/frequency coordinates for the extracted detection points. -This output serves as the input for subsequent postprocessing steps, such as -extracting predicted class probabilities and bounding box sizes at these -specific locations. 
-""" - -from typing import Optional - -import numpy as np -import xarray as xr -from soundevent.arrays import Dimensions, get_dim_width - -__all__ = [ - "extract_detections_from_array", - "get_max_detections", - "DEFAULT_DETECTION_THRESHOLD", - "TOP_K_PER_SEC", -] - -DEFAULT_DETECTION_THRESHOLD = 0.01 -"""Default confidence score threshold used for filtering detections.""" - -TOP_K_PER_SEC = 200 -"""Default desired maximum number of detections per second of audio.""" - - -def extract_detections_from_array( - detection_array: xr.DataArray, - max_detections: Optional[int] = None, - threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD, -) -> xr.DataArray: - """Extract detection locations (time, freq) and scores from a heatmap. - - Identifies the pixels with the highest scores in the input detection - heatmap, filters them based on an optional score `threshold`, limits the - number to an optional `max_detections`, and returns their scores along with - their corresponding time and frequency coordinates. - - Parameters - ---------- - detection_array : xr.DataArray - A 2D xarray DataArray representing the detection heatmap. Must have - dimensions and coordinates named 'time' and 'frequency'. Higher values - are assumed to indicate higher detection confidence. - max_detections : int, optional - The absolute maximum number of detections to return. If specified, only - the top `max_detections` highest-scoring detections (passing the - threshold) are returned. If None (default), all detections passing - the threshold are returned, sorted by score. - threshold : float, optional - The minimum confidence score required for a detection peak to be - kept. Detections with scores below this value are discarded. - Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no - thresholding is applied. - - Returns - ------- - xr.DataArray - A 1D xarray DataArray named 'score' with a 'detection' dimension. 
- - The data values are the scores of the extracted detections, sorted - in descending order. - - It has coordinates 'time' and 'frequency' (also indexed by the - 'detection' dimension) indicating the location of each detection - peak in the original coordinate system. - - Returns an empty DataArray if no detections pass the criteria. - - Raises - ------ - ValueError - If `max_detections` is not None and not a positive integer, or if - `detection_array` lacks required dimensions/coordinates. - """ - if max_detections is not None: - if max_detections <= 0: - raise ValueError("Max detections must be positive") - - values = detection_array.values.flatten() - - if max_detections is not None: - top_indices = np.argpartition(-values, max_detections)[:max_detections] - top_sorted_indices = top_indices[np.argsort(-values[top_indices])] - else: - top_sorted_indices = np.argsort(-values) - - top_values = values[top_sorted_indices] - - if threshold is not None: - mask = top_values > threshold - top_values = top_values[mask] - top_sorted_indices = top_sorted_indices[mask] - - freq_indices, time_indices = np.unravel_index( - top_sorted_indices, - detection_array.shape, - ) - - times = detection_array.coords[Dimensions.time.value].values[time_indices] - freqs = detection_array.coords[Dimensions.frequency.value].values[ - freq_indices - ] - - return xr.DataArray( - data=top_values, - coords={ - Dimensions.frequency.value: ("detection", freqs), - Dimensions.time.value: ("detection", times), - }, - dims="detection", - name="score", - ) - - -def get_max_detections( - detection_array: xr.DataArray, - top_k_per_sec: int = TOP_K_PER_SEC, -) -> int: - """Calculate max detections allowed based on duration and rate. - - Determines the total maximum number of detections to extract from a - heatmap based on its time duration and a desired rate of detections - per second. 
- - Parameters - ---------- - detection_array : xr.DataArray - The detection heatmap, requiring 'time' coordinates from which the - total duration can be calculated using - `soundevent.arrays.get_dim_width`. - top_k_per_sec : int, default=TOP_K_PER_SEC - The desired maximum number of detections to allow per second of audio. - - Returns - ------- - int - The calculated total maximum number of detections allowed for the - entire duration of the `detection_array`. - - Raises - ------ - ValueError - If the duration cannot be calculated from the `detection_array` (e.g., - missing or invalid 'time' coordinates/dimension). - """ - if top_k_per_sec < 0: - raise ValueError("top_k_per_sec cannot be negative.") - - duration = get_dim_width(detection_array, Dimensions.time.value) - return int(duration * top_k_per_sec) diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 2809ab7..361d936 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -15,108 +15,73 @@ precise time-frequency location of each detection. The final output aggregates all extracted information into a structured `xarray.Dataset`. """ -import xarray as xr -from soundevent.arrays import Dimensions +from typing import List, Optional, Tuple, Union + +import torch + +from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression +from batdetect2.typing.postprocess import Detections, ModelOutput __all__ = [ - "extract_values_at_positions", - "extract_detection_xr_dataset", + "extract_prediction_tensor", ] -def extract_values_at_positions( - array: xr.DataArray, - positions: xr.DataArray, -) -> xr.DataArray: - """Extract values from an array at specified time-frequency positions. - - Uses coordinate-based indexing to retrieve values from a source `array` - (e.g., class probabilities, size predictions, features) at the time and - frequency coordinates defined in the `positions` array. 
- - Parameters - ---------- - array : xr.DataArray - The source DataArray from which to extract values. Must have 'time' - and 'frequency' dimensions and coordinates matching the space of - `positions`. - positions : xr.DataArray - A 1D DataArray whose 'time' and 'frequency' coordinates specify the - locations from which to extract values. - - Returns - ------- - xr.DataArray - A DataArray containing the values extracted from `array` at the given - positions. - - Raises - ------ - ValueError, IndexError, KeyError - If dimensions or coordinates are missing or incompatible between - `array` and `positions`, or if selection fails. - """ - return array.sel( - **{ - Dimensions.frequency.value: positions.coords[ - Dimensions.frequency.value - ], - Dimensions.time.value: positions.coords[Dimensions.time.value], - } - ).T - - -def extract_detection_xr_dataset( - positions: xr.DataArray, - sizes: xr.DataArray, - classes: xr.DataArray, - features: xr.DataArray, -) -> xr.Dataset: - """Combine extracted detection information into a structured xr.Dataset. - - Takes the detection positions/scores and the full model output heatmaps - (sizes, classes, optional features), extracts the relevant data at the - detection positions, and packages everything into a single `xarray.Dataset` - where all variables are indexed by a common 'detection' dimension. - - Parameters - ---------- - positions : xr.DataArray - Output from `extract_detections_from_array`, containing detection - scores as data and 'time', 'frequency' coordinates along the - 'detection' dimension. - sizes : xr.DataArray - The full size prediction heatmap from the model, with dimensions like - ('dimension', 'time', 'frequency'). - classes : xr.DataArray - The full class probability heatmap from the model, with dimensions like - ('category', 'time', 'frequency'). - features : xr.DataArray - The full feature map from the model, with - dimensions like ('feature', 'time', 'frequency'). 
- - Returns - ------- - xr.Dataset - An xarray Dataset containing aligned information for each detection: - - 'scores': DataArray from `positions` (score data, time/freq coords). - - 'dimensions': DataArray with extracted size values - (dims: 'detection', 'dimension'). - - 'classes': DataArray with extracted class probabilities - (dims: 'detection', 'category'). - - 'features': DataArray with extracted feature vectors - (dims: 'detection', 'feature'), if `features` was provided. All - DataArrays share the 'detection' dimension and associated - time/frequency coordinates. - """ - sizes = extract_values_at_positions(sizes, positions) - classes = extract_values_at_positions(classes, positions) - features = extract_values_at_positions(features, positions) - return xr.Dataset( - { - "scores": positions, - "dimensions": sizes, - "classes": classes, - "features": features, - } +def extract_prediction_tensor( + output: ModelOutput, + max_detections: int = 200, + threshold: Optional[float] = None, + nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, +) -> List[Detections]: + detection_heatmap = non_max_suppression( + output.detection_probs, + kernel_size=nms_kernel_size, ) + + height = detection_heatmap.shape[-2] + width = detection_heatmap.shape[-1] + + freqs, times = torch.meshgrid( + torch.arange(height, dtype=torch.int32), + torch.arange(width, dtype=torch.int32), + indexing="ij", + ) + + freqs = freqs.flatten() + times = times.flatten() + + predictions = [] + for idx, item in enumerate(detection_heatmap): + item = item.squeeze().flatten() # Remove channel dim + indices = torch.argsort(item, descending=True)[:max_detections] + + detection_scores = item.take(indices) + detection_freqs = freqs.take(indices) + detection_times = times.take(indices) + sizes = output.size_preds[idx, :, detection_freqs, detection_times].T + features = output.features[idx, :, detection_freqs, detection_times].T + class_scores = output.class_probs[ + idx, :, detection_freqs, 
detection_times + ].T + + if threshold is not None: + mask = detection_scores >= threshold + detection_scores = detection_scores[mask] + sizes = sizes[mask] + detection_times = detection_times[mask] + detection_freqs = detection_freqs[mask] + features = features[mask] + class_scores = class_scores[mask] + + predictions.append( + Detections( + scores=detection_scores, + sizes=sizes, + features=features, + class_scores=class_scores, + times=detection_times.to(torch.float32) / width, + frequencies=(detection_freqs.to(torch.float32) / height), + ) + ) + + return predictions diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index 7112046..6e1a02e 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -20,6 +20,7 @@ import xarray as xr from soundevent.arrays import Dimensions from batdetect2.preprocess import MAX_FREQ, MIN_FREQ +from batdetect2.typing.postprocess import Detections __all__ = [ "features_to_xarray", @@ -29,6 +30,26 @@ __all__ = [ ] +def map_detection_to_clip( + detections: Detections, + start_time: float, + end_time: float, + min_freq: float, + max_freq: float, +) -> Detections: + duration = end_time - start_time + bandwidth = max_freq - min_freq + print(f"{bandwidth=} {min_freq=} {detections.frequencies=}") + return Detections( + scores=detections.scores, + sizes=detections.sizes, + features=detections.features, + class_scores=detections.class_scores, + times=(detections.times * duration + start_time), + frequencies=(detections.frequencies * bandwidth + min_freq), + ) + + def features_to_xarray( features: torch.Tensor, start_time: float, diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index f8df745..7da0725 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -53,6 +53,7 @@ from batdetect2.preprocess.spectrogram import ( SpectrogramConfig, SpectrogramPipeline, STFTConfig, + 
_spec_params_from_config, build_spectrogram_builder, build_spectrogram_pipeline, ) @@ -109,7 +110,9 @@ def load_preprocessing_config( class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): """Standard implementation of the `Preprocessor` protocol.""" - samplerate: int + input_samplerate: int + output_samplerate: float + max_freq: float min_freq: float @@ -117,22 +120,33 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): self, audio_pipeline: torch.nn.Module, spectrogram_pipeline: SpectrogramPipeline, - samplerate: int, + input_samplerate: int, + output_samplerate: float, max_freq: float, min_freq: float, ) -> None: super().__init__() self.audio_pipeline = audio_pipeline self.spectrogram_pipeline = spectrogram_pipeline - self.samplerate = samplerate + self.max_freq = max_freq self.min_freq = min_freq + self.input_samplerate = input_samplerate + self.output_samplerate = output_samplerate + def forward(self, wav: torch.Tensor) -> torch.Tensor: wav = self.audio_pipeline(wav) return self.spectrogram_pipeline(wav) +def compute_output_samplerate(config: PreprocessingConfig) -> float: + samplerate = config.audio.samplerate + _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft) + factor = config.spectrogram.size.resize_factor + return samplerate * factor / hop_size + + def build_preprocessor( config: Optional[PreprocessingConfig] = None, ) -> PreprocessorProtocol: @@ -148,16 +162,15 @@ def build_preprocessor( min_freq = config.spectrogram.frequencies.min_freq max_freq = config.spectrogram.frequencies.max_freq + output_samplerate = compute_output_samplerate(config) + return StandardPreprocessor( audio_pipeline=build_audio_pipeline(config.audio), spectrogram_pipeline=build_spectrogram_pipeline( samplerate, config.spectrogram ), - samplerate=samplerate, + input_samplerate=samplerate, + output_samplerate=output_samplerate, min_freq=min_freq, max_freq=max_freq, ) - - -def get_default_preprocessor(): - return 
build_preprocessor() diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 20c1de6..b89ae97 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -148,7 +148,7 @@ def add_echo( """Add a synthetic echo to the audio waveform.""" audio = example.audio - delay_steps = int(preprocessor.samplerate * delay) + delay_steps = int(preprocessor.input_samplerate * delay) audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1]) audio = audio + weight * audio_delay diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index befbf06..67b0ad5 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -1,10 +1,12 @@ from typing import Optional, Tuple import numpy as np +import torch from loguru import logger from batdetect2.configs import BaseConfig from batdetect2.typing import ClipperProtocol +from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.train import PreprocessedExample from batdetect2.utils.arrays import adjust_width @@ -18,24 +20,26 @@ class ClipingConfig(BaseConfig): max_empty: float = DEFAULT_MAX_EMPTY_CLIP -class Clipper(ClipperProtocol): +class Clipper(torch.nn.Module): def __init__( self, - samplerate: int, + preprocessor: PreprocessorProtocol, duration: float = 0.5, max_empty: float = 0.2, random: bool = True, ): - self.samplerate = samplerate + super().__init__() + self.preprocessor = preprocessor self.duration = duration self.random = random self.max_empty = max_empty - def extract_clip( - self, example: PreprocessedExample + def forward( + self, + example: PreprocessedExample, ) -> Tuple[PreprocessedExample, float, float]: start_time = 0 - duration = example.audio.shape[-1] / self.samplerate + duration = example.audio.shape[-1] / self.preprocessor.input_samplerate if self.random: start_time = np.random.uniform( @@ -48,7 +52,8 @@ class Clipper(ClipperProtocol): example, start=start_time, 
duration=self.duration, - samplerate=self.samplerate, + input_samplerate=self.preprocessor.input_samplerate, + output_samplerate=self.preprocessor.output_samplerate, ), start_time, start_time + self.duration, @@ -56,7 +61,7 @@ class Clipper(ClipperProtocol): def build_clipper( - samplerate: int, + preprocessor: PreprocessorProtocol, config: Optional[ClipingConfig] = None, random: Optional[bool] = None, ) -> ClipperProtocol: @@ -66,7 +71,7 @@ def build_clipper( lambda: config.to_yaml_string(), ) return Clipper( - samplerate=samplerate, + preprocessor=preprocessor, duration=config.duration, max_empty=config.max_empty, random=config.random if random else False, @@ -77,11 +82,12 @@ def select_subclip( example: PreprocessedExample, start: float, duration: float, - samplerate: float, + input_samplerate: float, + output_samplerate: float, fill_value: float = 0, ) -> PreprocessedExample: - audio_width = int(np.floor(duration * samplerate)) - audio_start = int(np.floor(start * samplerate)) + audio_width = int(np.floor(duration * input_samplerate)) + audio_start = int(np.floor(start * input_samplerate)) audio = adjust_width( example.audio[audio_start : audio_start + audio_width], @@ -89,12 +95,8 @@ def select_subclip( value=fill_value, ) - audio_duration = example.audio.shape[-1] / samplerate - spec_sr = example.spectrogram.shape[-1] / audio_duration - - spec_start = int(np.floor(start * spec_sr)) - spec_width = int(np.floor(duration * spec_sr)) - + spec_start = int(np.floor(start * output_samplerate)) + spec_width = int(np.floor(duration * output_samplerate)) return PreprocessedExample( audio=audio, spectrogram=adjust_width( diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 58ecfa5..ae41bd0 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -32,7 +32,7 @@ class LabeledDataset(Dataset): def __getitem__(self, idx) -> TrainExample: example = self.get_example(idx) - example, start_time, end_time = 
self.clipper.extract_clip(example) + example, start_time, end_time = self.clipper(example) if self.augmentation: example = self.augmentation(example) @@ -64,9 +64,7 @@ class LabeledDataset(Dataset): def get_random_example(self) -> Tuple[PreprocessedExample, float, float]: idx = np.random.randint(0, len(self)) dataset = self.get_example(idx) - - dataset, start_time, end_time = self.clipper.extract_clip(dataset) - + dataset, start_time, end_time = self.clipper(dataset) return dataset, start_time, end_time def get_example(self, idx) -> PreprocessedExample: @@ -107,5 +105,5 @@ class RandomExampleSource: index = int(np.random.randint(len(self.filenames))) filename = self.filenames[index] example = load_preprocessed_example(filename) - example, _, _ = self.clipper.extract_clip(example) + example, _, _ = self.clipper(example) return example diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 20f37b1..9da168c 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -229,7 +229,7 @@ def build_train_dataset( config = config or TrainingConfig() clipper = build_clipper( - samplerate=preprocessor.samplerate, + preprocessor=preprocessor, config=config.cliping, random=True, ) @@ -265,7 +265,7 @@ def build_val_dataset( logger.info("Building validation dataset...") config = config or TrainingConfig() clipper = build_clipper( - samplerate=preprocessor.samplerate, + preprocessor=preprocessor, config=config.cliping, random=train, ) diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index 9aeca94..dbf23da 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import List, NamedTuple, Optional, Protocol import numpy as np -import xarray as xr +import torch from soundevent import data from batdetect2.typing.models import ModelOutput @@ -77,6 +77,15 @@ class RawPrediction(NamedTuple): features: 
np.ndarray +class Detections(NamedTuple): + scores: torch.Tensor + sizes: torch.Tensor + class_scores: torch.Tensor + times: torch.Tensor + frequencies: torch.Tensor + features: torch.Tensor + + @dataclass class BatDetect2Prediction: raw: RawPrediction @@ -84,154 +93,13 @@ class BatDetect2Prediction: class PostprocessorProtocol(Protocol): - """Protocol defining the interface for the full postprocessing pipeline. + """Protocol defining the interface for the full postprocessing pipeline.""" - This protocol outlines the standard methods for an object that takes raw - output from a BatDetect2 model and the corresponding input clip metadata, - and processes it through various stages (e.g., coordinate remapping, NMS, - detection extraction, data extraction, decoding) to produce interpretable - results at different levels of completion. - - Implementations manage the configured logic for all postprocessing steps. - """ - - def get_feature_arrays( + def get_detections( self, output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Remap feature tensors to coordinate-aware DataArrays. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch, expected - to contain the necessary feature tensors. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects, one for each item in the - processed batch. This list provides the timing, recording, and - other metadata context needed to calculate real-world coordinates - (seconds, Hz) for the output arrays. The length of this list must - correspond to the batch size of the `output`. - - Returns - ------- - List[xr.DataArray] - A list of xarray DataArrays, one for each input clip in the batch, - in the same order. Each DataArray contains the feature vectors - with dimensions like ('feature', 'time', 'frequency') and - corresponding real-world coordinates. - """ - ... 
- - def get_detection_arrays( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Remap detection tensors to coordinate-aware DataArrays. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch, - containing detection heatmaps. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing coordinate context. Must match the batch size of - `output`. - - Returns - ------- - List[xr.DataArray] - A list of 2D xarray DataArrays (one per input clip, in order), - representing the detection heatmap with 'time' and 'frequency' - coordinates. Values typically indicate detection confidence. - """ - ... - - def get_classification_arrays( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Remap classification tensors to coordinate-aware DataArrays. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch, - containing class probability tensors. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing coordinate context. Must match the batch size of - `output`. - - Returns - ------- - List[xr.DataArray] - A list of 3D xarray DataArrays (one per input clip, in order), - representing class probabilities with 'category', 'time', and - 'frequency' dimensions and coordinates. - """ - ... - - def get_sizes_arrays( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.DataArray]: - """Remap size prediction tensors to coordinate-aware DataArrays. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch, - containing predicted size tensors (e.g., width and height). - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing coordinate context. 
Must match the batch size of - `output`. - - Returns - ------- - List[xr.DataArray] - A list of 3D xarray DataArrays (one per input clip, in order), - representing predicted sizes with 'dimension' - (e.g., ['width', 'height']), 'time', and 'frequency' dimensions and - coordinates. Values represent estimated detection sizes. - """ - ... - - def get_detection_datasets( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[xr.Dataset]: - """Perform remapping, NMS, detection, and data extraction for a batch. - - Processes the raw model output for a batch to identify detection peaks - and extract all associated information (score, position, size, class - probs, features) at those peak locations, returning a structured - dataset for each input clip in the batch. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing context. Must match the batch size of `output`. - - Returns - ------- - List[xr.Dataset] - A list of xarray Datasets (one per input clip, in order). Each - Dataset contains multiple DataArrays ('scores', 'dimensions', - 'classes', 'features') sharing a common 'detection' dimension, - providing aligned data for each detected event in that clip. - """ - ... + clips: Optional[List[data.Clip]] = None, + ) -> List[Detections]: ... 
def get_raw_predictions( self, diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py index 584f739..31e7603 100644 --- a/src/batdetect2/typing/preprocess.py +++ b/src/batdetect2/typing/preprocess.py @@ -148,7 +148,9 @@ class PreprocessorProtocol(Protocol): min_freq: float - samplerate: int + input_samplerate: int + + output_samplerate: float audio_pipeline: AudioPipeline diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py index 7720a27..646f5d0 100644 --- a/src/batdetect2/typing/train.py +++ b/src/batdetect2/typing/train.py @@ -96,6 +96,6 @@ class LossProtocol(Protocol): class ClipperProtocol(Protocol): - def extract_clip( + def __call__( self, example: PreprocessedExample ) -> Tuple[PreprocessedExample, float, float]: ... diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 771aa69..4bced06 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -10,7 +10,6 @@ from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, - convert_xr_dataset_to_raw_prediction, get_class_tags, get_generic_tags, get_prediction_features, @@ -278,63 +277,6 @@ def sample_raw_predictions() -> List[RawPrediction]: return [pred1, pred2, pred3] -def test_convert_xr_dataset_basic(sample_detection_dataset, dummy_targets): - """Test basic conversion of a dataset to RawPrediction list.""" - raw_predictions = convert_xr_dataset_to_raw_prediction( - sample_detection_dataset, - dummy_targets.decode_roi, - ) - - assert isinstance(raw_predictions, list) - assert len(raw_predictions) == 2 - - pred1 = raw_predictions[0] - assert isinstance(pred1, RawPrediction) - assert pred1.detection_score == 0.9 - - assert pred1.geometry.coordinates == [ - 20 - 7 / 2, - 300 - 16 / 2, - 20 + 7 / 2, - 300 + 16 / 2, - ] - 
np.testing.assert_allclose( - pred1.class_scores, - sample_detection_dataset["classes"].sel(detection=0), - ) - np.testing.assert_allclose( - pred1.features, sample_detection_dataset["features"].sel(detection=0) - ) - - pred2 = raw_predictions[1] - assert isinstance(pred2, RawPrediction) - assert pred2.detection_score == 0.8 - - assert pred2.geometry.coordinates == [ - 10 - 3 / 2, - 200 - 12 / 2, - 10 + 3 / 2, - 200 + 12 / 2, - ] - np.testing.assert_allclose( - pred2.class_scores, - sample_detection_dataset["classes"].sel(detection=1), - ) - np.testing.assert_allclose( - pred2.features, sample_detection_dataset["features"].sel(detection=1) - ) - - -def test_convert_xr_dataset_empty(empty_detection_dataset, dummy_targets): - """Test conversion of an empty dataset.""" - raw_predictions = convert_xr_dataset_to_raw_prediction( - empty_detection_dataset, - dummy_targets.decode_roi, - ) - assert isinstance(raw_predictions, list) - assert len(raw_predictions) == 0 - - def test_convert_raw_to_sound_event_basic( sample_raw_predictions, sample_recording, diff --git a/tests/test_postprocessing/test_detection.py b/tests/test_postprocessing/test_detection.py deleted file mode 100644 index 65aaad1..0000000 --- a/tests/test_postprocessing/test_detection.py +++ /dev/null @@ -1,214 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -from soundevent.arrays import Dimensions - -from batdetect2.postprocess.detection import extract_detections_from_array - - -@pytest.fixture -def sample_data_array(): - """Provides a basic 3x3 DataArray. 
- Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30) - """ - array = xr.DataArray( - np.zeros([3, 3]), - coords={ - Dimensions.frequency.value: [100, 200, 300], - Dimensions.time.value: [10, 20, 30], - }, - dims=[ - Dimensions.frequency.value, - Dimensions.time.value, - ], - ) - - array.loc[dict(time=10, frequency=100)] = 0.005 - array.loc[dict(time=10, frequency=200)] = 0.5 - array.loc[dict(time=10, frequency=300)] = 0.03 - array.loc[dict(time=20, frequency=100)] = 0.8 - array.loc[dict(time=20, frequency=200)] = 0.02 - array.loc[dict(time=20, frequency=300)] = 0.6 - array.loc[dict(time=30, frequency=100)] = 0.04 - array.loc[dict(time=30, frequency=200)] = 0.9 - array.loc[dict(time=30, frequency=300)] = 0.7 - return array - - -@pytest.fixture -def data_array_with_nans(sample_data_array: xr.DataArray): - """Provides a 2D DataArray containing NaN values.""" - array = sample_data_array.copy() - array.loc[dict(time=10, frequency=300)] = np.nan - array.loc[dict(time=30, frequency=100)] = np.nan - return array - - -def test_basic_extraction(sample_data_array: xr.DataArray): - threshold = 0.1 - max_detections = 3 - - actual_result = extract_detections_from_array( - sample_data_array, - threshold=threshold, - max_detections=max_detections, - ) - - expected_values = np.array([0.9, 0.8, 0.7]) - expected_times = np.array([30, 20, 30]) - expected_freqs = np.array([200, 100, 300]) - expected_coords = { - Dimensions.frequency.value: ("detection", expected_freqs), - Dimensions.time.value: ("detection", expected_times), - } - expected_result = xr.DataArray( - expected_values, - coords=expected_coords, - dims="detection", - name="score", - ) - - xr.testing.assert_equal(actual_result, expected_result) - - -def test_threshold_only(sample_data_array): - input_array = sample_data_array - threshold = 0.5 - actual_result = extract_detections_from_array( - input_array, threshold=threshold - ) - expected_values = np.array([0.9, 0.8, 0.7, 0.6]) - expected_times = 
np.array([30, 20, 30, 20]) - expected_freqs = np.array([200, 100, 300, 300]) - expected_coords = { - Dimensions.time.value: ("detection", expected_times), - Dimensions.frequency.value: ("detection", expected_freqs), - } - expected_result = xr.DataArray( - expected_values, - coords=expected_coords, - dims="detection", - name="detection_value", - ) - xr.testing.assert_equal(actual_result, expected_result) - - -def test_max_detections_only(sample_data_array): - input_array = sample_data_array - max_detections = 4 - actual_result = extract_detections_from_array( - input_array, max_detections=max_detections - ) - expected_values = np.array([0.9, 0.8, 0.7, 0.6]) - expected_times = np.array([30, 20, 30, 20]) - expected_freqs = np.array([200, 100, 300, 300]) - expected_coords = { - Dimensions.time.value: ("detection", expected_times), - Dimensions.frequency.value: ("detection", expected_freqs), - } - expected_result = xr.DataArray( - expected_values, - coords=expected_coords, - dims="detection", - name="detection_value", - ) - xr.testing.assert_equal(actual_result, expected_result) - - -def test_no_optional_args(sample_data_array): - input_array = sample_data_array - actual_result = extract_detections_from_array(input_array) - expected_values = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.04, 0.03, 0.02]) - expected_times = np.array([30, 20, 30, 20, 10, 30, 10, 20]) - expected_freqs = np.array([200, 100, 300, 300, 200, 100, 300, 200]) - expected_coords = { - Dimensions.time.value: ("detection", expected_times), - Dimensions.frequency.value: ("detection", expected_freqs), - } - expected_result = xr.DataArray( - expected_values, - coords=expected_coords, - dims="detection", - name="detection_value", - ) - xr.testing.assert_equal(actual_result, expected_result) - - -def test_no_values_above_threshold(sample_data_array): - input_array = sample_data_array - threshold = 1.0 - actual_result = extract_detections_from_array( - input_array, threshold=threshold - ) - expected_coords = { - 
Dimensions.time.value: ("detection", np.array([], dtype=np.int64)), - Dimensions.frequency.value: ( - "detection", - np.array([], dtype=np.int64), - ), - } - expected_result = xr.DataArray( - np.array([], dtype=np.float64), - coords=expected_coords, - dims="detection", - name="detection_value", - ) - xr.testing.assert_equal(actual_result, expected_result) - assert actual_result.sizes["detection"] == 0 - - -def test_max_detections_zero(sample_data_array): - input_array = sample_data_array - max_detections = 0 - with pytest.raises(ValueError): - extract_detections_from_array( - input_array, - max_detections=max_detections, - ) - - -def test_empty_input_array(): - empty_array = xr.DataArray( - np.empty((0, 0)), - coords={Dimensions.time.value: [], Dimensions.frequency.value: []}, - dims=[Dimensions.time.value, Dimensions.frequency.value], - ) - actual_result = extract_detections_from_array(empty_array) - expected_coords = { - Dimensions.time.value: ("detection", np.array([], dtype=np.int64)), - Dimensions.frequency.value: ( - "detection", - np.array([], dtype=np.int64), - ), - } - expected_result = xr.DataArray( - np.array([], dtype=np.float64), - coords=expected_coords, - dims="detection", - name="detection_value", - ) - xr.testing.assert_equal(actual_result, expected_result) - assert actual_result.sizes["detection"] == 0 - - -def test_nan_handling(data_array_with_nans): - input_array = data_array_with_nans - threshold = 0.1 - max_detections = 3 - actual_result = extract_detections_from_array( - input_array, threshold=threshold, max_detections=max_detections - ) - expected_values = np.array([0.9, 0.8, 0.7]) - expected_times = np.array([30, 20, 30]) - expected_freqs = np.array([200, 100, 300]) - expected_coords = { - Dimensions.time.value: ("detection", expected_times), - Dimensions.frequency.value: ("detection", expected_freqs), - } - expected_result = xr.DataArray( - expected_values, - coords=expected_coords, - dims="detection", - name="detection_value", - ) - 
xr.testing.assert_equal(actual_result, expected_result) diff --git a/tests/test_postprocessing/test_extraction.py b/tests/test_postprocessing/test_extraction.py index 125b597..b9de9cf 100644 --- a/tests/test_postprocessing/test_extraction.py +++ b/tests/test_postprocessing/test_extraction.py @@ -1,397 +1,2 @@ import numpy as np import pytest -import xarray as xr -from soundevent.arrays import Dimensions - -from batdetect2.postprocess.detection import extract_detections_from_array -from batdetect2.postprocess.extraction import ( - extract_detection_xr_dataset, - extract_values_at_positions, -) - - -@pytest.fixture -def sample_data_array(): - """Provides a basic 3x3 DataArray. - Top values: 0.9 (f=300, t=20), 0.8 (f=200, t=10), 0.7 (f=300, t=30) - """ - coords = { - Dimensions.frequency.value: [100, 200, 300], - Dimensions.time.value: [10, 20, 30], - } - array = xr.DataArray( - np.zeros([3, 3]), - coords=coords, - dims=[ - Dimensions.frequency.value, - Dimensions.time.value, - ], - ) - - array.loc[dict(time=10, frequency=100)] = 0.005 - array.loc[dict(time=10, frequency=200)] = 0.5 - array.loc[dict(time=10, frequency=300)] = 0.03 - array.loc[dict(time=20, frequency=100)] = 0.8 - array.loc[dict(time=20, frequency=200)] = 0.02 - array.loc[dict(time=20, frequency=300)] = 0.6 - array.loc[dict(time=30, frequency=100)] = 0.04 - array.loc[dict(time=30, frequency=200)] = 0.9 - array.loc[dict(time=30, frequency=300)] = 0.7 - return array - - -@pytest.fixture -def sample_array_for_extraction(): - """Provides a simple array (1-9) for value extraction tests.""" - data = np.arange(1, 10).reshape(3, 3) - coords = { - Dimensions.frequency.value: [100, 200, 300], - Dimensions.time.value: [10, 20, 30], - } - return xr.DataArray( - data, - coords=coords, - dims=[ - Dimensions.frequency.value, - Dimensions.time.value, - ], - name="test_values", - ) - - -@pytest.fixture -def sample_positions_top3(sample_data_array): - """Get top 3 detection positions from sample_data_array.""" - - 
return extract_detections_from_array( - sample_data_array, - max_detections=3, - threshold=None, - ) - - -@pytest.fixture -def sample_positions_top2(sample_data_array): - """Get top 2 detection positions from sample_data_array.""" - return extract_detections_from_array( - sample_data_array, - max_detections=2, - threshold=None, - ) - - -@pytest.fixture -def empty_positions(sample_data_array): - """Get an empty positions array (high threshold).""" - return extract_detections_from_array( - sample_data_array, - threshold=1.0, - ) - - -@pytest.fixture -def sample_sizes_array(sample_data_array): - """Provides a sample sizes array matching sample_data_array coords.""" - coords = sample_data_array.coords - data = np.array( - [ - [ - [0, 1, 2], - [3, 4, 5], - [6, 7, 8], - ], - [ - [9, 10, 11], - [12, 13, 14], - [15, 16, 17], - ], - ], - dtype=np.float32, - ) - - return xr.DataArray( - data, - coords={ - "dimension": ["width", "height"], - Dimensions.frequency.value: coords[Dimensions.frequency.value], - Dimensions.time.value: coords[Dimensions.time.value], - }, - dims=["dimension", Dimensions.frequency.value, Dimensions.time.value], - name="sizes", - ) - - -@pytest.fixture -def sample_classes_array(sample_data_array): - """Provides a sample classes array matching sample_data_array coords.""" - coords = sample_data_array.coords - data = np.linspace(0.1, 0.9, 18, dtype=np.float32).reshape(2, 3, 3) - return xr.DataArray( - data, - coords={ - "category": ["bat", "noise"], - Dimensions.frequency.value: coords[Dimensions.frequency.value], - Dimensions.time.value: coords[Dimensions.time.value], - }, - dims=["category", Dimensions.frequency.value, Dimensions.time.value], - name="class_scores", - ) - - -@pytest.fixture -def sample_features_array(sample_data_array): - """Provides a sample features array matching sample_data_array coords.""" - coords = sample_data_array.coords - data = np.arange(0, 36, dtype=np.float32).reshape(4, 3, 3) - return xr.DataArray( - data, - coords={ - 
"feature": ["f0", "f1", "f2", "f3"], - Dimensions.frequency.value: coords[Dimensions.frequency.value], - Dimensions.time.value: coords[Dimensions.time.value], - }, - dims=["feature", Dimensions.frequency.value, Dimensions.time.value], - name="features", - ) - - -def test_extract_values_at_positions_correct( - sample_array_for_extraction, - sample_positions_top3, -): - """Verify correct values are extracted based on positions coords.""" - expected_values = np.array( - [ - sample_array_for_extraction.sel(time=30, frequency=200).values, - sample_array_for_extraction.sel(time=20, frequency=100).values, - sample_array_for_extraction.sel(time=30, frequency=300).values, - ] - ) - - expected = xr.DataArray( - expected_values, - coords=sample_positions_top3.coords, - dims="detection", - name="test_values", - ) - - extracted = extract_values_at_positions( - sample_array_for_extraction, sample_positions_top3 - ) - - xr.testing.assert_allclose(extracted, expected) - - -def test_extract_values_at_positions_extra_dims( - sample_sizes_array, - sample_positions_top2, -): - """Test extraction preserves other dimensions in the source array.""" - times = np.array([30, 20]) - freqs = np.array([200, 100]) - - expected_values = np.array( - [ - sample_sizes_array.sel(time=30, frequency=200).values, - sample_sizes_array.sel(time=20, frequency=100).values, - ], - dtype=np.float32, - ) - - expected = xr.DataArray( - expected_values, - coords={ - "dimension": ["width", "height"], - Dimensions.frequency.value: ("detection", freqs), - Dimensions.time.value: ("detection", times), - }, - dims=["detection", "dimension"], - name="sizes", - ) - - extracted = extract_values_at_positions( - sample_sizes_array, - sample_positions_top2, - ) - - xr.testing.assert_equal(extracted, expected) - - -def test_extract_values_at_positions_empty( - sample_array_for_extraction, empty_positions -): - """Test extraction with empty positions returns empty array.""" - extracted = extract_values_at_positions( - 
sample_array_for_extraction, empty_positions - ) - assert extracted.sizes["detection"] == 0 - assert Dimensions.time.value in extracted.coords - assert Dimensions.frequency.value in extracted.coords - assert extracted.coords[Dimensions.time.value].size == 0 - assert extracted.coords[Dimensions.frequency.value].size == 0 - assert extracted.name == sample_array_for_extraction.name - - -def test_extract_values_at_positions_missing_coord_in_array( - sample_array_for_extraction, sample_positions_top2 -): - """Test error if source array misses required coordinates.""" - array_no_time = sample_array_for_extraction.copy() - del array_no_time.coords[Dimensions.time.value] - with pytest.raises(IndexError): - extract_values_at_positions(array_no_time, sample_positions_top2) - - array_no_freq = sample_array_for_extraction.copy() - del array_no_freq.coords[Dimensions.frequency.value] - with pytest.raises(IndexError): - extract_values_at_positions(array_no_freq, sample_positions_top2) - - -def test_extract_values_at_positions_missing_coord_in_positions( - sample_array_for_extraction, sample_positions_top2 -): - """Test error if positions array misses required coordinates.""" - positions_no_time = sample_positions_top2.copy() - del positions_no_time.coords[Dimensions.time.value] - with pytest.raises(KeyError): - extract_values_at_positions( - sample_array_for_extraction, positions_no_time - ) - - positions_no_freq = sample_positions_top2.copy() - del positions_no_freq.coords[Dimensions.frequency.value] - with pytest.raises(KeyError): - extract_values_at_positions( - sample_array_for_extraction, positions_no_freq - ) - - -def test_extract_values_at_positions_mismatched_coords( - sample_array_for_extraction, sample_positions_top2 -): - """Test error if positions requests coords not in source array.""" - bad_positions = sample_positions_top2.copy() - bad_positions.coords[Dimensions.time.value] = ( - "detection", - np.array([40, 10]), - ) - with pytest.raises(KeyError): - 
extract_values_at_positions(sample_array_for_extraction, bad_positions) - - -def test_extract_detection_xr_dataset_correct( - sample_positions_top2, - sample_sizes_array, - sample_classes_array, - sample_features_array, -): - """Tests extracting and bundling info for top 2 detections.""" - actual_dataset = extract_detection_xr_dataset( - sample_positions_top2, - sample_sizes_array, - sample_classes_array, - sample_features_array, - ) - - expected_times = np.array([30, 20]) - expected_freqs = np.array([200, 100]) - detection_coords = { - Dimensions.time.value: ("detection", expected_times), - Dimensions.frequency.value: ("detection", expected_freqs), - } - - expected_score = sample_positions_top2 - - expected_dimensions_data = np.array( - [ - sample_sizes_array.sel(time=30, frequency=200).values, - sample_sizes_array.sel(time=20, frequency=100).values, - ], - dtype=np.float32, - ) - expected_dimensions = xr.DataArray( - expected_dimensions_data, - coords={**detection_coords, "dimension": ["width", "height"]}, - dims=["detection", "dimension"], - name="dimensions", - ) - - expected_classes_data = np.array( - [ - sample_classes_array.sel(time=30, frequency=200).values, - sample_classes_array.sel(time=20, frequency=100).values, - ], - dtype=np.float32, - ) - expected_classes = xr.DataArray( - expected_classes_data, - coords={**detection_coords, "category": ["bat", "noise"]}, - dims=["detection", "category"], - name="classes", - ) - - expected_features_data = np.array( - [ - sample_features_array.sel(time=30, frequency=200).values, - sample_features_array.sel(time=20, frequency=100).values, - ], - dtype=np.float32, - ) - expected_features = xr.DataArray( - expected_features_data, - coords={**detection_coords, "feature": ["f0", "f1", "f2", "f3"]}, - dims=["detection", "feature"], - name="features", - ) - - expected_dataset = xr.Dataset( - { - "scores": expected_score, - "dimensions": expected_dimensions, - "classes": expected_classes, - "features": expected_features, - } 
- ) - expected_dataset = expected_dataset.assign_coords(detection_coords) - - xr.testing.assert_allclose(actual_dataset, expected_dataset) - - -def test_extract_detection_xr_dataset_empty( - empty_positions, - sample_sizes_array, - sample_classes_array, - sample_features_array, -): - """Test extraction with empty positions yields an empty dataset.""" - actual_dataset = extract_detection_xr_dataset( - empty_positions, - sample_sizes_array, - sample_classes_array, - sample_features_array, - ) - - assert isinstance(actual_dataset, xr.Dataset) - assert "detection" in actual_dataset.dims - assert actual_dataset.sizes["detection"] == 0 - - assert "scores" in actual_dataset - assert actual_dataset["scores"].dims == ("detection",) - assert actual_dataset["scores"].size == 0 - - assert "dimensions" in actual_dataset - assert actual_dataset["dimensions"].dims == ("detection", "dimension") - assert actual_dataset["dimensions"].shape == (0, 2) - - assert "classes" in actual_dataset - assert actual_dataset["classes"].dims == ("detection", "category") - assert actual_dataset["classes"].shape == (0, 2) - - assert "features" in actual_dataset - assert actual_dataset["features"].dims == ("detection", "feature") - assert actual_dataset["features"].shape == (0, 4) - - assert Dimensions.time.value in actual_dataset.coords - assert Dimensions.frequency.value in actual_dataset.coords - assert actual_dataset.coords[Dimensions.time.value].size == 0 - assert actual_dataset.coords[Dimensions.frequency.value].size == 0 diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index c363d0e..b5348bb 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -162,7 +162,8 @@ def test_selected_random_subclip_has_the_correct_width( subclip = select_subclip( original, - samplerate=256_000, + input_samplerate=256_000, + output_samplerate=1000, start=0, duration=0.512, ) diff --git a/tests/test_train/test_preprocessing.py 
b/tests/test_train/test_preprocessing.py index 6e74348..fc30501 100644 --- a/tests/test_train/test_preprocessing.py +++ b/tests/test_train/test_preprocessing.py @@ -39,9 +39,8 @@ def build_from_config( ) postprocessor = build_postprocessor( targets, + preprocessor=preprocessor, config=postprocessing_config, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, ) return targets, preprocessor, labeller, postprocessor @@ -84,7 +83,10 @@ def test_encoding_decoding_roundtrip_recovers_object( clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) encoded = generate_train_example( - clip_annotation, sample_audio_loader, preprocessor, labeller + clip_annotation, + sample_audio_loader, + preprocessor, + labeller, ) predictions = postprocessor.get_predictions( ModelOutput(