Added BlockGroups

This commit is contained in:
mbsantiago 2025-04-23 23:14:11 +01:00
parent bfcab0331e
commit 6498b6ca37
3 changed files with 50 additions and 7 deletions

View File

@ -27,7 +27,7 @@ A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects. of these blocks based on configuration objects.
""" """
from typing import Annotated, Literal, Tuple, Union from typing import Annotated, List, Literal, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -38,6 +38,7 @@ from batdetect2.configs import BaseConfig
__all__ = [ __all__ = [
"ConvBlock", "ConvBlock",
"BlockGroupConfig",
"VerticalConv", "VerticalConv",
"FreqCoordConvDownBlock", "FreqCoordConvDownBlock",
"StandardConvDownBlock", "StandardConvDownBlock",
@ -653,12 +654,18 @@ LayerConfig = Annotated[
StandardConvDownConfig, StandardConvDownConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
"BlockGroupConfig",
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
"""Type alias for the discriminated union of block configuration models.""" """Type alias for the discriminated union of block configuration models."""
class BlockGroupConfig(BaseConfig):
block_type: Literal["group"] = "group"
blocks: List[LayerConfig]
def build_layer_from_config( def build_layer_from_config(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
@ -762,4 +769,20 @@ def build_layer_from_config(
input_height * 2, input_height * 2,
) )
if config.block_type == "group":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.blocks:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.block_type}") raise NotImplementedError(f"Unknown block type {config.block_type}")

View File

@ -26,6 +26,7 @@ from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig, ConvConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
@ -40,7 +41,12 @@ __all__ = [
] ]
DecoderLayerConfig = Annotated[ DecoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig], Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in Decoder."""
@ -191,8 +197,12 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[ layers=[
FreqCoordConvUpConfig(out_channels=64), FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpConfig(out_channels=32),
FreqCoordConvUpConfig(out_channels=32), BlockGroupConfig(
ConvConfig(out_channels=32), blocks=[
FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32),
]
),
], ],
) )
"""A default configuration for the Decoder's *layer sequence*. """A default configuration for the Decoder's *layer sequence*.

View File

@ -28,6 +28,7 @@ from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
StandardConvDownConfig, StandardConvDownConfig,
@ -42,7 +43,12 @@ __all__ = [
] ]
EncoderLayerConfig = Annotated[ EncoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig], Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"), Field(discriminator="block_type"),
] ]
"""Type alias for the discriminated union of block configs usable in Encoder.""" """Type alias for the discriminated union of block configs usable in Encoder."""
@ -224,8 +230,12 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[ layers=[
FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64), FreqCoordConvDownConfig(out_channels=64),
FreqCoordConvDownConfig(out_channels=128), BlockGroupConfig(
ConvConfig(out_channels=256), blocks=[
FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256),
]
),
], ],
) )
"""Default configuration for the Encoder. """Default configuration for the Encoder.