Create an encoder module

This commit is contained in:
mbsantiago 2025-04-21 15:28:56 +01:00
parent ffa4c2e5e9
commit 096d180ea3

View File

@ -0,0 +1,357 @@
"""Constructs the Encoder part of an Encoder-Decoder neural network.
This module defines the configuration structure (`EncoderConfig`) and provides
the `Encoder` class (an `nn.Module`) along with a factory function
(`build_encoder`) to create sequential encoders commonly used as the
downsampling path in architectures like U-Nets for spectrogram analysis.
The encoder is built by stacking configurable downscaling blocks. Two types
of downscaling blocks are supported, selectable via the configuration:
- `StandardConvDownBlock`: A basic Conv2d -> MaxPool2d -> BN -> ReLU block.
- `FreqCoordConvDownBlock`: A similar block that incorporates frequency
coordinate information (CoordF) before the convolution to potentially aid
spatial awareness along the frequency axis.
The `Encoder`'s `forward` method provides access to intermediate feature maps
from each stage, suitable for use as skip connections in a corresponding
Decoder. A separate `encode` method returns only the final output (bottleneck)
features.
"""
from enum import Enum
from typing import List
import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
FreqCoordConvDownBlock,
StandardConvDownBlock,
)
__all__ = [
"DownscalingLayer",
"EncoderLayer",
"EncoderConfig",
"Encoder",
"build_encoder",
]
class DownscalingLayer(str, Enum):
"""Enumeration of available downscaling layer types for the Encoder.
Used in configuration to specify which block implementation to use at each
stage of the encoder.
Attributes
----------
standard : str
Identifier for the `StandardConvDownBlock`.
coord : str
Identifier for the `FreqCoordConvDownBlock` (incorporates frequency
coords).
"""
standard = "ConvBlockDownStandard"
coord = "FreqCoordConvDownBlock"
class EncoderLayer(BaseConfig):
"""Configuration for a single layer within the Encoder sequence.
Attributes
----------
layer_type : DownscalingLayer
Specifies the type of downscaling block to use for this layer
(either 'standard' or 'coord').
channels : int
The number of output channels this layer should produce. Must be > 0.
"""
layer_type: DownscalingLayer
channels: int
class EncoderConfig(BaseConfig):
"""Configuration for building the entire sequential Encoder.
Attributes
----------
input_height : int
The expected height (number of frequency bins) of the input spectrogram
tensor fed into the first layer of the encoder. Required for
calculating intermediate heights, especially for CoordF layers. Must be
> 0.
layers : List[EncoderLayer]
An ordered list defining the sequence of downscaling layers in the
encoder. Each item specifies the layer type and its output channel
count. The number of input channels for each layer is inferred from the
previous layer's output channels (or `input_channels` for the first
layer). Must contain at least one layer definition.
input_channels : int, default=1
The number of channels in the initial input tensor to the encoder
(e.g., 1 for a standard single-channel spectrogram). Must be > 0.
"""
input_height: int = Field(gt=0)
layers: List[EncoderLayer] = Field(min_length=1)
input_channels: int = Field(gt=0)
def build_downscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: DownscalingLayer,
) -> tuple[nn.Module, int, int]:
"""Build a single downscaling layer based on configuration.
Internal factory function used by `build_encoder`. Instantiates the
appropriate downscaling block (`StandardConvDownBlock` or
`FreqCoordConvDownBlock`) and returns it along with its expected output
channel count and output height (assuming 2x spatial downsampling).
Parameters
----------
in_channels : int
Number of input channels to the layer.
out_channels : int
Desired number of output channels from the layer.
input_height : int
Height of the input feature map to this layer.
layer_type : DownscalingLayer
The type of layer to build ('standard' or 'coord').
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` layer.
- The number of output channels (`out_channels`).
- The expected output height (`input_height // 2`).
Raises
------
ValueError
If `layer_type` is invalid.
"""
if layer_type == DownscalingLayer.standard:
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=out_channels,
),
out_channels,
input_height // 2,
)
if layer_type == DownscalingLayer.coord:
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
),
out_channels,
input_height // 2,
)
raise ValueError(
f"Invalid downscaling layer type {layer_type}. "
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
)
class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers.
Constructs the downsampling path of an encoder-decoder network by stacking
multiple downscaling blocks.
The `forward` method executes the sequence and returns the output feature
map from *each* downscaling stage, facilitating the implementation of skip
connections in U-Net-like architectures. The `encode` method returns only
the final output tensor (bottleneck features).
Attributes
----------
input_channels : int
Number of channels expected in the input tensor.
input_height : int
Height (frequency bins) expected in the input tensor.
output_channels : int
Number of channels in the final output tensor (bottleneck).
layers : nn.ModuleList
The sequence of instantiated downscaling layer modules.
depth : int
The number of downscaling layers in the encoder.
"""
def __init__(
self,
output_channels: int,
layers: List[nn.Module],
input_height: int = 128,
input_channels: int = 1,
):
"""Initialize the Encoder module.
Note: This constructor is typically called internally by the
`build_encoder` factory function, which prepares the `layers` list.
Parameters
----------
output_channels : int
Number of channels produced by the final layer.
layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g.,
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
input_height : int, default=128
Expected height of the input tensor.
input_channels : int, default=1
Expected number of channels in the input tensor.
"""
super().__init__()
self.input_channels = input_channels
self.input_height = input_height
self.output_channels = output_channels
self.layers = nn.ModuleList(layers)
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs.
This method is typically used when the Encoder is part of a U-Net or
similar architecture requiring skip connections.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
`self.input_channels` and `H_in` must match `self.input_height`.
Returns
-------
List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer
in the sequence. `outputs[0]` is the output of the first layer,
`outputs[-1]` is the final output (bottleneck) of the encoder.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.input_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.input_channels}."
)
if x.shape[2] != self.input_height:
raise ValueError(
f"Input tensor height {x.shape[2]} does not match "
f"encoder expected input_height {self.input_height}."
)
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output.
This method provides access to the bottleneck features produced after
the last downscaling layer.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
`input_channels` and `input_height`.
Returns
-------
torch.Tensor
The final output tensor (bottleneck features) from the last layer
of the encoder. Shape `(B, C_out, H_out, W_out)`.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.input_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.input_channels}."
)
if x.shape[2] != self.input_height:
raise ValueError(
f"Input tensor height {x.shape[2]} does not match "
f"encoder expected input_height {self.input_height}."
)
for layer in self.layers:
x = layer(x)
return x
def build_encoder(config: EncoderConfig) -> Encoder:
"""Factory function to build an Encoder instance from configuration.
Constructs a sequential `Encoder` module based on the specifications in
an `EncoderConfig` object. It iteratively builds the specified sequence
of downscaling layers (`StandardConvDownBlock` or `FreqCoordConvDownBlock`),
tracking the changing number of channels and feature map height.
Parameters
----------
config : EncoderConfig
The configuration object detailing the encoder architecture, including
input dimensions, layer types, and channel counts for each stage.
Returns
-------
Encoder
An initialized `Encoder` module.
Raises
------
ValueError
If the layer configuration is invalid (e.g., unknown layer type).
"""
current_channels = config.input_channels
current_height = config.input_height
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_downscaling_layer(
in_channels=current_channels,
out_channels=layer_config.channels,
input_height=current_height,
layer_type=layer_config.layer_type,
)
layers.append(layer)
return Encoder(
input_height=config.input_height,
layers=layers,
input_channels=config.input_channels,
output_channels=current_channels,
)