mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Remove unnecessary config and build modules
This commit is contained in:
parent
6c744eaac5
commit
ffa4c2e5e9
@ -1,17 +1,18 @@
|
||||
from typing import Sequence, Tuple
|
||||
from enum import Enum
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.blocks import (
|
||||
ConvBlock,
|
||||
Decoder,
|
||||
DownscalingLayer,
|
||||
Encoder,
|
||||
SelfAttention,
|
||||
UpscalingLayer,
|
||||
VerticalConv,
|
||||
)
|
||||
from batdetect2.models.decoder import Decoder, UpscalingLayer
|
||||
from batdetect2.models.encoder import DownscalingLayer, Encoder
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
@ -183,3 +184,70 @@ def restore_pad(
|
||||
x = x[:, :, :, :-w_pad]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
Net2DFast = "Net2DFast"
|
||||
Net2DFastNoAttn = "Net2DFastNoAttn"
|
||||
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
||||
Net2DPlain = "Net2DPlain"
|
||||
|
||||
|
||||
class BackboneConfig(BaseConfig):
|
||||
backbone_type: ModelType = ModelType.Net2DFast
|
||||
input_height: int = 128
|
||||
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
||||
bottleneck_channels: int = 256
|
||||
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> BackboneConfig:
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
|
||||
|
||||
def build_model_backbone(
|
||||
config: Optional[BackboneConfig] = None,
|
||||
) -> BackboneModel:
|
||||
config = config or BackboneConfig()
|
||||
|
||||
if config.backbone_type == ModelType.Net2DFast:
|
||||
return Net2DFast(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.backbone_type == ModelType.Net2DFastNoAttn:
|
||||
return Net2DFastNoAttn(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.backbone_type == ModelType.Net2DFastNoCoordConv:
|
||||
return Net2DFastNoCoordConv(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.backbone_type == ModelType.Net2DPlain:
|
||||
return Net2DPlain(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown model type: {config.backbone_type}")
|
||||
|
@ -1,58 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
Net2DFast,
|
||||
Net2DFastNoAttn,
|
||||
Net2DFastNoCoordConv,
|
||||
Net2DPlain,
|
||||
)
|
||||
from batdetect2.models.config import ModelConfig, ModelType
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
"build_architecture",
|
||||
]
|
||||
|
||||
|
||||
def build_architecture(
|
||||
config: Optional[ModelConfig] = None,
|
||||
) -> BackboneModel:
|
||||
config = config or ModelConfig()
|
||||
|
||||
if config.name == ModelType.Net2DFast:
|
||||
return Net2DFast(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DFastNoAttn:
|
||||
return Net2DFastNoAttn(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DFastNoCoordConv:
|
||||
return Net2DFastNoCoordConv(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DPlain:
|
||||
return Net2DPlain(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown model type: {config.name}")
|
@ -1,34 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"ModelType",
|
||||
"ModelConfig",
|
||||
"load_model_config",
|
||||
]
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
Net2DFast = "Net2DFast"
|
||||
Net2DFastNoAttn = "Net2DFastNoAttn"
|
||||
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
||||
Net2DPlain = "Net2DPlain"
|
||||
|
||||
|
||||
class ModelConfig(BaseConfig):
|
||||
name: ModelType = ModelType.Net2DFast
|
||||
input_height: int = 128
|
||||
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
||||
bottleneck_channels: int = 256
|
||||
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
def load_model_config(
|
||||
path: PathLike, field: Optional[str] = None
|
||||
) -> ModelConfig:
|
||||
return load_config(path, schema=ModelConfig, field=field)
|
Loading…
Reference in New Issue
Block a user