mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Improved docstrings for blocks
This commit is contained in:
parent e00674f628
commit 6c744eaac5
@@ -1,42 +1,93 @@
-"""Module containing custom NN blocks.
+"""Commonly used neural network building blocks for BatDetect2 models.

-All these classes are subclasses of `torch.nn.Module` and can be used to build
-complex neural network architectures.
+This module provides various reusable `torch.nn.Module` subclasses that form
+the fundamental building blocks for constructing convolutional neural network
+architectures, particularly encoder-decoder backbones used in BatDetect2.
+
+It includes standard components like basic convolutional blocks (`ConvBlock`),
+blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
+upsampling (`StandardConvUpBlock`).
+
+Additionally, it features specialized layers investigated in BatDetect2
+research:
+
+- `SelfAttention`: Applies self-attention along the time dimension, enabling
+  the model to weigh information across the entire temporal context, often
+  used in the bottleneck of an encoder-decoder.
+- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
+  concept by concatenating normalized frequency coordinate information as an
+  extra channel to the input of convolutional layers. This explicitly provides
+  spatial frequency information to filters, potentially enabling them to learn
+  frequency-dependent patterns more effectively.
+
+These blocks can be utilized directly in custom PyTorch model definitions or
+assembled into larger architectures.
 """

-import sys
-from typing import Iterable, List, Literal, Sequence, Tuple
+from typing import Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn

-if sys.version_info >= (3, 10):
-    from itertools import pairwise
-else:
-
-    def pairwise(iterable: Sequence) -> Iterable:
-        for x, y in zip(iterable[:-1], iterable[1:]):
-            yield x, y
-
-
 __all__ = [
     "ConvBlock",
-    "ConvBlockDownCoordF",
-    "ConvBlockDownStandard",
-    "ConvBlockUpF",
-    "ConvBlockUpStandard",
-    "SelfAttention",
     "VerticalConv",
-    "DownscalingLayer",
-    "UpscalingLayer",
+    "FreqCoordConvDownBlock",
+    "StandardConvDownBlock",
+    "FreqCoordConvUpBlock",
+    "StandardConvUpBlock",
+    "SelfAttention",
 ]


 class SelfAttention(nn.Module):
-    """Self-Attention module.
+    """Self-Attention mechanism operating along the time dimension.

-    This module implements self-attention mechanism.
+    This module implements a scaled dot-product self-attention mechanism,
+    specifically designed here to operate across the time steps of an input
+    feature map, typically after spatial dimensions (like frequency) have been
+    condensed or squeezed.
+
+    By calculating attention weights between all pairs of time steps, it allows
+    the model to capture long-range temporal dependencies and focus on relevant
+    parts of the sequence. It's often employed in the bottleneck or
+    intermediate layers of an encoder-decoder architecture to integrate global
+    temporal context.
+
+    The implementation uses linear projections to create query, key, and value
+    representations, computes scaled dot-product attention scores, applies
+    softmax, and produces an output by weighting the values according to the
+    attention scores, followed by a final linear projection. Positional
+    encoding is not explicitly included in this block.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of input channels (features per time step after spatial squeeze).
+    attention_channels : int
+        Number of channels for the query, key, and value projections. Also the
+        dimension of the output projection's input.
+    temperature : float, default=1.0
+        Scaling factor applied *before* the final projection layer. Can be used
+        to adjust the sharpness or focus of the attention mechanism, although
+        scaling within the softmax (dividing by sqrt(dim)) is more common for
+        standard transformers. Here it scales the weighted values.
+
+    Attributes
+    ----------
+    key_fun : nn.Linear
+        Linear layer for key projection.
+    value_fun : nn.Linear
+        Linear layer for value projection.
+    query_fun : nn.Linear
+        Linear layer for query projection.
+    pro_fun : nn.Linear
+        Final linear projection layer applied after attention weighting.
+    temperature : float
+        Scaling factor applied before final projection.
+    att_dim : int
+        Dimensionality of the attention space (`attention_channels`).
     """

     def __init__(
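Note: a minimal usage sketch for the interface documented above, assuming
`SelfAttention` is imported from this module. The tensor sizes are illustrative
assumptions, not values taken from the repository:

    import torch

    # Hypothetical sizes: batch of 8, 256 features per step, 512 time steps.
    attention = SelfAttention(in_channels=256, attention_channels=128)
    x = torch.randn(8, 256, 1, 512)  # (B, C, H=1, W), e.g. after a VerticalConv
    out = attention(x)               # same shape per the docstring: (8, 256, 1, 512)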
@@ -56,6 +107,27 @@ class SelfAttention(nn.Module):
         self.pro_fun = nn.Linear(attention_channels, in_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply self-attention along the time dimension.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, expected shape `(B, C, H, W)`, where H is typically
+            squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
+            applying attention along the W (time) dimension.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor of the same shape as the input `(B, C, H, W)`, where
+            attention has been applied across the W dimension.
+
+        Raises
+        ------
+        RuntimeError
+            If input tensor dimensions are incompatible with operations.
+        """
         x = x.squeeze(2).permute(0, 2, 1)

         key = torch.matmul(
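Note: the committed forward body continues past this hunk; the following
self-contained sketch only approximates the computation the docstring
describes. The sqrt(att_dim) softmax scaling is an assumption (the docstring
itself hedges on it), and the projection callables stand in for the module's
nn.Linear attributes:

    import torch

    def time_attention_sketch(q_fun, k_fun, v_fun, p_fun, att_dim, x, temperature=1.0):
        seq = x.squeeze(2).permute(0, 2, 1)           # (B, C, 1, W) -> (B, W, C)
        q, k, v = q_fun(seq), k_fun(seq), v_fun(seq)  # project each time step
        att = torch.softmax(q @ k.transpose(1, 2) / att_dim**0.5, dim=-1)
        out = p_fun(temperature * (att @ v))          # temperature scales weighted values
        return out.permute(0, 2, 1).unsqueeze(2)      # back to (B, C, 1, W)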
@@ -83,6 +155,26 @@ class SelfAttention(nn.Module):


 class ConvBlock(nn.Module):
+    """Basic Convolutional Block.
+
+    A standard building block consisting of a 2D convolution, followed by
+    batch normalization and a ReLU activation function.
+
+    Sequence: Conv2d -> BatchNorm2d -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of channels produced by the convolution.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Amount of padding added to preserve spatial dimensions (assuming
+        stride=1 and kernel_size=3).
+    """
+
     def __init__(
         self,
         in_channels: int,
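Note: the documented Conv2d -> BatchNorm2d -> ReLU sequence is equivalent to a
small nn.Sequential; a sketch (the helper name is made up for illustration):

    from torch import nn

    def conv_block_sketch(in_channels, out_channels, kernel_size=3, pad_size=1):
        # With stride=1, kernel_size=3, pad_size=1 the spatial size is preserved:
        # out = (in + 2 * pad - kernel) // stride + 1 = in
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad_size),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )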
@@ -100,27 +192,41 @@ class ConvBlock(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H, W)`.
+        """
         return F.relu_(self.conv_bn(self.conv(x)))


 class VerticalConv(nn.Module):
-    """Convolutional layer over full height.
+    """Convolutional layer that aggregates features across the entire height.

-    This layer applies a convolution that captures information across the
-    entire height of the input image. It uses a kernel with the same height as
-    the input, effectively condensing the vertical information into a single
-    output row.
-
-    More specifically:
-
-    * **Input:** (B, C, H, W) where B is the batch size, C is the number of
-      input channels, H is the image height, and W is the image width.
-    * **Kernel:** (C', H, 1) where C' is the number of output channels.
-    * **Output:** (B, C', 1, W) - The height dimension is 1 because the
-      convolution integrates information from all rows of the input.
-
-    This process effectively extracts features that span the full height of
-    the input image.
+    Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
+    This collapses the height dimension (H) to 1 while preserving the width (W),
+    effectively summarizing features across the full vertical extent (e.g.,
+    frequency axis) at each time step. Followed by BatchNorm and ReLU.
+
+    Useful for summarizing frequency information before applying operations
+    along the time axis (like SelfAttention).
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of channels produced by the convolution.
+    input_height : int
+        The height (H dimension) of the input tensor. The convolutional kernel
+        will be sized `(input_height, 1)`.
     """

     def __init__(
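Note: a quick shape check of the documented collapse from (B, C, H, W) to
(B, C', 1, W), with an assumed height of 128 frequency bins:

    import torch
    from torch import nn

    x = torch.randn(2, 32, 128, 100)                # (B, C_in, H, W)
    conv = nn.Conv2d(32, 64, kernel_size=(128, 1))  # kernel spans the full height
    print(conv(x).shape)                            # torch.Size([2, 64, 1, 100])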
@@ -139,14 +245,53 @@ class VerticalConv(nn.Module):
         self.bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Vertical Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`, where H must match the
+            `input_height` provided during initialization.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, 1, W)`.
+        """
         return F.relu_(self.bn(self.conv(x)))


-class ConvBlockDownCoordF(nn.Module):
-    """Convolutional Block with Downsampling and Coord Feature.
+class FreqCoordConvDownBlock(nn.Module):
+    """Downsampling Conv Block incorporating Frequency Coordinate features.

-    This block performs convolution followed by downsampling
-    and concatenates with coordinate information.
+    This block implements a downsampling step (Conv2d + MaxPool2d) commonly
+    used in CNN encoders. Before the convolution, it concatenates an extra
+    channel representing the normalized vertical coordinate (frequency) to the
+    input tensor.
+
+    The purpose of adding coordinate features is to potentially help the
+    convolutional filters become spatially aware, allowing them to learn
+    patterns that might depend on the relative frequency position within the
+    spectrogram.
+
+    Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of output channels after the convolution.
+    input_height : int
+        Height (H dimension, frequency bins) of the input tensor to this block.
+        Used to generate the coordinate features.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Padding added before convolution.
+    stride : int, default=1
+        Stride of the convolution. Note: Downsampling is achieved via
+        MaxPool2d(2, 2).
     """

     def __init__(
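Note: the coordinate feature is just a fixed tensor of normalized heights
broadcast over batch and time, matching the `self.coords.repeat(...)` call in
the forward pass below. One plausible construction (the exact normalization
used in the repository may differ):

    import torch

    input_height = 64                            # assumed number of frequency bins
    coords = torch.linspace(-1, 1, input_height).view(1, 1, input_height, 1)

    x = torch.randn(8, 32, input_height, 100)    # (B, C, H, W)
    freq_info = coords.repeat(x.shape[0], 1, 1, x.shape[3])
    x = torch.cat((x, freq_info), dim=1)         # (8, 33, 64, 100): one extra channel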
@@ -174,6 +319,19 @@ class ConvBlockDownCoordF(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`, where H must match
+            `input_height`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
+        """
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
@@ -181,10 +339,26 @@
         return x


-class ConvBlockDownStandard(nn.Module):
-    """Convolutional Block with Downsampling.
+class StandardConvDownBlock(nn.Module):
+    """Standard Downsampling Convolutional Block.

-    This block performs convolution followed by downsampling.
+    A basic downsampling block consisting of a 2D convolution, followed by
+    2x2 max pooling, batch normalization, and ReLU activation.
+
+    Sequence: Conv -> MaxPool -> BN -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor.
+    out_channels : int
+        Number of output channels after the convolution.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Padding added before convolution.
+    stride : int, default=1
+        Stride of the convolution (downsampling is done by MaxPool).
     """

     def __init__(
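Note: the halving of H and W comes from the 2x2 max pool, not from the
convolution; a minimal shape check under assumed sizes:

    import torch
    import torch.nn.functional as F
    from torch import nn

    conv = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # preserves H and W
    x = torch.randn(1, 32, 64, 100)
    y = F.max_pool2d(conv(x), 2, 2)                     # halves H and W
    print(y.shape)                                      # torch.Size([1, 64, 32, 50])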
@@ -195,7 +369,7 @@ class ConvBlockDownStandard(nn.Module):
         pad_size: int = 1,
         stride: int = 1,
     ):
-        super(ConvBlockDownStandard, self).__init__()
+        super(StandardConvDownBlock, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
@@ -206,18 +380,55 @@ class ConvBlockDownStandard(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x):
+        """Apply Conv -> MaxPool -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H, W)`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H/2, W/2)`.
+        """
         x = F.max_pool2d(self.conv(x), 2, 2)
         return F.relu(self.conv_bn(x), inplace=True)


-DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
-
-
-class ConvBlockUpF(nn.Module):
-    """Convolutional Block with Upsampling and Coord Feature.
+class FreqCoordConvUpBlock(nn.Module):
+    """Upsampling Conv Block incorporating Frequency Coordinate features.

-    This block performs convolution followed by upsampling
-    and concatenates with coordinate information.
+    This block implements an upsampling step followed by a convolution,
+    commonly used in CNN decoders. Before the convolution, it concatenates an
+    extra channel representing the normalized vertical coordinate (frequency)
+    of the *upsampled* feature map.
+
+    The goal is to provide spatial awareness (frequency position) to the
+    filters during the decoding/upsampling process.
+
+    Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor (before upsampling).
+    out_channels : int
+        Number of output channels after the convolution.
+    input_height : int
+        Height (H dimension, frequency bins) of the tensor *before* upsampling.
+        Used to calculate the height for coordinate feature generation after
+        upsampling.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Padding added before convolution.
+    up_mode : str, default="bilinear"
+        Interpolation mode for upsampling (e.g., "nearest", "bilinear",
+        "bicubic").
+    up_scale : Tuple[int, int], default=(2, 2)
+        Scaling factor for height and width during upsampling
+        (typically (2, 2)).
     """

     def __init__(
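Note: because the coordinate channel is concatenated after interpolation, it
must be built at the upsampled height. A sketch under assumed sizes (the
repository may normalize the coordinates differently):

    import torch
    import torch.nn.functional as F

    input_height, up_scale = 32, (2, 2)            # pre-upsampling height (assumed)
    x = torch.randn(4, 64, input_height, 50)
    op = F.interpolate(x, scale_factor=up_scale, mode="bilinear")
    coords = torch.linspace(-1, 1, input_height * up_scale[0]).view(1, 1, -1, 1)
    op = torch.cat((op, coords.repeat(op.shape[0], 1, 1, op.shape[3])), dim=1)
    print(op.shape)                                # torch.Size([4, 65, 64, 100])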
@@ -249,6 +460,19 @@ class ConvBlockUpF(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should
+            match `input_height` used during initialization.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
+        """
         op = F.interpolate(
             x,
             size=(
@@ -265,10 +489,29 @@
         return op


-class ConvBlockUpStandard(nn.Module):
-    """Convolutional Block with Upsampling.
+class StandardConvUpBlock(nn.Module):
+    """Standard Upsampling Convolutional Block.

-    This block performs convolution followed by upsampling.
+    A basic upsampling block used in CNN decoders. It first upsamples the input
+    feature map using interpolation, then applies a 2D convolution, batch
+    normalization, and ReLU activation. Does not use coordinate features.
+
+    Sequence: Interpolate -> Conv -> BN -> ReLU.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input tensor (before upsampling).
+    out_channels : int
+        Number of output channels after the convolution.
+    kernel_size : int, default=3
+        Size of the square convolutional kernel.
+    pad_size : int, default=1
+        Padding added before convolution.
+    up_mode : str, default="bilinear"
+        Interpolation mode for upsampling (e.g., "nearest", "bilinear").
+    up_scale : Tuple[int, int], default=(2, 2)
+        Scaling factor for height and width during upsampling.
     """

     def __init__(
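Note: the forward pass below computes an explicit target `size=` from
`up_scale`; for integer scales this matches passing `scale_factor` directly:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 16, 32, 50)
    a = F.interpolate(x, size=(32 * 2, 50 * 2), mode="bilinear")
    b = F.interpolate(x, scale_factor=(2, 2), mode="bilinear")
    assert a.shape == b.shape == (1, 16, 64, 100)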
@@ -280,7 +523,7 @@ class ConvBlockUpStandard(nn.Module):
         up_mode: str = "bilinear",
         up_scale: Tuple[int, int] = (2, 2),
     ):
-        super(ConvBlockUpStandard, self).__init__()
+        super(StandardConvUpBlock, self).__init__()
         self.up_scale = up_scale
         self.up_mode = up_mode
         self.conv = nn.Conv2d(
@@ -292,6 +535,18 @@ class ConvBlockUpStandard(nn.Module):
         self.conv_bn = nn.BatchNorm2d(out_channels)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply Interpolate -> Conv -> BN -> ReLU.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, shape `(B, C_in, H_in, W_in)`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
+        """
         op = F.interpolate(
             x,
             size=(
@@ -304,143 +559,3 @@ class ConvBlockUpStandard(nn.Module):
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
         return op
-
-
-UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
-
-
-def build_downscaling_layer(
-    in_channels: int,
-    out_channels: int,
-    input_height: int,
-    layer_type: DownscalingLayer,
-) -> nn.Module:
-    if layer_type == "ConvBlockDownStandard":
-        return ConvBlockDownStandard(
-            in_channels=in_channels,
-            out_channels=out_channels,
-        )
-
-    if layer_type == "ConvBlockDownCoordF":
-        return ConvBlockDownCoordF(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            input_height=input_height,
-        )
-
-    raise ValueError(
-        f"Invalid downscaling layer type {layer_type}. "
-        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
-    )
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        channels: Sequence[int] = (1, 32, 62, 128),
-        input_height: int = 128,
-        layer_type: Literal[
-            "ConvBlockDownStandard", "ConvBlockDownCoordF"
-        ] = "ConvBlockDownStandard",
-    ):
-        super().__init__()
-
-        self.channels = channels
-        self.input_height = input_height
-
-        self.layers = nn.ModuleList(
-            [
-                build_downscaling_layer(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    input_height=input_height // (2**layer_num),
-                    layer_type=layer_type,
-                )
-                for layer_num, (in_channels, out_channels) in enumerate(
-                    pairwise(channels)
-                )
-            ]
-        )
-        self.depth = len(self.layers)
-
-    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        outputs = []
-
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x)
-
-        return outputs
-
-
-def build_upscaling_layer(
-    in_channels: int,
-    out_channels: int,
-    input_height: int,
-    layer_type: UpscalingLayer,
-) -> nn.Module:
-    if layer_type == "ConvBlockUpStandard":
-        return ConvBlockUpStandard(
-            in_channels=in_channels,
-            out_channels=out_channels,
-        )
-
-    if layer_type == "ConvBlockUpF":
-        return ConvBlockUpF(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            input_height=input_height,
-        )
-
-    raise ValueError(
-        f"Invalid upscaling layer type {layer_type}. "
-        f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
-    )
-
-
-class Decoder(nn.Module):
-    def __init__(
-        self,
-        channels: Sequence[int] = (256, 62, 32, 32),
-        input_height: int = 128,
-        layer_type: Literal[
-            "ConvBlockUpStandard", "ConvBlockUpF"
-        ] = "ConvBlockUpStandard",
-    ):
-        super().__init__()
-
-        self.channels = channels
-        self.input_height = input_height
-        self.depth = len(self.channels) - 1
-
-        self.layers = nn.ModuleList(
-            [
-                build_upscaling_layer(
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    input_height=input_height
-                    // (2 ** (self.depth - layer_num)),
-                    layer_type=layer_type,
-                )
-                for layer_num, (in_channels, out_channels) in enumerate(
-                    pairwise(channels)
-                )
-            ]
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        residuals: List[torch.Tensor],
-    ) -> torch.Tensor:
-        if len(residuals) != len(self.layers):
-            raise ValueError(
-                f"Incorrect number of residuals provided. "
-                f"Expected {len(self.layers)} (matching the number of layers), "
-                f"but got {len(residuals)}."
-            )
-
-        for layer, res in zip(self.layers, residuals[::-1]):
-            x = layer(x + res)
-
-        return x
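Note: with the Encoder/Decoder helpers removed above, an equivalent encoder
stack can still be assembled by hand from the renamed blocks. A sketch that
follows the removed code's channel pairing and per-level height halving
(channel sizes are the removed defaults; `FreqCoordConvDownBlock` is assumed
to be imported from this module):

    import torch
    from torch import nn

    channels, input_height = (1, 32, 62, 128), 128
    encoder = nn.ModuleList(
        FreqCoordConvDownBlock(
            in_channels=c_in,
            out_channels=c_out,
            input_height=input_height // (2**i),  # height halves at each level
        )
        for i, (c_in, c_out) in enumerate(zip(channels, channels[1:]))
    )

    x = torch.randn(1, 1, input_height, 256)
    residuals = []
    for layer in encoder:
        x = layer(x)          # each block halves H and W
        residuals.append(x)   # keep skip connections for a decoder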