Rename BlockGroupConfig to LayerGroupConfig

This commit is contained in:
mbsantiago 2025-06-26 10:04:42 -06:00
parent cbb02cf69e
commit 166dad20bd
5 changed files with 36 additions and 23 deletions

View File

@ -27,9 +27,23 @@ from torch import nn
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.blocks import ConvBlock from batdetect2.models.blocks import ConvBlock
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck from batdetect2.models.bottleneck import (
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder DEFAULT_BOTTLENECK_CONFIG,
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
Decoder,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
Encoder,
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import BackboneModel from batdetect2.models.types import BackboneModel
__all__ = [ __all__ = [
@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
input_height: int = 128 input_height: int = 128
in_channels: int = 1 in_channels: int = 1
encoder: Optional[EncoderConfig] = None encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: Optional[BottleneckConfig] = None bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: Optional[DecoderConfig] = None decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32 out_channels: int = 32

View File

@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
__all__ = [ __all__ = [
"ConvBlock", "ConvBlock",
"BlockGroupConfig", "LayerGroupConfig",
"VerticalConv", "VerticalConv",
"FreqCoordConvDownBlock", "FreqCoordConvDownBlock",
"StandardConvDownBlock", "StandardConvDownBlock",
@ -654,16 +654,16 @@ LayerConfig = Annotated[
StandardConvDownConfig, StandardConvDownConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
"BlockGroupConfig", "LayerGroupConfig",
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
"""Type alias for the discriminated union of block configuration models.""" """Type alias for the discriminated union of block configuration models."""
class BlockGroupConfig(BaseConfig): class LayerGroupConfig(BaseConfig):
block_type: Literal["group"] = "group" block_type: Literal["LayerGroup"] = "LayerGroup"
blocks: List[LayerConfig] layers: List[LayerConfig]
def build_layer_from_config( def build_layer_from_config(
@ -769,13 +769,13 @@ def build_layer_from_config(
input_height * 2, input_height * 2,
) )
if config.block_type == "group": if config.block_type == "LayerGroup":
current_channels = in_channels current_channels = in_channels
current_height = input_height current_height = input_height
blocks = [] blocks = []
for block_config in config.blocks: for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config( block, current_channels, current_height = build_layer_from_config(
input_height=current_height, input_height=current_height,
in_channels=current_channels, in_channels=current_channels,

View File

@ -26,7 +26,7 @@ from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
BlockGroupConfig, LayerGroupConfig,
ConvConfig, ConvConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
ConvConfig, ConvConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
BlockGroupConfig, LayerGroupConfig,
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[ layers=[
FreqCoordConvUpConfig(out_channels=64), FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpConfig(out_channels=32),
BlockGroupConfig( LayerGroupConfig(
blocks=[ layers=[
FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32), ConvConfig(out_channels=32),
] ]

View File

@ -28,7 +28,7 @@ from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
BlockGroupConfig, LayerGroupConfig,
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
StandardConvDownConfig, StandardConvDownConfig,
@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
StandardConvDownConfig, StandardConvDownConfig,
BlockGroupConfig, LayerGroupConfig,
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[ layers=[
FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64), FreqCoordConvDownConfig(out_channels=64),
BlockGroupConfig( LayerGroupConfig(
blocks=[ layers=[
FreqCoordConvDownConfig(out_channels=128), FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256), ConvConfig(out_channels=256),
] ]

View File

@ -2,9 +2,8 @@ from typing import Annotated, Literal, Optional, Union
from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers import Logger
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "logs" DEFAULT_LOGS_DIR: str = "logs"