diff --git a/batdetect2/post_process.py b/batdetect2/post_process.py
index 3b1df2d..85bb546 100644
--- a/batdetect2/post_process.py
+++ b/batdetect2/post_process.py
@@ -1,18 +1,20 @@
 """Module for postprocessing model outputs."""
 
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from pydantic import BaseModel, Field
+from pydantic import Field
 from soundevent import data
 from torch import nn
 
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.typing import ModelOutput
 
 __all__ = [
-    "postprocess_model_outputs",
     "PostprocessConfig",
+    "load_postprocess_config",
+    "postprocess_model_outputs",
 ]
 
 NMS_KERNEL_SIZE = 9
@@ -20,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
 TOP_K_PER_SEC = 200
 
 
-class PostprocessConfig(BaseModel):
+class PostprocessConfig(BaseConfig):
     """Configuration for postprocessing model outputs."""
 
     nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@@ -30,7 +32,14 @@
     top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
 
 
-class RawPrediction(BaseModel):
+def load_postprocess_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> PostprocessConfig:
+    return load_config(path, schema=PostprocessConfig, field=field)
+
+
+class RawPrediction(NamedTuple):
     start_time: float
     end_time: float
     low_freq: float
@@ -211,12 +220,6 @@
     return predictions
 
 
-def convert_raw_prediction_to_soundevent(
-    prediction: RawPrediction,
-) -> data.SoundEventPrediction:
-    pass
-
-
 def compute_sound_events_from_outputs(
     clip: data.Clip,
     scores: torch.Tensor,
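
For reference, a minimal usage sketch of the new load_postprocess_config helper added in this diff. The config file name, its contents, and the assumption that `field` selects a sub-section of the file are illustrative only; the option names used are the ones visible in the diff (nms_kernel_size, top_k_per_sec).

    # Hypothetical example, not part of the diff.
    # Assumed YAML file "config.yaml":
    #
    #   postprocessing:
    #     nms_kernel_size: 9
    #     top_k_per_sec: 200
    #
    from batdetect2.post_process import load_postprocess_config

    # `field` is assumed to pick the "postprocessing" section of the file.
    config = load_postprocess_config("config.yaml", field="postprocessing")
    print(config.nms_kernel_size, config.top_k_per_sec)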