mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Remove unnecessary config and build modules
This commit is contained in:
parent
6c744eaac5
commit
ffa4c2e5e9
@ -1,17 +1,18 @@
|
|||||||
from typing import Sequence, Tuple
|
from enum import Enum
|
||||||
|
from typing import Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvBlock,
|
ConvBlock,
|
||||||
Decoder,
|
|
||||||
DownscalingLayer,
|
|
||||||
Encoder,
|
|
||||||
SelfAttention,
|
SelfAttention,
|
||||||
UpscalingLayer,
|
|
||||||
VerticalConv,
|
VerticalConv,
|
||||||
)
|
)
|
||||||
|
from batdetect2.models.decoder import Decoder, UpscalingLayer
|
||||||
|
from batdetect2.models.encoder import DownscalingLayer, Encoder
|
||||||
from batdetect2.models.types import BackboneModel
|
from batdetect2.models.types import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -183,3 +184,70 @@ def restore_pad(
|
|||||||
x = x[:, :, :, :-w_pad]
|
x = x[:, :, :, :-w_pad]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
Net2DFast = "Net2DFast"
|
||||||
|
Net2DFastNoAttn = "Net2DFastNoAttn"
|
||||||
|
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
||||||
|
Net2DPlain = "Net2DPlain"
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneConfig(BaseConfig):
|
||||||
|
backbone_type: ModelType = ModelType.Net2DFast
|
||||||
|
input_height: int = 128
|
||||||
|
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
||||||
|
bottleneck_channels: int = 256
|
||||||
|
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
||||||
|
out_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
def load_backbone_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> BackboneConfig:
|
||||||
|
return load_config(path, schema=BackboneConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_backbone(
|
||||||
|
config: Optional[BackboneConfig] = None,
|
||||||
|
) -> BackboneModel:
|
||||||
|
config = config or BackboneConfig()
|
||||||
|
|
||||||
|
if config.backbone_type == ModelType.Net2DFast:
|
||||||
|
return Net2DFast(
|
||||||
|
input_height=config.input_height,
|
||||||
|
encoder_channels=config.encoder_channels,
|
||||||
|
bottleneck_channels=config.bottleneck_channels,
|
||||||
|
decoder_channels=config.decoder_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.backbone_type == ModelType.Net2DFastNoAttn:
|
||||||
|
return Net2DFastNoAttn(
|
||||||
|
input_height=config.input_height,
|
||||||
|
encoder_channels=config.encoder_channels,
|
||||||
|
bottleneck_channels=config.bottleneck_channels,
|
||||||
|
decoder_channels=config.decoder_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.backbone_type == ModelType.Net2DFastNoCoordConv:
|
||||||
|
return Net2DFastNoCoordConv(
|
||||||
|
input_height=config.input_height,
|
||||||
|
encoder_channels=config.encoder_channels,
|
||||||
|
bottleneck_channels=config.bottleneck_channels,
|
||||||
|
decoder_channels=config.decoder_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.backbone_type == ModelType.Net2DPlain:
|
||||||
|
return Net2DPlain(
|
||||||
|
input_height=config.input_height,
|
||||||
|
encoder_channels=config.encoder_channels,
|
||||||
|
bottleneck_channels=config.bottleneck_channels,
|
||||||
|
decoder_channels=config.decoder_channels,
|
||||||
|
out_channels=config.out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown model type: {config.backbone_type}")
|
||||||
|
@ -1,58 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
|
||||||
Net2DFast,
|
|
||||||
Net2DFastNoAttn,
|
|
||||||
Net2DFastNoCoordConv,
|
|
||||||
Net2DPlain,
|
|
||||||
)
|
|
||||||
from batdetect2.models.config import ModelConfig, ModelType
|
|
||||||
from batdetect2.models.types import BackboneModel
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"build_architecture",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_architecture(
|
|
||||||
config: Optional[ModelConfig] = None,
|
|
||||||
) -> BackboneModel:
|
|
||||||
config = config or ModelConfig()
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFast:
|
|
||||||
return Net2DFast(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFastNoAttn:
|
|
||||||
return Net2DFastNoAttn(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DFastNoCoordConv:
|
|
||||||
return Net2DFastNoCoordConv(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.name == ModelType.Net2DPlain:
|
|
||||||
return Net2DPlain(
|
|
||||||
input_height=config.input_height,
|
|
||||||
encoder_channels=config.encoder_channels,
|
|
||||||
bottleneck_channels=config.bottleneck_channels,
|
|
||||||
decoder_channels=config.decoder_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown model type: {config.name}")
|
|
@ -1,34 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ModelType",
|
|
||||||
"ModelConfig",
|
|
||||||
"load_model_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
|
||||||
Net2DFast = "Net2DFast"
|
|
||||||
Net2DFastNoAttn = "Net2DFastNoAttn"
|
|
||||||
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
|
||||||
Net2DPlain = "Net2DPlain"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
|
||||||
name: ModelType = ModelType.Net2DFast
|
|
||||||
input_height: int = 128
|
|
||||||
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
|
||||||
bottleneck_channels: int = 256
|
|
||||||
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
|
||||||
out_channels: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(
|
|
||||||
path: PathLike, field: Optional[str] = None
|
|
||||||
) -> ModelConfig:
|
|
||||||
return load_config(path, schema=ModelConfig, field=field)
|
|
Loading…
Reference in New Issue
Block a user