Added BatDetect2Prediction wrapper and method

This commit is contained in:
mbsantiago 2025-08-08 13:06:10 +01:00
parent 6213238585
commit aaec66c15e
3 changed files with 38 additions and 2 deletions

View File

@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import ( from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD, DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction, convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction, convert_xr_dataset_to_raw_prediction,
) )
@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
features_to_xarray, features_to_xarray,
sizes_to_xarray, sizes_to_xarray,
) )
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
for dataset in detection_datasets for dataset in detection_datasets
] ]
def get_sound_event_predictions(
    self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]:
    """Pair each raw prediction with its converted sound event prediction.

    Raw predictions are obtained from ``get_raw_predictions`` and each one
    is converted with ``convert_raw_prediction_to_sound_event_prediction``,
    using the recording of the clip it is paired with.

    Parameters
    ----------
    output : ModelOutput
        Model output, forwarded unchanged to ``get_raw_predictions``.
    clips : List[data.Clip]
        Clips forwarded to ``get_raw_predictions``. They are also paired
        positionally (via ``zip``) with the returned raw prediction lists.

    Returns
    -------
    List[List[BatDetect2Prediction]]
        One inner list per (raw prediction list, clip) pair. Each element
        wraps a raw prediction together with its converted sound event
        prediction.
    """
    raw_by_clip = self.get_raw_predictions(output, clips)

    wrapped_by_clip: List[List[BatDetect2Prediction]] = []
    for clip_raws, clip in zip(raw_by_clip, clips):
        wrapped: List[BatDetect2Prediction] = []
        for raw in clip_raws:
            # The targets/config attributes are read per detection, as in
            # the comprehension form, so attribute access order is kept.
            converted = convert_raw_prediction_to_sound_event_prediction(
                raw,
                recording=clip.recording,
                sound_event_decoder=self.targets.decode_class,
                generic_class_tags=self.targets.generic_class_tags,
                classification_threshold=self.config.classification_threshold,
            )
            wrapped.append(
                BatDetect2Prediction(
                    raw=raw,
                    sound_event_prediction=converted,
                )
            )
        wrapped_by_clip.append(wrapped)

    return wrapped_by_clip
def get_predictions( def get_predictions(
self, output: ModelOutput, clips: List[data.Clip] self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]: ) -> List[data.ClipPrediction]:

View File

@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]): for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num) det_info = detection_dataset.sel(detection=det_num)
# TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][ highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax() det_info["classes"].argmax()
].item() ].item()

View File

@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions. system that deal with model predictions.
""" """
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol from typing import List, NamedTuple, Optional, Protocol
import xarray as xr import xarray as xr
@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
features: xr.DataArray features: xr.DataArray
@dataclass
class BatDetect2Prediction:
    """Wrapper bundling a raw prediction with its sound event prediction.

    Attributes
    ----------
    raw : RawPrediction
        The raw prediction this wrapper was built from.
    sound_event_prediction : data.SoundEventPrediction
        The sound event prediction associated with ``raw``.
    """

    raw: RawPrediction
    sound_event_prediction: data.SoundEventPrediction
class PostprocessorProtocol(Protocol): class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline. """Protocol defining the interface for the full postprocessing pipeline.
@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
""" """
... ...
def get_sound_event_predictions(
    self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]:
    """Produce wrapped predictions from a model output and its clips.

    Parameters
    ----------
    output : ModelOutput
        The model output to postprocess.
    clips : List[data.Clip]
        The clips associated with ``output``.

    Returns
    -------
    List[List[BatDetect2Prediction]]
        A nested list of ``BatDetect2Prediction`` objects.
        NOTE(review): presumably one inner list per clip, in the same
        order as ``clips`` -- confirm against the implementations.
    """
    ...
def get_predictions( def get_predictions(
self, self,
output: ModelOutput, output: ModelOutput,