Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00) — 35 lines, 838 B, Python.
from enum import Enum
|
|
from typing import Optional, Tuple
|
|
|
|
from soundevent.data import PathLike
|
|
|
|
from batdetect2.configs import BaseConfig, load_config
|
|
|
|
# Public API of this module: the model-type enum, the config schema,
# and the helper that loads that schema from a file.
__all__ = ["ModelType", "ModelConfig", "load_model_config"]
|
|
|
|
|
|
class ModelType(str, Enum):
    """Enumerate the supported model architecture names.

    The enum mixes in ``str``, so every member is also a plain string equal
    to its value (e.g. ``ModelType.Net2DFast == "Net2DFast"``). This lets the
    values round-trip through text-based config files.
    """

    # Each value is spelled exactly like its member name.
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"
|
|
|
|
|
|
class ModelConfig(BaseConfig):
    """Configuration schema describing the model architecture.

    Every field has a default, so ``ModelConfig()`` yields the default
    architecture (``ModelType.Net2DFast``). Validation and parsing behaviour
    come from ``BaseConfig``, which is defined outside this module.
    """

    # Which architecture variant to build.
    name: ModelType = ModelType.Net2DFast

    # Height of the network input; presumably the number of spectrogram
    # frequency bins — TODO confirm against the model/preprocessing code.
    input_height: int = 128

    # Per-stage channel counts for the encoder (default: 1 -> 32 -> 64 -> 128).
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)

    # Channel count of the bottleneck between encoder and decoder.
    bottleneck_channels: int = 256

    # Per-stage channel counts for the decoder (default: 256 -> 64 -> 32 -> 32).
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)

    # Number of output channels of the network body.
    out_channels: int = 32
|
|
|
|
|
|
def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    """Load a ``ModelConfig`` from the config file at ``path``.

    Args:
        path: Location of the configuration file.
        field: Optional name passed through to ``load_config``; presumably
            selects a nested section of the file that holds the model
            config — TODO confirm against ``load_config``.

    Returns:
        The result of ``load_config`` with ``ModelConfig`` as the schema.
    """
    model_config = load_config(path, schema=ModelConfig, field=field)
    return model_config
|