mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Formatting
This commit is contained in:
parent
ff00da9a9a
commit
c2c4ac53fd
@ -1,18 +1,20 @@
|
||||
"""Module for postprocessing model outputs."""
|
||||
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"postprocess_model_outputs",
|
||||
"PostprocessConfig",
|
||||
"load_postprocess_config",
|
||||
"postprocess_model_outputs",
|
||||
]
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
@ -20,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
|
||||
TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
class PostprocessConfig(BaseModel):
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration for postprocessing model outputs."""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
@ -30,7 +32,14 @@ class PostprocessConfig(BaseModel):
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
class RawPrediction(BaseModel):
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
@ -211,12 +220,6 @@ def compute_predictions_from_outputs(
|
||||
return predictions
|
||||
|
||||
|
||||
def convert_raw_prediction_to_soundevent(
|
||||
prediction: RawPrediction,
|
||||
) -> data.SoundEventPrediction:
|
||||
pass
|
||||
|
||||
|
||||
def compute_sound_events_from_outputs(
|
||||
clip: data.Clip,
|
||||
scores: torch.Tensor,
|
||||
|
Loading…
Reference in New Issue
Block a user