From e3937092583018a29533ee2788dbbc588253d6e6 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 8 Mar 2026 14:43:16 +0000 Subject: [PATCH] Add interfaces for encoder/decoder/bottleneck --- src/batdetect2/models/backbones.py | 31 ++++++++------ src/batdetect2/models/blocks.py | 53 ++++++++++------------- src/batdetect2/models/bottleneck.py | 9 ++-- src/batdetect2/models/decoder.py | 2 +- src/batdetect2/models/encoder.py | 2 +- src/batdetect2/typing/models.py | 65 ++++++++++++++++++++++++++++- tests/test_models/test_blocks.py | 10 ++--- 7 files changed, 115 insertions(+), 57 deletions(-) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 56e83eb..98548f8 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -26,9 +26,14 @@ from torch import nn from batdetect2.models.bottleneck import build_bottleneck from batdetect2.models.config import BackboneConfig -from batdetect2.models.decoder import Decoder, build_decoder -from batdetect2.models.encoder import Encoder, build_encoder -from batdetect2.typing.models import BackboneModel +from batdetect2.models.decoder import build_decoder +from batdetect2.models.encoder import build_encoder +from batdetect2.typing.models import ( + BackboneModel, + BottleneckProtocol, + DecoderProtocol, + EncoderProtocol, +) __all__ = [ "Backbone", @@ -55,11 +60,11 @@ class Backbone(BackboneModel): Expected height of the input spectrogram. out_channels : int Number of channels in the final output feature map. - encoder : Encoder + encoder : EncoderProtocol The instantiated encoder module. - decoder : Decoder + decoder : DecoderProtocol The instantiated decoder module. - bottleneck : nn.Module + bottleneck : BottleneckProtocol The instantiated bottleneck module. final_conv : ConvBlock Final convolutional block applied after the decoder. @@ -71,9 +76,9 @@ class Backbone(BackboneModel): def __init__( self, input_height: int, - encoder: Encoder, - decoder: Decoder, - bottleneck: nn.Module, + encoder: EncoderProtocol, + decoder: DecoderProtocol, + bottleneck: BottleneckProtocol, ): """Initialize the Backbone network. @@ -83,11 +88,11 @@ class Backbone(BackboneModel): Expected height of the input spectrogram. out_channels : int Desired number of output channels for the backbone's feature map. - encoder : Encoder + encoder : EncoderProtocol An initialized Encoder module. - decoder : Decoder + decoder : DecoderProtocol An initialized Decoder module. - bottleneck : nn.Module + bottleneck : BottleneckProtocol An initialized Bottleneck module. Raises @@ -185,7 +190,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel: ) decoder = build_decoder( - in_channels=bottleneck.get_output_channels(), + in_channels=bottleneck.out_channels, input_height=encoder.output_height, config=config.decoder, ) diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 6890dd8..fa42bdf 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -56,17 +56,14 @@ __all__ = [ ] -class BlockProtocol(Protocol): - def get_output_channels(self) -> int: - raise NotImplementedError +class Block(nn.Module): + in_channels: int + out_channels: int def get_output_height(self, input_height: int) -> int: return input_height -class Block(nn.Module, BlockProtocol): ... - - block_registry: Registry[Block, [int, int]] = Registry("block") @@ -132,6 +129,8 @@ class SelfAttention(Block): temperature: float = 1.0, ): super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels # Note, does not encode position information (absolute or relative) self.temperature = temperature @@ -206,9 +205,6 @@ class SelfAttention(Block): att_weights = F.softmax(kk_qq, 1) return att_weights - def get_output_channels(self) -> int: - return self.output_channels - @block_registry.register(SelfAttentionConfig) @staticmethod def from_config( @@ -267,6 +263,8 @@ class ConvBlock(Block): pad_size: int = 1, ): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.conv = nn.Conv2d( in_channels, out_channels, @@ -290,9 +288,6 @@ class ConvBlock(Block): """ return F.relu_(self.batch_norm(self.conv(x))) - def get_output_channels(self) -> int: - return self.conv.out_channels - @block_registry.register(ConvConfig) @staticmethod def from_config( @@ -342,6 +337,8 @@ class VerticalConv(Block): input_height: int, ): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, @@ -366,9 +363,6 @@ class VerticalConv(Block): """ return F.relu_(self.bn(self.conv(x))) - def get_output_channels(self) -> int: - return self.conv.out_channels - @block_registry.register(VerticalConvConfig) @staticmethod def from_config( @@ -438,6 +432,8 @@ class FreqCoordConvDownBlock(Block): pad_size: int = 1, ): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.coords = nn.Parameter( torch.linspace(-1, 1, input_height)[None, None, ..., None], @@ -472,9 +468,6 @@ class FreqCoordConvDownBlock(Block): x = F.relu(self.batch_norm(x), inplace=True) return x - def get_output_channels(self) -> int: - return self.conv.out_channels - def get_output_height(self, input_height: int) -> int: return input_height // 2 @@ -538,6 +531,8 @@ class StandardConvDownBlock(Block): pad_size: int = 1, ): super(StandardConvDownBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.conv = nn.Conv2d( in_channels, out_channels, @@ -563,9 +558,6 @@ class StandardConvDownBlock(Block): x = F.max_pool2d(self.conv(x), 2, 2) return F.relu(self.batch_norm(x), inplace=True) - def get_output_channels(self) -> int: - return self.conv.out_channels - def get_output_height(self, input_height: int) -> int: return input_height // 2 @@ -652,6 +644,8 @@ class FreqCoordConvUpBlock(Block): up_scale: Tuple[int, int] = (2, 2), ): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.up_scale = up_scale self.up_mode = up_mode @@ -698,9 +692,6 @@ class FreqCoordConvUpBlock(Block): op = F.relu(self.batch_norm(op), inplace=True) return op - def get_output_channels(self) -> int: - return self.conv.out_channels - def get_output_height(self, input_height: int) -> int: return input_height * 2 @@ -779,6 +770,8 @@ class StandardConvUpBlock(Block): up_scale: Tuple[int, int] = (2, 2), ): super(StandardConvUpBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.up_scale = up_scale self.up_mode = up_mode self.conv = nn.Conv2d( @@ -815,9 +808,6 @@ class StandardConvUpBlock(Block): op = F.relu(self.batch_norm(op), inplace=True) return op - def get_output_channels(self) -> int: - return self.conv.out_channels - def get_output_height(self, input_height: int) -> int: return input_height * 2 @@ -868,14 +858,15 @@ class LayerGroup(nn.Module): input_channels: int, ): super().__init__() + self.in_channels = input_channels + self.out_channels = ( + layers[-1].out_channels if layers else input_channels + ) self.layers = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.layers(x) - def get_output_channels(self) -> int: - return self.layers[-1].get_output_channels() # type: ignore - def get_output_height(self, input_height: int) -> int: for block in self.layers: input_height = block.get_output_height(input_height) # type: ignore @@ -898,7 +889,7 @@ class LayerGroup(nn.Module): ) layers.append(layer) input_height = layer.get_output_height(input_height) - input_channels = layer.get_output_channels() + input_channels = layer.out_channels return LayerGroup( layers=layers, diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 859788e..290fb64 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -28,6 +28,8 @@ from batdetect2.models.blocks import ( build_layer, ) +from batdetect2.typing.models import BottleneckProtocol + __all__ = [ "BottleneckConfig", "Bottleneck", @@ -127,9 +129,6 @@ class Bottleneck(Block): return x.repeat([1, 1, self.input_height, 1]) - def get_output_channels(self) -> int: - return self.layers[-1].get_output_channels() # type: ignore - BottleneckLayerConfig = Annotated[ SelfAttentionConfig, @@ -177,7 +176,7 @@ def build_bottleneck( input_height: int, in_channels: int, config: BottleneckConfig | None = None, -) -> Block: +) -> BottleneckProtocol: """Factory function to build the Bottleneck module from configuration. Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based @@ -217,7 +216,7 @@ def build_bottleneck( config=layer_config, ) current_height = layer.get_output_height(current_height) - current_channels = layer.get_output_channels() + current_channels = layer.out_channels assert current_height == input_height, ( "Bottleneck layers should not change the spectrogram height" ) diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index 8275cb5..1673653 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -265,7 +265,7 @@ def build_decoder( config=layer_config, ) current_height = layer.get_output_height(current_height) - current_channels = layer.get_output_channels() + current_channels = layer.out_channels layers.append(layer) return Decoder( diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index a9a0173..142e9ab 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -307,7 +307,7 @@ def build_encoder( ) layers.append(layer) current_height = layer.get_output_height(current_height) - current_channels = layer.get_output_channels() + current_channels = layer.out_channels return Encoder( input_height=input_height, diff --git a/src/batdetect2/typing/models.py b/src/batdetect2/typing/models.py index d71193f..2bb1881 100644 --- a/src/batdetect2/typing/models.py +++ b/src/batdetect2/typing/models.py @@ -16,17 +16,80 @@ Key components: """ from abc import ABC, abstractmethod -from typing import NamedTuple +from typing import List, NamedTuple, Protocol import torch __all__ = [ "ModelOutput", "BackboneModel", + "EncoderDecoderModel", "DetectionModel", + "BlockProtocol", + "EncoderProtocol", + "BottleneckProtocol", + "DecoderProtocol", ] +class BlockProtocol(Protocol): + """Interface for blocks of network layers.""" + + in_channels: int + out_channels: int + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the block.""" + ... + + def get_output_height(self, input_height: int) -> int: + """Calculate the output height based on input height.""" + ... + + +class EncoderProtocol(Protocol): + """Interface for the downsampling path of a network.""" + + in_channels: int + out_channels: int + input_height: int + output_height: int + + def __call__(self, x: torch.Tensor) -> List[torch.Tensor]: + """Forward pass must return intermediate tensors for skip connections.""" + ... + + +class BottleneckProtocol(Protocol): + """Interface for the middle part of a U-Net-like network.""" + + in_channels: int + out_channels: int + input_height: int + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Processes the features from the encoder.""" + ... + + +class DecoderProtocol(Protocol): + """Interface for the upsampling reconstruction path.""" + + in_channels: int + out_channels: int + input_height: int + output_height: int + depth: int + + def __call__( + self, + x: torch.Tensor, + residuals: List[torch.Tensor], + ) -> torch.Tensor: + """Upsamples features while integrating skip connections.""" + ... + + class ModelOutput(NamedTuple): """Standard container for the outputs of a BatDetect2 detection model. diff --git a/tests/test_models/test_blocks.py b/tests/test_models/test_blocks.py index 0795b40..57315d5 100644 --- a/tests/test_models/test_blocks.py +++ b/tests/test_models/test_blocks.py @@ -53,7 +53,7 @@ def test_standard_block_protocol_methods( block = block_class(in_channels=in_channels, out_channels=out_channels) - assert block.get_output_channels() == out_channels + assert block.out_channels == out_channels assert block.get_output_height(input_height) == int( input_height * expected_h_scale ) @@ -80,7 +80,7 @@ def test_coord_block_protocol_methods( input_height=input_height, ) - assert block.get_output_channels() == out_channels + assert block.out_channels == out_channels assert block.get_output_height(input_height) == int( input_height * expected_h_scale ) @@ -96,7 +96,7 @@ def test_vertical_conv_forward_shape(dummy_input): output = block(dummy_input) assert output.shape == (2, out_channels, 1, 32) - assert block.get_output_channels() == out_channels + assert block.out_channels == out_channels def test_self_attention_forward_shape(dummy_bottleneck_input): @@ -110,7 +110,7 @@ def test_self_attention_forward_shape(dummy_bottleneck_input): output = block(dummy_bottleneck_input) assert output.shape == dummy_bottleneck_input.shape - assert block.get_output_channels() == in_channels + assert block.out_channels == in_channels def test_self_attention_weights(dummy_bottleneck_input): @@ -179,7 +179,7 @@ def test_layer_group_from_config_and_forward(dummy_input): assert len(layer_group.layers) == 2 # The group should report the output channels of the LAST block - assert layer_group.get_output_channels() == 64 + assert layer_group.out_channels == 64 # The group should report the accumulated height changes assert layer_group.get_output_height(input_height) == input_height // 2