mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add interfaces for encoder/decoder/bottleneck
This commit is contained in:
parent
54605ef269
commit
e393709258
@ -26,9 +26,14 @@ from torch import nn
|
||||
|
||||
from batdetect2.models.bottleneck import build_bottleneck
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.models.decoder import Decoder, build_decoder
|
||||
from batdetect2.models.encoder import Encoder, build_encoder
|
||||
from batdetect2.typing.models import BackboneModel
|
||||
from batdetect2.models.decoder import build_decoder
|
||||
from batdetect2.models.encoder import build_encoder
|
||||
from batdetect2.typing.models import (
|
||||
BackboneModel,
|
||||
BottleneckProtocol,
|
||||
DecoderProtocol,
|
||||
EncoderProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Backbone",
|
||||
@ -55,11 +60,11 @@ class Backbone(BackboneModel):
|
||||
Expected height of the input spectrogram.
|
||||
out_channels : int
|
||||
Number of channels in the final output feature map.
|
||||
encoder : Encoder
|
||||
encoder : EncoderProtocol
|
||||
The instantiated encoder module.
|
||||
decoder : Decoder
|
||||
decoder : DecoderProtocol
|
||||
The instantiated decoder module.
|
||||
bottleneck : nn.Module
|
||||
bottleneck : BottleneckProtocol
|
||||
The instantiated bottleneck module.
|
||||
final_conv : ConvBlock
|
||||
Final convolutional block applied after the decoder.
|
||||
@ -71,9 +76,9 @@ class Backbone(BackboneModel):
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int,
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
bottleneck: nn.Module,
|
||||
encoder: EncoderProtocol,
|
||||
decoder: DecoderProtocol,
|
||||
bottleneck: BottleneckProtocol,
|
||||
):
|
||||
"""Initialize the Backbone network.
|
||||
|
||||
@ -83,11 +88,11 @@ class Backbone(BackboneModel):
|
||||
Expected height of the input spectrogram.
|
||||
out_channels : int
|
||||
Desired number of output channels for the backbone's feature map.
|
||||
encoder : Encoder
|
||||
encoder : EncoderProtocol
|
||||
An initialized Encoder module.
|
||||
decoder : Decoder
|
||||
decoder : DecoderProtocol
|
||||
An initialized Decoder module.
|
||||
bottleneck : nn.Module
|
||||
bottleneck : BottleneckProtocol
|
||||
An initialized Bottleneck module.
|
||||
|
||||
Raises
|
||||
@ -185,7 +190,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||
)
|
||||
|
||||
decoder = build_decoder(
|
||||
in_channels=bottleneck.get_output_channels(),
|
||||
in_channels=bottleneck.out_channels,
|
||||
input_height=encoder.output_height,
|
||||
config=config.decoder,
|
||||
)
|
||||
|
||||
@ -56,17 +56,14 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
def get_output_channels(self) -> int:
|
||||
raise NotImplementedError
|
||||
class Block(nn.Module):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height
|
||||
|
||||
|
||||
class Block(nn.Module, BlockProtocol): ...
|
||||
|
||||
|
||||
block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||
|
||||
|
||||
@ -132,6 +129,8 @@ class SelfAttention(Block):
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
|
||||
# Note, does not encode position information (absolute or relative)
|
||||
self.temperature = temperature
|
||||
@ -206,9 +205,6 @@ class SelfAttention(Block):
|
||||
att_weights = F.softmax(kk_qq, 1)
|
||||
return att_weights
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.output_channels
|
||||
|
||||
@block_registry.register(SelfAttentionConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
@ -267,6 +263,8 @@ class ConvBlock(Block):
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -290,9 +288,6 @@ class ConvBlock(Block):
|
||||
"""
|
||||
return F.relu_(self.batch_norm(self.conv(x)))
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
@block_registry.register(ConvConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
@ -342,6 +337,8 @@ class VerticalConv(Block):
|
||||
input_height: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
@ -366,9 +363,6 @@ class VerticalConv(Block):
|
||||
"""
|
||||
return F.relu_(self.bn(self.conv(x)))
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
@block_registry.register(VerticalConvConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
@ -438,6 +432,8 @@ class FreqCoordConvDownBlock(Block):
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.coords = nn.Parameter(
|
||||
torch.linspace(-1, 1, input_height)[None, None, ..., None],
|
||||
@ -472,9 +468,6 @@ class FreqCoordConvDownBlock(Block):
|
||||
x = F.relu(self.batch_norm(x), inplace=True)
|
||||
return x
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height // 2
|
||||
|
||||
@ -538,6 +531,8 @@ class StandardConvDownBlock(Block):
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super(StandardConvDownBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -563,9 +558,6 @@ class StandardConvDownBlock(Block):
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.batch_norm(x), inplace=True)
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height // 2
|
||||
|
||||
@ -652,6 +644,8 @@ class FreqCoordConvUpBlock(Block):
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
@ -698,9 +692,6 @@ class FreqCoordConvUpBlock(Block):
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height * 2
|
||||
|
||||
@ -779,6 +770,8 @@ class StandardConvUpBlock(Block):
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super(StandardConvUpBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
self.conv = nn.Conv2d(
|
||||
@ -815,9 +808,6 @@ class StandardConvUpBlock(Block):
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.conv.out_channels
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height * 2
|
||||
|
||||
@ -868,14 +858,15 @@ class LayerGroup(nn.Module):
|
||||
input_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = input_channels
|
||||
self.out_channels = (
|
||||
layers[-1].out_channels if layers else input_channels
|
||||
)
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layers(x)
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.layers[-1].get_output_channels() # type: ignore
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
for block in self.layers:
|
||||
input_height = block.get_output_height(input_height) # type: ignore
|
||||
@ -898,7 +889,7 @@ class LayerGroup(nn.Module):
|
||||
)
|
||||
layers.append(layer)
|
||||
input_height = layer.get_output_height(input_height)
|
||||
input_channels = layer.get_output_channels()
|
||||
input_channels = layer.out_channels
|
||||
|
||||
return LayerGroup(
|
||||
layers=layers,
|
||||
|
||||
@ -28,6 +28,8 @@ from batdetect2.models.blocks import (
|
||||
build_layer,
|
||||
)
|
||||
|
||||
from batdetect2.typing.models import BottleneckProtocol
|
||||
|
||||
__all__ = [
|
||||
"BottleneckConfig",
|
||||
"Bottleneck",
|
||||
@ -127,9 +129,6 @@ class Bottleneck(Block):
|
||||
|
||||
return x.repeat([1, 1, self.input_height, 1])
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.layers[-1].get_output_channels() # type: ignore
|
||||
|
||||
|
||||
BottleneckLayerConfig = Annotated[
|
||||
SelfAttentionConfig,
|
||||
@ -177,7 +176,7 @@ def build_bottleneck(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
config: BottleneckConfig | None = None,
|
||||
) -> Block:
|
||||
) -> BottleneckProtocol:
|
||||
"""Factory function to build the Bottleneck module from configuration.
|
||||
|
||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||
@ -217,7 +216,7 @@ def build_bottleneck(
|
||||
config=layer_config,
|
||||
)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
current_channels = layer.out_channels
|
||||
assert current_height == input_height, (
|
||||
"Bottleneck layers should not change the spectrogram height"
|
||||
)
|
||||
|
||||
@ -265,7 +265,7 @@ def build_decoder(
|
||||
config=layer_config,
|
||||
)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
current_channels = layer.out_channels
|
||||
layers.append(layer)
|
||||
|
||||
return Decoder(
|
||||
|
||||
@ -307,7 +307,7 @@ def build_encoder(
|
||||
)
|
||||
layers.append(layer)
|
||||
current_height = layer.get_output_height(current_height)
|
||||
current_channels = layer.get_output_channels()
|
||||
current_channels = layer.out_channels
|
||||
|
||||
return Encoder(
|
||||
input_height=input_height,
|
||||
|
||||
@ -16,17 +16,80 @@ Key components:
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple
|
||||
from typing import List, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = [
|
||||
"ModelOutput",
|
||||
"BackboneModel",
|
||||
"EncoderDecoderModel",
|
||||
"DetectionModel",
|
||||
"BlockProtocol",
|
||||
"EncoderProtocol",
|
||||
"BottleneckProtocol",
|
||||
"DecoderProtocol",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
"""Interface for blocks of network layers."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass of the block."""
|
||||
...
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
"""Calculate the output height based on input height."""
|
||||
...
|
||||
|
||||
|
||||
class EncoderProtocol(Protocol):
|
||||
"""Interface for the downsampling path of a network."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Forward pass must return intermediate tensors for skip connections."""
|
||||
...
|
||||
|
||||
|
||||
class BottleneckProtocol(Protocol):
|
||||
"""Interface for the middle part of a U-Net-like network."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes the features from the encoder."""
|
||||
...
|
||||
|
||||
|
||||
class DecoderProtocol(Protocol):
|
||||
"""Interface for the upsampling reconstruction path."""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
output_height: int
|
||||
depth: int
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Upsamples features while integrating skip connections."""
|
||||
...
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Standard container for the outputs of a BatDetect2 detection model.
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ def test_standard_block_protocol_methods(
|
||||
|
||||
block = block_class(in_channels=in_channels, out_channels=out_channels)
|
||||
|
||||
assert block.get_output_channels() == out_channels
|
||||
assert block.out_channels == out_channels
|
||||
assert block.get_output_height(input_height) == int(
|
||||
input_height * expected_h_scale
|
||||
)
|
||||
@ -80,7 +80,7 @@ def test_coord_block_protocol_methods(
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
assert block.get_output_channels() == out_channels
|
||||
assert block.out_channels == out_channels
|
||||
assert block.get_output_height(input_height) == int(
|
||||
input_height * expected_h_scale
|
||||
)
|
||||
@ -96,7 +96,7 @@ def test_vertical_conv_forward_shape(dummy_input):
|
||||
output = block(dummy_input)
|
||||
|
||||
assert output.shape == (2, out_channels, 1, 32)
|
||||
assert block.get_output_channels() == out_channels
|
||||
assert block.out_channels == out_channels
|
||||
|
||||
|
||||
def test_self_attention_forward_shape(dummy_bottleneck_input):
|
||||
@ -110,7 +110,7 @@ def test_self_attention_forward_shape(dummy_bottleneck_input):
|
||||
output = block(dummy_bottleneck_input)
|
||||
|
||||
assert output.shape == dummy_bottleneck_input.shape
|
||||
assert block.get_output_channels() == in_channels
|
||||
assert block.out_channels == in_channels
|
||||
|
||||
|
||||
def test_self_attention_weights(dummy_bottleneck_input):
|
||||
@ -179,7 +179,7 @@ def test_layer_group_from_config_and_forward(dummy_input):
|
||||
assert len(layer_group.layers) == 2
|
||||
|
||||
# The group should report the output channels of the LAST block
|
||||
assert layer_group.get_output_channels() == 64
|
||||
assert layer_group.out_channels == 64
|
||||
|
||||
# The group should report the accumulated height changes
|
||||
assert layer_group.get_output_height(input_height) == input_height // 2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user