mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Rename BlockGroupConfig to LayerGroupConfig
This commit is contained in:
parent
cbb02cf69e
commit
166dad20bd
@ -27,9 +27,23 @@ from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.blocks import ConvBlock
|
||||
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
|
||||
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
|
||||
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
build_bottleneck,
|
||||
)
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
Decoder,
|
||||
DecoderConfig,
|
||||
build_decoder,
|
||||
)
|
||||
from batdetect2.models.encoder import (
|
||||
DEFAULT_ENCODER_CONFIG,
|
||||
Encoder,
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
|
||||
|
||||
input_height: int = 128
|
||||
in_channels: int = 1
|
||||
encoder: Optional[EncoderConfig] = None
|
||||
bottleneck: Optional[BottleneckConfig] = None
|
||||
decoder: Optional[DecoderConfig] = None
|
||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
"BlockGroupConfig",
|
||||
"LayerGroupConfig",
|
||||
"VerticalConv",
|
||||
"FreqCoordConvDownBlock",
|
||||
"StandardConvDownBlock",
|
||||
@ -654,16 +654,16 @@ LayerConfig = Annotated[
|
||||
StandardConvDownConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
"BlockGroupConfig",
|
||||
"LayerGroupConfig",
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
class BlockGroupConfig(BaseConfig):
|
||||
block_type: Literal["group"] = "group"
|
||||
blocks: List[LayerConfig]
|
||||
class LayerGroupConfig(BaseConfig):
|
||||
block_type: Literal["LayerGroup"] = "LayerGroup"
|
||||
layers: List[LayerConfig]
|
||||
|
||||
|
||||
def build_layer_from_config(
|
||||
@ -769,13 +769,13 @@ def build_layer_from_config(
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.block_type == "group":
|
||||
if config.block_type == "LayerGroup":
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
blocks = []
|
||||
|
||||
for block_config in config.blocks:
|
||||
for block_config in config.layers:
|
||||
block, current_channels, current_height = build_layer_from_config(
|
||||
input_height=current_height,
|
||||
in_channels=current_channels,
|
||||
|
@ -26,7 +26,7 @@ from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
LayerGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
BlockGroupConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvUpConfig(out_channels=64),
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
LayerGroupConfig(
|
||||
layers=[
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
ConvConfig(out_channels=32),
|
||||
]
|
||||
|
@ -28,7 +28,7 @@ from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
LayerGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
BlockGroupConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvDownConfig(out_channels=32),
|
||||
FreqCoordConvDownConfig(out_channels=64),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
LayerGroupConfig(
|
||||
layers=[
|
||||
FreqCoordConvDownConfig(out_channels=128),
|
||||
ConvConfig(out_channels=256),
|
||||
]
|
||||
|
@ -2,9 +2,8 @@ from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
DEFAULT_LOGS_DIR: str = "logs"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user