Remove unnecessary config and build modules

This commit is contained in:
mbsantiago 2025-04-21 15:28:47 +01:00
parent 6c744eaac5
commit ffa4c2e5e9
3 changed files with 73 additions and 97 deletions

View File

@ -1,17 +1,18 @@
from typing import Sequence, Tuple
from enum import Enum
from typing import Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.blocks import (
ConvBlock,
Decoder,
DownscalingLayer,
Encoder,
SelfAttention,
UpscalingLayer,
VerticalConv,
)
from batdetect2.models.decoder import Decoder, UpscalingLayer
from batdetect2.models.encoder import DownscalingLayer, Encoder
from batdetect2.models.types import BackboneModel
__all__ = [
@ -183,3 +184,70 @@ def restore_pad(
x = x[:, :, :, :-w_pad]
return x
class ModelType(str, Enum):
    """Identifiers of the backbone architectures that can be requested.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value and a member can be looked up from that string with
    ``ModelType("<value>")``. Each member's value is spelled exactly like
    its name.
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"
class BackboneConfig(BaseConfig):
    """Settings describing which backbone to build and how to size it.

    ``build_model_backbone`` reads ``backbone_type`` to pick the backbone
    class and forwards every other field, unchanged and by keyword, to
    that class's constructor. Validation and loading behaviour come from
    ``BaseConfig``, which is defined outside this module.
    """

    # Selects the backbone class; defaults to the ``Net2DFast`` variant.
    backbone_type: ModelType = ModelType.Net2DFast

    # Passed to the backbone constructor as ``input_height``.
    input_height: int = 128

    # Passed to the backbone constructor as ``encoder_channels``.
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)

    # Passed to the backbone constructor as ``bottleneck_channels``.
    bottleneck_channels: int = 256

    # Passed to the backbone constructor as ``decoder_channels``.
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)

    # Passed to the backbone constructor as ``out_channels``.
    out_channels: int = 32
def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load a ``BackboneConfig`` from a configuration file.

    Parameters
    ----------
    path
        Location of the configuration file; handed to ``load_config``.
    field
        Forwarded unchanged to ``load_config``. Presumably names a
        sub-section of the file to read the config from -- confirm
        against ``load_config``.

    Returns
    -------
    BackboneConfig
        Whatever ``load_config`` produces for the ``BackboneConfig``
        schema.
    """
    backbone_config = load_config(path, schema=BackboneConfig, field=field)
    return backbone_config
def build_model_backbone(
    config: Optional[BackboneConfig] = None,
) -> BackboneModel:
    """Instantiate the backbone selected by ``config.backbone_type``.

    Parameters
    ----------
    config
        Backbone settings. A falsy value (including ``None``) is replaced
        by a default-constructed ``BackboneConfig``.

    Returns
    -------
    BackboneModel
        A new instance of the matching backbone class, built from the
        size-related fields of ``config``.

    Raises
    ------
    ValueError
        If ``config.backbone_type`` equals none of the known model types.
    """
    if not config:
        config = BackboneConfig()

    # Checked in this order with ``==``, so a plain string equal to a
    # member's value is matched the same way the enum member itself is.
    known_backbones = (
        (ModelType.Net2DFast, Net2DFast),
        (ModelType.Net2DFastNoAttn, Net2DFastNoAttn),
        (ModelType.Net2DFastNoCoordConv, Net2DFastNoCoordConv),
        (ModelType.Net2DPlain, Net2DPlain),
    )

    for model_type, backbone_cls in known_backbones:
        if config.backbone_type == model_type:
            # Every backbone takes the same keyword arguments.
            return backbone_cls(
                input_height=config.input_height,
                encoder_channels=config.encoder_channels,
                bottleneck_channels=config.bottleneck_channels,
                decoder_channels=config.decoder_channels,
                out_channels=config.out_channels,
            )

    raise ValueError(f"Unknown model type: {config.backbone_type}")

View File

@ -1,58 +0,0 @@
from typing import Optional
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
)
from batdetect2.models.config import ModelConfig, ModelType
from batdetect2.models.types import BackboneModel
__all__ = [
"build_architecture",
]
def build_architecture(
    config: Optional[ModelConfig] = None,
) -> BackboneModel:
    """Instantiate the backbone named by ``config.name``.

    Parameters
    ----------
    config
        Model settings. A falsy value (including ``None``) is replaced by
        a default-constructed ``ModelConfig``.

    Returns
    -------
    BackboneModel
        A new instance of the matching backbone class, built from the
        size-related fields of ``config``.

    Raises
    ------
    ValueError
        If ``config.name`` equals none of the known model types.
    """
    if not config:
        config = ModelConfig()

    # Pairs are tried in this order and matched with ``==``.
    architectures = (
        (ModelType.Net2DFast, Net2DFast),
        (ModelType.Net2DFastNoAttn, Net2DFastNoAttn),
        (ModelType.Net2DFastNoCoordConv, Net2DFastNoCoordConv),
        (ModelType.Net2DPlain, Net2DPlain),
    )

    for model_type, architecture in architectures:
        if config.name == model_type:
            # All architectures share one constructor signature.
            return architecture(
                input_height=config.input_height,
                encoder_channels=config.encoder_channels,
                bottleneck_channels=config.bottleneck_channels,
                decoder_channels=config.decoder_channels,
                out_channels=config.out_channels,
            )

    raise ValueError(f"Unknown model type: {config.name}")

View File

@ -1,34 +0,0 @@
from enum import Enum
from typing import Optional, Tuple
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
__all__ = [
"ModelType",
"ModelConfig",
"load_model_config",
]
class ModelType(str, Enum):
    """Closed set of model architecture names.

    Because ``str`` is mixed in, a member is itself a string equal to its
    value, and ``ModelType("<value>")`` returns the corresponding member.
    Names and values are spelled identically.
    """

    # Iteration follows declaration order.
    Net2DFast = "Net2DFast"
    Net2DFastNoAttn = "Net2DFastNoAttn"
    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
    Net2DPlain = "Net2DPlain"
class ModelConfig(BaseConfig):
    """Settings naming a model architecture and its layer sizes.

    Loading and validation behaviour is inherited from ``BaseConfig``,
    which is defined outside this module.
    """

    # Which architecture to build; defaults to ``ModelType.Net2DFast``.
    name: ModelType = ModelType.Net2DFast

    # Sizing parameters. Their exact meaning is defined by the model
    # classes that consume them, which are not visible in this module.
    input_height: int = 128
    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
    bottleneck_channels: int = 256
    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
    out_channels: int = 32
def load_model_config(
    path: PathLike,
    field: Optional[str] = None,
) -> ModelConfig:
    """Load a ``ModelConfig`` from a configuration file.

    Parameters
    ----------
    path
        Location of the configuration file; handed to ``load_config``.
    field
        Forwarded unchanged to ``load_config``. Presumably selects a
        sub-section of the file -- confirm against ``load_config``.

    Returns
    -------
    ModelConfig
        Whatever ``load_config`` produces for the ``ModelConfig`` schema.
    """
    model_config = load_config(path, schema=ModelConfig, field=field)
    return model_config