Rename BlockGroupConfig to LayerGroupConfig

This commit is contained in:
mbsantiago 2025-06-26 10:04:42 -06:00
parent cbb02cf69e
commit 166dad20bd
5 changed files with 36 additions and 23 deletions

View File

@ -27,9 +27,23 @@ from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.blocks import ConvBlock
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
Decoder,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
Encoder,
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import BackboneModel
__all__ = [
@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
input_height: int = 128
in_channels: int = 1
encoder: Optional[EncoderConfig] = None
bottleneck: Optional[BottleneckConfig] = None
decoder: Optional[DecoderConfig] = None
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32

View File

@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
__all__ = [
"ConvBlock",
"BlockGroupConfig",
"LayerGroupConfig",
"VerticalConv",
"FreqCoordConvDownBlock",
"StandardConvDownBlock",
@ -654,16 +654,16 @@ LayerConfig = Annotated[
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
"BlockGroupConfig",
"LayerGroupConfig",
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configuration models."""
class BlockGroupConfig(BaseConfig):
block_type: Literal["group"] = "group"
blocks: List[LayerConfig]
class LayerGroupConfig(BaseConfig):
block_type: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
def build_layer_from_config(
@ -769,13 +769,13 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "group":
if config.block_type == "LayerGroup":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.blocks:
for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,

View File

@ -26,7 +26,7 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
LayerGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
BlockGroupConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
]
@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[
FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32),
BlockGroupConfig(
blocks=[
LayerGroupConfig(
layers=[
FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32),
]

View File

@ -28,7 +28,7 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
LayerGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
BlockGroupConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
]
@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[
FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64),
BlockGroupConfig(
blocks=[
LayerGroupConfig(
layers=[
FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256),
]

View File

@ -2,9 +2,8 @@ from typing import Annotated, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "logs"