Added BatDetect2Prediction wrapper and method

This commit is contained in:
mbsantiago 2025-08-08 13:06:10 +01:00
parent 6213238585
commit aaec66c15e
3 changed files with 38 additions and 2 deletions

View File

@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
for dataset in detection_datasets
]
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
sound_event_decoder=self.targets.decode_class,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:

View File

@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
# TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()

View File

@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
import xarray as xr
@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
features: xr.DataArray
@dataclass
class BatDetect2Prediction:
raw: RawPrediction
sound_event_prediction: data.SoundEventPrediction
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline.
@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,