From 6498b6ca3743da66f66c143d003c35e362efbde2 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 23 Apr 2025 23:14:11 +0100 Subject: [PATCH] Added BlockGroups --- batdetect2/models/blocks.py | 25 ++++++++++++++++++++++++- batdetect2/models/decoder.py | 16 +++++++++++++--- batdetect2/models/encoder.py | 16 +++++++++++++--- 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/batdetect2/models/blocks.py b/batdetect2/models/blocks.py index 59fa7c0..4df17f1 100644 --- a/batdetect2/models/blocks.py +++ b/batdetect2/models/blocks.py @@ -27,7 +27,7 @@ A unified factory function `build_layer_from_config` allows creating instances of these blocks based on configuration objects. """ -from typing import Annotated, Literal, Tuple, Union +from typing import Annotated, List, Literal, Tuple, Union import torch import torch.nn.functional as F @@ -38,6 +38,7 @@ from batdetect2.configs import BaseConfig __all__ = [ "ConvBlock", + "BlockGroupConfig", "VerticalConv", "FreqCoordConvDownBlock", "StandardConvDownBlock", @@ -653,12 +654,18 @@ LayerConfig = Annotated[ StandardConvDownConfig, FreqCoordConvUpConfig, StandardConvUpConfig, + "BlockGroupConfig", ], Field(discriminator="block_type"), ] """Type alias for the discriminated union of block configuration models.""" +class BlockGroupConfig(BaseConfig): + block_type: Literal["group"] = "group" + blocks: List[LayerConfig] + + def build_layer_from_config( input_height: int, in_channels: int, @@ -762,4 +769,20 @@ def build_layer_from_config( input_height * 2, ) + if config.block_type == "group": + current_channels = in_channels + current_height = input_height + + blocks = [] + + for block_config in config.blocks: + block, current_channels, current_height = build_layer_from_config( + input_height=current_height, + in_channels=current_channels, + config=block_config, + ) + blocks.append(block) + + return nn.Sequential(*blocks), current_channels, current_height + raise NotImplementedError(f"Unknown block type {config.block_type}") diff --git a/batdetect2/models/decoder.py b/batdetect2/models/decoder.py index 9760ede..91c1089 100644 --- a/batdetect2/models/decoder.py +++ b/batdetect2/models/decoder.py @@ -26,6 +26,7 @@ from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.blocks import ( + BlockGroupConfig, ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig, @@ -40,7 +41,12 @@ __all__ = [ ] DecoderLayerConfig = Annotated[ - Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig], + Union[ + ConvConfig, + FreqCoordConvUpConfig, + StandardConvUpConfig, + BlockGroupConfig, + ], Field(discriminator="block_type"), ] """Type alias for the discriminated union of block configs usable in Decoder.""" @@ -191,8 +197,12 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig( layers=[ FreqCoordConvUpConfig(out_channels=64), FreqCoordConvUpConfig(out_channels=32), - FreqCoordConvUpConfig(out_channels=32), - ConvConfig(out_channels=32), + BlockGroupConfig( + blocks=[ + FreqCoordConvUpConfig(out_channels=32), + ConvConfig(out_channels=32), + ] + ), ], ) """A default configuration for the Decoder's *layer sequence*. diff --git a/batdetect2/models/encoder.py b/batdetect2/models/encoder.py index 7f2600e..3a0c609 100644 --- a/batdetect2/models/encoder.py +++ b/batdetect2/models/encoder.py @@ -28,6 +28,7 @@ from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.blocks import ( + BlockGroupConfig, ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig, @@ -42,7 +43,12 @@ __all__ = [ ] EncoderLayerConfig = Annotated[ - Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig], + Union[ + ConvConfig, + FreqCoordConvDownConfig, + StandardConvDownConfig, + BlockGroupConfig, + ], Field(discriminator="block_type"), ] """Type alias for the discriminated union of block configs usable in Encoder.""" @@ -224,8 +230,12 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig( layers=[ FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownConfig(out_channels=64), - FreqCoordConvDownConfig(out_channels=128), - ConvConfig(out_channels=256), + BlockGroupConfig( + blocks=[ + FreqCoordConvDownConfig(out_channels=128), + ConvConfig(out_channels=256), + ] + ), ], ) """Default configuration for the Encoder.