diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index cf93b0c..eb7b19d 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config from batdetect2.models.types import ModelOutput from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, + convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, convert_xr_dataset_to_raw_prediction, ) @@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import ( features_to_xarray, sizes_to_xarray, ) -from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction +from batdetect2.postprocess.types import ( + BatDetect2Prediction, + PostprocessorProtocol, + RawPrediction, +) from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.targets.types import TargetProtocol @@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol): for dataset in detection_datasets ] + def get_sound_event_predictions( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[List[BatDetect2Prediction]]: + raw_predictions = self.get_raw_predictions(output, clips) + return [ + [ + BatDetect2Prediction( + raw=raw, + sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( + raw, + recording=clip.recording, + sound_event_decoder=self.targets.decode_class, + generic_class_tags=self.targets.generic_class_tags, + classification_threshold=self.config.classification_threshold, + ), + ) + for raw in predictions + ] + for predictions, clip in zip(raw_predictions, clips) + ] + def get_predictions( self, output: ModelOutput, clips: List[data.Clip] ) -> List[data.ClipPrediction]: diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 6a5846c..5237d6c 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction( for det_num in range(detection_dataset.sizes["detection"]): det_info = detection_dataset.sel(detection=det_num) - # TODO: Maybe clean this up highest_scoring_class = det_info.coords["category"][ det_info["classes"].argmax() ].item() diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py index c34b57f..a15e4d8 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/postprocess/types.py @@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2 system that deal with model predictions. """ +from dataclasses import dataclass from typing import List, NamedTuple, Optional, Protocol import xarray as xr @@ -75,6 +76,12 @@ class RawPrediction(NamedTuple): features: xr.DataArray +@dataclass +class BatDetect2Prediction: + raw: RawPrediction + sound_event_prediction: data.SoundEventPrediction + + class PostprocessorProtocol(Protocol): """Protocol defining the interface for the full postprocessing pipeline. @@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol): """ ... + def get_sound_event_predictions( + self, output: ModelOutput, clips: List[data.Clip] + ) -> List[List[BatDetect2Prediction]]: ... + def get_predictions( self, output: ModelOutput,