Restructured models module

mbsantiago 2025-04-21 21:25:50 +01:00
parent 096d180ea3
commit ce15afc231
8 changed files with 1635 additions and 401 deletions

View File

@ -1,26 +1,135 @@
"""Defines and builds the neural network models used in BatDetect2.
This package (`batdetect2.models`) contains the PyTorch implementations of the
deep neural network architectures used for detecting and classifying bat calls
from spectrograms. It provides modular components and configuration-driven
assembly, allowing for experimentation and use of different architectural
variants.
Key Submodules:
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
- `.blocks`: Provides reusable neural network building blocks.
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
- `.bottleneck`: Defines and builds the central bottleneck component.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
- `.backbones`: Assembles the encoder, bottleneck, and decoder into a complete
feature extraction backbone (e.g., a U-Net-like structure).
- `.heads`: Defines simple prediction heads (detection, classification, size)
that attach to the backbone features.
- `.detectors`: Assembles the backbone and prediction heads into the final,
end-to-end `Detector` model.
This module re-exports the most important classes, configurations, and builder
functions from these submodules for convenient access. The primary entry point
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import Optional
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
Backbone,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.build import build_architecture
from batdetect2.models.config import ModelConfig, ModelType, load_model_config
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,
FreqCoordConvUpConfig,
StandardConvDownConfig,
StandardConvUpConfig,
)
from batdetect2.models.bottleneck import (
Bottleneck,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.detectors import (
Detector,
build_detector,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"BBoxHead",
"Backbone",
"BackboneConfig",
"BackboneModel",
"Bottleneck",
"BottleneckConfig",
"ClassifierHead",
"ConvConfig",
"DEFAULT_DECODER_CONFIG",
"DEFAULT_ENCODER_CONFIG",
"DecoderConfig",
"DetectionModel",
"Detector",
"DetectorHead",
"EncoderConfig",
"FreqCoordConvDownConfig",
"FreqCoordConvUpConfig",
"ModelConfig",
"ModelOutput",
"ModelType",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"Net2DPlain",
"StandardConvDownConfig",
"StandardConvUpConfig",
"build_architecture",
"build_backbone",
"build_bottleneck",
"build_decoder",
"build_detector",
"build_encoder",
"build_model",
"load_backbone_config",
"load_model_config",
]
def build_model(
num_classes: int,
config: Optional[BackboneConfig] = None,
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
This high-level factory function constructs the standard BatDetect2 model
architecture. It first builds the feature extraction backbone (typically an
encoder-bottleneck-decoder structure) based on the provided
`BackboneConfig` (or defaults if None), and then attaches the standard
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
`build_detector` function.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Returns
-------
DetectionModel
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
backbone = build_backbone(config or BackboneConfig())
return build_detector(num_classes, backbone)
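# Usage sketch (editor's example, not part of the commit): build a model
# with all-default backbone settings. The class count (17) is hypothetical.
import torch
from batdetect2.models import build_model

model = build_model(num_classes=17)
spec = torch.randn(1, 1, 128, 512)  # (batch, channels, freq_bins, time_bins)
out = model(spec)  # ModelOutput(detection_probs, size_preds, class_probs, features)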

View File

@ -1,204 +1,194 @@
from enum import Enum
from typing import Optional, Sequence, Tuple
"""Assembles a complete Encoder-Decoder Backbone network.
This module defines the configuration (`BackboneConfig`) and implementation
(`Backbone`) for a standard encoder-decoder style neural network backbone.
It orchestrates the connection between three main components, built using their
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
The resulting `Backbone` module takes a spectrogram as input and outputs a
final feature map, typically used by subsequent prediction heads. It includes
automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.blocks import (
ConvBlock,
SelfAttention,
VerticalConv,
)
from batdetect2.models.decoder import Decoder, UpscalingLayer
from batdetect2.models.encoder import DownscalingLayer, Encoder
from batdetect2.models.blocks import ConvBlock
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
from batdetect2.models.types import BackboneModel
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"Net2DPlain",
"Backbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class Net2DPlain(BackboneModel):
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
class Backbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
skip connections between the Encoder and Decoder. Implements the standard
U-Net style forward pass. Includes automatic input padding to handle
various input sizes and a final convolutional block to adjust the output
channels.
This class inherits from `BackboneModel` and implements its `forward`
method. Instances are typically created using the `build_backbone` factory
function.
Attributes
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Number of channels in the final output feature map.
encoder : Encoder
The instantiated encoder module.
decoder : Decoder
The instantiated decoder module.
bottleneck : nn.Module
The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int
The total downsampling factor applied by the encoder (the ratio of
input height to bottleneck height), used for automatic input padding.
"""
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
input_height: int,
out_channels: int,
encoder: Encoder,
decoder: Decoder,
bottleneck: nn.Module,
):
super().__init__()
"""Initialize the Backbone network.
Parameters
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
encoder : Encoder
An initialized Encoder module.
decoder : Decoder
An initialized Decoder module.
bottleneck : nn.Module
An initialized Bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
"""
super().__init__()
self.input_height = input_height
self.encoder_channels = tuple(encoder_channels)
self.decoder_channels = tuple(decoder_channels)
self.out_channels = out_channels
if len(encoder_channels) != len(decoder_channels):
raise ValueError(
f"Mismatched encoder and decoder channel lists. "
f"The encoder has {len(encoder_channels)} channels "
f"(implying {len(encoder_channels) - 1} layers), "
f"while the decoder has {len(decoder_channels)} channels "
f"(implying {len(decoder_channels) - 1} layers). "
f"These lengths must be equal."
)
self.encoder = encoder
self.decoder = decoder
self.bottleneck = bottleneck
self.divide_factor = 2 ** (len(encoder_channels) - 1)
if self.input_height % self.divide_factor != 0:
raise ValueError(
f"Input height ({self.input_height}) must be divisible by "
f"the divide factor ({self.divide_factor}). "
f"This ensures proper upscaling after downscaling to recover "
f"the original input height."
)
self.encoder = Encoder(
channels=encoder_channels,
input_height=self.input_height,
layer_type=self.downscaling_layer_type,
)
self.conv_same_1 = ConvBlock(
in_channels=encoder_channels[-1],
out_channels=bottleneck_channels,
)
# bottleneck
self.conv_vert = VerticalConv(
in_channels=bottleneck_channels,
out_channels=bottleneck_channels,
input_height=self.input_height // (2**self.encoder.depth),
)
self.decoder = Decoder(
channels=decoder_channels,
input_height=self.input_height,
layer_type=self.upscaling_layer_type,
)
self.conv_same_2 = ConvBlock(
in_channels=decoder_channels[-1],
self.final_conv = ConvBlock(
in_channels=decoder.out_channels,
out_channels=out_channels,
)
# Down/Up scaling factor. Need to ensure inputs are divisible by
# this factor in order to be processed by the down/up scaling layers
# and recover the correct shape
self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
"""Perform the forward pass through the encoder-decoder backbone.
Applies padding, runs encoder, bottleneck, decoder (with skip
connections), removes padding, and applies a final convolution.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
`self.encoder.input_channels` and `self.input_height`.
Returns
-------
torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
`C_out` is `self.out_channels`.
"""
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
x = self.bottleneck(residuals[-1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
class Net2DFast(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__(
input_height=input_height,
encoder_channels=encoder_channels,
bottleneck_channels=bottleneck_channels,
decoder_channels=decoder_channels,
out_channels=out_channels,
)
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = self.att(x)
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
class Net2DFastNoAttn(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
class Net2DFastNoCoordConv(Net2DFast):
downscaling_layer_type = "ConvBlockDownStandard"
upscaling_layer_type = "ConvBlockUpStandard"
def pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return x
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
return self.final_conv(x)
class BackboneConfig(BaseConfig):
backbone_type: ModelType = ModelType.Net2DFast
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
in_channels: int = 1
encoder: Optional[EncoderConfig] = None
bottleneck: Optional[BottleneckConfig] = None
decoder: Optional[DecoderConfig] = None
out_channels: int = 32
@ -206,48 +196,162 @@ def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
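# Usage sketch (editor's example): load a backbone config from a nested
# YAML section and build the backbone. Path and field are hypothetical.
cfg = load_backbone_config("config.yaml", field="model.backbone")
backbone = build_backbone(cfg)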
def build_model_backbone(
config: Optional[BackboneConfig] = None,
) -> BackboneModel:
config = config or BackboneConfig()
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.
if config.backbone_type == ModelType.Net2DFast:
return Net2DFast(
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
the provided `BackboneConfig`, validates their compatibility, and assembles
them into a `Backbone` instance.
Parameters
----------
config : BackboneConfig
The configuration object detailing the backbone architecture, including
input dimensions and configurations for encoder, bottleneck, and
decoder.
Returns
-------
BackboneModel
An initialized `Backbone` module ready for use.
Raises
------
ValueError
If sub-component configurations are incompatible
(e.g., channel mismatches, decoder output height doesn't match backbone
input height).
NotImplementedError
If an unknown block type is specified in sub-configs.
"""
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
config=config.encoder,
)
if config.backbone_type == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
if config.backbone_type == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if config.backbone_type == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
raise ValueError(f"Unknown model type: {config.backbone_type}")
return Backbone(
input_height=config.input_height,
out_channels=config.out_channels,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
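# Usage sketch (editor's example): with an all-default BackboneConfig the
# encoder reduces height 128 -> 16, the decoder restores it, and the final
# feature map has 32 channels. Width is padded internally to a multiple of
# the downsampling factor and cropped back afterwards.
backbone = build_backbone(BackboneConfig())
features = backbone(torch.randn(1, 1, 128, 100))
assert features.shape == (1, 32, 128, 100)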
def _pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
"""Pad tensor height and width to be divisible by a factor.
Calculates the required padding for the last two dimensions (H, W) to make
them divisible by `factor` and applies right/bottom padding using
`torch.nn.functional.pad`.
Parameters
----------
spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`.
factor : int, default=32
The factor to make height and width divisible by.
Returns
-------
Tuple[torch.Tensor, int, int]
A tuple containing:
- The padded tensor.
- The amount of padding added to height (`h_pad`).
- The amount of padding added to width (`w_pad`).
"""
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor
if h_pad == 0 and w_pad == 0:
return spec, 0, 0
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
"""Remove padding added by _pad_adjust.
Removes padding from the bottom and right edges of the tensor.
Parameters
----------
x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
h_pad : int, default=0
Amount of padding previously added to the height (bottom).
w_pad : int, default=0
Amount of padding previously added to the width (right).
Returns
-------
torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
"""
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return x
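# Behavior sketch (editor's example): pad a 100x230 feature map up to
# multiples of 32, then crop back to the original size.
x = torch.randn(1, 1, 100, 230)
padded, h_pad, w_pad = _pad_adjust(x, factor=32)
assert padded.shape == (1, 1, 128, 256) and (h_pad, w_pad) == (28, 26)
restored = _restore_pad(padded, h_pad=h_pad, w_pad=w_pad)
assert restored.shape == (1, 1, 100, 230)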

View File

@ -22,14 +22,20 @@ research:
These blocks can be utilized directly in custom PyTorch model definitions or
assembled into larger architectures.
A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects.
"""
from typing import Tuple
from typing import Annotated, Literal, Tuple, Union
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
__all__ = [
"ConvBlock",
"VerticalConv",
@ -38,6 +44,13 @@ __all__ = [
"FreqCoordConvUpBlock",
"StandardConvUpBlock",
"SelfAttention",
"ConvConfig",
"FreqCoordConvDownConfig",
"StandardConvDownConfig",
"FreqCoordConvUpConfig",
"StandardConvUpConfig",
"LayerConfig",
"build_layer_from_config",
]
@ -154,6 +167,22 @@ class SelfAttention(nn.Module):
return op
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
block_type: Literal["ConvBlock"] = "ConvBlock"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class ConvBlock(nn.Module):
"""Basic Convolutional Block.
@ -171,8 +200,7 @@ class ConvBlock(nn.Module):
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Amount of padding added to preserve spatial dimensions (assuming
stride=1 and kernel_size=3).
Amount of padding added to preserve spatial dimensions.
"""
def __init__(
@ -261,6 +289,22 @@ class VerticalConv(nn.Module):
return F.relu_(self.bn(self.conv(x)))
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class FreqCoordConvDownBlock(nn.Module):
"""Downsampling Conv Block incorporating Frequency Coordinate features.
@ -289,9 +333,6 @@ class FreqCoordConvDownBlock(nn.Module):
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
stride : int, default=1
Stride of the convolution. Note: Downsampling is achieved via
MaxPool2d(2, 2).
"""
def __init__(
@ -301,7 +342,6 @@ class FreqCoordConvDownBlock(nn.Module):
input_height: int,
kernel_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super().__init__()
@ -314,7 +354,7 @@ class FreqCoordConvDownBlock(nn.Module):
out_channels,
kernel_size=kernel_size,
padding=pad_size,
stride=stride,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
@ -339,6 +379,22 @@ class FreqCoordConvDownBlock(nn.Module):
return x
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
block_type: Literal["StandardConvDown"] = "StandardConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class StandardConvDownBlock(nn.Module):
"""Standard Downsampling Convolutional Block.
@ -357,8 +413,6 @@ class StandardConvDownBlock(nn.Module):
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
stride : int, default=1
Stride of the convolution (downsampling is done by MaxPool).
"""
def __init__(
@ -367,7 +421,6 @@ class StandardConvDownBlock(nn.Module):
out_channels: int,
kernel_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super(StandardConvDownBlock, self).__init__()
self.conv = nn.Conv2d(
@ -375,7 +428,7 @@ class StandardConvDownBlock(nn.Module):
out_channels,
kernel_size=kernel_size,
padding=pad_size,
stride=stride,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
@ -396,6 +449,22 @@ class StandardConvDownBlock(nn.Module):
return F.relu(self.conv_bn(x), inplace=True)
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class FreqCoordConvUpBlock(nn.Module):
"""Upsampling Conv Block incorporating Frequency Coordinate features.
@ -489,6 +558,22 @@ class FreqCoordConvUpBlock(nn.Module):
return op
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
block_type: Literal["StandardConvUp"] = "StandardConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class StandardConvUpBlock(nn.Module):
"""Standard Upsampling Convolutional Block.
@ -559,3 +644,122 @@ class StandardConvUpBlock(nn.Module):
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
LayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configuration models."""
def build_layer_from_config(
input_height: int,
in_channels: int,
config: LayerConfig,
) -> Tuple[nn.Module, int, int]:
"""Factory function to build a specific nn.Module block from its config.
Takes configuration object (one of the types included in the `LayerConfig`
union) and instantiates the corresponding nn.Module block with the correct
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `block_type` field within the `config` object to determine
which block class to instantiate.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor *to this layer*.
in_channels : int
Number of channels in the input tensor *to this layer*.
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `block_type` field.
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` block.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Raises
------
NotImplementedError
If the `config.block_type` does not correspond to a known block type.
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.block_type == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if config.block_type == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.block_type == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.block_type == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.block_type == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
raise NotImplementedError(f"Unknown block type {config.block_type}")
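# Usage sketch (editor's example): build a single downsampling block and
# track the channel/height bookkeeping the encoder and decoder builders
# rely on.
layer, out_ch, out_h = build_layer_from_config(
    input_height=128,
    in_channels=1,
    config=FreqCoordConvDownConfig(out_channels=32),
)
assert (out_ch, out_h) == (32, 64)  # height halved by the down block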

View File

@ -0,0 +1,254 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
This module provides the configuration (`BottleneckConfig`) and
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
Decoder (upsampling path) in networks like U-Nets.
The bottleneck processes the lowest-resolution, highest-dimensionality feature
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Optional
import torch
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
__all__ = [
"BottleneckConfig",
"Bottleneck",
"BottleneckAttn",
"build_bottleneck",
]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
This is the simplest bottleneck structure provided, consisting primarily
of a `VerticalConv` layer. This layer
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
This base version does *not* include self-attention.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
in_channels : int
Number of channels in the input tensor from the encoder. Must be
positive.
out_channels : int
Number of output channels. Must be positive.
Attributes
----------
in_channels : int
Number of input channels accepted by the bottleneck.
input_height : int
Expected height of the input tensor.
out_channels : int
Number of output channels produced by the bottleneck.
conv_vert : VerticalConv
The vertical convolution layer.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.conv_vert = VerticalConv(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck.
Applies vertical convolution and repeats the output height.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
dimension `H_in` is restored via repetition after the vertical
convolution.
"""
x = self.conv_vert(x)
return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck):
"""Bottleneck module including a Self-Attention layer.
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor from the encoder.
in_channels : int
Number of channels in the input tensor from the encoder.
out_channels : int
Number of output channels produced by the `VerticalConv` and
subsequently processed and output by this bottleneck. Also determines
the input/output channels of the internal `SelfAttention` layer.
attention : nn.Module
An initialized `SelfAttention` module instance.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256,
self_attention=True,
)
def build_bottleneck(
input_height: int,
in_channels: int,
config: Optional[BottleneckConfig] = None,
) -> nn.Module:
"""Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
on the `config.self_attention` flag.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
in_channels : int
Number of channels in the input tensor. Must be positive.
config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
Returns
-------
nn.Module
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
Raises
------
ValueError
If `input_height` or `in_channels` are not positive.
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention:
attention = SelfAttention(
in_channels=config.channels,
attention_channels=config.channels,
)
return BottleneckAttn(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
attention=attention,
)
return Bottleneck(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
)
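# Usage sketch (editor's example): with the default config (256 channels,
# self-attention enabled) the bottleneck collapses the height with
# VerticalConv, applies attention over time, then repeats the height back.
bottleneck = build_bottleneck(input_height=16, in_channels=256)
out = bottleneck(torch.randn(2, 256, 16, 64))
assert out.shape == (2, 256, 16, 64)  # height restored by repetition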

View File

@ -0,0 +1,267 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network.
This module defines the configuration structure (`DecoderConfig`) for the layer
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
The decoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `DecoderConfig.layers`. Each config
object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with upsampling) and its parameters. This allows
flexible definition of decoder architectures via configuration files.
The `Decoder`'s `forward` method is designed to accept skip connection tensors
(`residuals`) from the encoder, merging them with the upsampled feature maps
at each stage.
"""
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
build_layer_from_config,
)
__all__ = [
"DecoderConfig",
"Decoder",
"build_decoder",
"DEFAULT_DECODER_CONFIG",
]
DecoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
Attributes
----------
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `block_type` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
"""
layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
based on a list of layer modules provided during initialization (typically
created by the `build_decoder` factory function).
The `forward` method is designed to integrate skip connection tensors
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
out_channels : int
Number of channels in the final output tensor produced by the last
layer.
input_height : int
Height (frequency bins) expected in the input tensor.
output_height : int
Height (frequency bins) expected in the output tensor.
layers : nn.ModuleList
The sequence of instantiated upscaling layer modules.
depth : int
The number of upscaling layers (depth) in the decoder.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
input_height: int,
output_height: int,
layers: List[nn.Module],
):
"""Initialize the Decoder module.
Note: This constructor is typically called internally by the
`build_decoder` factory function.
Parameters
----------
in_channels : int
Expected number of channels in the input tensor (bottleneck).
out_channels : int
Number of channels produced by the final layer.
input_height : int
Expected height of the input tensor (bottleneck).
output_height : int
Expected height of the output tensor produced by the final layer.
layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g.,
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
sequence (from bottleneck towards output resolution).
"""
super().__init__()
self.input_height = input_height
self.output_height = output_height
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = nn.ModuleList(layers)
self.depth = len(self.layers)
def forward(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling
layers. At each stage, the corresponding skip connection tensor from
the `residuals` list is added element-wise to the input before passing
it to the upscaling block.
Parameters
----------
x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck).
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding
encoder stages, ordered as produced by `Encoder.forward`: from the
shallowest layer output (highest resolution) to the deepest (lowest
resolution). The list is traversed in reverse internally so that the
deepest residual is merged with `x` first. The number of tensors must
match the number of decoder layers (`self.depth`), and each residual
must be shape-compatible with the corresponding decoder input for
element-wise addition.
Returns
-------
torch.Tensor
The final decoded feature map tensor produced by the last layer.
Shape `(B, C_out, H_out, W_out)`.
Raises
------
ValueError
If the number of `residuals` provided does not match the decoder
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
"""
if len(residuals) != len(self.layers):
raise ValueError(
f"Incorrect number of residuals provided. "
f"Expected {len(self.layers)} (matching the number of layers), "
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
x = layer(x + res)
return x
DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[
FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32),
FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32),
],
)
"""A default configuration for the Decoder's *layer sequence*.
Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
"""
def build_decoder(
in_channels: int,
input_height: int,
config: Optional[DecoderConfig] = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.
Constructs a sequential `Decoder` module based on the layer sequence
defined in a `DecoderConfig` object and the provided input dimensions
(bottleneck channels and height). If no config is provided, uses the
default layer sequence from `DEFAULT_DECODER_CONFIG`.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Parameters
----------
in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0.
input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be
> 0.
config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
Returns
-------
Decoder
An initialized `Decoder` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., an empty layer list).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
"""
config = config or DEFAULT_DECODER_CONFIG
current_channels = in_channels
current_height = input_height
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
return Decoder(
in_channels=in_channels,
out_channels=current_channels,
input_height=input_height,
output_height=current_height,
layers=layers,
)
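# Usage sketch (editor's example): with the default layer sequence, a
# 256-channel bottleneck at height 16 is upsampled 16 -> 32 -> 64 -> 128
# by the three FreqCoordConvUp blocks; the final ConvBlock keeps the
# height and leaves 32 output channels.
decoder = build_decoder(in_channels=256, input_height=16)
assert decoder.out_channels == 32 and decoder.output_height == 128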

View File

@ -0,0 +1,173 @@
"""Assembles the complete BatDetect2 Detection Model.
This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
`Detector` instance given a backbone and the number of target classes.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
import torch
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model.
Assembles a complete detection and classification model by combining a
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class
probabilities.
Attributes
----------
backbone : BackboneModel
The feature extraction backbone network module.
num_classes : int
The number of specific target classes the model predicts (derived from
the `classifier_head`).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
detector_head : DetectorHead
The prediction head responsible for generating detection probabilities.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
"""
backbone: BackboneModel
def __init__(
self,
backbone: BackboneModel,
classifier_head: ClassifierHead,
detector_head: DetectorHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
Note: Instances are typically created using the `build_detector`
factory function.
Parameters
----------
backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by
`build_backbone` from the `.backbone` module).
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
detector_head : DetectorHead
An initialized detection head module.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.
Raises
------
TypeError
If the provided modules are not of the expected types.
"""
super().__init__()
self.backbone = backbone
self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head
self.detector_head = detector_head
self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model.
Processes the input spectrogram through the backbone to extract
features, then passes these features through the separate prediction
heads to generate detection probabilities, class probabilities, and
size predictions.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The
shape must be compatible with the `self.backbone` input
requirements.
Returns
-------
ModelOutput
A NamedTuple containing the four output tensors:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
- `class_probs`: Class probabilities (excluding background)
`(B, num_classes, H, W)`.
- `features`: Output feature map from the backbone
`(B, C_out, H, W)`.
"""
features = self.backbone(spec)
detection = self.detector_head(features)
classification = self.classifier_head(features)
size_preds = self.bbox_head(features)
return ModelOutput(
detection_probs=detection,
size_preds=size_preds,
class_probs=classification,
features=features,
)
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
"""Factory function to build a standard Detector model instance.
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
`BBoxHead`) configured appropriately based on the output channels of the
provided `backbone` and the specified `num_classes`. It then assembles
these components into a `Detector` model.
Parameters
----------
num_classes : int
The number of specific target classes for the classification head
(excluding any implicit background class). Must be positive.
backbone : BackboneModel
An initialized feature extraction backbone module instance. The number
of output channels from this backbone (`backbone.out_channels`) is used
to configure the input channels for the prediction heads.
Returns
-------
Detector
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive.
AttributeError
If `backbone` does not have the required `out_channels` attribute.
"""
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,
)
detector_head = DetectorHead(
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
)
return Detector(
backbone=backbone,
classifier_head=classifier_head,
detector_head=detector_head,
bbox_head=bbox_head,
)
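# Usage sketch (editor's example): wire a default backbone (built with
# `build_backbone` from the sibling backbone module) to the three heads.
# The class count (17) is hypothetical; all output heatmaps share the
# input's spatial dimensions.
model = build_detector(num_classes=17, backbone=build_backbone(BackboneConfig()))
out = model(torch.randn(1, 1, 128, 512))
assert out.detection_probs.shape == (1, 1, 128, 512)
assert out.class_probs.shape == (1, 17, 128, 512)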

View File

@ -1,25 +1,26 @@
"""Constructs the Encoder part of an Encoder-Decoder neural network.
"""Constructs the Encoder part of a configurable neural network backbone.
This module defines the configuration structure (`EncoderConfig`) and provides
the `Encoder` class (an `nn.Module`) along with a factory function
(`build_encoder`) to create sequential encoders commonly used as the
downsampling path in architectures like U-Nets for spectrogram analysis.
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
The encoder is built by stacking configurable downscaling blocks. Two types
of downscaling blocks are supported, selectable via the configuration:
- `StandardConvDownBlock`: A basic Conv2d -> MaxPool2d -> BN -> ReLU block.
- `FreqCoordConvDownBlock`: A similar block that incorporates frequency
coordinate information (CoordF) before the convolution to potentially aid
spatial awareness along the frequency axis.
The encoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `EncoderConfig.layers`. Each
configuration object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with downsampling) and its parameters
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
The `Encoder`'s `forward` method provides access to intermediate feature maps
from each stage, suitable for use as skip connections in a corresponding
Decoder. A separate `encode` method returns only the final output (bottleneck)
features.
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
suitable for skip connections, while the `encode` method returns only the final
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided.
"""
from enum import Enum
from typing import List
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
@ -27,142 +28,44 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
FreqCoordConvDownBlock,
StandardConvDownBlock,
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
build_layer_from_config,
)
__all__ = [
"DownscalingLayer",
"EncoderLayer",
"EncoderConfig",
"Encoder",
"build_encoder",
"DEFAULT_ENCODER_CONFIG",
]
class DownscalingLayer(str, Enum):
"""Enumeration of available downscaling layer types for the Encoder.
Used in configuration to specify which block implementation to use at each
stage of the encoder.
Attributes
----------
standard : str
Identifier for the `StandardConvDownBlock`.
coord : str
Identifier for the `FreqCoordConvDownBlock` (incorporates frequency
coords).
"""
standard = "ConvBlockDownStandard"
coord = "FreqCoordConvDownBlock"
class EncoderLayer(BaseConfig):
"""Configuration for a single layer within the Encoder sequence.
Attributes
----------
layer_type : DownscalingLayer
Specifies the type of downscaling block to use for this layer
(either 'standard' or 'coord').
channels : int
The number of output channels this layer should produce. Must be > 0.
"""
layer_type: DownscalingLayer
channels: int
EncoderLayerConfig = Annotated[
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
class EncoderConfig(BaseConfig):
"""Configuration for building the entire sequential Encoder.
"""Configuration for building the sequential Encoder module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
Attributes
----------
input_height : int
The expected height (number of frequency bins) of the input spectrogram
tensor fed into the first layer of the encoder. Required for
calculating intermediate heights, especially for CoordF layers. Must be
> 0.
layers : List[EncoderLayer]
An ordered list defining the sequence of downscaling layers in the
encoder. Each item specifies the layer type and its output channel
count. The number of input channels for each layer is inferred from the
previous layer's output channels (or `input_channels` for the first
layer). Must contain at least one layer definition.
input_channels : int, default=1
The number of channels in the initial input tensor to the encoder
(e.g., 1 for a standard single-channel spectrogram). Must be > 0.
layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `block_type` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
"""
input_height: int = Field(gt=0)
layers: List[EncoderLayer] = Field(min_length=1)
input_channels: int = Field(gt=0)
def build_downscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: DownscalingLayer,
) -> tuple[nn.Module, int, int]:
"""Build a single downscaling layer based on configuration.
Internal factory function used by `build_encoder`. Instantiates the
appropriate downscaling block (`StandardConvDownBlock` or
`FreqCoordConvDownBlock`) and returns it along with its expected output
channel count and output height (assuming 2x spatial downsampling).
Parameters
----------
in_channels : int
Number of input channels to the layer.
out_channels : int
Desired number of output channels from the layer.
input_height : int
Height of the input feature map to this layer.
layer_type : DownscalingLayer
The type of layer to build ('standard' or 'coord').
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` layer.
- The number of output channels (`out_channels`).
- The expected output height (`input_height // 2`).
Raises
------
ValueError
If `layer_type` is invalid.
"""
if layer_type == DownscalingLayer.standard:
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=out_channels,
),
out_channels,
input_height // 2,
)
if layer_type == DownscalingLayer.coord:
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
),
out_channels,
input_height // 2,
)
raise ValueError(
f"Invalid downscaling layer type {layer_type}. "
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
)
layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module):
@ -178,12 +81,14 @@ class Encoder(nn.Module):
Attributes
----------
input_channels : int
in_channels : int
Number of channels expected in the input tensor.
input_height : int
Height (frequency bins) expected in the input tensor.
output_channels : int
Number of channels in the final output tensor (bottleneck).
output_height : int
Height (frequency bins) expected in the output tensor.
layers : nn.ModuleList
The sequence of instantiated downscaling layer modules.
depth : int
@ -193,9 +98,10 @@ class Encoder(nn.Module):
def __init__(
self,
output_channels: int,
output_height: int,
layers: List[nn.Module],
input_height: int = 128,
input_channels: int = 1,
in_channels: int = 1,
):
"""Initialize the Encoder module.
@ -206,20 +112,23 @@ class Encoder(nn.Module):
----------
output_channels : int
Number of channels produced by the final layer.
output_height : int
The expected height of the output tensor.
layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g.,
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
input_height : int, default=128
Expected height of the input tensor.
input_channels : int, default=1
in_channels : int, default=1
Expected number of channels in the input tensor.
"""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = output_channels
self.output_height = output_height
self.layers = nn.ModuleList(layers)
self.depth = len(self.layers)
@ -234,7 +143,7 @@ class Encoder(nn.Module):
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
`self.in_channels` and `H_in` must match `self.input_height`.
Returns
-------
@ -249,10 +158,10 @@ class Encoder(nn.Module):
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.input_channels}."
f"but encoder expects {self.in_channels}."
)
if x.shape[2] != self.input_height:
@ -279,7 +188,7 @@ class Encoder(nn.Module):
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
`in_channels` and `input_height`.
Returns
-------
@ -293,10 +202,10 @@ class Encoder(nn.Module):
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.input_channels}."
f"but encoder expects {self.in_channels}."
)
if x.shape[2] != self.input_height:
@ -311,19 +220,53 @@ class Encoder(nn.Module):
return x
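Because the constructor takes pre-instantiated modules, an `Encoder` can also be assembled by hand. A minimal sketch, assuming the block classes used by `build_downscaling_layer` above are importable from `batdetect2.models.blocks`:

from batdetect2.models.blocks import (
    FreqCoordConvDownBlock,
    StandardConvDownBlock,
)

blocks = [
    FreqCoordConvDownBlock(in_channels=1, out_channels=32, input_height=128),
    StandardConvDownBlock(in_channels=32, out_channels=64),
]
encoder = Encoder(
    output_channels=64,
    output_height=32,  # 128 -> 64 -> 32 after two 2x height reductions
    layers=blocks,
    input_height=128,
    in_channels=1,
)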
DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[
FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64),
FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256),
],
)
"""Default configuration for the Encoder.
Specifies an architecture typically used in BatDetect2:
- Input: 1 channel, 128 frequency bins.
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
"""
def build_encoder(
in_channels: int,
input_height: int,
config: Optional[EncoderConfig] = None,
) -> Encoder:
"""Factory function to build an Encoder instance from configuration.
Constructs a sequential `Encoder` module based on the layer sequence
defined in an `EncoderConfig` object and the provided input dimensions.
If no config is provided, uses the default layer sequence from
`DEFAULT_ENCODER_CONFIG`.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate-aware blocks.
Parameters
----------
in_channels : int
The number of channels expected in the input tensor to the encoder.
Must be > 0.
input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0.
Crucial for initializing coordinate-aware layers correctly.
config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
Returns
-------
@ -333,25 +276,33 @@ def build_encoder(config: EncoderConfig) -> Encoder:
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")
config = config or DEFAULT_ENCODER_CONFIG
current_channels = in_channels
current_height = input_height
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
return Encoder(
input_height=input_height,
layers=layers,
in_channels=in_channels,
output_channels=current_channels,
output_height=current_height,
)
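As a usage sketch, building the default encoder and passing a dummy single-channel spectrogram through it (batch and time sizes are arbitrary):

import torch

encoder = build_encoder(in_channels=1, input_height=128)
spec = torch.randn(2, 1, 128, 256)  # (batch, channels, freq_bins, time)
features = encoder(spec)
# With the default config: encoder.out_channels == 256 and
# encoder.output_height == 16.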

View File

@ -1,42 +1,199 @@
"""Prediction Head modules for BatDetect2 models.
This module defines simple `torch.nn.Module` subclasses that serve as
prediction heads, typically attached to the output feature map of a backbone
network.
Each head is responsible for generating one specific type of output required
by the BatDetect2 task:
- `DetectorHead`: Predicts the probability of sound event presence.
- `ClassifierHead`: Predicts the probability distribution over target classes.
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
box.
These heads use 1x1 convolutions to map the backbone feature channels
to the desired number of output channels for each prediction task at each
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
"""
import torch
from torch import nn
__all__ = ["ClassifierHead"]
class Output(NamedTuple):
detection: torch.Tensor
classification: torch.Tensor
__all__ = [
"ClassifierHead",
"DetectorHead",
"BBoxHead",
]
class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities.
Takes an input feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution
to map input channels to `num_classes + 1` outputs (one for each target
class plus an assumed background/generic class), applies softmax across the
channels, and returns the probabilities for the specific target classes
(excluding the last background/generic channel).
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(excluding any background or generic category). Must be positive.
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
num_classes : int
Number of specific output classes.
in_channels : int
Number of input channels expected.
classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
"""
def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead."""
super().__init__()
if num_classes <= 0 or in_channels <= 0:
raise ValueError("num_classes and in_channels must be positive.")
self.num_classes = num_classes
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
# Add one to account for the background class
self.num_classes + 1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`.
Contains probabilities for the specific target classes after
softmax, excluding the implicit background/generic class channel.
"""
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
return probs[:, :-1]
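A brief behavioural sketch with hypothetical sizes; note that the per-pixel class probabilities sum to strictly less than 1, the remainder belonging to the implicit background channel:

import torch

head = ClassifierHead(num_classes=5, in_channels=256)
features = torch.randn(2, 256, 16, 128)
class_probs = head(features)  # shape (2, 5, 16, 128)
assert float(class_probs.sum(dim=1).max()) < 1.0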
class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability.
Takes an input feature map and produces a single-channel heatmap where
each value represents the probability ([0, 1]) of a relevant sound event
(of any class) being present at that spatial location.
Uses a 1x1 convolution to map input channels to 1 output channel, followed
by a sigmoid activation function.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
"""
def __init__(self, in_channels: int):
"""Initialize the DetectorHead."""
super().__init__()
if in_channels <= 0:
raise ValueError("in_channels must be positive.")
self.in_channels = in_channels
self.detector = nn.Conv2d(
in_channels=self.in_channels,
out_channels=1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
Values are in the range [0, 1] due to the sigmoid activation.
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
"""
return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions.
Takes an input feature map and produces a two-channel map where each
channel represents a predicted size dimension (typically width/duration and
height/bandwidth) for a potential sound event at that spatial location.
Uses a 1x1 convolution to map input channels to 2 output channels. No
activation function is applied, as size prediction is treated as a
direct regression task. The output values typically represent *scaled*
dimensions that must be un-scaled during postprocessing.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
"""
def __init__(self, in_channels: int):
"""Initialize the BBoxHead."""
super().__init__()
self.in_channels = in_channels
@ -48,4 +205,19 @@ class BBoxHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
represents scaled width, Channel 1 scaled height. These values
need to be un-scaled during postprocessing.
"""
return self.bbox(features)
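Taken together, the three heads share one backbone feature map. A minimal sketch with hypothetical feature sizes:

import torch

features = torch.randn(2, 256, 16, 128)  # backbone output (hypothetical)

detection = DetectorHead(in_channels=256)(features)  # (2, 1, 16, 128)
class_probs = ClassifierHead(num_classes=5, in_channels=256)(features)  # (2, 5, 16, 128)
sizes = BBoxHead(in_channels=256)(features)  # (2, 2, 16, 128)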