diff --git a/batdetect2/models/blocks.py b/batdetect2/models/blocks.py index 1c8dda0..ac68888 100644 --- a/batdetect2/models/blocks.py +++ b/batdetect2/models/blocks.py @@ -1,42 +1,93 @@ -"""Module containing custom NN blocks. +"""Commonly used neural network building blocks for BatDetect2 models. -All these classes are subclasses of `torch.nn.Module` and can be used to build -complex neural network architectures. +This module provides various reusable `torch.nn.Module` subclasses that form +the fundamental building blocks for constructing convolutional neural network +architectures, particularly encoder-decoder backbones used in BatDetect2. + +It includes standard components like basic convolutional blocks (`ConvBlock`), +blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with +upsampling (`StandardConvUpBlock`). + +Additionally, it features specialized layers investigated in BatDetect2 +research: + +- `SelfAttention`: Applies self-attention along the time dimension, enabling + the model to weigh information across the entire temporal context, often + used in the bottleneck of an encoder-decoder. +- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv" + concept by concatenating normalized frequency coordinate information as an + extra channel to the input of convolutional layers. This explicitly provides + spatial frequency information to filters, potentially enabling them to learn + frequency-dependent patterns more effectively. + +These blocks can be utilized directly in custom PyTorch model definitions or +assembled into larger architectures. """ -import sys -from typing import Iterable, List, Literal, Sequence, Tuple +from typing import Tuple import torch import torch.nn.functional as F from torch import nn -if sys.version_info >= (3, 10): - from itertools import pairwise -else: - - def pairwise(iterable: Sequence) -> Iterable: - for x, y in zip(iterable[:-1], iterable[1:]): - yield x, y - - __all__ = [ "ConvBlock", - "ConvBlockDownCoordF", - "ConvBlockDownStandard", - "ConvBlockUpF", - "ConvBlockUpStandard", - "SelfAttention", "VerticalConv", - "DownscalingLayer", - "UpscalingLayer", + "FreqCoordConvDownBlock", + "StandardConvDownBlock", + "FreqCoordConvUpBlock", + "StandardConvUpBlock", + "SelfAttention", ] class SelfAttention(nn.Module): - """Self-Attention module. + """Self-Attention mechanism operating along the time dimension. - This module implements self-attention mechanism. + This module implements a scaled dot-product self-attention mechanism, + specifically designed here to operate across the time steps of an input + feature map, typically after spatial dimensions (like frequency) have been + condensed or squeezed. + + By calculating attention weights between all pairs of time steps, it allows + the model to capture long-range temporal dependencies and focus on relevant + parts of the sequence. It's often employed in the bottleneck or + intermediate layers of an encoder-decoder architecture to integrate global + temporal context. + + The implementation uses linear projections to create query, key, and value + representations, computes scaled dot-product attention scores, applies + softmax, and produces an output by weighting the values according to the + attention scores, followed by a final linear projection. Positional encoding + is not explicitly included in this block. + + Parameters + ---------- + in_channels : int + Number of input channels (features per time step after spatial squeeze). 
+    attention_channels : int
+        Number of channels used for the query, key, and value projections,
+        and the input dimension of the final output projection.
+    temperature : float, default=1.0
+        Scaling factor applied to the attention-weighted values *before* the
+        final projection layer. Unlike the softmax temperature of standard
+        transformers (which divides the attention logits and controls how
+        sharply the attention distribution is peaked), this factor only
+        rescales the magnitude of the attended values.
+
+    Attributes
+    ----------
+    key_fun : nn.Linear
+        Linear layer for key projection.
+    value_fun : nn.Linear
+        Linear layer for value projection.
+    query_fun : nn.Linear
+        Linear layer for query projection.
+    pro_fun : nn.Linear
+        Final linear projection layer applied after attention weighting.
+    temperature : float
+        Scaling factor applied to the attended values before the final
+        projection.
+    att_dim : int
+        Dimensionality of the attention space (`attention_channels`).
     """

     def __init__(
@@ -56,6 +107,27 @@ class SelfAttention(nn.Module):
         self.pro_fun = nn.Linear(attention_channels, in_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply self-attention along the time dimension.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, expected shape `(B, C, H, W)`, where H must be 1
+            (e.g., after a `VerticalConv` or pooling). The height axis is
+            squeezed before attention is applied along the W (time) dimension.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of the same shape as the input `(B, C, H, W)`, where
+            attention has been applied across the W dimension.
+
+        Raises
+        ------
+        RuntimeError
+            If the height dimension cannot be squeezed away (i.e., if H != 1),
+            since the subsequent reshaping assumes a 3D `(B, W, C)` tensor.
+        """
+
         x = x.squeeze(2).permute(0, 2, 1)

         key = torch.matmul(
@@ -83,6 +155,26 @@ class SelfAttention(nn.Module):


 class ConvBlock(nn.Module):
+    """Basic Convolutional Block.
+
+    A standard building block consisting of a 2D convolution, followed by
+    batch normalization and a ReLU activation function.
+
+    Sequence: Conv2d -> BatchNorm2d -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of channels produced by the convolution.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Amount of padding added to preserve spatial dimensions (assuming
+        stride=1 and kernel_size=3).
+    """
+
     def __init__(
         self,
         in_channels: int,
@@ -100,27 +192,41 @@ class ConvBlock(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H, W)`.
+        """
         return F.relu_(self.conv_bn(self.conv(x)))


 class VerticalConv(nn.Module):
-    """Convolutional layer over full height.
+    """Convolutional layer that aggregates features across the entire height.

-    This layer applies a convolution that captures information across the
-    entire height of the input image. It uses a kernel with the same height as
-    the input, effectively condensing the vertical information into a single
-    output row.
+    Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
+    This collapses the height dimension (H) to 1 while preserving the width (W),
+    effectively summarizing features across the full vertical extent (e.g.,
+    frequency axis) at each time step. Followed by BatchNorm and ReLU.
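+
+    For example, with `input_height=128`, an input of shape `(B, C, 128, W)`
+    is reduced to `(B, out_channels, 1, W)`.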
-    More specifically:
+    Useful for summarizing frequency information before applying operations
+    along the time axis (like SelfAttention).

-    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
-      input channels, H is the image height, and W is the image width.
-    * **Kernel:** (C', H, 1) where C' is the number of output channels.
-    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
-      convolution integrates information from all rows of the input.
-
-    This process effectively extracts features that span the full height of
-    the input image.
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of channels produced by the convolution.
+    input_height : int
+        The height (H dimension) of the input tensor. The convolutional kernel
+        will be sized `(input_height, 1)`.
     """

     def __init__(
@@ -139,14 +245,53 @@
         self.bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Vertical Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`, where H must match the
+            `input_height` provided during initialization.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, 1, W)`.
+        """
         return F.relu_(self.bn(self.conv(x)))


-class ConvBlockDownCoordF(nn.Module):
-    """Convolutional Block with Downsampling and Coord Feature.

-    This block performs convolution followed by downsampling
-    and concatenates with coordinate information.
+class FreqCoordConvDownBlock(nn.Module):
+    """Downsampling Conv Block incorporating Frequency Coordinate features.
+
+    This block implements a downsampling step (Conv2d + MaxPool2d) commonly
+    used in CNN encoders. Before the convolution, it concatenates an extra
+    channel representing the normalized vertical coordinate (frequency) to the
+    input tensor.
+
+    Adding coordinate features can help the convolutional filters become
+    spatially aware, allowing them to learn patterns that depend on the
+    relative frequency position within the spectrogram.
+
+    Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of output channels after the convolution.
+    input_height : int
+        Height (H dimension, frequency bins) of the input tensor to this block.
+        Used to generate the coordinate features.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Padding added before convolution.
+    stride : int, default=1
+        Stride of the convolution. Note: Downsampling is achieved via
+        MaxPool2d(2, 2).
     """

     def __init__(
@@ -174,6 +319,19 @@
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Concat Coords -> Conv -> MaxPool -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`, where H must match
+            `input_height`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
+        """
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
@@ -181,10 +339,26 @@
         return x


-class ConvBlockDownStandard(nn.Module):
-    """Convolutional Block with Downsampling.
+class StandardConvDownBlock(nn.Module): + """Standard Downsampling Convolutional Block. - This block performs convolution followed by downsampling. + A basic downsampling block consisting of a 2D convolution, followed by + 2x2 max pooling, batch normalization, and ReLU activation. + + Sequence: Conv -> MaxPool -> BN -> ReLU. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of output channels after the convolution. + kernel_size : int, default=3 + Size of the square convolutional kernel. + pad_size : int, default=1 + Padding added before convolution. + stride : int, default=1 + Stride of the convolution (downsampling is done by MaxPool). """ def __init__( @@ -195,7 +369,7 @@ class ConvBlockDownStandard(nn.Module): pad_size: int = 1, stride: int = 1, ): - super(ConvBlockDownStandard, self).__init__() + super(StandardConvDownBlock, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, @@ -206,18 +380,55 @@ class ConvBlockDownStandard(nn.Module): self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x): + """Apply Conv -> MaxPool -> BN -> ReLU. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape `(B, C_in, H, W)`. + + Returns + ------- + torch.Tensor + Output tensor, shape `(B, C_out, H/2, W/2)`. + """ x = F.max_pool2d(self.conv(x), 2, 2) return F.relu(self.conv_bn(x), inplace=True) -DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"] +class FreqCoordConvUpBlock(nn.Module): + """Upsampling Conv Block incorporating Frequency Coordinate features. + This block implements an upsampling step followed by a convolution, + commonly used in CNN decoders. Before the convolution, it concatenates an + extra channel representing the normalized vertical coordinate (frequency) + of the *upsampled* feature map. -class ConvBlockUpF(nn.Module): - """Convolutional Block with Upsampling and Coord Feature. + The goal is to provide spatial awareness (frequency position) to the + filters during the decoding/upsampling process. - This block performs convolution followed by upsampling - and concatenates with coordinate information. + Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor (before upsampling). + out_channels : int + Number of output channels after the convolution. + input_height : int + Height (H dimension, frequency bins) of the tensor *before* upsampling. + Used to calculate the height for coordinate feature generation after + upsampling. + kernel_size : int, default=3 + Size of the square convolutional kernel. + pad_size : int, default=1 + Padding added before convolution. + up_mode : str, default="bilinear" + Interpolation mode for upsampling (e.g., "nearest", "bilinear", + "bicubic"). + up_scale : Tuple[int, int], default=(2, 2) + Scaling factor for height and width during upsampling + (typically (2, 2)). """ def __init__( @@ -249,6 +460,19 @@ class ConvBlockUpF(nn.Module): self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should match + `input_height` used during initialization. + + Returns + ------- + torch.Tensor + Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`. 
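+
+        Examples
+        --------
+        A minimal shape sketch (illustrative, assuming the default
+        `up_scale=(2, 2)`):
+
+        >>> import torch
+        >>> block = FreqCoordConvUpBlock(32, 16, input_height=64)
+        >>> x = torch.randn(1, 32, 64, 128)
+        >>> block(x).shape
+        torch.Size([1, 16, 128, 256])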
+ """ op = F.interpolate( x, size=( @@ -265,10 +489,29 @@ class ConvBlockUpF(nn.Module): return op -class ConvBlockUpStandard(nn.Module): - """Convolutional Block with Upsampling. +class StandardConvUpBlock(nn.Module): + """Standard Upsampling Convolutional Block. - This block performs convolution followed by upsampling. + A basic upsampling block used in CNN decoders. It first upsamples the input + feature map using interpolation, then applies a 2D convolution, batch + normalization, and ReLU activation. Does not use coordinate features. + + Sequence: Interpolate -> Conv -> BN -> ReLU. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor (before upsampling). + out_channels : int + Number of output channels after the convolution. + kernel_size : int, default=3 + Size of the square convolutional kernel. + pad_size : int, default=1 + Padding added before convolution. + up_mode : str, default="bilinear" + Interpolation mode for upsampling (e.g., "nearest", "bilinear"). + up_scale : Tuple[int, int], default=(2, 2) + Scaling factor for height and width during upsampling. """ def __init__( @@ -280,7 +523,7 @@ class ConvBlockUpStandard(nn.Module): up_mode: str = "bilinear", up_scale: Tuple[int, int] = (2, 2), ): - super(ConvBlockUpStandard, self).__init__() + super(StandardConvUpBlock, self).__init__() self.up_scale = up_scale self.up_mode = up_mode self.conv = nn.Conv2d( @@ -292,6 +535,18 @@ class ConvBlockUpStandard(nn.Module): self.conv_bn = nn.BatchNorm2d(out_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply Interpolate -> Conv -> BN -> ReLU. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape `(B, C_in, H_in, W_in)`. + + Returns + ------- + torch.Tensor + Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`. + """ op = F.interpolate( x, size=( @@ -304,143 +559,3 @@ class ConvBlockUpStandard(nn.Module): op = self.conv(op) op = F.relu(self.conv_bn(op), inplace=True) return op - - -UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"] - - -def build_downscaling_layer( - in_channels: int, - out_channels: int, - input_height: int, - layer_type: DownscalingLayer, -) -> nn.Module: - if layer_type == "ConvBlockDownStandard": - return ConvBlockDownStandard( - in_channels=in_channels, - out_channels=out_channels, - ) - - if layer_type == "ConvBlockDownCoordF": - return ConvBlockDownCoordF( - in_channels=in_channels, - out_channels=out_channels, - input_height=input_height, - ) - - raise ValueError( - f"Invalid downscaling layer type {layer_type}. 
" - f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard" - ) - - -class Encoder(nn.Module): - def __init__( - self, - channels: Sequence[int] = (1, 32, 62, 128), - input_height: int = 128, - layer_type: Literal[ - "ConvBlockDownStandard", "ConvBlockDownCoordF" - ] = "ConvBlockDownStandard", - ): - super().__init__() - - self.channels = channels - self.input_height = input_height - - self.layers = nn.ModuleList( - [ - build_downscaling_layer( - in_channels=in_channels, - out_channels=out_channels, - input_height=input_height // (2**layer_num), - layer_type=layer_type, - ) - for layer_num, (in_channels, out_channels) in enumerate( - pairwise(channels) - ) - ] - ) - self.depth = len(self.layers) - - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - outputs = [] - - for layer in self.layers: - x = layer(x) - outputs.append(x) - - return outputs - - -def build_upscaling_layer( - in_channels: int, - out_channels: int, - input_height: int, - layer_type: UpscalingLayer, -) -> nn.Module: - if layer_type == "ConvBlockUpStandard": - return ConvBlockUpStandard( - in_channels=in_channels, - out_channels=out_channels, - ) - - if layer_type == "ConvBlockUpF": - return ConvBlockUpF( - in_channels=in_channels, - out_channels=out_channels, - input_height=input_height, - ) - - raise ValueError( - f"Invalid upscaling layer type {layer_type}. " - f"Valid values: ConvBlockUpStandard, ConvBlockUpF" - ) - - -class Decoder(nn.Module): - def __init__( - self, - channels: Sequence[int] = (256, 62, 32, 32), - input_height: int = 128, - layer_type: Literal[ - "ConvBlockUpStandard", "ConvBlockUpF" - ] = "ConvBlockUpStandard", - ): - super().__init__() - - self.channels = channels - self.input_height = input_height - self.depth = len(self.channels) - 1 - - self.layers = nn.ModuleList( - [ - build_upscaling_layer( - in_channels=in_channels, - out_channels=out_channels, - input_height=input_height - // (2 ** (self.depth - layer_num)), - layer_type=layer_type, - ) - for layer_num, (in_channels, out_channels) in enumerate( - pairwise(channels) - ) - ] - ) - - def forward( - self, - x: torch.Tensor, - residuals: List[torch.Tensor], - ) -> torch.Tensor: - if len(residuals) != len(self.layers): - raise ValueError( - f"Incorrect number of residuals provided. " - f"Expected {len(self.layers)} (matching the number of layers), " - f"but got {len(residuals)}." - ) - - for layer, res in zip(self.layers, residuals[::-1]): - x = layer(x + res) - - return x