# batdetect2/batdetect2/models/config.py
# 2025-04-12 16:48:40 +01:00
#
# 35 lines
# 838 B
# Python

from enum import Enum
from typing import Optional, Tuple
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
# Public names exported by a star-import of this module.
__all__ = ["ModelType", "ModelConfig", "load_model_config"]
class ModelType(str, Enum):
    """Names of the model variants accepted by ``ModelConfig.name``.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (for example ``ModelType.Net2DFast == "Net2DFast"``), and
    ``ModelType("Net2DFast")`` looks a member up by that value. Every
    member's value is identical to its name.
    """

    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
    """Configuration schema for a model.

    Every field declares a default value. The descriptions of what each
    field controls are inferred from the field names only.
    NOTE(review): confirm them against the code that consumes this config.

    Attributes
    ----------
    name : ModelType
        Which model variant is selected. Defaults to
        ``ModelType.Net2DFast``.
    input_height : int
        Defaults to 128. Presumably the height of the model input --
        TODO confirm.
    encoder_channels : Tuple[int, ...]
        Defaults to ``(1, 32, 64, 128)``. Presumably the channel counts of
        the encoder stages -- TODO confirm.
    bottleneck_channels : int
        Defaults to 256. Presumably the channel count of the bottleneck --
        TODO confirm.
    decoder_channels : Tuple[int, ...]
        Defaults to ``(256, 64, 32, 32)``. Presumably the channel counts of
        the decoder stages -- TODO confirm.
    out_channels : int
        Defaults to 32. Presumably the channel count of the model output --
        TODO confirm.
    """

    name: ModelType = ModelType.Net2DFast
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32
def load_model_config(
    path: PathLike,
    field: Optional[str] = None,
) -> ModelConfig:
    """Load a ``ModelConfig`` by delegating to ``load_config``.

    Thin wrapper: forwards both arguments to ``load_config`` and fixes the
    schema to ``ModelConfig``.

    Parameters
    ----------
    path : PathLike
        Passed to ``load_config`` unchanged; the location of the
        configuration to load.
    field : str, optional
        Passed to ``load_config`` unchanged. Defaults to ``None``.
        NOTE(review): presumably names a sub-section of the file to read
        the model config from -- confirm against ``load_config``.

    Returns
    -------
    ModelConfig
        The object returned by ``load_config`` for the ``ModelConfig``
        schema.
    """
    return load_config(path, schema=ModelConfig, field=field)