mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Formatting
This commit is contained in:
parent
ff00da9a9a
commit
c2c4ac53fd
@ -1,18 +1,20 @@
|
|||||||
"""Module for postprocessing model outputs."""
|
"""Module for postprocessing model outputs."""
|
||||||
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.typing import ModelOutput
|
from batdetect2.models.typing import ModelOutput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"postprocess_model_outputs",
|
|
||||||
"PostprocessConfig",
|
"PostprocessConfig",
|
||||||
|
"load_postprocess_config",
|
||||||
|
"postprocess_model_outputs",
|
||||||
]
|
]
|
||||||
|
|
||||||
NMS_KERNEL_SIZE = 9
|
NMS_KERNEL_SIZE = 9
|
||||||
@ -20,7 +22,7 @@ DETECTION_THRESHOLD = 0.01
|
|||||||
TOP_K_PER_SEC = 200
|
TOP_K_PER_SEC = 200
|
||||||
|
|
||||||
|
|
||||||
class PostprocessConfig(BaseModel):
|
class PostprocessConfig(BaseConfig):
|
||||||
"""Configuration for postprocessing model outputs."""
|
"""Configuration for postprocessing model outputs."""
|
||||||
|
|
||||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||||
@ -30,7 +32,14 @@ class PostprocessConfig(BaseModel):
|
|||||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||||
|
|
||||||
|
|
||||||
class RawPrediction(BaseModel):
|
def load_postprocess_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> PostprocessConfig:
|
||||||
|
return load_config(path, schema=PostprocessConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
class RawPrediction(NamedTuple):
|
||||||
start_time: float
|
start_time: float
|
||||||
end_time: float
|
end_time: float
|
||||||
low_freq: float
|
low_freq: float
|
||||||
@ -211,12 +220,6 @@ def compute_predictions_from_outputs(
|
|||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def convert_raw_prediction_to_soundevent(
|
|
||||||
prediction: RawPrediction,
|
|
||||||
) -> data.SoundEventPrediction:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def compute_sound_events_from_outputs(
|
def compute_sound_events_from_outputs(
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
scores: torch.Tensor,
|
scores: torch.Tensor,
|
||||||
|
Loading…
Reference in New Issue
Block a user