mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Improved docstrings for blocks
This commit is contained in:
parent
e00674f628
commit
6c744eaac5
@ -1,42 +1,93 @@
|
||||
"""Module containing custom NN blocks.
|
||||
"""Commonly used neural network building blocks for BatDetect2 models.
|
||||
|
||||
All these classes are subclasses of `torch.nn.Module` and can be used to build
|
||||
complex neural network architectures.
|
||||
This module provides various reusable `torch.nn.Module` subclasses that form
|
||||
the fundamental building blocks for constructing convolutional neural network
|
||||
architectures, particularly encoder-decoder backbones used in BatDetect2.
|
||||
|
||||
It includes standard components like basic convolutional blocks (`ConvBlock`),
|
||||
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
|
||||
upsampling (`StandardConvUpBlock`).
|
||||
|
||||
Additionally, it features specialized layers investigated in BatDetect2
|
||||
research:
|
||||
|
||||
- `SelfAttention`: Applies self-attention along the time dimension, enabling
|
||||
the model to weigh information across the entire temporal context, often
|
||||
used in the bottleneck of an encoder-decoder.
|
||||
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
|
||||
concept by concatenating normalized frequency coordinate information as an
|
||||
extra channel to the input of convolutional layers. This explicitly provides
|
||||
spatial frequency information to filters, potentially enabling them to learn
|
||||
frequency-dependent patterns more effectively.
|
||||
|
||||
These blocks can be utilized directly in custom PyTorch model definitions or
|
||||
assembled into larger architectures.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Iterable, List, Literal, Sequence, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from itertools import pairwise
|
||||
else:
|
||||
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
"ConvBlockDownCoordF",
|
||||
"ConvBlockDownStandard",
|
||||
"ConvBlockUpF",
|
||||
"ConvBlockUpStandard",
|
||||
"SelfAttention",
|
||||
"VerticalConv",
|
||||
"DownscalingLayer",
|
||||
"UpscalingLayer",
|
||||
"FreqCoordConvDownBlock",
|
||||
"StandardConvDownBlock",
|
||||
"FreqCoordConvUpBlock",
|
||||
"StandardConvUpBlock",
|
||||
"SelfAttention",
|
||||
]
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Self-Attention module.
|
||||
"""Self-Attention mechanism operating along the time dimension.
|
||||
|
||||
This module implements self-attention mechanism.
|
||||
This module implements a scaled dot-product self-attention mechanism,
|
||||
specifically designed here to operate across the time steps of an input
|
||||
feature map, typically after spatial dimensions (like frequency) have been
|
||||
condensed or squeezed.
|
||||
|
||||
By calculating attention weights between all pairs of time steps, it allows
|
||||
the model to capture long-range temporal dependencies and focus on relevant
|
||||
parts of the sequence. It's often employed in the bottleneck or
|
||||
intermediate layers of an encoder-decoder architecture to integrate global
|
||||
temporal context.
|
||||
|
||||
The implementation uses linear projections to create query, key, and value
|
||||
representations, computes scaled dot-product attention scores, applies
|
||||
softmax, and produces an output by weighting the values according to the
|
||||
attention scores, followed by a final linear projection. Positional encoding
|
||||
is not explicitly included in this block.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels (features per time step after spatial squeeze).
|
||||
attention_channels : int
|
||||
Number of channels for the query, key, and value projections. Also the
|
||||
dimension of the output projection's input.
|
||||
temperature : float, default=1.0
|
||||
Scaling factor applied *before* the final projection layer. Can be used
|
||||
to adjust the sharpness or focus of the attention mechanism, although
|
||||
scaling within the softmax (dividing by sqrt(dim)) is more common for
|
||||
standard transformers. Here it scales the weighted values.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key_fun : nn.Linear
|
||||
Linear layer for key projection.
|
||||
value_fun : nn.Linear
|
||||
Linear layer for value projection.
|
||||
query_fun : nn.Linear
|
||||
Linear layer for query projection.
|
||||
pro_fun : nn.Linear
|
||||
Final linear projection layer applied after attention weighting.
|
||||
temperature : float
|
||||
Scaling factor applied before final projection.
|
||||
att_dim : int
|
||||
Dimensionality of the attention space (`attention_channels`).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -56,6 +107,27 @@ class SelfAttention(nn.Module):
|
||||
self.pro_fun = nn.Linear(attention_channels, in_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply self-attention along the time dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, expected shape `(B, C, H, W)`, where H is typically
|
||||
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
|
||||
applying attention along the W (time) dimension.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor of the same shape as the input `(B, C, H, W)`, where
|
||||
attention has been applied across the W dimension.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If input tensor dimensions are incompatible with operations.
|
||||
"""
|
||||
|
||||
x = x.squeeze(2).permute(0, 2, 1)
|
||||
|
||||
key = torch.matmul(
|
||||
@ -83,6 +155,26 @@ class SelfAttention(nn.Module):
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic Convolutional Block.
|
||||
|
||||
A standard building block consisting of a 2D convolution, followed by
|
||||
batch normalization and a ReLU activation function.
|
||||
|
||||
Sequence: Conv2d -> BatchNorm2d -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Amount of padding added to preserve spatial dimensions (assuming
|
||||
stride=1 and kernel_size=3).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@ -100,27 +192,41 @@ class ConvBlock(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H, W)`.
|
||||
"""
|
||||
return F.relu_(self.conv_bn(self.conv(x)))
|
||||
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
"""Convolutional layer over full height.
|
||||
"""Convolutional layer that aggregates features across the entire height.
|
||||
|
||||
This layer applies a convolution that captures information across the
|
||||
entire height of the input image. It uses a kernel with the same height as
|
||||
the input, effectively condensing the vertical information into a single
|
||||
output row.
|
||||
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
|
||||
This collapses the height dimension (H) to 1 while preserving the width (W),
|
||||
effectively summarizing features across the full vertical extent (e.g.,
|
||||
frequency axis) at each time step. Followed by BatchNorm and ReLU.
|
||||
|
||||
More specifically:
|
||||
Useful for summarizing frequency information before applying operations
|
||||
along the time axis (like SelfAttention).
|
||||
|
||||
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
|
||||
input channels, H is the image height, and W is the image width.
|
||||
* **Kernel:** (C', H, 1) where C' is the number of output channels.
|
||||
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
|
||||
convolution integrates information from all rows of the input.
|
||||
|
||||
This process effectively extracts features that span the full height of
|
||||
the input image.
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels produced by the convolution.
|
||||
input_height : int
|
||||
The height (H dimension) of the input tensor. The convolutional kernel
|
||||
will be sized `(input_height, 1)`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -139,14 +245,53 @@ class VerticalConv(nn.Module):
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Vertical Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`, where H must match the
|
||||
`input_height` provided during initialization.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, 1, W)`.
|
||||
"""
|
||||
return F.relu_(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ConvBlockDownCoordF(nn.Module):
|
||||
"""Convolutional Block with Downsampling and Coord Feature.
|
||||
class FreqCoordConvDownBlock(nn.Module):
|
||||
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block performs convolution followed by downsampling
|
||||
and concatenates with coordinate information.
|
||||
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
||||
used in CNN encoders. Before the convolution, it concatenates an extra
|
||||
channel representing the normalized vertical coordinate (frequency) to the
|
||||
input tensor.
|
||||
|
||||
The purpose of adding coordinate features is to potentially help the
|
||||
convolutional filters become spatially aware, allowing them to learn
|
||||
patterns that might depend on the relative frequency position within the
|
||||
spectrogram.
|
||||
|
||||
Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
Height (H dimension, frequency bins) of the input tensor to this block.
|
||||
Used to generate the coordinate features.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
stride : int, default=1
|
||||
Stride of the convolution. Note: Downsampling is achieved via
|
||||
MaxPool2d(2, 2).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -174,6 +319,19 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`, where H must match
|
||||
`input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
|
||||
"""
|
||||
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
|
||||
x = torch.cat((x, freq_info), 1)
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
@ -181,10 +339,26 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class ConvBlockDownStandard(nn.Module):
|
||||
"""Convolutional Block with Downsampling.
|
||||
class StandardConvDownBlock(nn.Module):
|
||||
"""Standard Downsampling Convolutional Block.
|
||||
|
||||
This block performs convolution followed by downsampling.
|
||||
A basic downsampling block consisting of a 2D convolution, followed by
|
||||
2x2 max pooling, batch normalization, and ReLU activation.
|
||||
|
||||
Sequence: Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
stride : int, default=1
|
||||
Stride of the convolution (downsampling is done by MaxPool).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -195,7 +369,7 @@ class ConvBlockDownStandard(nn.Module):
|
||||
pad_size: int = 1,
|
||||
stride: int = 1,
|
||||
):
|
||||
super(ConvBlockDownStandard, self).__init__()
|
||||
super(StandardConvDownBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -206,18 +380,55 @@ class ConvBlockDownStandard(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)`.
|
||||
"""
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.conv_bn(x), inplace=True)
|
||||
|
||||
|
||||
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
|
||||
class FreqCoordConvUpBlock(nn.Module):
|
||||
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements an upsampling step followed by a convolution,
|
||||
commonly used in CNN decoders. Before the convolution, it concatenates an
|
||||
extra channel representing the normalized vertical coordinate (frequency)
|
||||
of the *upsampled* feature map.
|
||||
|
||||
class ConvBlockUpF(nn.Module):
|
||||
"""Convolutional Block with Upsampling and Coord Feature.
|
||||
The goal is to provide spatial awareness (frequency position) to the
|
||||
filters during the decoding/upsampling process.
|
||||
|
||||
This block performs convolution followed by upsampling
|
||||
and concatenates with coordinate information.
|
||||
Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
||||
Used to calculate the height for coordinate feature generation after
|
||||
upsampling.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
||||
"bicubic").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
Scaling factor for height and width during upsampling
|
||||
(typically (2, 2)).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -249,6 +460,19 @@ class ConvBlockUpF(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should match
|
||||
`input_height` used during initialization.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
|
||||
"""
|
||||
op = F.interpolate(
|
||||
x,
|
||||
size=(
|
||||
@ -265,10 +489,29 @@ class ConvBlockUpF(nn.Module):
|
||||
return op
|
||||
|
||||
|
||||
class ConvBlockUpStandard(nn.Module):
|
||||
"""Convolutional Block with Upsampling.
|
||||
class StandardConvUpBlock(nn.Module):
|
||||
"""Standard Upsampling Convolutional Block.
|
||||
|
||||
This block performs convolution followed by upsampling.
|
||||
A basic upsampling block used in CNN decoders. It first upsamples the input
|
||||
feature map using interpolation, then applies a 2D convolution, batch
|
||||
normalization, and ReLU activation. Does not use coordinate features.
|
||||
|
||||
Sequence: Interpolate -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
Scaling factor for height and width during upsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -280,7 +523,7 @@ class ConvBlockUpStandard(nn.Module):
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super(ConvBlockUpStandard, self).__init__()
|
||||
super(StandardConvUpBlock, self).__init__()
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
self.conv = nn.Conv2d(
|
||||
@ -292,6 +535,18 @@ class ConvBlockUpStandard(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W_in)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
|
||||
"""
|
||||
op = F.interpolate(
|
||||
x,
|
||||
size=(
|
||||
@ -304,143 +559,3 @@ class ConvBlockUpStandard(nn.Module):
|
||||
op = self.conv(op)
|
||||
op = F.relu(self.conv_bn(op), inplace=True)
|
||||
return op
|
||||
|
||||
|
||||
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
|
||||
|
||||
|
||||
def build_downscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: DownscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockDownStandard":
|
||||
return ConvBlockDownStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockDownCoordF":
|
||||
return ConvBlockDownCoordF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid downscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
||||
)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (1, 32, 62, 128),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockDownStandard", "ConvBlockDownCoordF"
|
||||
] = "ConvBlockDownStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_downscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height // (2**layer_num),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
outputs = []
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def build_upscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: UpscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockUpStandard":
|
||||
return ConvBlockUpStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockUpF":
|
||||
return ConvBlockUpF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid upscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
|
||||
)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (256, 62, 32, 32),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockUpStandard", "ConvBlockUpF"
|
||||
] = "ConvBlockUpStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
self.depth = len(self.channels) - 1
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_upscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height
|
||||
// (2 ** (self.depth - layer_num)),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if len(residuals) != len(self.layers):
|
||||
raise ValueError(
|
||||
f"Incorrect number of residuals provided. "
|
||||
f"Expected {len(self.layers)} (matching the number of layers), "
|
||||
f"but got {len(residuals)}."
|
||||
)
|
||||
|
||||
for layer, res in zip(self.layers, residuals[::-1]):
|
||||
x = layer(x + res)
|
||||
|
||||
return x
|
||||
|
Loading…
Reference in New Issue
Block a user