mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Added BatDetect2Prediction wrapper and method
This commit is contained in:
parent
6213238585
commit
aaec66c15e
@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
[
|
||||
BatDetect2Prediction(
|
||||
raw=raw,
|
||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||
raw,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
),
|
||||
)
|
||||
for raw in predictions
|
||||
]
|
||||
for predictions, clip in zip(raw_predictions, clips)
|
||||
]
|
||||
|
||||
def get_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[data.ClipPrediction]:
|
||||
|
||||
@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
# TODO: Maybe clean this up
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
|
||||
@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import xarray as xr
|
||||
@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
|
||||
features: xr.DataArray
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatDetect2Prediction:
|
||||
raw: RawPrediction
|
||||
sound_event_prediction: data.SoundEventPrediction
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||
|
||||
@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[BatDetect2Prediction]]: ...
|
||||
|
||||
def get_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user