Added BlockGroups

This commit is contained in:
mbsantiago 2025-04-23 23:14:11 +01:00
parent bfcab0331e
commit 6498b6ca37
3 changed files with 50 additions and 7 deletions

View File

@ -27,7 +27,7 @@ A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects.
"""
from typing import Annotated, Literal, Tuple, Union
from typing import Annotated, List, Literal, Tuple, Union
import torch
import torch.nn.functional as F
@ -38,6 +38,7 @@ from batdetect2.configs import BaseConfig
__all__ = [
"ConvBlock",
"BlockGroupConfig",
"VerticalConv",
"FreqCoordConvDownBlock",
"StandardConvDownBlock",
@ -653,12 +654,18 @@ LayerConfig = Annotated[
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
"BlockGroupConfig",
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configuration models."""
class BlockGroupConfig(BaseConfig):
block_type: Literal["group"] = "group"
blocks: List[LayerConfig]
def build_layer_from_config(
input_height: int,
in_channels: int,
@ -762,4 +769,20 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "group":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.blocks:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.block_type}")

View File

@ -26,6 +26,7 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
@ -40,7 +41,12 @@ __all__ = [
]
DecoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig],
Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -191,8 +197,12 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[
FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32),
BlockGroupConfig(
blocks=[
FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32),
]
),
],
)
"""A default configuration for the Decoder's *layer sequence*.

View File

@ -28,6 +28,7 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
@ -42,7 +43,12 @@ __all__ = [
]
EncoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig],
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
@ -224,8 +230,12 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[
FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64),
BlockGroupConfig(
blocks=[
FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256),
]
),
],
)
"""Default configuration for the Encoder.