Add interfaces for encoder/decoder/bottleneck

This commit is contained in:
mbsantiago 2026-03-08 14:43:16 +00:00
parent 54605ef269
commit e393709258
7 changed files with 115 additions and 57 deletions

View File

@ -26,9 +26,14 @@ from torch import nn
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel
from batdetect2.models.decoder import build_decoder
from batdetect2.models.encoder import build_encoder
from batdetect2.typing.models import (
BackboneModel,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
)
__all__ = [
"Backbone",
@ -55,11 +60,11 @@ class Backbone(BackboneModel):
Expected height of the input spectrogram.
out_channels : int
Number of channels in the final output feature map.
encoder : Encoder
encoder : EncoderProtocol
The instantiated encoder module.
decoder : Decoder
decoder : DecoderProtocol
The instantiated decoder module.
bottleneck : nn.Module
bottleneck : BottleneckProtocol
The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
@ -71,9 +76,9 @@ class Backbone(BackboneModel):
def __init__(
self,
input_height: int,
encoder: Encoder,
decoder: Decoder,
bottleneck: nn.Module,
encoder: EncoderProtocol,
decoder: DecoderProtocol,
bottleneck: BottleneckProtocol,
):
"""Initialize the Backbone network.
@ -83,11 +88,11 @@ class Backbone(BackboneModel):
Expected height of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
encoder : Encoder
encoder : EncoderProtocol
An initialized Encoder module.
decoder : Decoder
decoder : DecoderProtocol
An initialized Decoder module.
bottleneck : nn.Module
bottleneck : BottleneckProtocol
An initialized Bottleneck module.
Raises
@ -185,7 +190,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
)
decoder = build_decoder(
in_channels=bottleneck.get_output_channels(),
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)

View File

@ -56,17 +56,14 @@ __all__ = [
]
class BlockProtocol(Protocol):
def get_output_channels(self) -> int:
raise NotImplementedError
class Block(nn.Module):
in_channels: int
out_channels: int
def get_output_height(self, input_height: int) -> int:
return input_height
class Block(nn.Module, BlockProtocol): ...
block_registry: Registry[Block, [int, int]] = Registry("block")
@ -132,6 +129,8 @@ class SelfAttention(Block):
temperature: float = 1.0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
@ -206,9 +205,6 @@ class SelfAttention(Block):
att_weights = F.softmax(kk_qq, 1)
return att_weights
def get_output_channels(self) -> int:
return self.output_channels
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
@ -267,6 +263,8 @@ class ConvBlock(Block):
pad_size: int = 1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels,
out_channels,
@ -290,9 +288,6 @@ class ConvBlock(Block):
"""
return F.relu_(self.batch_norm(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(ConvConfig)
@staticmethod
def from_config(
@ -342,6 +337,8 @@ class VerticalConv(Block):
input_height: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@ -366,9 +363,6 @@ class VerticalConv(Block):
"""
return F.relu_(self.bn(self.conv(x)))
def get_output_channels(self) -> int:
return self.conv.out_channels
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
@ -438,6 +432,8 @@ class FreqCoordConvDownBlock(Block):
pad_size: int = 1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.coords = nn.Parameter(
torch.linspace(-1, 1, input_height)[None, None, ..., None],
@ -472,9 +468,6 @@ class FreqCoordConvDownBlock(Block):
x = F.relu(self.batch_norm(x), inplace=True)
return x
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@ -538,6 +531,8 @@ class StandardConvDownBlock(Block):
pad_size: int = 1,
):
super(StandardConvDownBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels,
out_channels,
@ -563,9 +558,6 @@ class StandardConvDownBlock(Block):
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True)
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@ -652,6 +644,8 @@ class FreqCoordConvUpBlock(Block):
up_scale: Tuple[int, int] = (2, 2),
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale
self.up_mode = up_mode
@ -698,9 +692,6 @@ class FreqCoordConvUpBlock(Block):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@ -779,6 +770,8 @@ class StandardConvUpBlock(Block):
up_scale: Tuple[int, int] = (2, 2),
):
super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
@ -815,9 +808,6 @@ class StandardConvUpBlock(Block):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_channels(self) -> int:
return self.conv.out_channels
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@ -868,14 +858,15 @@ class LayerGroup(nn.Module):
input_channels: int,
):
super().__init__()
self.in_channels = input_channels
self.out_channels = (
layers[-1].out_channels if layers else input_channels
)
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x)
def get_output_channels(self) -> int:
return self.layers[-1].get_output_channels() # type: ignore
def get_output_height(self, input_height: int) -> int:
for block in self.layers:
input_height = block.get_output_height(input_height) # type: ignore
@ -898,7 +889,7 @@ class LayerGroup(nn.Module):
)
layers.append(layer)
input_height = layer.get_output_height(input_height)
input_channels = layer.get_output_channels()
input_channels = layer.out_channels
return LayerGroup(
layers=layers,

View File

@ -28,6 +28,8 @@ from batdetect2.models.blocks import (
build_layer,
)
from batdetect2.typing.models import BottleneckProtocol
__all__ = [
"BottleneckConfig",
"Bottleneck",
@ -127,9 +129,6 @@ class Bottleneck(Block):
return x.repeat([1, 1, self.input_height, 1])
def get_output_channels(self) -> int:
return self.layers[-1].get_output_channels() # type: ignore
BottleneckLayerConfig = Annotated[
SelfAttentionConfig,
@ -177,7 +176,7 @@ def build_bottleneck(
input_height: int,
in_channels: int,
config: BottleneckConfig | None = None,
) -> Block:
) -> BottleneckProtocol:
"""Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
@ -217,7 +216,7 @@ def build_bottleneck(
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
current_channels = layer.out_channels
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)

View File

@ -265,7 +265,7 @@ def build_decoder(
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
current_channels = layer.out_channels
layers.append(layer)
return Decoder(

View File

@ -307,7 +307,7 @@ def build_encoder(
)
layers.append(layer)
current_height = layer.get_output_height(current_height)
current_channels = layer.get_output_channels()
current_channels = layer.out_channels
return Encoder(
input_height=input_height,

View File

@ -16,17 +16,80 @@ Key components:
"""
from abc import ABC, abstractmethod
from typing import NamedTuple
from typing import List, NamedTuple, Protocol
import torch
__all__ = [
"ModelOutput",
"BackboneModel",
"EncoderDecoderModel",
"DetectionModel",
"BlockProtocol",
"EncoderProtocol",
"BottleneckProtocol",
"DecoderProtocol",
]
class BlockProtocol(Protocol):
    """Structural interface for blocks of network layers.

    Any object (typically an ``nn.Module``) that exposes these attributes
    and methods satisfies the protocol; no explicit inheritance is needed.
    """

    # Number of channels expected in the input feature map.
    in_channels: int
    # Number of channels produced in the output feature map.
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the block.

        Takes a feature map ``x`` and returns a transformed feature map
        (expected to have ``out_channels`` channels).
        """
        ...

    def get_output_height(self, input_height: int) -> int:
        """Calculate the output height based on input height.

        Lets callers track spectrogram height through a stack of blocks
        (e.g. downsampling blocks halve it, upsampling blocks double it).
        """
        ...
class EncoderProtocol(Protocol):
    """Interface for the downsampling path of a network.

    Describes the contract expected from encoder modules used by a
    U-Net-like backbone.
    """

    # Number of channels expected in the input spectrogram.
    in_channels: int
    # Number of channels in the deepest (final) feature map.
    out_channels: int
    # Height (frequency bins) of the input spectrogram.
    input_height: int
    # Height of the final feature map after all downsampling stages.
    output_height: int

    def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Forward pass must return intermediate tensors for skip connections.

        The returned list holds one feature map per encoder stage so a
        decoder can consume them as residual/skip inputs.
        """
        ...
class BottleneckProtocol(Protocol):
    """Interface for the middle part of a U-Net-like network.

    Sits between the encoder and decoder, transforming the deepest
    feature map without participating in skip connections.
    """

    # Number of channels expected from the encoder's output.
    in_channels: int
    # Number of channels handed to the decoder.
    out_channels: int
    # Height of the feature map at the bottleneck; implementations are
    # expected not to change it (see the assertion in build_bottleneck).
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Processes the features from the encoder."""
        ...
class DecoderProtocol(Protocol):
    """Interface for the upsampling reconstruction path.

    Consumes the bottleneck output together with the encoder's
    intermediate feature maps (skip connections) and reconstructs a
    high-resolution feature map.
    """

    # Number of channels expected from the bottleneck output.
    in_channels: int
    # Number of channels in the reconstructed output feature map.
    out_channels: int
    # Height of the feature map entering the decoder.
    input_height: int
    # Height of the feature map after all upsampling stages.
    output_height: int
    # Number of upsampling stages; presumably matches the number of
    # skip tensors expected in ``residuals`` — confirm against Decoder.
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: List[torch.Tensor],
    ) -> torch.Tensor:
        """Upsamples features while integrating skip connections.

        ``residuals`` are the per-stage tensors produced by an
        ``EncoderProtocol`` forward pass.
        """
        ...
class ModelOutput(NamedTuple):
"""Standard container for the outputs of a BatDetect2 detection model.

View File

@ -53,7 +53,7 @@ def test_standard_block_protocol_methods(
block = block_class(in_channels=in_channels, out_channels=out_channels)
assert block.get_output_channels() == out_channels
assert block.out_channels == out_channels
assert block.get_output_height(input_height) == int(
input_height * expected_h_scale
)
@ -80,7 +80,7 @@ def test_coord_block_protocol_methods(
input_height=input_height,
)
assert block.get_output_channels() == out_channels
assert block.out_channels == out_channels
assert block.get_output_height(input_height) == int(
input_height * expected_h_scale
)
@ -96,7 +96,7 @@ def test_vertical_conv_forward_shape(dummy_input):
output = block(dummy_input)
assert output.shape == (2, out_channels, 1, 32)
assert block.get_output_channels() == out_channels
assert block.out_channels == out_channels
def test_self_attention_forward_shape(dummy_bottleneck_input):
@ -110,7 +110,7 @@ def test_self_attention_forward_shape(dummy_bottleneck_input):
output = block(dummy_bottleneck_input)
assert output.shape == dummy_bottleneck_input.shape
assert block.get_output_channels() == in_channels
assert block.out_channels == in_channels
def test_self_attention_weights(dummy_bottleneck_input):
@ -179,7 +179,7 @@ def test_layer_group_from_config_and_forward(dummy_input):
assert len(layer_group.layers) == 2
# The group should report the output channels of the LAST block
assert layer_group.get_output_channels() == 64
assert layer_group.out_channels == 64
# The group should report the accumulated height changes
assert layer_group.get_output_height(input_height) == input_height // 2