Improved docstrings for blocks

mbsantiago 2025-04-21 15:28:26 +01:00
parent e00674f628
commit 6c744eaac5


@ -1,42 +1,93 @@
"""Module containing custom NN blocks.
"""Commonly used neural network building blocks for BatDetect2 models.
This module provides various reusable `torch.nn.Module` subclasses that form
the fundamental building blocks for constructing convolutional neural network
architectures, particularly encoder-decoder backbones used in BatDetect2.
It includes standard components like basic convolutional blocks (`ConvBlock`),
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
upsampling (`StandardConvUpBlock`).
Additionally, it features specialized layers investigated in BatDetect2
research:
- `SelfAttention`: Applies self-attention along the time dimension, enabling
the model to weigh information across the entire temporal context, often
used in the bottleneck of an encoder-decoder.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
concept by concatenating normalized frequency coordinate information as an
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
These blocks can be utilized directly in custom PyTorch model definitions or
assembled into larger architectures.
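Examples
--------
A minimal sketch of composing two of these blocks into a small
encoder-style stack (illustrative only; the channel counts and input
shape below are arbitrary assumptions, not BatDetect2 defaults):

>>> import torch
>>> from torch import nn
>>> encoder = nn.Sequential(
...     StandardConvDownBlock(in_channels=1, out_channels=32),
...     FreqCoordConvDownBlock(in_channels=32, out_channels=64, input_height=64),
... )
>>> x = torch.randn(1, 1, 128, 512)  # (batch, channels, freq bins, time steps)
>>> encoder(x).shape
torch.Size([1, 64, 32, 128])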
"""
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
__all__ = [
"ConvBlock",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
"SelfAttention",
"VerticalConv",
"DownscalingLayer",
"UpscalingLayer",
"FreqCoordConvDownBlock",
"StandardConvDownBlock",
"FreqCoordConvUpBlock",
"StandardConvUpBlock",
"SelfAttention",
]
class SelfAttention(nn.Module):
"""Self-Attention module.
"""Self-Attention mechanism operating along the time dimension.
This module implements a scaled dot-product self-attention mechanism,
specifically designed here to operate across the time steps of an input
feature map, typically after spatial dimensions (like frequency) have been
condensed or squeezed.
By calculating attention weights between all pairs of time steps, it allows
the model to capture long-range temporal dependencies and focus on relevant
parts of the sequence. It's often employed in the bottleneck or
intermediate layers of an encoder-decoder architecture to integrate global
temporal context.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
Parameters
----------
in_channels : int
Number of input channels (features per time step after spatial squeeze).
attention_channels : int
Number of channels for the query, key, and value projections. Also the
dimension of the output projection's input.
temperature : float, default=1.0
Scaling factor applied to the attention-weighted values *before* the
final projection layer. It can be used to adjust the sharpness or
focus of the attention output. Note that this differs from the
standard transformer convention of scaling the attention logits by
1/sqrt(dim) before the softmax.
Attributes
----------
key_fun : nn.Linear
Linear layer for key projection.
value_fun : nn.Linear
Linear layer for value projection.
query_fun : nn.Linear
Linear layer for query projection.
pro_fun : nn.Linear
Final linear projection layer applied after attention weighting.
temperature : float
Scaling factor applied before final projection.
att_dim : int
Dimensionality of the attention space (`attention_channels`).
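Examples
--------
A minimal sketch (illustrative; assumes the input height has already
been squeezed to 1, e.g. by `VerticalConv`, and that the output shape
follows the documented behavior):

>>> import torch
>>> attention = SelfAttention(in_channels=64, attention_channels=32)
>>> x = torch.randn(2, 64, 1, 128)  # (B, C, H=1, W)
>>> attention(x).shape
torch.Size([2, 64, 1, 128])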
"""
def __init__(
@ -56,6 +107,27 @@ class SelfAttention(nn.Module):
self.pro_fun = nn.Linear(attention_channels, in_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply self-attention along the time dimension.
Parameters
----------
x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H should already
be 1 (e.g., after a `VerticalConv` or pooling layer), so that
attention is applied along the W (time) dimension.
Returns
-------
torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where
attention has been applied across the W dimension.
Raises
------
RuntimeError
If the input height dimension cannot be squeezed to size 1, or if the
tensor shape is otherwise incompatible with the attention operations.
"""
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
@ -83,6 +155,26 @@ class SelfAttention(nn.Module):
class ConvBlock(nn.Module):
"""Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by
batch normalization and a ReLU activation function.
Sequence: Conv2d -> BatchNorm2d -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels produced by the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Amount of padding added to preserve spatial dimensions (assuming
stride=1 and kernel_size=3).
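Examples
--------
Illustrative usage with arbitrary sizes; spatial dimensions are
preserved with the default kernel size and padding:

>>> import torch
>>> block = ConvBlock(in_channels=1, out_channels=32)
>>> x = torch.randn(2, 1, 128, 512)
>>> block(x).shape
torch.Size([2, 32, 128, 512])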
"""
def __init__(
self,
in_channels: int,
@ -100,27 +192,41 @@ class ConvBlock(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H, W)`.
"""
return F.relu_(self.conv_bn(self.conv(x)))
class VerticalConv(nn.Module):
"""Convolutional layer over full height.
"""Convolutional layer that aggregates features across the entire height.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
This collapses the height dimension (H) to 1 while preserving the width (W),
effectively summarizing features across the full vertical extent (e.g.,
frequency axis) at each time step. Followed by BatchNorm and ReLU.
Useful for summarizing frequency information before applying operations
along the time axis (like SelfAttention).
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels produced by the convolution.
input_height : int
The height (H dimension) of the input tensor. The convolutional kernel
will be sized `(input_height, 1)`.
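Examples
--------
Illustrative usage with arbitrary sizes; the height dimension is
collapsed to 1 while the width is preserved:

>>> import torch
>>> conv = VerticalConv(in_channels=32, out_channels=64, input_height=128)
>>> x = torch.randn(2, 32, 128, 512)
>>> conv(x).shape
torch.Size([2, 64, 1, 512])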
"""
def __init__(
@ -139,14 +245,53 @@ class VerticalConv(nn.Module):
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Vertical Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`, where H must match the
`input_height` provided during initialization.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, 1, W)`.
"""
return F.relu_(self.bn(self.conv(x)))
class FreqCoordConvDownBlock(nn.Module):
"""Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
used in CNN encoders. Before the convolution, it concatenates an extra
channel representing the normalized vertical coordinate (frequency) to the
input tensor.
The purpose of adding coordinate features is to potentially help the
convolutional filters become spatially aware, allowing them to learn
patterns that might depend on the relative frequency position within the
spectrogram.
Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of output channels after the convolution.
input_height : int
Height (H dimension, frequency bins) of the input tensor to this block.
Used to generate the coordinate features.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
stride : int, default=1
Stride of the convolution. Note: Downsampling is achieved via
MaxPool2d(2, 2).
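Examples
--------
Illustrative usage with arbitrary sizes; height and width are halved
by the max pooling step:

>>> import torch
>>> block = FreqCoordConvDownBlock(
...     in_channels=32, out_channels=64, input_height=128
... )
>>> x = torch.randn(2, 32, 128, 512)
>>> block(x).shape
torch.Size([2, 64, 64, 256])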
"""
def __init__(
@ -174,6 +319,19 @@ class ConvBlockDownCoordF(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`, where H must match
`input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
"""
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2)
@ -181,10 +339,26 @@ class ConvBlockDownCoordF(nn.Module):
return x
class StandardConvDownBlock(nn.Module):
"""Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by
2x2 max pooling, batch normalization, and ReLU activation.
Sequence: Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of output channels after the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
stride : int, default=1
Stride of the convolution (downsampling is done by MaxPool).
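Examples
--------
Illustrative usage with arbitrary sizes; height and width are halved
by the max pooling step:

>>> import torch
>>> block = StandardConvDownBlock(in_channels=32, out_channels=64)
>>> x = torch.randn(2, 32, 128, 512)
>>> block(x).shape
torch.Size([2, 64, 64, 256])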
"""
def __init__(
@ -195,7 +369,7 @@ class ConvBlockDownStandard(nn.Module):
pad_size: int = 1,
stride: int = 1,
):
super(StandardConvDownBlock, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
@ -206,18 +380,55 @@ class ConvBlockDownStandard(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H/2, W/2)`.
"""
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.conv_bn(x), inplace=True)
class FreqCoordConvUpBlock(nn.Module):
"""Upsampling Conv Block incorporating Frequency Coordinate features.
This block implements an upsampling step followed by a convolution,
commonly used in CNN decoders. Before the convolution, it concatenates an
extra channel representing the normalized vertical coordinate (frequency)
of the *upsampled* feature map.
The goal is to provide spatial awareness (frequency position) to the
filters during the decoding/upsampling process.
Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor (before upsampling).
out_channels : int
Number of output channels after the convolution.
input_height : int
Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after
upsampling.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
up_mode : str, default="bilinear"
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic").
up_scale : Tuple[int, int], default=(2, 2)
Scaling factor for height and width during upsampling
(typically (2, 2)).
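Examples
--------
Illustrative usage with arbitrary sizes; height and width are doubled
by the default up_scale of (2, 2):

>>> import torch
>>> block = FreqCoordConvUpBlock(
...     in_channels=64, out_channels=32, input_height=16
... )
>>> x = torch.randn(2, 64, 16, 64)
>>> block(x).shape
torch.Size([2, 32, 32, 128])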
"""
def __init__(
@ -249,6 +460,19 @@ class ConvBlockUpF(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should match
`input_height` used during initialization.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
"""
op = F.interpolate(
x,
size=(
@ -265,10 +489,29 @@ class ConvBlockUpF(nn.Module):
return op
class StandardConvUpBlock(nn.Module):
"""Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input
feature map using interpolation, then applies a 2D convolution, batch
normalization, and ReLU activation. Does not use coordinate features.
Sequence: Interpolate -> Conv -> BN -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor (before upsampling).
out_channels : int
Number of output channels after the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
up_mode : str, default="bilinear"
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2)
Scaling factor for height and width during upsampling.
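Examples
--------
Illustrative usage with arbitrary sizes; height and width are doubled
by the default up_scale of (2, 2):

>>> import torch
>>> block = StandardConvUpBlock(in_channels=64, out_channels=32)
>>> x = torch.randn(2, 64, 16, 64)
>>> block(x).shape
torch.Size([2, 32, 32, 128])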
"""
def __init__(
@ -280,7 +523,7 @@ class ConvBlockUpStandard(nn.Module):
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
):
super(StandardConvUpBlock, self).__init__()
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
@ -292,6 +535,18 @@ class ConvBlockUpStandard(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W_in)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
"""
op = F.interpolate(
x,
size=(
@ -304,143 +559,3 @@ class ConvBlockUpStandard(nn.Module):
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op