Add interfaces for encoder/decoder/bottleneck

This commit is contained in:
mbsantiago 2026-03-08 14:43:16 +00:00
parent 54605ef269
commit e393709258
7 changed files with 115 additions and 57 deletions

View File

@ -26,9 +26,14 @@ from torch import nn
from batdetect2.models.bottleneck import build_bottleneck from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder from batdetect2.models.decoder import build_decoder
from batdetect2.models.encoder import Encoder, build_encoder from batdetect2.models.encoder import build_encoder
from batdetect2.typing.models import BackboneModel from batdetect2.typing.models import (
BackboneModel,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
)
__all__ = [ __all__ = [
"Backbone", "Backbone",
@ -55,11 +60,11 @@ class Backbone(BackboneModel):
Expected height of the input spectrogram. Expected height of the input spectrogram.
out_channels : int out_channels : int
Number of channels in the final output feature map. Number of channels in the final output feature map.
encoder : Encoder encoder : EncoderProtocol
The instantiated encoder module. The instantiated encoder module.
decoder : Decoder decoder : DecoderProtocol
The instantiated decoder module. The instantiated decoder module.
bottleneck : nn.Module bottleneck : BottleneckProtocol
The instantiated bottleneck module. The instantiated bottleneck module.
final_conv : ConvBlock final_conv : ConvBlock
Final convolutional block applied after the decoder. Final convolutional block applied after the decoder.
@ -71,9 +76,9 @@ class Backbone(BackboneModel):
def __init__( def __init__(
self, self,
input_height: int, input_height: int,
encoder: Encoder, encoder: EncoderProtocol,
decoder: Decoder, decoder: DecoderProtocol,
bottleneck: nn.Module, bottleneck: BottleneckProtocol,
): ):
"""Initialize the Backbone network. """Initialize the Backbone network.
@ -83,11 +88,11 @@ class Backbone(BackboneModel):
Expected height of the input spectrogram. Expected height of the input spectrogram.
out_channels : int out_channels : int
Desired number of output channels for the backbone's feature map. Desired number of output channels for the backbone's feature map.
encoder : Encoder encoder : EncoderProtocol
An initialized Encoder module. An initialized Encoder module.
decoder : Decoder decoder : DecoderProtocol
An initialized Decoder module. An initialized Decoder module.
bottleneck : nn.Module bottleneck : BottleneckProtocol
An initialized Bottleneck module. An initialized Bottleneck module.
Raises Raises
@ -185,7 +190,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
) )
decoder = build_decoder( decoder = build_decoder(
in_channels=bottleneck.get_output_channels(), in_channels=bottleneck.out_channels,
input_height=encoder.output_height, input_height=encoder.output_height,
config=config.decoder, config=config.decoder,
) )

View File

@ -56,17 +56,14 @@ __all__ = [
] ]
class BlockProtocol(Protocol): class Block(nn.Module):
def get_output_channels(self) -> int: in_channels: int
raise NotImplementedError out_channels: int
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height return input_height
class Block(nn.Module, BlockProtocol): ...
block_registry: Registry[Block, [int, int]] = Registry("block") block_registry: Registry[Block, [int, int]] = Registry("block")
@ -132,6 +129,8 @@ class SelfAttention(Block):
temperature: float = 1.0, temperature: float = 1.0,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
# Note, does not encode position information (absolute or relative) # Note, does not encode position information (absolute or relative)
self.temperature = temperature self.temperature = temperature
@ -206,9 +205,6 @@ class SelfAttention(Block):
att_weights = F.softmax(kk_qq, 1) att_weights = F.softmax(kk_qq, 1)
return att_weights return att_weights
def get_output_channels(self) -> int:
return self.output_channels
@block_registry.register(SelfAttentionConfig) @block_registry.register(SelfAttentionConfig)
@staticmethod @staticmethod
def from_config( def from_config(
@ -267,6 +263,8 @@ class ConvBlock(Block):
pad_size: int = 1, pad_size: int = 1,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels, in_channels,
out_channels, out_channels,
@ -290,9 +288,6 @@ class ConvBlock(Block):
""" """
return F.relu_(self.batch_norm(self.conv(x))) return F.relu_(self.batch_norm(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(ConvConfig) @block_registry.register(ConvConfig)
@staticmethod @staticmethod
def from_config( def from_config(
@ -342,6 +337,8 @@ class VerticalConv(Block):
input_height: int, input_height: int,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
@ -366,9 +363,6 @@ class VerticalConv(Block):
""" """
return F.relu_(self.bn(self.conv(x))) return F.relu_(self.bn(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(VerticalConvConfig) @block_registry.register(VerticalConvConfig)
@staticmethod @staticmethod
def from_config( def from_config(
@ -438,6 +432,8 @@ class FreqCoordConvDownBlock(Block):
pad_size: int = 1, pad_size: int = 1,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.coords = nn.Parameter( self.coords = nn.Parameter(
torch.linspace(-1, 1, input_height)[None, None, ..., None], torch.linspace(-1, 1, input_height)[None, None, ..., None],
@ -472,9 +468,6 @@ class FreqCoordConvDownBlock(Block):
x = F.relu(self.batch_norm(x), inplace=True) x = F.relu(self.batch_norm(x), inplace=True)
return x return x
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height // 2 return input_height // 2
@ -538,6 +531,8 @@ class StandardConvDownBlock(Block):
pad_size: int = 1, pad_size: int = 1,
): ):
super(StandardConvDownBlock, self).__init__() super(StandardConvDownBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels, in_channels,
out_channels, out_channels,
@ -563,9 +558,6 @@ class StandardConvDownBlock(Block):
x = F.max_pool2d(self.conv(x), 2, 2) x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True) return F.relu(self.batch_norm(x), inplace=True)
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height // 2 return input_height // 2
@ -652,6 +644,8 @@ class FreqCoordConvUpBlock(Block):
up_scale: Tuple[int, int] = (2, 2), up_scale: Tuple[int, int] = (2, 2),
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
@ -698,9 +692,6 @@ class FreqCoordConvUpBlock(Block):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height * 2 return input_height * 2
@ -779,6 +770,8 @@ class StandardConvUpBlock(Block):
up_scale: Tuple[int, int] = (2, 2), up_scale: Tuple[int, int] = (2, 2),
): ):
super(StandardConvUpBlock, self).__init__() super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
@ -815,9 +808,6 @@ class StandardConvUpBlock(Block):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height * 2 return input_height * 2
@ -868,14 +858,15 @@ class LayerGroup(nn.Module):
input_channels: int, input_channels: int,
): ):
super().__init__() super().__init__()
self.in_channels = input_channels
self.out_channels = (
layers[-1].out_channels if layers else input_channels
)
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x) return self.layers(x)
def get_output_channels(self) -> int:
return self.layers[-1].get_output_channels() # type: ignore
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
for block in self.layers: for block in self.layers:
input_height = block.get_output_height(input_height) # type: ignore input_height = block.get_output_height(input_height) # type: ignore
@ -898,7 +889,7 @@ class LayerGroup(nn.Module):
) )
layers.append(layer) layers.append(layer)
input_height = layer.get_output_height(input_height) input_height = layer.get_output_height(input_height)
input_channels = layer.get_output_channels() input_channels = layer.out_channels
return LayerGroup( return LayerGroup(
layers=layers, layers=layers,

View File

@ -28,6 +28,8 @@ from batdetect2.models.blocks import (
build_layer, build_layer,
) )
from batdetect2.typing.models import BottleneckProtocol
__all__ = [ __all__ = [
"BottleneckConfig", "BottleneckConfig",
"Bottleneck", "Bottleneck",
@ -127,9 +129,6 @@ class Bottleneck(Block):
return x.repeat([1, 1, self.input_height, 1]) return x.repeat([1, 1, self.input_height, 1])
def get_output_channels(self) -> int:
return self.layers[-1].get_output_channels() # type: ignore
BottleneckLayerConfig = Annotated[ BottleneckLayerConfig = Annotated[
SelfAttentionConfig, SelfAttentionConfig,
@ -177,7 +176,7 @@ def build_bottleneck(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: BottleneckConfig | None = None, config: BottleneckConfig | None = None,
) -> Block: ) -> BottleneckProtocol:
"""Factory function to build the Bottleneck module from configuration. """Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
@ -217,7 +216,7 @@ def build_bottleneck(
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height) current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels() current_channels = layer.out_channels
assert current_height == input_height, ( assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height" "Bottleneck layers should not change the spectrogram height"
) )

View File

@ -265,7 +265,7 @@ def build_decoder(
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height) current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels() current_channels = layer.out_channels
layers.append(layer) layers.append(layer)
return Decoder( return Decoder(

View File

@ -307,7 +307,7 @@ def build_encoder(
) )
layers.append(layer) layers.append(layer)
current_height = layer.get_output_height(current_height) current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels() current_channels = layer.out_channels
return Encoder( return Encoder(
input_height=input_height, input_height=input_height,

View File

@ -16,17 +16,80 @@ Key components:
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import NamedTuple from typing import List, NamedTuple, Protocol
import torch import torch
__all__ = [ __all__ = [
"ModelOutput", "ModelOutput",
"BackboneModel", "BackboneModel",
"EncoderDecoderModel",
"DetectionModel", "DetectionModel",
"BlockProtocol",
"EncoderProtocol",
"BottleneckProtocol",
"DecoderProtocol",
] ]
class BlockProtocol(Protocol):
    """Structural interface for a single block of network layers.

    Any callable module that exposes its channel counts and a
    height-transform rule satisfies this protocol; no inheritance
    is required (structural subtyping via ``typing.Protocol``).
    """

    # Number of channels expected in the input tensor.
    in_channels: int
    # Number of channels produced in the output tensor.
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return the transformed tensor."""
        ...

    def get_output_height(self, input_height: int) -> int:
        """Return the output height produced for a given input height.

        Parameters
        ----------
        input_height : int
            Height of the incoming feature map.

        Returns
        -------
        int
            Height of the feature map after this block is applied.
        """
        ...
class EncoderProtocol(Protocol):
    """Structural interface for the downsampling path of a network.

    Implementations consume a spectrogram tensor and return the
    intermediate feature maps needed for skip connections in a
    U-Net-style decoder.
    """

    # Number of channels expected in the input tensor.
    in_channels: int
    # Number of channels in the final (deepest) feature map.
    out_channels: int
    # Height of the input spectrogram.
    input_height: int
    # Height of the final feature map after downsampling.
    output_height: int

    # NOTE(review): unlike DecoderProtocol, no ``depth`` attribute is
    # declared here — confirm whether that asymmetry is intentional.

    def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward pass must return intermediate tensors for skip connections."""
        ...
class BottleneckProtocol(Protocol):
    """Structural interface for the middle part of a U-Net-like network.

    Sits between the encoder and decoder, transforming the deepest
    feature map without providing skip connections.
    """

    # Number of channels expected from the encoder output.
    in_channels: int
    # Number of channels handed to the decoder.
    out_channels: int
    # Height of the feature map entering the bottleneck.
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Processes the features from the encoder."""
        ...
class DecoderProtocol(Protocol):
    """Structural interface for the upsampling reconstruction path.

    Implementations upsample the bottleneck features back toward the
    input resolution, merging in the encoder's intermediate outputs
    (``residuals``) as skip connections along the way.
    """

    # Number of channels expected from the bottleneck output.
    in_channels: int
    # Number of channels in the reconstructed feature map.
    out_channels: int
    # Height of the feature map entering the decoder.
    input_height: int
    # Height of the feature map after all upsampling stages.
    output_height: int
    # Number of upsampling stages; presumably matches the number of
    # ``residuals`` expected in ``__call__`` — TODO confirm against
    # implementations.
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: List[torch.Tensor],
    ) -> torch.Tensor:
        """Upsamples features while integrating skip connections."""
        ...
class ModelOutput(NamedTuple): class ModelOutput(NamedTuple):
"""Standard container for the outputs of a BatDetect2 detection model. """Standard container for the outputs of a BatDetect2 detection model.

View File

@ -53,7 +53,7 @@ def test_standard_block_protocol_methods(
block = block_class(in_channels=in_channels, out_channels=out_channels) block = block_class(in_channels=in_channels, out_channels=out_channels)
assert block.get_output_channels() == out_channels assert block.out_channels == out_channels
assert block.get_output_height(input_height) == int( assert block.get_output_height(input_height) == int(
input_height * expected_h_scale input_height * expected_h_scale
) )
@ -80,7 +80,7 @@ def test_coord_block_protocol_methods(
input_height=input_height, input_height=input_height,
) )
assert block.get_output_channels() == out_channels assert block.out_channels == out_channels
assert block.get_output_height(input_height) == int( assert block.get_output_height(input_height) == int(
input_height * expected_h_scale input_height * expected_h_scale
) )
@ -96,7 +96,7 @@ def test_vertical_conv_forward_shape(dummy_input):
output = block(dummy_input) output = block(dummy_input)
assert output.shape == (2, out_channels, 1, 32) assert output.shape == (2, out_channels, 1, 32)
assert block.get_output_channels() == out_channels assert block.out_channels == out_channels
def test_self_attention_forward_shape(dummy_bottleneck_input): def test_self_attention_forward_shape(dummy_bottleneck_input):
@ -110,7 +110,7 @@ def test_self_attention_forward_shape(dummy_bottleneck_input):
output = block(dummy_bottleneck_input) output = block(dummy_bottleneck_input)
assert output.shape == dummy_bottleneck_input.shape assert output.shape == dummy_bottleneck_input.shape
assert block.get_output_channels() == in_channels assert block.out_channels == in_channels
def test_self_attention_weights(dummy_bottleneck_input): def test_self_attention_weights(dummy_bottleneck_input):
@ -179,7 +179,7 @@ def test_layer_group_from_config_and_forward(dummy_input):
assert len(layer_group.layers) == 2 assert len(layer_group.layers) == 2
# The group should report the output channels of the LAST block # The group should report the output channels of the LAST block
assert layer_group.get_output_channels() == 64 assert layer_group.out_channels == 64
# The group should report the accumulated height changes # The group should report the accumulated height changes
assert layer_group.get_output_height(input_height) == input_height // 2 assert layer_group.get_output_height(input_height) == input_height // 2