Formatting

This commit is contained in:
mbsantiago 2025-04-03 16:49:11 +01:00
parent ff00da9a9a
commit c2c4ac53fd

View File

@ -1,18 +1,20 @@
"""Module for postprocessing model outputs."""
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from pydantic import BaseModel, Field
from pydantic import Field
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput
__all__ = [
"postprocess_model_outputs",
"PostprocessConfig",
"load_postprocess_config",
"postprocess_model_outputs",
]
NMS_KERNEL_SIZE = 9
@ -20,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200
class PostprocessConfig(BaseModel):
class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@ -30,7 +32,14 @@ class PostprocessConfig(BaseModel):
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
class RawPrediction(BaseModel):
def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load a postprocessing configuration from a file.

    Thin wrapper around ``load_config`` that fixes the schema to
    ``PostprocessConfig``.

    Parameters
    ----------
    path
        Location of the configuration file; passed through to
        ``load_config`` unchanged.
    field
        Forwarded unchanged to ``load_config``. Presumably selects a
        sub-section of the file to parse -- confirm against
        ``load_config``.

    Returns
    -------
    PostprocessConfig
        Whatever ``load_config`` produces for the ``PostprocessConfig``
        schema.
    """
    config = load_config(
        path,
        schema=PostprocessConfig,
        field=field,
    )
    return config
class RawPrediction(NamedTuple):
start_time: float
end_time: float
low_freq: float
@ -211,12 +220,6 @@ def compute_predictions_from_outputs(
return predictions
def convert_raw_prediction_to_soundevent(
    prediction: RawPrediction,
) -> data.SoundEventPrediction:
    """Placeholder for converting a raw prediction into a sound event.

    NOTE(review): this is an unimplemented stub. Despite the annotated
    return type, it ignores ``prediction`` and returns ``None``.
    """
    return None
def compute_sound_events_from_outputs(
clip: data.Clip,
scores: torch.Tensor,