mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Added BatDetect2Prediction wrapper and method
This commit is contained in:
parent
6213238585
commit
aaec66c15e
@ -39,6 +39,7 @@ from batdetect2.configs import BaseConfig, load_config
|
|||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.postprocess.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
convert_raw_predictions_to_clip_prediction,
|
convert_raw_predictions_to_clip_prediction,
|
||||||
convert_xr_dataset_to_raw_prediction,
|
convert_xr_dataset_to_raw_prediction,
|
||||||
)
|
)
|
||||||
@ -61,7 +62,11 @@ from batdetect2.postprocess.remapping import (
|
|||||||
features_to_xarray,
|
features_to_xarray,
|
||||||
sizes_to_xarray,
|
sizes_to_xarray,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
from batdetect2.postprocess.types import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
)
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
@ -537,6 +542,27 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
for dataset in detection_datasets
|
for dataset in detection_datasets
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_sound_event_predictions(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[List[BatDetect2Prediction]]:
|
||||||
|
raw_predictions = self.get_raw_predictions(output, clips)
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
BatDetect2Prediction(
|
||||||
|
raw=raw,
|
||||||
|
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||||
|
raw,
|
||||||
|
recording=clip.recording,
|
||||||
|
sound_event_decoder=self.targets.decode_class,
|
||||||
|
generic_class_tags=self.targets.generic_class_tags,
|
||||||
|
classification_threshold=self.config.classification_threshold,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for raw in predictions
|
||||||
|
]
|
||||||
|
for predictions, clip in zip(raw_predictions, clips)
|
||||||
|
]
|
||||||
|
|
||||||
def get_predictions(
|
def get_predictions(
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
) -> List[data.ClipPrediction]:
|
) -> List[data.ClipPrediction]:
|
||||||
|
|||||||
@ -95,7 +95,6 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
for det_num in range(detection_dataset.sizes["detection"]):
|
for det_num in range(detection_dataset.sizes["detection"]):
|
||||||
det_info = detection_dataset.sel(detection=det_num)
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
# TODO: Maybe clean this up
|
|
||||||
highest_scoring_class = det_info.coords["category"][
|
highest_scoring_class = det_info.coords["category"][
|
||||||
det_info["classes"].argmax()
|
det_info["classes"].argmax()
|
||||||
].item()
|
].item()
|
||||||
|
|||||||
@ -11,6 +11,7 @@ modularity and consistent interaction between different parts of the BatDetect2
|
|||||||
system that deal with model predictions.
|
system that deal with model predictions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Optional, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
@ -75,6 +76,12 @@ class RawPrediction(NamedTuple):
|
|||||||
features: xr.DataArray
|
features: xr.DataArray
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatDetect2Prediction:
|
||||||
|
raw: RawPrediction
|
||||||
|
sound_event_prediction: data.SoundEventPrediction
|
||||||
|
|
||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||||
|
|
||||||
@ -254,6 +261,10 @@ class PostprocessorProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_sound_event_predictions(
|
||||||
|
self, output: ModelOutput, clips: List[data.Clip]
|
||||||
|
) -> List[List[BatDetect2Prediction]]: ...
|
||||||
|
|
||||||
def get_predictions(
|
def get_predictions(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user