diff --git a/batdetect2/models/backbones.py b/batdetect2/models/backbones.py index 5bdf5e4..d5dadee 100644 --- a/batdetect2/models/backbones.py +++ b/batdetect2/models/backbones.py @@ -1,17 +1,18 @@ -from typing import Sequence, Tuple +from enum import Enum +from typing import Optional, Sequence, Tuple import torch import torch.nn.functional as F +from soundevent import data +from batdetect2.configs import BaseConfig, load_config from batdetect2.models.blocks import ( ConvBlock, - Decoder, - DownscalingLayer, - Encoder, SelfAttention, - UpscalingLayer, VerticalConv, ) +from batdetect2.models.decoder import Decoder, UpscalingLayer +from batdetect2.models.encoder import DownscalingLayer, Encoder from batdetect2.models.types import BackboneModel __all__ = [ @@ -183,3 +184,70 @@ def restore_pad( x = x[:, :, :, :-w_pad] return x + + +class ModelType(str, Enum): + Net2DFast = "Net2DFast" + Net2DFastNoAttn = "Net2DFastNoAttn" + Net2DFastNoCoordConv = "Net2DFastNoCoordConv" + Net2DPlain = "Net2DPlain" + + +class BackboneConfig(BaseConfig): + backbone_type: ModelType = ModelType.Net2DFast + input_height: int = 128 + encoder_channels: Tuple[int, ...] = (1, 32, 64, 128) + bottleneck_channels: int = 256 + decoder_channels: Tuple[int, ...] = (256, 64, 32, 32) + out_channels: int = 32 + + +def load_backbone_config( + path: data.PathLike, + field: Optional[str] = None, +) -> BackboneConfig: + return load_config(path, schema=BackboneConfig, field=field) + + +def build_model_backbone( + config: Optional[BackboneConfig] = None, +) -> BackboneModel: + config = config or BackboneConfig() + + if config.backbone_type == ModelType.Net2DFast: + return Net2DFast( + input_height=config.input_height, + encoder_channels=config.encoder_channels, + bottleneck_channels=config.bottleneck_channels, + decoder_channels=config.decoder_channels, + out_channels=config.out_channels, + ) + + if config.backbone_type == ModelType.Net2DFastNoAttn: + return Net2DFastNoAttn( + input_height=config.input_height, + encoder_channels=config.encoder_channels, + bottleneck_channels=config.bottleneck_channels, + decoder_channels=config.decoder_channels, + out_channels=config.out_channels, + ) + + if config.backbone_type == ModelType.Net2DFastNoCoordConv: + return Net2DFastNoCoordConv( + input_height=config.input_height, + encoder_channels=config.encoder_channels, + bottleneck_channels=config.bottleneck_channels, + decoder_channels=config.decoder_channels, + out_channels=config.out_channels, + ) + + if config.backbone_type == ModelType.Net2DPlain: + return Net2DPlain( + input_height=config.input_height, + encoder_channels=config.encoder_channels, + bottleneck_channels=config.bottleneck_channels, + decoder_channels=config.decoder_channels, + out_channels=config.out_channels, + ) + + raise ValueError(f"Unknown model type: {config.backbone_type}") diff --git a/batdetect2/models/build.py b/batdetect2/models/build.py deleted file mode 100644 index 474002c..0000000 --- a/batdetect2/models/build.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional - -from batdetect2.models.backbones import ( - Net2DFast, - Net2DFastNoAttn, - Net2DFastNoCoordConv, - Net2DPlain, -) -from batdetect2.models.config import ModelConfig, ModelType -from batdetect2.models.types import BackboneModel - -__all__ = [ - "build_architecture", -] - - -def build_architecture( - config: Optional[ModelConfig] = None, -) -> BackboneModel: - config = config or ModelConfig() - - if config.name == ModelType.Net2DFast: - return Net2DFast( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) - - if config.name == ModelType.Net2DFastNoAttn: - return Net2DFastNoAttn( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) - - if config.name == ModelType.Net2DFastNoCoordConv: - return Net2DFastNoCoordConv( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) - - if config.name == ModelType.Net2DPlain: - return Net2DPlain( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) - - raise ValueError(f"Unknown model type: {config.name}") diff --git a/batdetect2/models/config.py b/batdetect2/models/config.py deleted file mode 100644 index 1feef1d..0000000 --- a/batdetect2/models/config.py +++ /dev/null @@ -1,34 +0,0 @@ -from enum import Enum -from typing import Optional, Tuple - -from soundevent.data import PathLike - -from batdetect2.configs import BaseConfig, load_config - -__all__ = [ - "ModelType", - "ModelConfig", - "load_model_config", -] - - -class ModelType(str, Enum): - Net2DFast = "Net2DFast" - Net2DFastNoAttn = "Net2DFastNoAttn" - Net2DFastNoCoordConv = "Net2DFastNoCoordConv" - Net2DPlain = "Net2DPlain" - - -class ModelConfig(BaseConfig): - name: ModelType = ModelType.Net2DFast - input_height: int = 128 - encoder_channels: Tuple[int, ...] = (1, 32, 64, 128) - bottleneck_channels: int = 256 - decoder_channels: Tuple[int, ...] = (256, 64, 32, 32) - out_channels: int = 32 - - -def load_model_config( - path: PathLike, field: Optional[str] = None -) -> ModelConfig: - return load_config(path, schema=ModelConfig, field=field)