Formatting

This commit is contained in:
mbsantiago 2025-04-03 16:49:11 +01:00
parent ff00da9a9a
commit c2c4ac53fd

View File

@ -1,18 +1,20 @@
"""Module for postprocessing model outputs.""" """Module for postprocessing model outputs."""
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from pydantic import BaseModel, Field from pydantic import Field
from soundevent import data from soundevent import data
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput from batdetect2.models.typing import ModelOutput
__all__ = [ __all__ = [
"postprocess_model_outputs",
"PostprocessConfig", "PostprocessConfig",
"load_postprocess_config",
"postprocess_model_outputs",
] ]
NMS_KERNEL_SIZE = 9 NMS_KERNEL_SIZE = 9
@ -20,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200 TOP_K_PER_SEC = 200
class PostprocessConfig(BaseModel): class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs.""" """Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
@ -30,7 +32,14 @@ class PostprocessConfig(BaseModel):
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
class RawPrediction(BaseModel): def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
return load_config(path, schema=PostprocessConfig, field=field)
class RawPrediction(NamedTuple):
start_time: float start_time: float
end_time: float end_time: float
low_freq: float low_freq: float
@ -211,12 +220,6 @@ def compute_predictions_from_outputs(
return predictions return predictions
def convert_raw_prediction_to_soundevent(
prediction: RawPrediction,
) -> data.SoundEventPrediction:
pass
def compute_sound_events_from_outputs( def compute_sound_events_from_outputs(
clip: data.Clip, clip: data.Clip,
scores: torch.Tensor, scores: torch.Tensor,