From ce15afc2313fa57b2d4b66f9ecd44ae65926b794 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Mon, 21 Apr 2025 21:25:50 +0100
Subject: [PATCH] Restructured models module

---
 batdetect2/models/__init__.py   | 143 +++++++--
 batdetect2/models/backbones.py  | 498 +++++++++++++++++++-------------
 batdetect2/models/blocks.py     | 228 ++++++++++++++-
 batdetect2/models/bottleneck.py | 254 ++++++++++++++++
 batdetect2/models/decoder.py    | 267 +++++++++++++++++
 batdetect2/models/detectors.py  | 173 +++++++++++
 batdetect2/models/encoder.py    | 275 ++++++++----------
 batdetect2/models/heads.py      | 198 ++++++++++++-
 8 files changed, 1635 insertions(+), 401 deletions(-)
 create mode 100644 batdetect2/models/bottleneck.py
 create mode 100644 batdetect2/models/decoder.py
 create mode 100644 batdetect2/models/detectors.py

diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py
index 6d909dd..daa5ad6 100644
--- a/batdetect2/models/__init__.py
+++ b/batdetect2/models/__init__.py
@@ -1,26 +1,135 @@
+"""Defines and builds the neural network models used in BatDetect2.
+
+This package (`batdetect2.models`) contains the PyTorch implementations of the
+deep neural network architectures used for detecting and classifying bat calls
+from spectrograms. It provides modular components and configuration-driven
+assembly, allowing for experimentation and use of different architectural
+variants.
+
+Key Submodules:
+- `.types`: Defines core data structures (`ModelOutput`) and abstract base
+  classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
+- `.blocks`: Provides reusable neural network building blocks.
+- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
+- `.bottleneck`: Defines and builds the central bottleneck component.
+- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
+- `.backbones`: Assembles the encoder, bottleneck, and decoder into a complete
+  feature extraction backbone (e.g., a U-Net like structure).
+- `.heads`: Defines simple prediction heads (detection, classification, size)
+  that attach to the backbone features.
+- `.detectors`: Assembles the backbone and prediction heads into the final,
+  end-to-end `Detector` model.
+
+This module re-exports the most important classes, configurations, and builder
+functions from these submodules for convenient access. The primary entry point
+for creating a standard BatDetect2 model instance is the `build_model` function
+provided here.
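+
+Example
+-------
+A minimal usage sketch of the entry point described above (the class count of
+6 is illustrative, not a project default)::
+
+    from batdetect2.models import build_model
+
+    model = build_model(num_classes=6)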
+""" + +from typing import Optional + from batdetect2.models.backbones import ( - Net2DFast, - Net2DFastNoAttn, - Net2DFastNoCoordConv, - Net2DPlain, + Backbone, + BackboneConfig, + build_backbone, + load_backbone_config, ) -from batdetect2.models.build import build_architecture -from batdetect2.models.config import ModelConfig, ModelType, load_model_config -from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.models.types import BackboneModel, ModelOutput +from batdetect2.models.blocks import ( + ConvConfig, + FreqCoordConvDownConfig, + FreqCoordConvUpConfig, + StandardConvDownConfig, + StandardConvUpConfig, +) +from batdetect2.models.bottleneck import ( + Bottleneck, + BottleneckConfig, + build_bottleneck, +) +from batdetect2.models.decoder import ( + DEFAULT_DECODER_CONFIG, + DecoderConfig, + build_decoder, +) +from batdetect2.models.detectors import ( + Detector, + build_detector, +) +from batdetect2.models.encoder import ( + DEFAULT_ENCODER_CONFIG, + EncoderConfig, + build_encoder, +) +from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead +from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput __all__ = [ "BBoxHead", + "Backbone", + "BackboneConfig", "BackboneModel", + "BackboneModel", + "Bottleneck", + "BottleneckConfig", "ClassifierHead", - "ModelConfig", + "ConvConfig", + "DEFAULT_DECODER_CONFIG", + "DEFAULT_ENCODER_CONFIG", + "DecoderConfig", + "DetectionModel", + "Detector", + "DetectorHead", + "EncoderConfig", + "FreqCoordConvDownConfig", + "FreqCoordConvUpConfig", "ModelOutput", - "ModelType", - "Net2DFast", - "Net2DFastNoAttn", - "Net2DFastNoCoordConv", - "Net2DPlain", - "build_architecture", - "build_architecture", - "load_model_config", + "StandardConvDownConfig", + "StandardConvUpConfig", + "build_backbone", + "build_bottleneck", + "build_decoder", + "build_detector", + "build_encoder", + "build_model", + "load_backbone_config", ] + + +def build_model( + num_classes: int, + config: Optional[BackboneConfig] = None, +) -> DetectionModel: + """Build the complete BatDetect2 detection model. + + This high-level factory function constructs the standard BatDetect2 model + architecture. It first builds the feature extraction backbone (typically an + encoder-bottleneck-decoder structure) based on the provided + `BackboneConfig` (or defaults if None), and then attaches the standard + prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the + `build_detector` function. + + Parameters + ---------- + num_classes : int + The number of specific target classes the model should predict + (required for the `ClassifierHead`). Must be positive. + config : BackboneConfig, optional + Configuration object specifying the architecture of the backbone + (encoder, bottleneck, decoder). If None, default configurations defined + within the respective builder functions (`build_encoder`, etc.) will be + used to construct a default backbone architecture. + + Returns + ------- + DetectionModel + An initialized `Detector` model instance. + + Raises + ------ + ValueError + If `num_classes` is not positive, or if errors occur during the + construction of the backbone or detector components (e.g., incompatible + configurations, invalid parameters). 
+ """ + backbone = build_backbone(config or BackboneConfig()) + return build_detector(num_classes, backbone) diff --git a/batdetect2/models/backbones.py b/batdetect2/models/backbones.py index d5dadee..1728b58 100644 --- a/batdetect2/models/backbones.py +++ b/batdetect2/models/backbones.py @@ -1,204 +1,194 @@ -from enum import Enum -from typing import Optional, Sequence, Tuple +"""Assembles a complete Encoder-Decoder Backbone network. + +This module defines the configuration (`BackboneConfig`) and implementation +(`Backbone`) for a standard encoder-decoder style neural network backbone. + +It orchestrates the connection between three main components, built using their +respective configurations and factory functions from sibling modules: +1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features + at multiple resolutions and provides skip connections. +2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the + lowest resolution, optionally applying self-attention. +3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high- + resolution features using bottleneck features and skip connections. + +The resulting `Backbone` module takes a spectrogram as input and outputs a +final feature map, typically used by subsequent prediction heads. It includes +automatic padding to handle input sizes not perfectly divisible by the +network's total downsampling factor. +""" + +from typing import Optional, Tuple import torch import torch.nn.functional as F from soundevent import data +from torch import nn from batdetect2.configs import BaseConfig, load_config -from batdetect2.models.blocks import ( - ConvBlock, - SelfAttention, - VerticalConv, -) -from batdetect2.models.decoder import Decoder, UpscalingLayer -from batdetect2.models.encoder import DownscalingLayer, Encoder +from batdetect2.models.blocks import ConvBlock +from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck +from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder +from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder from batdetect2.models.types import BackboneModel __all__ = [ - "Net2DFast", - "Net2DFastNoAttn", - "Net2DFastNoCoordConv", - "Net2DPlain", + "Backbone", + "BackboneConfig", + "load_backbone_config", + "build_backbone", ] -class Net2DPlain(BackboneModel): - downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard" - upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard" +class Backbone(BackboneModel): + """Encoder-Decoder Backbone Network Implementation. + + Combines an Encoder, Bottleneck, and Decoder module sequentially, using + skip connections between the Encoder and Decoder. Implements the standard + U-Net style forward pass. Includes automatic input padding to handle + various input sizes and a final convolutional block to adjust the output + channels. + + This class inherits from `BackboneModel` and implements its `forward` + method. Instances are typically created using the `build_backbone` factory + function. + + Attributes + ---------- + input_height : int + Expected height of the input spectrogram. + out_channels : int + Number of channels in the final output feature map. + encoder : Encoder + The instantiated encoder module. + decoder : Decoder + The instantiated decoder module. + bottleneck : nn.Module + The instantiated bottleneck module. + final_conv : ConvBlock + Final convolutional block applied after the decoder. 
+    divide_factor : int
+        The total downsampling factor (2^depth) applied by the encoder,
+        used for automatic input padding.
+    """
 
     def __init__(
         self,
-        input_height: int = 128,
-        encoder_channels: Sequence[int] = (1, 32, 64, 128),
-        bottleneck_channels: int = 256,
-        decoder_channels: Sequence[int] = (256, 64, 32, 32),
-        out_channels: int = 32,
+        input_height: int,
+        out_channels: int,
+        encoder: Encoder,
+        decoder: Decoder,
+        bottleneck: nn.Module,
     ):
-        super().__init__()
+        """Initialize the Backbone network.
 
+        Parameters
+        ----------
+        input_height : int
+            Expected height of the input spectrogram.
+        out_channels : int
+            Desired number of output channels for the backbone's feature map.
+        encoder : Encoder
+            An initialized Encoder module.
+        decoder : Decoder
+            An initialized Decoder module.
+        bottleneck : nn.Module
+            An initialized Bottleneck module.
+
+        Raises
+        ------
+        ValueError
+            If component output/input channels or heights are incompatible.
+        """
+        super().__init__()
         self.input_height = input_height
-        self.encoder_channels = tuple(encoder_channels)
-        self.decoder_channels = tuple(decoder_channels)
         self.out_channels = out_channels
 
-        if len(encoder_channels) != len(decoder_channels):
-            raise ValueError(
-                f"Mismatched encoder and decoder channel lists. "
-                f"The encoder has {len(encoder_channels)} channels "
-                f"(implying {len(encoder_channels) - 1} layers), "
-                f"while the decoder has {len(decoder_channels)} channels "
-                f"(implying {len(decoder_channels) - 1} layers). "
-                f"These lengths must be equal."
-            )
+        self.encoder = encoder
+        self.decoder = decoder
+        self.bottleneck = bottleneck
 
-        self.divide_factor = 2 ** (len(encoder_channels) - 1)
-        if self.input_height % self.divide_factor != 0:
-            raise ValueError(
-                f"Input height ({self.input_height}) must be divisible by "
-                f"the divide factor ({self.divide_factor}). "
-                f"This ensures proper upscaling after downscaling to recover "
-                f"the original input height."
-            )
-
-        self.encoder = Encoder(
-            channels=encoder_channels,
-            input_height=self.input_height,
-            layer_type=self.downscaling_layer_type,
-        )
-
-        self.conv_same_1 = ConvBlock(
-            in_channels=encoder_channels[-1],
-            out_channels=bottleneck_channels,
-        )
-
-        # bottleneck
-        self.conv_vert = VerticalConv(
-            in_channels=bottleneck_channels,
-            out_channels=bottleneck_channels,
-            input_height=self.input_height // (2**self.encoder.depth),
-        )
-
-        self.decoder = Decoder(
-            channels=decoder_channels,
-            input_height=self.input_height,
-            layer_type=self.upscaling_layer_type,
-        )
-
-        self.conv_same_2 = ConvBlock(
-            in_channels=decoder_channels[-1],
+        self.final_conv = ConvBlock(
+            in_channels=decoder.out_channels,
             out_channels=out_channels,
         )
 
+        # Down/Up scaling factor. Inputs must be divisible by this factor so
+        # that the down/up scaling layers can process them and recover the
+        # correct output shape.
+        self.divide_factor = input_height // self.encoder.output_height
+
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
+        """Perform the forward pass through the encoder-decoder backbone.
+
+        Applies padding, runs encoder, bottleneck, decoder (with skip
+        connections), removes padding, and applies a final convolution.
+
+        Parameters
+        ----------
+        spec : torch.Tensor
+            Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must
+            match `self.encoder.in_channels` and `self.input_height`.
+ + Returns + ------- + torch.Tensor + Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where + `C_out` is `self.out_channels`. + """ + spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor) # encoder residuals = self.encoder(spec) - residuals[-1] = self.conv_same_1(residuals[-1]) # bottleneck - x = self.conv_vert(residuals[-1]) - x = x.repeat([1, 1, residuals[-1].shape[-2], 1]) + x = self.bottleneck(residuals[-1]) # decoder x = self.decoder(x, residuals=residuals) # Restore original size - x = restore_pad(x, h_pad=h_pad, w_pad=w_pad) + x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad) - return self.conv_same_2(x) - - -class Net2DFast(Net2DPlain): - downscaling_layer_type = "ConvBlockDownCoordF" - upscaling_layer_type = "ConvBlockUpF" - - def __init__( - self, - input_height: int = 128, - encoder_channels: Sequence[int] = (1, 32, 64, 128), - bottleneck_channels: int = 256, - decoder_channels: Sequence[int] = (256, 64, 32, 32), - out_channels: int = 32, - ): - super().__init__( - input_height=input_height, - encoder_channels=encoder_channels, - bottleneck_channels=bottleneck_channels, - decoder_channels=decoder_channels, - out_channels=out_channels, - ) - - self.att = SelfAttention(bottleneck_channels, bottleneck_channels) - - def forward(self, spec: torch.Tensor) -> torch.Tensor: - spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor) - - # encoder - residuals = self.encoder(spec) - residuals[-1] = self.conv_same_1(residuals[-1]) - - # bottleneck - x = self.conv_vert(residuals[-1]) - x = self.att(x) - x = x.repeat([1, 1, residuals[-1].shape[-2], 1]) - - # decoder - x = self.decoder(x, residuals=residuals) - - # Restore original size - x = restore_pad(x, h_pad=h_pad, w_pad=w_pad) - - return self.conv_same_2(x) - - -class Net2DFastNoAttn(Net2DPlain): - downscaling_layer_type = "ConvBlockDownCoordF" - upscaling_layer_type = "ConvBlockUpF" - - -class Net2DFastNoCoordConv(Net2DFast): - downscaling_layer_type = "ConvBlockDownStandard" - upscaling_layer_type = "ConvBlockUpStandard" - - -def pad_adjust( - spec: torch.Tensor, - factor: int = 32, -) -> Tuple[torch.Tensor, int, int]: - h, w = spec.shape[2:] - h_pad = -h % factor - w_pad = -w % factor - return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad - - -def restore_pad( - x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 -) -> torch.Tensor: - # Restore original size - if h_pad > 0: - x = x[:, :, :-h_pad, :] - - if w_pad > 0: - x = x[:, :, :, :-w_pad] - - return x - - -class ModelType(str, Enum): - Net2DFast = "Net2DFast" - Net2DFastNoAttn = "Net2DFastNoAttn" - Net2DFastNoCoordConv = "Net2DFastNoCoordConv" - Net2DPlain = "Net2DPlain" + return self.final_conv(x) class BackboneConfig(BaseConfig): - backbone_type: ModelType = ModelType.Net2DFast + """Configuration for the Encoder-Decoder Backbone network. + + Aggregates configurations for the encoder, bottleneck, and decoder + components, along with defining the input and final output dimensions + for the complete backbone. + + Attributes + ---------- + input_height : int, default=128 + Expected height (frequency bins) of the input spectrograms to the + backbone. Must be positive. + in_channels : int, default=1 + Expected number of channels in the input spectrograms (e.g., 1 for + mono). Must be positive. + encoder : EncoderConfig, optional + Configuration for the encoder. If None or omitted, + the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the + encoder module) will be used. 
+ bottleneck : BottleneckConfig, optional + Configuration for the bottleneck layer connecting encoder and decoder. + If None or omitted, the default bottleneck configuration will be used. + decoder : DecoderConfig, optional + Configuration for the decoder. If None or omitted, + the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the + decoder module) will be used. + out_channels : int, default=32 + Desired number of channels in the final feature map output by the + backbone. Must be positive. + """ + input_height: int = 128 - encoder_channels: Tuple[int, ...] = (1, 32, 64, 128) - bottleneck_channels: int = 256 - decoder_channels: Tuple[int, ...] = (256, 64, 32, 32) + in_channels: int = 1 + encoder: Optional[EncoderConfig] = None + bottleneck: Optional[BottleneckConfig] = None + decoder: Optional[DecoderConfig] = None out_channels: int = 32 @@ -206,48 +196,162 @@ def load_backbone_config( path: data.PathLike, field: Optional[str] = None, ) -> BackboneConfig: + """Load the backbone configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `BackboneConfig` schema, potentially extracting data from a nested field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + backbone configuration (e.g., "model.backbone"). If None, the entire + file content is used. + + Returns + ------- + BackboneConfig + The loaded and validated backbone configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded config data does not conform to `BackboneConfig`. + KeyError, TypeError + If `field` specifies an invalid path. + """ return load_config(path, schema=BackboneConfig, field=field) -def build_model_backbone( - config: Optional[BackboneConfig] = None, -) -> BackboneModel: - config = config or BackboneConfig() +def build_backbone(config: BackboneConfig) -> BackboneModel: + """Factory function to build a Backbone from configuration. - if config.backbone_type == ModelType.Net2DFast: - return Net2DFast( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, + Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on + the provided `BackboneConfig`, validates their compatibility, and assembles + them into a `Backbone` instance. + + Parameters + ---------- + config : BackboneConfig + The configuration object detailing the backbone architecture, including + input dimensions and configurations for encoder, bottleneck, and + decoder. + + Returns + ------- + BackboneModel + An initialized `Backbone` module ready for use. + + Raises + ------ + ValueError + If sub-component configurations are incompatible + (e.g., channel mismatches, decoder output height doesn't match backbone + input height). + NotImplementedError + If an unknown block type is specified in sub-configs. 
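+
+    Examples
+    --------
+    A minimal sketch using the default sub-component configurations (the
+    input shape is illustrative):
+
+    >>> import torch
+    >>> backbone = build_backbone(BackboneConfig())
+    >>> backbone(torch.rand([1, 1, 128, 512])).shape
+    torch.Size([1, 32, 128, 512])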
+ """ + encoder = build_encoder( + in_channels=config.in_channels, + input_height=config.input_height, + config=config.encoder, + ) + + bottleneck = build_bottleneck( + input_height=encoder.output_height, + in_channels=encoder.out_channels, + config=config.bottleneck, + ) + + decoder = build_decoder( + in_channels=bottleneck.out_channels, + input_height=encoder.output_height, + config=config.decoder, + ) + + if decoder.output_height != config.input_height: + raise ValueError( + "Invalid configuration: Decoder output height " + f"({decoder.output_height}) must match the Backbone input height " + f"({config.input_height}). Check encoder/decoder layer " + "configurations and input/bottleneck heights." ) - if config.backbone_type == ModelType.Net2DFastNoAttn: - return Net2DFastNoAttn( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) + return Backbone( + input_height=config.input_height, + out_channels=config.out_channels, + encoder=encoder, + decoder=decoder, + bottleneck=bottleneck, + ) - if config.backbone_type == ModelType.Net2DFastNoCoordConv: - return Net2DFastNoCoordConv( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) - if config.backbone_type == ModelType.Net2DPlain: - return Net2DPlain( - input_height=config.input_height, - encoder_channels=config.encoder_channels, - bottleneck_channels=config.bottleneck_channels, - decoder_channels=config.decoder_channels, - out_channels=config.out_channels, - ) +def _pad_adjust( + spec: torch.Tensor, + factor: int = 32, +) -> Tuple[torch.Tensor, int, int]: + """Pad tensor height and width to be divisible by a factor. - raise ValueError(f"Unknown model type: {config.backbone_type}") + Calculates the required padding for the last two dimensions (H, W) to make + them divisible by `factor` and applies right/bottom padding using + `torch.nn.functional.pad`. + + Parameters + ---------- + spec : torch.Tensor + Input tensor, typically shape `(B, C, H, W)`. + factor : int, default=32 + The factor to make height and width divisible by. + + Returns + ------- + Tuple[torch.Tensor, int, int] + A tuple containing: + - The padded tensor. + - The amount of padding added to height (`h_pad`). + - The amount of padding added to width (`w_pad`). + """ + h, w = spec.shape[2:] + h_pad = -h % factor + w_pad = -w % factor + + if h_pad == 0 and w_pad == 0: + return spec, 0, 0 + + return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad + + +def _restore_pad( + x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 +) -> torch.Tensor: + """Remove padding added by _pad_adjust. + + Removes padding from the bottom and right edges of the tensor. + + Parameters + ---------- + x : torch.Tensor + Padded tensor, typically shape `(B, C, H_padded, W_padded)`. + h_pad : int, default=0 + Amount of padding previously added to the height (bottom). + w_pad : int, default=0 + Amount of padding previously added to the width (right). + + Returns + ------- + torch.Tensor + Tensor with padding removed, shape `(B, C, H_original, W_original)`. 
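+
+    Examples
+    --------
+    Round trip with `_pad_adjust` (the input shape is illustrative):
+
+    >>> import torch
+    >>> spec = torch.rand([1, 1, 100, 500])
+    >>> padded, h_pad, w_pad = _pad_adjust(spec, factor=32)
+    >>> padded.shape
+    torch.Size([1, 1, 128, 512])
+    >>> _restore_pad(padded, h_pad=h_pad, w_pad=w_pad).shape
+    torch.Size([1, 1, 100, 500])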
+ """ + if h_pad > 0: + x = x[:, :, :-h_pad, :] + + if w_pad > 0: + x = x[:, :, :, :-w_pad] + + return x diff --git a/batdetect2/models/blocks.py b/batdetect2/models/blocks.py index ac68888..59fa7c0 100644 --- a/batdetect2/models/blocks.py +++ b/batdetect2/models/blocks.py @@ -22,14 +22,20 @@ research: These blocks can be utilized directly in custom PyTorch model definitions or assembled into larger architectures. + +A unified factory function `build_layer_from_config` allows creating instances +of these blocks based on configuration objects. """ -from typing import Tuple +from typing import Annotated, Literal, Tuple, Union import torch import torch.nn.functional as F +from pydantic import Field from torch import nn +from batdetect2.configs import BaseConfig + __all__ = [ "ConvBlock", "VerticalConv", @@ -38,6 +44,13 @@ __all__ = [ "FreqCoordConvUpBlock", "StandardConvUpBlock", "SelfAttention", + "ConvConfig", + "FreqCoordConvDownConfig", + "StandardConvDownConfig", + "FreqCoordConvUpConfig", + "StandardConvUpConfig", + "LayerConfig", + "build_layer_from_config", ] @@ -154,6 +167,22 @@ class SelfAttention(nn.Module): return op +class ConvConfig(BaseConfig): + """Configuration for a basic ConvBlock.""" + + block_type: Literal["ConvBlock"] = "ConvBlock" + """Discriminator field indicating the block type.""" + + out_channels: int + """Number of output channels.""" + + kernel_size: int = 3 + """Size of the square convolutional kernel.""" + + pad_size: int = 1 + """Padding size.""" + + class ConvBlock(nn.Module): """Basic Convolutional Block. @@ -171,8 +200,7 @@ class ConvBlock(nn.Module): kernel_size : int, default=3 Size of the square convolutional kernel. pad_size : int, default=1 - Amount of padding added to preserve spatial dimensions (assuming - stride=1 and kernel_size=3). + Amount of padding added to preserve spatial dimensions. """ def __init__( @@ -261,6 +289,22 @@ class VerticalConv(nn.Module): return F.relu_(self.bn(self.conv(x))) +class FreqCoordConvDownConfig(BaseConfig): + """Configuration for a FreqCoordConvDownBlock.""" + + block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown" + """Discriminator field indicating the block type.""" + + out_channels: int + """Number of output channels.""" + + kernel_size: int = 3 + """Size of the square convolutional kernel.""" + + pad_size: int = 1 + """Padding size.""" + + class FreqCoordConvDownBlock(nn.Module): """Downsampling Conv Block incorporating Frequency Coordinate features. @@ -289,9 +333,6 @@ class FreqCoordConvDownBlock(nn.Module): Size of the square convolutional kernel. pad_size : int, default=1 Padding added before convolution. - stride : int, default=1 - Stride of the convolution. Note: Downsampling is achieved via - MaxPool2d(2, 2). 
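+
+    Examples
+    --------
+    Height and width are halved by the internal max-pooling (the input shape
+    is illustrative):
+
+    >>> import torch
+    >>> block = FreqCoordConvDownBlock(
+    ...     in_channels=1, out_channels=32, input_height=128
+    ... )
+    >>> block(torch.rand([1, 1, 128, 512])).shape
+    torch.Size([1, 32, 64, 256])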
""" def __init__( @@ -301,7 +342,6 @@ class FreqCoordConvDownBlock(nn.Module): input_height: int, kernel_size: int = 3, pad_size: int = 1, - stride: int = 1, ): super().__init__() @@ -314,7 +354,7 @@ class FreqCoordConvDownBlock(nn.Module): out_channels, kernel_size=kernel_size, padding=pad_size, - stride=stride, + stride=1, ) self.conv_bn = nn.BatchNorm2d(out_channels) @@ -339,6 +379,22 @@ class FreqCoordConvDownBlock(nn.Module): return x +class StandardConvDownConfig(BaseConfig): + """Configuration for a StandardConvDownBlock.""" + + block_type: Literal["StandardConvDown"] = "StandardConvDown" + """Discriminator field indicating the block type.""" + + out_channels: int + """Number of output channels.""" + + kernel_size: int = 3 + """Size of the square convolutional kernel.""" + + pad_size: int = 1 + """Padding size.""" + + class StandardConvDownBlock(nn.Module): """Standard Downsampling Convolutional Block. @@ -357,8 +413,6 @@ class StandardConvDownBlock(nn.Module): Size of the square convolutional kernel. pad_size : int, default=1 Padding added before convolution. - stride : int, default=1 - Stride of the convolution (downsampling is done by MaxPool). """ def __init__( @@ -367,7 +421,6 @@ class StandardConvDownBlock(nn.Module): out_channels: int, kernel_size: int = 3, pad_size: int = 1, - stride: int = 1, ): super(StandardConvDownBlock, self).__init__() self.conv = nn.Conv2d( @@ -375,7 +428,7 @@ class StandardConvDownBlock(nn.Module): out_channels, kernel_size=kernel_size, padding=pad_size, - stride=stride, + stride=1, ) self.conv_bn = nn.BatchNorm2d(out_channels) @@ -396,6 +449,22 @@ class StandardConvDownBlock(nn.Module): return F.relu(self.conv_bn(x), inplace=True) +class FreqCoordConvUpConfig(BaseConfig): + """Configuration for a FreqCoordConvUpBlock.""" + + block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp" + """Discriminator field indicating the block type.""" + + out_channels: int + """Number of output channels.""" + + kernel_size: int = 3 + """Size of the square convolutional kernel.""" + + pad_size: int = 1 + """Padding size.""" + + class FreqCoordConvUpBlock(nn.Module): """Upsampling Conv Block incorporating Frequency Coordinate features. @@ -489,6 +558,22 @@ class FreqCoordConvUpBlock(nn.Module): return op +class StandardConvUpConfig(BaseConfig): + """Configuration for a StandardConvUpBlock.""" + + block_type: Literal["StandardConvUp"] = "StandardConvUp" + """Discriminator field indicating the block type.""" + + out_channels: int + """Number of output channels.""" + + kernel_size: int = 3 + """Size of the square convolutional kernel.""" + + pad_size: int = 1 + """Padding size.""" + + class StandardConvUpBlock(nn.Module): """Standard Upsampling Convolutional Block. @@ -559,3 +644,122 @@ class StandardConvUpBlock(nn.Module): op = self.conv(op) op = F.relu(self.conv_bn(op), inplace=True) return op + + +LayerConfig = Annotated[ + Union[ + ConvConfig, + FreqCoordConvDownConfig, + StandardConvDownConfig, + FreqCoordConvUpConfig, + StandardConvUpConfig, + ], + Field(discriminator="block_type"), +] +"""Type alias for the discriminated union of block configuration models.""" + + +def build_layer_from_config( + input_height: int, + in_channels: int, + config: LayerConfig, +) -> Tuple[nn.Module, int, int]: + """Factory function to build a specific nn.Module block from its config. 
+ + Takes configuration object (one of the types included in the `LayerConfig` + union) and instantiates the corresponding nn.Module block with the correct + parameters derived from the config and the current pipeline state + (`input_height`, `in_channels`). + + It uses the `block_type` field within the `config` object to determine + which block class to instantiate. + + Parameters + ---------- + input_height : int + Height (frequency bins) of the input tensor *to this layer*. + in_channels : int + Number of channels in the input tensor *to this layer*. + config : LayerConfig + A Pydantic configuration object for the desired block (e.g., an + instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified + by its `block_type` field. + + Returns + ------- + Tuple[nn.Module, int, int] + A tuple containing: + - The instantiated `nn.Module` block. + - The number of output channels produced by the block. + - The calculated height of the output produced by the block. + + Raises + ------ + NotImplementedError + If the `config.block_type` does not correspond to a known block type. + ValueError + If parameters derived from the config are invalid for the block. + """ + if config.block_type == "ConvBlock": + return ( + ConvBlock( + in_channels=in_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ), + config.out_channels, + input_height, + ) + + if config.block_type == "FreqCoordConvDown": + return ( + FreqCoordConvDownBlock( + in_channels=in_channels, + out_channels=config.out_channels, + input_height=input_height, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ), + config.out_channels, + input_height // 2, + ) + + if config.block_type == "StandardConvDown": + return ( + StandardConvDownBlock( + in_channels=in_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ), + config.out_channels, + input_height // 2, + ) + + if config.block_type == "FreqCoordConvUp": + return ( + FreqCoordConvUpBlock( + in_channels=in_channels, + out_channels=config.out_channels, + input_height=input_height, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ), + config.out_channels, + input_height * 2, + ) + + if config.block_type == "StandardConvUp": + return ( + StandardConvUpBlock( + in_channels=in_channels, + out_channels=config.out_channels, + kernel_size=config.kernel_size, + pad_size=config.pad_size, + ), + config.out_channels, + input_height * 2, + ) + + raise NotImplementedError(f"Unknown block type {config.block_type}") diff --git a/batdetect2/models/bottleneck.py b/batdetect2/models/bottleneck.py new file mode 100644 index 0000000..d93ea55 --- /dev/null +++ b/batdetect2/models/bottleneck.py @@ -0,0 +1,254 @@ +"""Defines the Bottleneck component of an Encoder-Decoder architecture. + +This module provides the configuration (`BottleneckConfig`) and +`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the +bottleneck layer(s) that typically connect the Encoder (downsampling path) and +Decoder (upsampling path) in networks like U-Nets. + +The bottleneck processes the lowest-resolution, highest-dimensionality feature +map produced by the Encoder. This module offers a configurable option to include +a `SelfAttention` layer within the bottleneck, allowing the model to capture +global temporal context before features are passed to the Decoder. 
+
+A factory function `build_bottleneck` constructs the appropriate bottleneck
+module based on the provided configuration.
+"""
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from batdetect2.configs import BaseConfig
+from batdetect2.models.blocks import SelfAttention, VerticalConv
+
+__all__ = [
+    "BottleneckConfig",
+    "Bottleneck",
+    "BottleneckAttn",
+    "build_bottleneck",
+]
+
+
+class BottleneckConfig(BaseConfig):
+    """Configuration for the bottleneck layer(s).
+
+    Defines the number of channels within the bottleneck and whether to include
+    a self-attention mechanism.
+
+    Attributes
+    ----------
+    channels : int
+        The number of output channels produced by the main convolutional layer
+        within the bottleneck. This often matches the number of channels coming
+        from the last encoder stage, but can be different. Must be positive.
+        This also defines the channel dimensions used within the optional
+        `SelfAttention` layer.
+    self_attention : bool
+        If True, includes a `SelfAttention` layer operating on the time
+        dimension after an initial `VerticalConv` layer within the bottleneck.
+        If False, only the initial `VerticalConv` (and height repetition) is
+        performed.
+    """
+
+    channels: int
+    self_attention: bool
+
+
+class Bottleneck(nn.Module):
+    """Base Bottleneck module for Encoder-Decoder architectures.
+
+    This implementation represents the simplest bottleneck structure,
+    consisting primarily of a `VerticalConv` layer. This layer collapses the
+    frequency dimension (height) to 1, summarizing information across
+    frequencies at each time step. The output is then repeated along the
+    height dimension to match the original bottleneck input height before
+    being passed to the decoder.
+
+    This base version does *not* include self-attention.
+
+    Parameters
+    ----------
+    input_height : int
+        Height (frequency bins) of the input tensor. Must be positive.
+    in_channels : int
+        Number of channels in the input tensor from the encoder. Must be
+        positive.
+    out_channels : int
+        Number of output channels. Must be positive.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels accepted by the bottleneck.
+    input_height : int
+        Expected height of the input tensor.
+    out_channels : int
+        Number of output channels.
+    conv_vert : VerticalConv
+        The vertical convolution layer.
+
+    Raises
+    ------
+    ValueError
+        If `input_height`, `in_channels`, or `out_channels` are not positive.
+    """
+
+    def __init__(
+        self,
+        input_height: int,
+        in_channels: int,
+        out_channels: int,
+    ) -> None:
+        """Initialize the base Bottleneck layer."""
+        super().__init__()
+        self.in_channels = in_channels
+        self.input_height = input_height
+        self.out_channels = out_channels
+
+        self.conv_vert = VerticalConv(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            input_height=input_height,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input features through the bottleneck.
+
+        Applies vertical convolution and repeats the output height.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor from the encoder bottleneck, shape
+            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
+            `H_in` must match `self.input_height`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
+            dimension `H_in` is restored via repetition after the vertical
+            convolution.
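+
+        Examples
+        --------
+        Illustrative shapes (the height is collapsed, then repeated):
+
+        >>> import torch
+        >>> bottleneck = Bottleneck(
+        ...     input_height=16, in_channels=256, out_channels=256
+        ... )
+        >>> bottleneck(torch.rand([1, 256, 16, 64])).shape
+        torch.Size([1, 256, 16, 64])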
+ """ + x = self.conv_vert(x) + return x.repeat([1, 1, self.input_height, 1]) + + +class BottleneckAttn(Bottleneck): + """Bottleneck module including a Self-Attention layer. + + Extends the base `Bottleneck` by inserting a `SelfAttention` layer after + the initial `VerticalConv`. This allows the bottleneck to capture global + temporal dependencies in the summarized frequency features before passing + them to the decoder. + + Sequence: VerticalConv -> SelfAttention -> Repeat Height. + + Parameters + ---------- + input_height : int + Height (frequency bins) of the input tensor from the encoder. + in_channels : int + Number of channels in the input tensor from the encoder. + out_channels : int + Number of output channels produced by the `VerticalConv` and + subsequently processed and output by this bottleneck. Also determines + the input/output channels of the internal `SelfAttention` layer. + attention : nn.Module + An initialized `SelfAttention` module instance. + + Raises + ------ + ValueError + If `input_height`, `in_channels`, or `out_channels` are not positive. + """ + + def __init__( + self, + input_height: int, + in_channels: int, + out_channels: int, + attention: nn.Module, + ) -> None: + """Initialize the Bottleneck with Self-Attention.""" + super().__init__(input_height, in_channels, out_channels) + self.attention = attention + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Process input tensor. + + Parameters + ---------- + x : torch.Tensor + Input tensor from the encoder bottleneck, shape + `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`, + `H_in` must match `self.input_height`. + + Returns + ------- + torch.Tensor + Output tensor, shape `(B, C_out, H_in, W)`, after applying attention + and repeating the height dimension. + """ + x = self.conv_vert(x) + x = self.attention(x) + return x.repeat([1, 1, self.input_height, 1]) + + +DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( + channels=256, + self_attention=True, +) + + +def build_bottleneck( + input_height: int, + in_channels: int, + config: Optional[BottleneckConfig] = None, +) -> nn.Module: + """Factory function to build the Bottleneck module from configuration. + + Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based + on the `config.self_attention` flag. + + Parameters + ---------- + input_height : int + Height (frequency bins) of the input tensor. Must be positive. + in_channels : int + Number of channels in the input tensor. Must be positive. + config : BottleneckConfig, optional + Configuration object specifying the bottleneck channels and whether + to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None. + + Returns + ------- + nn.Module + An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`). + + Raises + ------ + ValueError + If `input_height` or `in_channels` are not positive. 
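+
+    Examples
+    --------
+    A minimal sketch using the default configuration (256 channels with
+    self-attention); the input dimensions are illustrative:
+
+    >>> bottleneck = build_bottleneck(input_height=16, in_channels=256)
+    >>> bottleneck.out_channels
+    256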
+ """ + config = config or DEFAULT_BOTTLENECK_CONFIG + + if config.self_attention: + attention = SelfAttention( + in_channels=config.channels, + attention_channels=config.channels, + ) + + return BottleneckAttn( + input_height=input_height, + in_channels=in_channels, + out_channels=config.channels, + attention=attention, + ) + + return Bottleneck( + input_height=input_height, + in_channels=in_channels, + out_channels=config.channels, + ) diff --git a/batdetect2/models/decoder.py b/batdetect2/models/decoder.py new file mode 100644 index 0000000..9760ede --- /dev/null +++ b/batdetect2/models/decoder.py @@ -0,0 +1,267 @@ +"""Constructs the Decoder part of an Encoder-Decoder neural network. + +This module defines the configuration structure (`DecoderConfig`) for the layer +sequence and provides the `Decoder` class (an `nn.Module`) along with a factory +function (`build_decoder`). Decoders typically form the upsampling path in +architectures like U-Nets, taking bottleneck features +(usually from an `Encoder`) and skip connections to reconstruct +higher-resolution feature maps. + +The decoder is built dynamically by stacking neural network blocks based on a +list of configuration objects provided in `DecoderConfig.layers`. Each config +object specifies the type of block (e.g., standard convolution, +coordinate-feature convolution with upsampling) and its parameters. This allows +flexible definition of decoder architectures via configuration files. + +The `Decoder`'s `forward` method is designed to accept skip connection tensors +(`residuals`) from the encoder, merging them with the upsampled feature maps +at each stage. +""" + +from typing import Annotated, List, Optional, Union + +import torch +from pydantic import Field +from torch import nn + +from batdetect2.configs import BaseConfig +from batdetect2.models.blocks import ( + ConvConfig, + FreqCoordConvUpConfig, + StandardConvUpConfig, + build_layer_from_config, +) + +__all__ = [ + "DecoderConfig", + "Decoder", + "build_decoder", + "DEFAULT_DECODER_CONFIG", +] + +DecoderLayerConfig = Annotated[ + Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig], + Field(discriminator="block_type"), +] +"""Type alias for the discriminated union of block configs usable in Decoder.""" + + +class DecoderConfig(BaseConfig): + """Configuration for the sequence of layers in the Decoder module. + + Defines the types and parameters of the neural network blocks that + constitute the decoder's upsampling path. + + Attributes + ---------- + layers : List[DecoderLayerConfig] + An ordered list of configuration objects, each defining one layer or + block in the decoder sequence. Each item must be a valid block + config including a `block_type` field and necessary parameters like + `out_channels`. Input channels for each layer are inferred sequentially. + The list must contain at least one layer. + """ + + layers: List[DecoderLayerConfig] = Field(min_length=1) + + +class Decoder(nn.Module): + """Sequential Decoder module composed of configurable upsampling layers. + + Constructs the upsampling path of an encoder-decoder network by stacking + multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`) + based on a list of layer modules provided during initialization (typically + created by the `build_decoder` factory function). + + The `forward` method is designed to integrate skip connection tensors + (`residuals`) from the corresponding encoder stages, by adding them + element-wise to the input of each decoder layer before processing. 
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of channels expected in the input tensor.
+    out_channels : int
+        Number of channels in the final output tensor produced by the last
+        layer.
+    input_height : int
+        Height (frequency bins) expected in the input tensor.
+    output_height : int
+        Height (frequency bins) expected in the output tensor.
+    layers : nn.ModuleList
+        The sequence of instantiated upscaling layer modules.
+    depth : int
+        The number of upscaling layers (depth) in the decoder.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        input_height: int,
+        output_height: int,
+        layers: List[nn.Module],
+    ):
+        """Initialize the Decoder module.
+
+        Note: This constructor is typically called internally by the
+        `build_decoder` factory function.
+
+        Parameters
+        ----------
+        in_channels : int
+            Expected number of channels in the input tensor (bottleneck).
+        out_channels : int
+            Number of channels produced by the final layer.
+        input_height : int
+            Expected height of the input tensor (bottleneck).
+        output_height : int
+            Expected height of the tensor produced by the final layer.
+        layers : List[nn.Module]
+            A list of pre-instantiated upscaling layer modules (e.g.,
+            `StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
+            sequence (from bottleneck towards output resolution).
+        """
+        super().__init__()
+
+        self.input_height = input_height
+        self.output_height = output_height
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.layers = nn.ModuleList(layers)
+        self.depth = len(self.layers)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        residuals: List[torch.Tensor],
+    ) -> torch.Tensor:
+        """Pass input through decoder layers, incorporating skip connections.
+
+        Processes the input tensor `x` sequentially through the upscaling
+        layers. At each stage, the corresponding skip connection tensor from
+        the `residuals` list is added element-wise to the input before passing
+        it to the upscaling block.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor from the previous stage (e.g., encoder bottleneck).
+            Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
+            `self.in_channels`.
+        residuals : List[torch.Tensor]
+            List containing the skip connection tensors from the corresponding
+            encoder stages, ordered from the shallowest encoder stage (highest
+            resolution) to the deepest (lowest resolution), exactly as
+            returned by `Encoder.forward`. The list is reversed internally so
+            that the deepest residual is added first. The number of tensors
+            must match the number of decoder layers (`self.depth`), and each
+            residual's shape must match the tensor it is added to
+            element-wise.
+
+        Returns
+        -------
+        torch.Tensor
+            The final decoded feature map tensor produced by the last layer.
+            Shape `(B, C_out, H_out, W_out)`.
+
+        Raises
+        ------
+        ValueError
+            If the number of `residuals` provided does not match the decoder
+            depth.
+        RuntimeError
+            If shapes mismatch during skip connection addition or layer
+            processing.
+        """
+        if len(residuals) != len(self.layers):
+            raise ValueError(
+                f"Incorrect number of residuals provided. "
+                f"Expected {len(self.layers)} (matching the number of layers), "
+                f"but got {len(residuals)}."
+ ) + + for layer, res in zip(self.layers, residuals[::-1]): + x = layer(x + res) + + return x + + +DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig( + layers=[ + FreqCoordConvUpConfig(out_channels=64), + FreqCoordConvUpConfig(out_channels=32), + FreqCoordConvUpConfig(out_channels=32), + ConvConfig(out_channels=32), + ], +) +"""A default configuration for the Decoder's *layer sequence*. + +Specifies an architecture often used in BatDetect2, consisting of three +frequency coordinate-aware upsampling blocks followed by a standard +convolutional block. +""" + + +def build_decoder( + in_channels: int, + input_height: int, + config: Optional[DecoderConfig] = None, +) -> Decoder: + """Factory function to build a Decoder instance from configuration. + + Constructs a sequential `Decoder` module based on the layer sequence + defined in a `DecoderConfig` object and the provided input dimensions + (bottleneck channels and height). If no config is provided, uses the + default layer sequence from `DEFAULT_DECODER_CONFIG`. + + It iteratively builds the layers using the unified `build_layer_from_config` + factory (from `.blocks`), tracking the changing number of channels and + feature map height required for each subsequent layer. + + Parameters + ---------- + in_channels : int + The number of channels in the input tensor to the decoder. Must be > 0. + input_height : int + The height (frequency bins) of the input tensor to the decoder. Must be + > 0. + config : DecoderConfig, optional + The configuration object detailing the sequence of layers and their + parameters. If None, `DEFAULT_DECODER_CONFIG` is used. + + Returns + ------- + Decoder + An initialized `Decoder` module. + + Raises + ------ + ValueError + If `in_channels` or `input_height` are not positive, or if the layer + configuration is invalid (e.g., empty list, unknown `block_type`). + NotImplementedError + If `build_layer_from_config` encounters an unknown `block_type`. + """ + config = config or DEFAULT_DECODER_CONFIG + + current_channels = in_channels + current_height = input_height + + layers = [] + + for layer_config in config.layers: + layer, current_channels, current_height = build_layer_from_config( + in_channels=current_channels, + input_height=current_height, + config=layer_config, + ) + layers.append(layer) + + return Decoder( + in_channels=in_channels, + out_channels=current_channels, + input_height=input_height, + output_height=current_height, + layers=layers, + ) diff --git a/batdetect2/models/detectors.py b/batdetect2/models/detectors.py new file mode 100644 index 0000000..c5ab691 --- /dev/null +++ b/batdetect2/models/detectors.py @@ -0,0 +1,173 @@ +"""Assembles the complete BatDetect2 Detection Model. + +This module defines the concrete `Detector` class, which implements the +`DetectionModel` interface defined in `.types`. It combines a feature +extraction backbone with specific prediction heads to create the end-to-end +neural network used for detecting bat calls, predicting their size, and +classifying them. + +The primary components are: +- `Detector`: The `torch.nn.Module` subclass representing the complete model. +- `build_detector`: A factory function to conveniently construct a standard + `Detector` instance given a backbone and the number of target classes. + +This module focuses purely on the neural network architecture definition. The +logic for preprocessing inputs and postprocessing/decoding outputs resides in +the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. 
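+
+Example
+-------
+A minimal assembly sketch (the class count of 6 is illustrative)::
+
+    from batdetect2.models.backbones import BackboneConfig, build_backbone
+    from batdetect2.models.detectors import build_detector
+
+    backbone = build_backbone(BackboneConfig())
+    detector = build_detector(num_classes=6, backbone=backbone)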
+""" + +import torch + +from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead +from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput + + +class Detector(DetectionModel): + """Concrete implementation of the BatDetect2 Detection Model. + + Assembles a complete detection and classification model by combining a + feature extraction backbone network with specific prediction heads for + detection probability, bounding box size regression, and class + probabilities. + + Attributes + ---------- + backbone : BackboneModel + The feature extraction backbone network module. + num_classes : int + The number of specific target classes the model predicts (derived from + the `classifier_head`). + classifier_head : ClassifierHead + The prediction head responsible for generating class probabilities. + detector_head : DetectorHead + The prediction head responsible for generating detection probabilities. + bbox_head : BBoxHead + The prediction head responsible for generating bounding box size + predictions. + """ + + backbone: BackboneModel + + def __init__( + self, + backbone: BackboneModel, + classifier_head: ClassifierHead, + detector_head: DetectorHead, + bbox_head: BBoxHead, + ): + """Initialize the Detector model. + + Note: Instances are typically created using the `build_detector` + factory function. + + Parameters + ---------- + backbone : BackboneModel + An initialized feature extraction backbone module (e.g., built by + `build_backbone` from the `.backbone` module). + classifier_head : ClassifierHead + An initialized classification head module. The number of classes + is inferred from this head. + detector_head : DetectorHead + An initialized detection head module. + bbox_head : BBoxHead + An initialized bounding box size prediction head module. + + Raises + ------ + TypeError + If the provided modules are not of the expected types. + """ + super().__init__() + + self.backbone = backbone + self.num_classes = classifier_head.num_classes + self.classifier_head = classifier_head + self.detector_head = detector_head + self.bbox_head = bbox_head + + def forward(self, spec: torch.Tensor) -> ModelOutput: + """Perform the forward pass of the complete detection model. + + Processes the input spectrogram through the backbone to extract + features, then passes these features through the separate prediction + heads to generate detection probabilities, class probabilities, and + size predictions. + + Parameters + ---------- + spec : torch.Tensor + Input spectrogram tensor, typically with shape + `(batch_size, input_channels, frequency_bins, time_bins)`. The + shape must be compatible with the `self.backbone` input + requirements. + + Returns + ------- + ModelOutput + A NamedTuple containing the four output tensors: + - `detection_probs`: Detection probability heatmap `(B, 1, H, W)`. + - `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`. + - `class_probs`: Class probabilities (excluding background) + `(B, num_classes, H, W)`. + - `features`: Output feature map from the backbone + `(B, C_out, H, W)`. + """ + features = self.backbone(spec) + detection = self.detector_head(features) + classification = self.classifier_head(features) + size_preds = self.bbox_head(features) + return ModelOutput( + detection_probs=detection, + size_preds=size_preds, + class_probs=classification, + features=features, + ) + + +def build_detector(num_classes: int, backbone: BackboneModel) -> Detector: + """Factory function to build a standard Detector model instance. 
+ + Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`, + `BBoxHead`) configured appropriately based on the output channels of the + provided `backbone` and the specified `num_classes`. It then assembles + these components into a `Detector` model. + + Parameters + ---------- + num_classes : int + The number of specific target classes for the classification head + (excluding any implicit background class). Must be positive. + backbone : BackboneModel + An initialized feature extraction backbone module instance. The number + of output channels from this backbone (`backbone.out_channels`) is used + to configure the input channels for the prediction heads. + + Returns + ------- + Detector + An initialized `Detector` model instance. + + Raises + ------ + ValueError + If `num_classes` is not positive. + AttributeError + If `backbone` does not have the required `out_channels` attribute. + """ + classifier_head = ClassifierHead( + num_classes=num_classes, + in_channels=backbone.out_channels, + ) + detector_head = DetectorHead( + in_channels=backbone.out_channels, + ) + bbox_head = BBoxHead( + in_channels=backbone.out_channels, + ) + return Detector( + backbone=backbone, + classifier_head=classifier_head, + detector_head=detector_head, + bbox_head=bbox_head, + ) diff --git a/batdetect2/models/encoder.py b/batdetect2/models/encoder.py index f522565..7f2600e 100644 --- a/batdetect2/models/encoder.py +++ b/batdetect2/models/encoder.py @@ -1,25 +1,26 @@ -"""Constructs the Encoder part of an Encoder-Decoder neural network. +"""Constructs the Encoder part of a configurable neural network backbone. This module defines the configuration structure (`EncoderConfig`) and provides the `Encoder` class (an `nn.Module`) along with a factory function -(`build_encoder`) to create sequential encoders commonly used as the -downsampling path in architectures like U-Nets for spectrogram analysis. +(`build_encoder`) to create sequential encoders. Encoders typically form the +downsampling path in architectures like U-Nets, processing input feature maps +(like spectrograms) to produce lower-resolution, higher-dimensionality feature +representations (bottleneck features). -The encoder is built by stacking configurable downscaling blocks. Two types -of downscaling blocks are supported, selectable via the configuration: -- `StandardConvDownBlock`: A basic Conv2d -> MaxPool2d -> BN -> ReLU block. -- `FreqCoordConvDownBlock`: A similar block that incorporates frequency - coordinate information (CoordF) before the convolution to potentially aid - spatial awareness along the frequency axis. +The encoder is built dynamically by stacking neural network blocks based on a +list of configuration objects provided in `EncoderConfig.layers`. Each +configuration object specifies the type of block (e.g., standard convolution, +coordinate-feature convolution with downsampling) and its parameters +(e.g., output channels). This allows for flexible definition of encoder +architectures via configuration files. -The `Encoder`'s `forward` method provides access to intermediate feature maps -from each stage, suitable for use as skip connections in a corresponding -Decoder. A separate `encode` method returns only the final output (bottleneck) -features. +The `Encoder`'s `forward` method returns outputs from all intermediate layers, +suitable for skip connections, while the `encode` method returns only the final +bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also +provided. 
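+
+Example
+-------
+A minimal sketch using the default layer sequence (the input shape is
+illustrative)::
+
+    import torch
+
+    from batdetect2.models.encoder import build_encoder
+
+    encoder = build_encoder(in_channels=1, input_height=128)
+    residuals = encoder(torch.rand([1, 1, 128, 512]))
+    bottleneck = encoder.encode(torch.rand([1, 1, 128, 512]))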
""" -from enum import Enum -from typing import List +from typing import Annotated, List, Optional, Union import torch from pydantic import Field @@ -27,142 +28,44 @@ from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.blocks import ( - FreqCoordConvDownBlock, - StandardConvDownBlock, + ConvConfig, + FreqCoordConvDownConfig, + StandardConvDownConfig, + build_layer_from_config, ) __all__ = [ - "DownscalingLayer", - "EncoderLayer", "EncoderConfig", "Encoder", "build_encoder", + "DEFAULT_ENCODER_CONFIG", ] - -class DownscalingLayer(str, Enum): - """Enumeration of available downscaling layer types for the Encoder. - - Used in configuration to specify which block implementation to use at each - stage of the encoder. - - Attributes - ---------- - standard : str - Identifier for the `StandardConvDownBlock`. - coord : str - Identifier for the `FreqCoordConvDownBlock` (incorporates frequency - coords). - """ - - standard = "ConvBlockDownStandard" - coord = "FreqCoordConvDownBlock" - - -class EncoderLayer(BaseConfig): - """Configuration for a single layer within the Encoder sequence. - - Attributes - ---------- - layer_type : DownscalingLayer - Specifies the type of downscaling block to use for this layer - (either 'standard' or 'coord'). - channels : int - The number of output channels this layer should produce. Must be > 0. - """ - - layer_type: DownscalingLayer - channels: int +EncoderLayerConfig = Annotated[ + Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig], + Field(discriminator="block_type"), +] +"""Type alias for the discriminated union of block configs usable in Encoder.""" class EncoderConfig(BaseConfig): - """Configuration for building the entire sequential Encoder. + """Configuration for building the sequential Encoder module. + + Defines the sequence of neural network blocks that constitute the encoder + (downsampling path). Attributes ---------- - input_height : int - The expected height (number of frequency bins) of the input spectrogram - tensor fed into the first layer of the encoder. Required for - calculating intermediate heights, especially for CoordF layers. Must be - > 0. - layers : List[EncoderLayer] - An ordered list defining the sequence of downscaling layers in the - encoder. Each item specifies the layer type and its output channel - count. The number of input channels for each layer is inferred from the - previous layer's output channels (or `input_channels` for the first - layer). Must contain at least one layer definition. - input_channels : int, default=1 - The number of channels in the initial input tensor to the encoder - (e.g., 1 for a standard single-channel spectrogram). Must be > 0. + layers : List[EncoderLayerConfig] + An ordered list of configuration objects, each defining one layer or + block in the encoder sequence. Each item must be a valid block config + (e.g., `ConvConfig`, `FreqCoordConvDownConfig`, + `StandardConvDownConfig`) including a `block_type` field and necessary + parameters like `out_channels`. Input channels for each layer are + inferred sequentially. The list must contain at least one layer. """ - input_height: int = Field(gt=0) - layers: List[EncoderLayer] = Field(min_length=1) - input_channels: int = Field(gt=0) - - -def build_downscaling_layer( - in_channels: int, - out_channels: int, - input_height: int, - layer_type: DownscalingLayer, -) -> tuple[nn.Module, int, int]: - """Build a single downscaling layer based on configuration. - - Internal factory function used by `build_encoder`. 
-    appropriate downscaling block (`StandardConvDownBlock` or
-    `FreqCoordConvDownBlock`) and returns it along with its expected output
-    channel count and output height (assuming 2x spatial downsampling).
-
-    Parameters
-    ----------
-    in_channels : int
-        Number of input channels to the layer.
-    out_channels : int
-        Desired number of output channels from the layer.
-    input_height : int
-        Height of the input feature map to this layer.
-    layer_type : DownscalingLayer
-        The type of layer to build ('standard' or 'coord').
-
-    Returns
-    -------
-    Tuple[nn.Module, int, int]
-        A tuple containing:
-        - The instantiated `nn.Module` layer.
-        - The number of output channels (`out_channels`).
-        - The expected output height (`input_height // 2`).
-
-    Raises
-    ------
-    ValueError
-        If `layer_type` is invalid.
-    """
-    if layer_type == DownscalingLayer.standard:
-        return (
-            StandardConvDownBlock(
-                in_channels=in_channels,
-                out_channels=out_channels,
-            ),
-            out_channels,
-            input_height // 2,
-        )
-
-    if layer_type == DownscalingLayer.coord:
-        return (
-            FreqCoordConvDownBlock(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                input_height=input_height,
-            ),
-            out_channels,
-            input_height // 2,
-        )
-
-    raise ValueError(
-        f"Invalid downscaling layer type {layer_type}. "
-        f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
-    )
+    layers: List[EncoderLayerConfig] = Field(min_length=1)
 
 
 class Encoder(nn.Module):
@@ -178,12 +81,14 @@ class Encoder(nn.Module):
 
     Attributes
     ----------
-    input_channels : int
+    in_channels : int
         Number of channels expected in the input tensor.
     input_height : int
         Height (frequency bins) expected in the input tensor.
-    output_channels : int
+    out_channels : int
         Number of channels in the final output tensor (bottleneck).
+    output_height : int
+        Height (frequency bins) of the final output tensor.
     layers : nn.ModuleList
         The sequence of instantiated downscaling layer modules.
     depth : int
@@ -193,9 +98,10 @@
     def __init__(
        self,
        output_channels: int,
+       output_height: int,
        layers: List[nn.Module],
        input_height: int = 128,
-       input_channels: int = 1,
+       in_channels: int = 1,
    ):
        """Initialize the Encoder module.
 
        Parameters
        ----------
        output_channels : int
            Number of channels produced by the final layer.
+       output_height : int
+           The expected height of the output tensor.
        layers : List[nn.Module]
            A list of pre-instantiated downscaling layer modules (e.g.,
            `StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the
            desired sequence.
        input_height : int, default=128
            Expected height of the input tensor.
-       input_channels : int, default=1
+       in_channels : int, default=1
            Expected number of channels in the input tensor.
        """
        super().__init__()
 
-       self.input_channels = input_channels
+       self.in_channels = in_channels
        self.input_height = input_height
-       self.output_channels = output_channels
+       self.out_channels = output_channels
+       self.output_height = output_height
        self.layers = nn.ModuleList(layers)
        self.depth = len(self.layers)
 
@@ -234,7 +143,7 @@
        ----------
        x : torch.Tensor
            Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
-           `self.input_channels` and `H_in` must match `self.input_height`.
+           `self.in_channels` and `H_in` must match `self.input_height`.
 
        Returns
        -------
@@ -249,10 +158,10 @@
            If input tensor channel count or height does not match expected
            values.
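+
+        Examples
+        --------
+        A minimal sketch, assuming the default encoder configuration and
+        that `forward` collects one output per stage:
+
+        >>> import torch
+        >>> encoder = build_encoder(in_channels=1, input_height=128)
+        >>> outputs = encoder(torch.rand(1, 1, 128, 256))
+        >>> len(outputs) == encoder.depth
+        True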
""" - if x.shape[1] != self.input_channels: + if x.shape[1] != self.in_channels: raise ValueError( f"Input tensor has {x.shape[1]} channels, " - f"but encoder expects {self.input_channels}." + f"but encoder expects {self.in_channels}." ) if x.shape[2] != self.input_height: @@ -279,7 +188,7 @@ class Encoder(nn.Module): ---------- x : torch.Tensor Input tensor, shape `(B, C_in, H_in, W)`. Must match expected - `input_channels` and `input_height`. + `in_channels` and `input_height`. Returns ------- @@ -293,10 +202,10 @@ class Encoder(nn.Module): If input tensor channel count or height does not match expected values. """ - if x.shape[1] != self.input_channels: + if x.shape[1] != self.in_channels: raise ValueError( f"Input tensor has {x.shape[1]} channels, " - f"but encoder expects {self.input_channels}." + f"but encoder expects {self.in_channels}." ) if x.shape[2] != self.input_height: @@ -311,19 +220,53 @@ class Encoder(nn.Module): return x -def build_encoder(config: EncoderConfig) -> Encoder: +DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig( + layers=[ + FreqCoordConvDownConfig(out_channels=32), + FreqCoordConvDownConfig(out_channels=64), + FreqCoordConvDownConfig(out_channels=128), + ConvConfig(out_channels=256), + ], +) +"""Default configuration for the Encoder. + +Specifies an architecture typically used in BatDetect2: +- Input: 1 channel, 128 frequency bins. +- Layer 1: FreqCoordConvDown -> 32 channels, H=64 +- Layer 2: FreqCoordConvDown -> 64 channels, H=32 +- Layer 3: FreqCoordConvDown -> 128 channels, H=16 +- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck) +""" + + +def build_encoder( + in_channels: int, + input_height: int, + config: Optional[EncoderConfig] = None, +) -> Encoder: """Factory function to build an Encoder instance from configuration. - Constructs a sequential `Encoder` module based on the specifications in - an `EncoderConfig` object. It iteratively builds the specified sequence - of downscaling layers (`StandardConvDownBlock` or `FreqCoordConvDownBlock`), - tracking the changing number of channels and feature map height. + Constructs a sequential `Encoder` module based on the layer sequence + defined in an `EncoderConfig` object and the provided input dimensions. + If no config is provided, uses the default layer sequence from + `DEFAULT_ENCODER_CONFIG`. + + It iteratively builds the layers using the unified + `build_layer_from_config` factory (from `.blocks`), tracking the changing + number of channels and feature map height required for each subsequent + layer, especially for coordinate- aware blocks. Parameters ---------- - config : EncoderConfig - The configuration object detailing the encoder architecture, including - input dimensions, layer types, and channel counts for each stage. + in_channels : int + The number of channels expected in the input tensor to the encoder. + Must be > 0. + input_height : int + The height (frequency bins) expected in the input tensor. Must be > 0. + Crucial for initializing coordinate-aware layers correctly. + config : EncoderConfig, optional + The configuration object detailing the sequence of layers and their + parameters. If None, `DEFAULT_ENCODER_CONFIG` is used. Returns ------- @@ -333,25 +276,33 @@ def build_encoder(config: EncoderConfig) -> Encoder: Raises ------ ValueError - If the layer configuration is invalid (e.g., unknown layer type). + If `in_channels` or `input_height` are not positive, or if the layer + configuration is invalid (e.g., empty list, unknown `block_type`). 
     """
-    current_channels = config.input_channels
-    current_height = config.input_height
+    if in_channels <= 0 or input_height <= 0:
+        raise ValueError("in_channels and input_height must be positive.")
+
+    config = config or DEFAULT_ENCODER_CONFIG
+
+    current_channels = in_channels
+    current_height = input_height
 
     layers = []
 
     for layer_config in config.layers:
-        layer, current_channels, current_height = build_downscaling_layer(
+        layer, current_channels, current_height = build_layer_from_config(
             in_channels=current_channels,
-            out_channels=layer_config.channels,
             input_height=current_height,
-            layer_type=layer_config.layer_type,
+            config=layer_config,
         )
         layers.append(layer)
 
     return Encoder(
-        input_height=config.input_height,
+        input_height=input_height,
         layers=layers,
-        input_channels=config.input_channels,
+        in_channels=in_channels,
         output_channels=current_channels,
+        output_height=current_height,
     )
diff --git a/batdetect2/models/heads.py b/batdetect2/models/heads.py
index 5d7ce3f..1fd9f8e 100644
--- a/batdetect2/models/heads.py
+++ b/batdetect2/models/heads.py
@@ -1,42 +1,199 @@
-from typing import NamedTuple
+"""Prediction Head modules for BatDetect2 models.
+
+This module defines simple `torch.nn.Module` subclasses that serve as
+prediction heads, typically attached to the output feature map of a backbone
+network.
+
+Each head is responsible for generating one specific type of output required
+by the BatDetect2 task:
+- `DetectorHead`: Predicts the probability of sound event presence.
+- `ClassifierHead`: Predicts the probability distribution over target classes.
+- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
+  box.
+
+These heads use 1x1 convolutions to map the backbone feature channels to the
+desired number of output channels for each prediction task at each spatial
+location, followed by an appropriate activation function (e.g., sigmoid for
+detection, softmax for classification, none for size regression).
+"""
 
 import torch
 from torch import nn
 
-__all__ = ["ClassifierHead"]
-
-
-class Output(NamedTuple):
-    detection: torch.Tensor
-    classification: torch.Tensor
+__all__ = [
+    "ClassifierHead",
+    "DetectorHead",
+    "BBoxHead",
+]
 
 
 class ClassifierHead(nn.Module):
+    """Prediction head for multi-class classification probabilities.
+
+    Takes an input feature map and produces a probability map where each
+    channel corresponds to a specific target class. It uses a 1x1 convolution
+    to map input channels to `num_classes + 1` outputs (one for each target
+    class plus an assumed background/generic class), applies softmax across
+    the channels, and returns the probabilities for the specific target
+    classes (excluding the last background/generic channel).
+
+    Parameters
+    ----------
+    num_classes : int
+        The number of specific target classes the model should predict
+        (excluding any background or generic category). Must be positive.
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    num_classes : int
+        Number of specific output classes.
+    in_channels : int
+        Number of input channels expected.
+    classifier : nn.Conv2d
+        The 1x1 convolutional layer used for prediction.
+        Output channels = num_classes + 1.
+
+    Raises
+    ------
+    ValueError
+        If `num_classes` or `in_channels` are not positive.
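+
+    Examples
+    --------
+    A minimal sketch with illustrative tensor sizes:
+
+    >>> import torch
+    >>> head = ClassifierHead(num_classes=6, in_channels=32)
+    >>> probs = head(torch.rand(1, 32, 64, 128))
+    >>> probs.shape
+    torch.Size([1, 6, 64, 128])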
+    """
+
     def __init__(self, num_classes: int, in_channels: int):
+        """Initialize the ClassifierHead."""
         super().__init__()
+
+        if num_classes <= 0 or in_channels <= 0:
+            raise ValueError(
+                "num_classes and in_channels must be positive."
+            )
+
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.classifier = nn.Conv2d(
             self.in_channels,
-            # Add one to account for the background class
             self.num_classes + 1,
             kernel_size=1,
             padding=0,
         )
 
-    def forward(self, features: torch.Tensor) -> Output:
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute class probabilities from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Class probability map tensor with shape `(B, num_classes, H, W)`.
+            Contains probabilities for the specific target classes after
+            softmax, excluding the implicit background/generic class channel.
+        """
         logits = self.classifier(features)
         probs = torch.softmax(logits, dim=1)
-        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
-        return Output(
-            detection=detection_probs,
-            classification=probs[:, :-1],
+        return probs[:, :-1]
+
+
+class DetectorHead(nn.Module):
+    """Prediction head for sound event detection probability.
+
+    Takes an input feature map and produces a single-channel heatmap where
+    each value represents the probability ([0, 1]) of a relevant sound event
+    (of any class) being present at that spatial location.
+
+    Uses a 1x1 convolution to map input channels to 1 output channel,
+    followed by a sigmoid activation function.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels expected.
+    detector : nn.Conv2d
+        The 1x1 convolutional layer mapping to a single output channel.
+
+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
+    """
+
+    def __init__(self, in_channels: int):
+        """Initialize the DetectorHead."""
+        super().__init__()
+
+        if in_channels <= 0:
+            raise ValueError("in_channels must be positive.")
+
+        self.in_channels = in_channels
+
+        self.detector = nn.Conv2d(
+            in_channels=self.in_channels,
+            out_channels=1,
+            kernel_size=1,
+            padding=0,
         )
 
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute detection probabilities from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Detection probability heatmap tensor with shape `(B, 1, H, W)`.
+            Values are in the range [0, 1] due to the sigmoid activation.
+
+        Raises
+        ------
+        RuntimeError
+            If input channel count does not match `self.in_channels`.
+        """
+        return torch.sigmoid(self.detector(features))
+
 
 class BBoxHead(nn.Module):
+    """Prediction head for bounding box size dimensions.
+
+    Takes an input feature map and produces a two-channel map where each
+    channel represents a predicted size dimension (typically width/duration
+    and height/bandwidth) for a potential sound event at that spatial
+    location.
+
+    Uses a 1x1 convolution to map input channels to 2 output channels. No
+    activation function is typically applied, as size prediction is often
+    treated as a direct regression task. The output values usually represent
+    *scaled* dimensions that need to be un-scaled during postprocessing.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels expected.
+    bbox : nn.Conv2d
+        The 1x1 convolutional layer mapping to 2 output channels
+        (width, height).
+
+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
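+
+    Examples
+    --------
+    A minimal sketch with illustrative tensor sizes (the two channels hold
+    scaled width and height):
+
+    >>> import torch
+    >>> head = BBoxHead(in_channels=32)
+    >>> sizes = head(torch.rand(1, 32, 64, 128))
+    >>> sizes.shape
+    torch.Size([1, 2, 64, 128])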
+    """
+
     def __init__(self, in_channels: int):
+        """Initialize the BBoxHead."""
         super().__init__()
+
+        if in_channels <= 0:
+            raise ValueError("in_channels must be positive.")
+
         self.in_channels = in_channels
 
@@ -48,4 +205,19 @@
         )
 
     def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute predicted bounding box dimensions from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Predicted size tensor with shape `(B, 2, H, W)`. Channel 0
+            usually represents scaled width, Channel 1 scaled height. These
+            values need to be un-scaled during postprocessing.
+        """
         return self.bbox(features)