mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Create an encoder module
This commit is contained in:
parent
ffa4c2e5e9
commit
096d180ea3
357
batdetect2/models/encoder.py
Normal file
357
batdetect2/models/encoder.py
Normal file
@ -0,0 +1,357 @@
|
||||
"""Constructs the Encoder part of an Encoder-Decoder neural network.
|
||||
|
||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
||||
(`build_encoder`) to create sequential encoders commonly used as the
|
||||
downsampling path in architectures like U-Nets for spectrogram analysis.
|
||||
|
||||
The encoder is built by stacking configurable downscaling blocks. Two types
|
||||
of downscaling blocks are supported, selectable via the configuration:
|
||||
- `StandardConvDownBlock`: A basic Conv2d -> MaxPool2d -> BN -> ReLU block.
|
||||
- `FreqCoordConvDownBlock`: A similar block that incorporates frequency
|
||||
coordinate information (CoordF) before the convolution to potentially aid
|
||||
spatial awareness along the frequency axis.
|
||||
|
||||
The `Encoder`'s `forward` method provides access to intermediate feature maps
|
||||
from each stage, suitable for use as skip connections in a corresponding
|
||||
Decoder. A separate `encode` method returns only the final output (bottleneck)
|
||||
features.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
FreqCoordConvDownBlock,
|
||||
StandardConvDownBlock,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DownscalingLayer",
|
||||
"EncoderLayer",
|
||||
"EncoderConfig",
|
||||
"Encoder",
|
||||
"build_encoder",
|
||||
]
|
||||
|
||||
|
||||
class DownscalingLayer(str, Enum):
|
||||
"""Enumeration of available downscaling layer types for the Encoder.
|
||||
|
||||
Used in configuration to specify which block implementation to use at each
|
||||
stage of the encoder.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
standard : str
|
||||
Identifier for the `StandardConvDownBlock`.
|
||||
coord : str
|
||||
Identifier for the `FreqCoordConvDownBlock` (incorporates frequency
|
||||
coords).
|
||||
"""
|
||||
|
||||
standard = "ConvBlockDownStandard"
|
||||
coord = "FreqCoordConvDownBlock"
|
||||
|
||||
|
||||
class EncoderLayer(BaseConfig):
|
||||
"""Configuration for a single layer within the Encoder sequence.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
layer_type : DownscalingLayer
|
||||
Specifies the type of downscaling block to use for this layer
|
||||
(either 'standard' or 'coord').
|
||||
channels : int
|
||||
The number of output channels this layer should produce. Must be > 0.
|
||||
"""
|
||||
|
||||
layer_type: DownscalingLayer
|
||||
channels: int
|
||||
|
||||
|
||||
class EncoderConfig(BaseConfig):
|
||||
"""Configuration for building the entire sequential Encoder.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
The expected height (number of frequency bins) of the input spectrogram
|
||||
tensor fed into the first layer of the encoder. Required for
|
||||
calculating intermediate heights, especially for CoordF layers. Must be
|
||||
> 0.
|
||||
layers : List[EncoderLayer]
|
||||
An ordered list defining the sequence of downscaling layers in the
|
||||
encoder. Each item specifies the layer type and its output channel
|
||||
count. The number of input channels for each layer is inferred from the
|
||||
previous layer's output channels (or `input_channels` for the first
|
||||
layer). Must contain at least one layer definition.
|
||||
input_channels : int, default=1
|
||||
The number of channels in the initial input tensor to the encoder
|
||||
(e.g., 1 for a standard single-channel spectrogram). Must be > 0.
|
||||
"""
|
||||
|
||||
input_height: int = Field(gt=0)
|
||||
layers: List[EncoderLayer] = Field(min_length=1)
|
||||
input_channels: int = Field(gt=0)
|
||||
|
||||
|
||||
def build_downscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: DownscalingLayer,
|
||||
) -> tuple[nn.Module, int, int]:
|
||||
"""Build a single downscaling layer based on configuration.
|
||||
|
||||
Internal factory function used by `build_encoder`. Instantiates the
|
||||
appropriate downscaling block (`StandardConvDownBlock` or
|
||||
`FreqCoordConvDownBlock`) and returns it along with its expected output
|
||||
channel count and output height (assuming 2x spatial downsampling).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels to the layer.
|
||||
out_channels : int
|
||||
Desired number of output channels from the layer.
|
||||
input_height : int
|
||||
Height of the input feature map to this layer.
|
||||
layer_type : DownscalingLayer
|
||||
The type of layer to build ('standard' or 'coord').
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[nn.Module, int, int]
|
||||
A tuple containing:
|
||||
- The instantiated `nn.Module` layer.
|
||||
- The number of output channels (`out_channels`).
|
||||
- The expected output height (`input_height // 2`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `layer_type` is invalid.
|
||||
"""
|
||||
if layer_type == DownscalingLayer.standard:
|
||||
return (
|
||||
StandardConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
),
|
||||
out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
if layer_type == DownscalingLayer.coord:
|
||||
return (
|
||||
FreqCoordConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
),
|
||||
out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid downscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
||||
)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Sequential Encoder module composed of configurable downscaling layers.
|
||||
|
||||
Constructs the downsampling path of an encoder-decoder network by stacking
|
||||
multiple downscaling blocks.
|
||||
|
||||
The `forward` method executes the sequence and returns the output feature
|
||||
map from *each* downscaling stage, facilitating the implementation of skip
|
||||
connections in U-Net-like architectures. The `encode` method returns only
|
||||
the final output tensor (bottleneck features).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_channels : int
|
||||
Number of channels expected in the input tensor.
|
||||
input_height : int
|
||||
Height (frequency bins) expected in the input tensor.
|
||||
output_channels : int
|
||||
Number of channels in the final output tensor (bottleneck).
|
||||
layers : nn.ModuleList
|
||||
The sequence of instantiated downscaling layer modules.
|
||||
depth : int
|
||||
The number of downscaling layers in the encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_channels: int,
|
||||
layers: List[nn.Module],
|
||||
input_height: int = 128,
|
||||
input_channels: int = 1,
|
||||
):
|
||||
"""Initialize the Encoder module.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_encoder` factory function, which prepares the `layers` list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_channels : int
|
||||
Number of channels produced by the final layer.
|
||||
layers : List[nn.Module]
|
||||
A list of pre-instantiated downscaling layer modules (e.g.,
|
||||
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
||||
sequence.
|
||||
input_height : int, default=128
|
||||
Expected height of the input tensor.
|
||||
input_channels : int, default=1
|
||||
Expected number of channels in the input tensor.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_channels = input_channels
|
||||
self.input_height = input_height
|
||||
self.output_channels = output_channels
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Pass input through encoder layers, returns all intermediate outputs.
|
||||
|
||||
This method is typically used when the Encoder is part of a U-Net or
|
||||
similar architecture requiring skip connections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
||||
`self.input_channels` and `H_in` must match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[torch.Tensor]
|
||||
A list containing the output tensors from *each* downscaling layer
|
||||
in the sequence. `outputs[0]` is the output of the first layer,
|
||||
`outputs[-1]` is the final output (bottleneck) of the encoder.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
"""
|
||||
if x.shape[1] != self.input_channels:
|
||||
raise ValueError(
|
||||
f"Input tensor has {x.shape[1]} channels, "
|
||||
f"but encoder expects {self.input_channels}."
|
||||
)
|
||||
|
||||
if x.shape[2] != self.input_height:
|
||||
raise ValueError(
|
||||
f"Input tensor height {x.shape[2]} does not match "
|
||||
f"encoder expected input_height {self.input_height}."
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pass input through encoder layers, returning only the final output.
|
||||
|
||||
This method provides access to the bottleneck features produced after
|
||||
the last downscaling layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
||||
`input_channels` and `input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final output tensor (bottleneck features) from the last layer
|
||||
of the encoder. Shape `(B, C_out, H_out, W_out)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
"""
|
||||
if x.shape[1] != self.input_channels:
|
||||
raise ValueError(
|
||||
f"Input tensor has {x.shape[1]} channels, "
|
||||
f"but encoder expects {self.input_channels}."
|
||||
)
|
||||
|
||||
if x.shape[2] != self.input_height:
|
||||
raise ValueError(
|
||||
f"Input tensor height {x.shape[2]} does not match "
|
||||
f"encoder expected input_height {self.input_height}."
|
||||
)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_encoder(config: EncoderConfig) -> Encoder:
|
||||
"""Factory function to build an Encoder instance from configuration.
|
||||
|
||||
Constructs a sequential `Encoder` module based on the specifications in
|
||||
an `EncoderConfig` object. It iteratively builds the specified sequence
|
||||
of downscaling layers (`StandardConvDownBlock` or `FreqCoordConvDownBlock`),
|
||||
tracking the changing number of channels and feature map height.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : EncoderConfig
|
||||
The configuration object detailing the encoder architecture, including
|
||||
input dimensions, layer types, and channel counts for each stage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Encoder
|
||||
An initialized `Encoder` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the layer configuration is invalid (e.g., unknown layer type).
|
||||
"""
|
||||
current_channels = config.input_channels
|
||||
current_height = config.input_height
|
||||
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_downscaling_layer(
|
||||
in_channels=current_channels,
|
||||
out_channels=layer_config.channels,
|
||||
input_height=current_height,
|
||||
layer_type=layer_config.layer_type,
|
||||
)
|
||||
layers.append(layer)
|
||||
|
||||
return Encoder(
|
||||
input_height=config.input_height,
|
||||
layers=layers,
|
||||
input_channels=config.input_channels,
|
||||
output_channels=current_channels,
|
||||
)
|
Loading…
Reference in New Issue
Block a user