Restructured models module
commit ce15afc231 (parent 096d180ea3)
mirror of https://github.com/macaodha/batdetect2.git, synced 2025-06-29 14:41:58 +02:00

batdetect2/models/__init__.py
@@ -1,26 +1,135 @@
+"""Defines and builds the neural network models used in BatDetect2.
+
+This package (`batdetect2.models`) contains the PyTorch implementations of the
+deep neural network architectures used for detecting and classifying bat calls
+from spectrograms. It provides modular components and configuration-driven
+assembly, allowing for experimentation and use of different architectural
+variants.
+
+Key Submodules:
+
+- `.types`: Defines core data structures (`ModelOutput`) and abstract base
+  classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
+- `.blocks`: Provides reusable neural network building blocks.
+- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
+- `.bottleneck`: Defines and builds the central bottleneck component.
+- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
+- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
+  feature extraction backbone (e.g., a U-Net like structure).
+- `.heads`: Defines simple prediction heads (detection, classification, size)
+  that attach to the backbone features.
+- `.detectors`: Assembles the backbone and prediction heads into the final,
+  end-to-end `Detector` model.
+
+This module re-exports the most important classes, configurations, and builder
+functions from these submodules for convenient access. The primary entry point
+for creating a standard BatDetect2 model instance is the `build_model` function
+provided here.
+"""
+
+from typing import Optional
+
-from batdetect2.models.backbones import (
-    Net2DFast,
-    Net2DFastNoAttn,
-    Net2DFastNoCoordConv,
-    Net2DPlain,
-)
-from batdetect2.models.build import build_architecture
-from batdetect2.models.config import ModelConfig, ModelType, load_model_config
-from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.models.types import BackboneModel, ModelOutput
+from batdetect2.models.backbones import (
+    Backbone,
+    BackboneConfig,
+    build_backbone,
+    load_backbone_config,
+)
+from batdetect2.models.blocks import (
+    ConvConfig,
+    FreqCoordConvDownConfig,
+    FreqCoordConvUpConfig,
+    StandardConvDownConfig,
+    StandardConvUpConfig,
+)
+from batdetect2.models.bottleneck import (
+    Bottleneck,
+    BottleneckConfig,
+    build_bottleneck,
+)
+from batdetect2.models.decoder import (
+    DEFAULT_DECODER_CONFIG,
+    DecoderConfig,
+    build_decoder,
+)
+from batdetect2.models.detectors import (
+    Detector,
+    build_detector,
+)
+from batdetect2.models.encoder import (
+    DEFAULT_ENCODER_CONFIG,
+    EncoderConfig,
+    build_encoder,
+)
+from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
+from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput

 __all__ = [
     "BBoxHead",
+    "Backbone",
+    "BackboneConfig",
     "BackboneModel",
+    "Bottleneck",
+    "BottleneckConfig",
     "ClassifierHead",
-    "ModelConfig",
+    "ConvConfig",
+    "DEFAULT_DECODER_CONFIG",
+    "DEFAULT_ENCODER_CONFIG",
+    "DecoderConfig",
+    "DetectionModel",
+    "Detector",
+    "DetectorHead",
+    "EncoderConfig",
+    "FreqCoordConvDownConfig",
+    "FreqCoordConvUpConfig",
     "ModelOutput",
-    "ModelType",
-    "Net2DFast",
-    "Net2DFastNoAttn",
-    "Net2DFastNoCoordConv",
-    "Net2DPlain",
-    "build_architecture",
-    "load_model_config",
+    "StandardConvDownConfig",
+    "StandardConvUpConfig",
+    "build_backbone",
+    "build_bottleneck",
+    "build_decoder",
+    "build_detector",
+    "build_encoder",
+    "build_model",
+    "load_backbone_config",
 ]
+
+
+def build_model(
+    num_classes: int,
+    config: Optional[BackboneConfig] = None,
+) -> DetectionModel:
+    """Build the complete BatDetect2 detection model.
+
+    This high-level factory function constructs the standard BatDetect2 model
+    architecture. It first builds the feature extraction backbone (typically an
+    encoder-bottleneck-decoder structure) based on the provided
+    `BackboneConfig` (or defaults if None), and then attaches the standard
+    prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
+    `build_detector` function.
+
+    Parameters
+    ----------
+    num_classes : int
+        The number of specific target classes the model should predict
+        (required for the `ClassifierHead`). Must be positive.
+    config : BackboneConfig, optional
+        Configuration object specifying the architecture of the backbone
+        (encoder, bottleneck, decoder). If None, default configurations defined
+        within the respective builder functions (`build_encoder`, etc.) will be
+        used to construct a default backbone architecture.
+
+    Returns
+    -------
+    DetectionModel
+        An initialized `Detector` model instance.
+
+    Raises
+    ------
+    ValueError
+        If `num_classes` is not positive, or if errors occur during the
+        construction of the backbone or detector components (e.g., incompatible
+        configurations, invalid parameters).
+    """
+    backbone = build_backbone(config or BackboneConfig())
+    return build_detector(num_classes, backbone)
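For orientation, here is a minimal usage sketch of the new entry point. The class count and spectrogram shape are illustrative, and calling the returned `Detector` directly on a spectrogram batch is assumed from the `ModelOutput` interface in `.types`:

    import torch

    from batdetect2.models import build_model

    # Hypothetical setup: 10 target classes, default backbone configuration.
    model = build_model(num_classes=10)

    # Dummy mono spectrogram batch: (batch, channels, freq_bins, time_steps).
    spec = torch.randn(1, 1, 128, 512)
    output = model(spec)  # expected: a ModelOutput with prediction maps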

batdetect2/models/backbones.py
@@ -1,204 +1,194 @@
-from enum import Enum
-from typing import Optional, Sequence, Tuple
+"""Assembles a complete Encoder-Decoder Backbone network.
+
+This module defines the configuration (`BackboneConfig`) and implementation
+(`Backbone`) for a standard encoder-decoder style neural network backbone.
+
+It orchestrates the connection between three main components, built using their
+respective configurations and factory functions from sibling modules:
+
+1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
+   at multiple resolutions and provides skip connections.
+2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
+   lowest resolution, optionally applying self-attention.
+3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs
+   high-resolution features using bottleneck features and skip connections.
+
+The resulting `Backbone` module takes a spectrogram as input and outputs a
+final feature map, typically used by subsequent prediction heads. It includes
+automatic padding to handle input sizes not perfectly divisible by the
+network's total downsampling factor.
+"""
+
+from typing import Optional, Tuple

 import torch
 import torch.nn.functional as F
 from soundevent import data
+from torch import nn

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.blocks import (
-    ConvBlock,
-    SelfAttention,
-    VerticalConv,
-)
-from batdetect2.models.decoder import Decoder, UpscalingLayer
-from batdetect2.models.encoder import DownscalingLayer, Encoder
+from batdetect2.models.blocks import ConvBlock
+from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
+from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
+from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
 from batdetect2.models.types import BackboneModel

 __all__ = [
-    "Net2DFast",
-    "Net2DFastNoAttn",
-    "Net2DFastNoCoordConv",
-    "Net2DPlain",
+    "Backbone",
+    "BackboneConfig",
+    "load_backbone_config",
+    "build_backbone",
 ]


-class Net2DPlain(BackboneModel):
-    downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
-    upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
-
-    def __init__(
-        self,
-        input_height: int = 128,
-        encoder_channels: Sequence[int] = (1, 32, 64, 128),
-        bottleneck_channels: int = 256,
-        decoder_channels: Sequence[int] = (256, 64, 32, 32),
-        out_channels: int = 32,
-    ):
-        super().__init__()
-        self.input_height = input_height
-        self.encoder_channels = tuple(encoder_channels)
-        self.decoder_channels = tuple(decoder_channels)
-        self.out_channels = out_channels
-
-        if len(encoder_channels) != len(decoder_channels):
-            raise ValueError(
-                f"Mismatched encoder and decoder channel lists. "
-                f"The encoder has {len(encoder_channels)} channels "
-                f"(implying {len(encoder_channels) - 1} layers), "
-                f"while the decoder has {len(decoder_channels)} channels "
-                f"(implying {len(decoder_channels) - 1} layers). "
-                f"These lengths must be equal."
-            )
-
-        self.divide_factor = 2 ** (len(encoder_channels) - 1)
-        if self.input_height % self.divide_factor != 0:
-            raise ValueError(
-                f"Input height ({self.input_height}) must be divisible by "
-                f"the divide factor ({self.divide_factor}). "
-                f"This ensures proper upscaling after downscaling to recover "
-                f"the original input height."
-            )
-
-        self.encoder = Encoder(
-            channels=encoder_channels,
-            input_height=self.input_height,
-            layer_type=self.downscaling_layer_type,
-        )
-
-        self.conv_same_1 = ConvBlock(
-            in_channels=encoder_channels[-1],
-            out_channels=bottleneck_channels,
-        )
-
-        # bottleneck
-        self.conv_vert = VerticalConv(
-            in_channels=bottleneck_channels,
-            out_channels=bottleneck_channels,
-            input_height=self.input_height // (2**self.encoder.depth),
-        )
-
-        self.decoder = Decoder(
-            channels=decoder_channels,
-            input_height=self.input_height,
-            layer_type=self.upscaling_layer_type,
-        )
-
-        self.conv_same_2 = ConvBlock(
-            in_channels=decoder_channels[-1],
-            out_channels=out_channels,
-        )
-
-    def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
-
-        # encoder
-        residuals = self.encoder(spec)
-        residuals[-1] = self.conv_same_1(residuals[-1])
-
-        # bottleneck
-        x = self.conv_vert(residuals[-1])
-        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
-
-        # decoder
-        x = self.decoder(x, residuals=residuals)
-
-        # Restore original size
-        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
-
-        return self.conv_same_2(x)
+class Backbone(BackboneModel):
+    """Encoder-Decoder Backbone Network Implementation.
+
+    Combines an Encoder, Bottleneck, and Decoder module sequentially, using
+    skip connections between the Encoder and Decoder. Implements the standard
+    U-Net style forward pass. Includes automatic input padding to handle
+    various input sizes and a final convolutional block to adjust the output
+    channels.
+
+    This class inherits from `BackboneModel` and implements its `forward`
+    method. Instances are typically created using the `build_backbone` factory
+    function.
+
+    Attributes
+    ----------
+    input_height : int
+        Expected height of the input spectrogram.
+    out_channels : int
+        Number of channels in the final output feature map.
+    encoder : Encoder
+        The instantiated encoder module.
+    decoder : Decoder
+        The instantiated decoder module.
+    bottleneck : nn.Module
+        The instantiated bottleneck module.
+    final_conv : ConvBlock
+        Final convolutional block applied after the decoder.
+    divide_factor : int
+        The total downsampling factor (2^depth) applied by the encoder,
+        used for automatic input padding.
+    """
+
+    def __init__(
+        self,
+        input_height: int,
+        out_channels: int,
+        encoder: Encoder,
+        decoder: Decoder,
+        bottleneck: nn.Module,
+    ):
+        """Initialize the Backbone network.
+
+        Parameters
+        ----------
+        input_height : int
+            Expected height of the input spectrogram.
+        out_channels : int
+            Desired number of output channels for the backbone's feature map.
+        encoder : Encoder
+            An initialized Encoder module.
+        decoder : Decoder
+            An initialized Decoder module.
+        bottleneck : nn.Module
+            An initialized Bottleneck module.
+
+        Raises
+        ------
+        ValueError
+            If component output/input channels or heights are incompatible.
+        """
+        super().__init__()
+        self.input_height = input_height
+        self.out_channels = out_channels
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.bottleneck = bottleneck
+
+        self.final_conv = ConvBlock(
+            in_channels=decoder.out_channels,
+            out_channels=out_channels,
+        )
+
+        # Down/Up scaling factor. Need to ensure inputs are divisible by
+        # this factor in order to be processed by the down/up scaling layers
+        # and recover the correct shape
+        self.divide_factor = input_height // self.encoder.output_height
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        """Perform the forward pass through the encoder-decoder backbone.
+
+        Applies padding, runs encoder, bottleneck, decoder (with skip
+        connections), removes padding, and applies a final convolution.
+
+        Parameters
+        ----------
+        spec : torch.Tensor
+            Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
+            `self.encoder.input_channels` and `self.input_height`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
+            `C_out` is `self.out_channels`.
+        """
+        spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
+
+        # encoder
+        residuals = self.encoder(spec)
+
+        # bottleneck
+        x = self.bottleneck(residuals[-1])
+
+        # decoder
+        x = self.decoder(x, residuals=residuals)
+
+        # Restore original size
+        x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
+
+        return self.final_conv(x)
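The forward pass above pairs with the `build_backbone` factory shown further down in this file; a quick shape sketch, assuming the default encoder/bottleneck/decoder sub-configurations assemble cleanly (sizes illustrative):

    import torch

    from batdetect2.models.backbones import BackboneConfig, build_backbone

    backbone = build_backbone(BackboneConfig(input_height=128, out_channels=32))

    # Width 300 is not divisible by the downsampling factor; padding is
    # added and removed internally, so spatial dims are preserved.
    spec = torch.randn(2, 1, 128, 300)
    features = backbone(spec)
    assert features.shape == (2, 32, 128, 300)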
-
-
-class Net2DFast(Net2DPlain):
-    downscaling_layer_type = "ConvBlockDownCoordF"
-    upscaling_layer_type = "ConvBlockUpF"
-
-    def __init__(
-        self,
-        input_height: int = 128,
-        encoder_channels: Sequence[int] = (1, 32, 64, 128),
-        bottleneck_channels: int = 256,
-        decoder_channels: Sequence[int] = (256, 64, 32, 32),
-        out_channels: int = 32,
-    ):
-        super().__init__(
-            input_height=input_height,
-            encoder_channels=encoder_channels,
-            bottleneck_channels=bottleneck_channels,
-            decoder_channels=decoder_channels,
-            out_channels=out_channels,
-        )
-
-        self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
-
-    def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
-
-        # encoder
-        residuals = self.encoder(spec)
-        residuals[-1] = self.conv_same_1(residuals[-1])
-
-        # bottleneck
-        x = self.conv_vert(residuals[-1])
-        x = self.att(x)
-        x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
-
-        # decoder
-        x = self.decoder(x, residuals=residuals)
-
-        # Restore original size
-        x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
-
-        return self.conv_same_2(x)
-
-
-class Net2DFastNoAttn(Net2DPlain):
-    downscaling_layer_type = "ConvBlockDownCoordF"
-    upscaling_layer_type = "ConvBlockUpF"
-
-
-class Net2DFastNoCoordConv(Net2DFast):
-    downscaling_layer_type = "ConvBlockDownStandard"
-    upscaling_layer_type = "ConvBlockUpStandard"
-
-
-def pad_adjust(
-    spec: torch.Tensor,
-    factor: int = 32,
-) -> Tuple[torch.Tensor, int, int]:
-    h, w = spec.shape[2:]
-    h_pad = -h % factor
-    w_pad = -w % factor
-    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
-
-
-def restore_pad(
-    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
-) -> torch.Tensor:
-    # Restore original size
-    if h_pad > 0:
-        x = x[:, :, :-h_pad, :]
-
-    if w_pad > 0:
-        x = x[:, :, :, :-w_pad]
-
-    return x
-
-
-class ModelType(str, Enum):
-    Net2DFast = "Net2DFast"
-    Net2DFastNoAttn = "Net2DFastNoAttn"
-    Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
-    Net2DPlain = "Net2DPlain"
-
-
 class BackboneConfig(BaseConfig):
-    backbone_type: ModelType = ModelType.Net2DFast
+    """Configuration for the Encoder-Decoder Backbone network.
+
+    Aggregates configurations for the encoder, bottleneck, and decoder
+    components, along with defining the input and final output dimensions
+    for the complete backbone.
+
+    Attributes
+    ----------
+    input_height : int, default=128
+        Expected height (frequency bins) of the input spectrograms to the
+        backbone. Must be positive.
+    in_channels : int, default=1
+        Expected number of channels in the input spectrograms (e.g., 1 for
+        mono). Must be positive.
+    encoder : EncoderConfig, optional
+        Configuration for the encoder. If None or omitted, the default
+        encoder configuration (`DEFAULT_ENCODER_CONFIG` from the encoder
+        module) will be used.
+    bottleneck : BottleneckConfig, optional
+        Configuration for the bottleneck layer connecting encoder and
+        decoder. If None or omitted, the default bottleneck configuration
+        will be used.
+    decoder : DecoderConfig, optional
+        Configuration for the decoder. If None or omitted, the default
+        decoder configuration (`DEFAULT_DECODER_CONFIG` from the decoder
+        module) will be used.
+    out_channels : int, default=32
+        Desired number of channels in the final feature map output by the
+        backbone. Must be positive.
+    """
+
     input_height: int = 128
-    encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
-    bottleneck_channels: int = 256
-    decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
+    in_channels: int = 1
+    encoder: Optional[EncoderConfig] = None
+    bottleneck: Optional[BottleneckConfig] = None
+    decoder: Optional[DecoderConfig] = None
     out_channels: int = 32
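Because each sub-configuration is optional, overrides can be as granular as a single component; a sketch (values illustrative) that swaps only the bottleneck while the encoder and decoder keep their module defaults:

    from batdetect2.models.backbones import BackboneConfig, build_backbone
    from batdetect2.models.bottleneck import BottleneckConfig

    config = BackboneConfig(
        input_height=128,
        in_channels=1,
        out_channels=32,
        # encoder=None and decoder=None fall back to the builder defaults.
        bottleneck=BottleneckConfig(channels=256, self_attention=False),
    )
    backbone = build_backbone(config)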
@@ -206,48 +196,162 @@ def load_backbone_config(
     path: data.PathLike,
     field: Optional[str] = None,
 ) -> BackboneConfig:
+    """Load the backbone configuration from a file.
+
+    Reads a configuration file (YAML) and validates it against the
+    `BackboneConfig` schema, potentially extracting data from a nested field.
+
+    Parameters
+    ----------
+    path : PathLike
+        Path to the configuration file.
+    field : str, optional
+        Dot-separated path to a nested section within the file containing the
+        backbone configuration (e.g., "model.backbone"). If None, the entire
+        file content is used.
+
+    Returns
+    -------
+    BackboneConfig
+        The loaded and validated backbone configuration object.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the config file path does not exist.
+    yaml.YAMLError
+        If the file content is not valid YAML.
+    pydantic.ValidationError
+        If the loaded config data does not conform to `BackboneConfig`.
+    KeyError, TypeError
+        If `field` specifies an invalid path.
+    """
     return load_config(path, schema=BackboneConfig, field=field)
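A sketch of the file-driven variant (the file name and nesting below are hypothetical):

    from batdetect2.models.backbones import load_backbone_config

    # Suppose model.yaml contains:
    #
    #   model:
    #     backbone:
    #       input_height: 128
    #       out_channels: 32
    #
    config = load_backbone_config("model.yaml", field="model.backbone")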

-def build_model_backbone(
-    config: Optional[BackboneConfig] = None,
-) -> BackboneModel:
-    config = config or BackboneConfig()
-
-    if config.backbone_type == ModelType.Net2DFast:
-        return Net2DFast(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.backbone_type == ModelType.Net2DFastNoAttn:
-        return Net2DFastNoAttn(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.backbone_type == ModelType.Net2DFastNoCoordConv:
-        return Net2DFastNoCoordConv(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    if config.backbone_type == ModelType.Net2DPlain:
-        return Net2DPlain(
-            input_height=config.input_height,
-            encoder_channels=config.encoder_channels,
-            bottleneck_channels=config.bottleneck_channels,
-            decoder_channels=config.decoder_channels,
-            out_channels=config.out_channels,
-        )
-
-    raise ValueError(f"Unknown model type: {config.backbone_type}")
+def build_backbone(config: BackboneConfig) -> BackboneModel:
+    """Factory function to build a Backbone from configuration.
+
+    Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
+    the provided `BackboneConfig`, validates their compatibility, and assembles
+    them into a `Backbone` instance.
+
+    Parameters
+    ----------
+    config : BackboneConfig
+        The configuration object detailing the backbone architecture, including
+        input dimensions and configurations for encoder, bottleneck, and
+        decoder.
+
+    Returns
+    -------
+    BackboneModel
+        An initialized `Backbone` module ready for use.
+
+    Raises
+    ------
+    ValueError
+        If sub-component configurations are incompatible (e.g., channel
+        mismatches, decoder output height doesn't match backbone input
+        height).
+    NotImplementedError
+        If an unknown block type is specified in sub-configs.
+    """
+    encoder = build_encoder(
+        in_channels=config.in_channels,
+        input_height=config.input_height,
+        config=config.encoder,
+    )
+
+    bottleneck = build_bottleneck(
+        input_height=encoder.output_height,
+        in_channels=encoder.out_channels,
+        config=config.bottleneck,
+    )
+
+    decoder = build_decoder(
+        in_channels=bottleneck.out_channels,
+        input_height=encoder.output_height,
+        config=config.decoder,
+    )
+
+    if decoder.output_height != config.input_height:
+        raise ValueError(
+            "Invalid configuration: Decoder output height "
+            f"({decoder.output_height}) must match the Backbone input height "
+            f"({config.input_height}). Check encoder/decoder layer "
+            "configurations and input/bottleneck heights."
+        )
+
+    return Backbone(
+        input_height=config.input_height,
+        out_channels=config.out_channels,
+        encoder=encoder,
+        decoder=decoder,
+        bottleneck=bottleneck,
+    )
+
+
+def _pad_adjust(
+    spec: torch.Tensor,
+    factor: int = 32,
+) -> Tuple[torch.Tensor, int, int]:
+    """Pad tensor height and width to be divisible by a factor.
+
+    Calculates the required padding for the last two dimensions (H, W) to make
+    them divisible by `factor` and applies right/bottom padding using
+    `torch.nn.functional.pad`.
+
+    Parameters
+    ----------
+    spec : torch.Tensor
+        Input tensor, typically shape `(B, C, H, W)`.
+    factor : int, default=32
+        The factor to make height and width divisible by.
+
+    Returns
+    -------
+    Tuple[torch.Tensor, int, int]
+        A tuple containing:
+        - The padded tensor.
+        - The amount of padding added to height (`h_pad`).
+        - The amount of padding added to width (`w_pad`).
+    """
+    h, w = spec.shape[2:]
+    h_pad = -h % factor
+    w_pad = -w % factor
+
+    if h_pad == 0 and w_pad == 0:
+        return spec, 0, 0
+
+    return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
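The `-h % factor` idiom yields the smallest non-negative padding that reaches the next multiple of `factor`; a quick check of the arithmetic (values illustrative):

    for h in (96, 100, 128):
        h_pad = -h % 32
        assert (h + h_pad) % 32 == 0
    # e.g. h = 100: -100 % 32 == 28, and 100 + 28 == 128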
+
+
+def _restore_pad(
+    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
+) -> torch.Tensor:
+    """Remove padding added by _pad_adjust.
+
+    Removes padding from the bottom and right edges of the tensor.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
+    h_pad : int, default=0
+        Amount of padding previously added to the height (bottom).
+    w_pad : int, default=0
+        Amount of padding previously added to the width (right).
+
+    Returns
+    -------
+    torch.Tensor
+        Tensor with padding removed, shape `(B, C, H_original, W_original)`.
+    """
+    if h_pad > 0:
+        x = x[:, :, :-h_pad, :]
+
+    if w_pad > 0:
+        x = x[:, :, :, :-w_pad]
+
+    return x
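Together the two helpers form an exact spatial round trip; a sketch (these are private helpers, imported here purely for illustration):

    import torch

    from batdetect2.models.backbones import _pad_adjust, _restore_pad

    x = torch.randn(1, 1, 100, 300)
    padded, h_pad, w_pad = _pad_adjust(x, factor=32)  # -> (1, 1, 128, 320)
    restored = _restore_pad(padded, h_pad=h_pad, w_pad=w_pad)
    assert restored.shape == x.shape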

batdetect2/models/blocks.py
@@ -22,14 +22,20 @@ research:
 These blocks can be utilized directly in custom PyTorch model definitions or
 assembled into larger architectures.
+
+A unified factory function `build_layer_from_config` allows creating instances
+of these blocks based on configuration objects.
 """

-from typing import Tuple
+from typing import Annotated, Literal, Tuple, Union

 import torch
 import torch.nn.functional as F
+from pydantic import Field
 from torch import nn

+from batdetect2.configs import BaseConfig
+
 __all__ = [
     "ConvBlock",
     "VerticalConv",
@@ -38,6 +44,13 @@ __all__ = [
     "FreqCoordConvUpBlock",
     "StandardConvUpBlock",
     "SelfAttention",
+    "ConvConfig",
+    "FreqCoordConvDownConfig",
+    "StandardConvDownConfig",
+    "FreqCoordConvUpConfig",
+    "StandardConvUpConfig",
+    "LayerConfig",
+    "build_layer_from_config",
 ]
@@ -154,6 +167,22 @@ class SelfAttention(nn.Module):
         return op


+class ConvConfig(BaseConfig):
+    """Configuration for a basic ConvBlock."""
+
+    block_type: Literal["ConvBlock"] = "ConvBlock"
+    """Discriminator field indicating the block type."""
+
+    out_channels: int
+    """Number of output channels."""
+
+    kernel_size: int = 3
+    """Size of the square convolutional kernel."""
+
+    pad_size: int = 1
+    """Padding size."""
+
+
 class ConvBlock(nn.Module):
     """Basic Convolutional Block.

@@ -171,8 +200,7 @@ class ConvBlock(nn.Module):
     kernel_size : int, default=3
         Size of the square convolutional kernel.
     pad_size : int, default=1
-        Amount of padding added to preserve spatial dimensions (assuming
-        stride=1 and kernel_size=3).
+        Amount of padding added to preserve spatial dimensions.
     """

     def __init__(
@@ -261,6 +289,22 @@ class VerticalConv(nn.Module):
         return F.relu_(self.bn(self.conv(x)))


+class FreqCoordConvDownConfig(BaseConfig):
+    """Configuration for a FreqCoordConvDownBlock."""
+
+    block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
+    """Discriminator field indicating the block type."""
+
+    out_channels: int
+    """Number of output channels."""
+
+    kernel_size: int = 3
+    """Size of the square convolutional kernel."""
+
+    pad_size: int = 1
+    """Padding size."""
+
+
 class FreqCoordConvDownBlock(nn.Module):
     """Downsampling Conv Block incorporating Frequency Coordinate features.

@@ -289,9 +333,6 @@ class FreqCoordConvDownBlock(nn.Module):
         Size of the square convolutional kernel.
     pad_size : int, default=1
         Padding added before convolution.
-    stride : int, default=1
-        Stride of the convolution. Note: Downsampling is achieved via
-        MaxPool2d(2, 2).
     """

     def __init__(
@@ -301,7 +342,6 @@ class FreqCoordConvDownBlock(nn.Module):
         input_height: int,
         kernel_size: int = 3,
         pad_size: int = 1,
-        stride: int = 1,
     ):
         super().__init__()

@@ -314,7 +354,7 @@ class FreqCoordConvDownBlock(nn.Module):
             out_channels,
             kernel_size=kernel_size,
             padding=pad_size,
-            stride=stride,
+            stride=1,
         )
         self.conv_bn = nn.BatchNorm2d(out_channels)

@@ -339,6 +379,22 @@ class FreqCoordConvDownBlock(nn.Module):
         return x


+class StandardConvDownConfig(BaseConfig):
+    """Configuration for a StandardConvDownBlock."""
+
+    block_type: Literal["StandardConvDown"] = "StandardConvDown"
+    """Discriminator field indicating the block type."""
+
+    out_channels: int
+    """Number of output channels."""
+
+    kernel_size: int = 3
+    """Size of the square convolutional kernel."""
+
+    pad_size: int = 1
+    """Padding size."""
+
+
 class StandardConvDownBlock(nn.Module):
     """Standard Downsampling Convolutional Block.

@@ -357,8 +413,6 @@ class StandardConvDownBlock(nn.Module):
         Size of the square convolutional kernel.
     pad_size : int, default=1
         Padding added before convolution.
-    stride : int, default=1
-        Stride of the convolution (downsampling is done by MaxPool).
     """

     def __init__(
@@ -367,7 +421,6 @@ class StandardConvDownBlock(nn.Module):
         out_channels: int,
         kernel_size: int = 3,
         pad_size: int = 1,
-        stride: int = 1,
     ):
         super(StandardConvDownBlock, self).__init__()
         self.conv = nn.Conv2d(
@@ -375,7 +428,7 @@ class StandardConvDownBlock(nn.Module):
             out_channels,
             kernel_size=kernel_size,
             padding=pad_size,
-            stride=stride,
+            stride=1,
         )
         self.conv_bn = nn.BatchNorm2d(out_channels)

@@ -396,6 +449,22 @@ class StandardConvDownBlock(nn.Module):
         return F.relu(self.conv_bn(x), inplace=True)


+class FreqCoordConvUpConfig(BaseConfig):
+    """Configuration for a FreqCoordConvUpBlock."""
+
+    block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
+    """Discriminator field indicating the block type."""
+
+    out_channels: int
+    """Number of output channels."""
+
+    kernel_size: int = 3
+    """Size of the square convolutional kernel."""
+
+    pad_size: int = 1
+    """Padding size."""
+
+
 class FreqCoordConvUpBlock(nn.Module):
     """Upsampling Conv Block incorporating Frequency Coordinate features.

@@ -489,6 +558,22 @@ class FreqCoordConvUpBlock(nn.Module):
         return op


+class StandardConvUpConfig(BaseConfig):
+    """Configuration for a StandardConvUpBlock."""
+
+    block_type: Literal["StandardConvUp"] = "StandardConvUp"
+    """Discriminator field indicating the block type."""
+
+    out_channels: int
+    """Number of output channels."""
+
+    kernel_size: int = 3
+    """Size of the square convolutional kernel."""
+
+    pad_size: int = 1
+    """Padding size."""
+
+
 class StandardConvUpBlock(nn.Module):
     """Standard Upsampling Convolutional Block.

@@ -559,3 +644,122 @@ class StandardConvUpBlock(nn.Module):
         op = self.conv(op)
         op = F.relu(self.conv_bn(op), inplace=True)
         return op
+
+
+LayerConfig = Annotated[
+    Union[
+        ConvConfig,
+        FreqCoordConvDownConfig,
+        StandardConvDownConfig,
+        FreqCoordConvUpConfig,
+        StandardConvUpConfig,
+    ],
+    Field(discriminator="block_type"),
+]
+"""Type alias for the discriminated union of block configuration models."""
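Because every member of the union carries a `Literal` discriminator, plain dicts (e.g. parsed from YAML) can be validated directly into the right config class; a sketch, assuming pydantic v2's `TypeAdapter`:

    from pydantic import TypeAdapter

    adapter = TypeAdapter(LayerConfig)
    layer = adapter.validate_python(
        {"block_type": "FreqCoordConvDown", "out_channels": 64}
    )
    assert isinstance(layer, FreqCoordConvDownConfig)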
+
+
+def build_layer_from_config(
+    input_height: int,
+    in_channels: int,
+    config: LayerConfig,
+) -> Tuple[nn.Module, int, int]:
+    """Factory function to build a specific nn.Module block from its config.
+
+    Takes a configuration object (one of the types included in the
+    `LayerConfig` union) and instantiates the corresponding nn.Module block
+    with the correct parameters derived from the config and the current
+    pipeline state (`input_height`, `in_channels`).
+
+    It uses the `block_type` field within the `config` object to determine
+    which block class to instantiate.
+
+    Parameters
+    ----------
+    input_height : int
+        Height (frequency bins) of the input tensor *to this layer*.
+    in_channels : int
+        Number of channels in the input tensor *to this layer*.
+    config : LayerConfig
+        A Pydantic configuration object for the desired block (e.g., an
+        instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
+        by its `block_type` field.
+
+    Returns
+    -------
+    Tuple[nn.Module, int, int]
+        A tuple containing:
+        - The instantiated `nn.Module` block.
+        - The number of output channels produced by the block.
+        - The calculated height of the output produced by the block.
+
+    Raises
+    ------
+    NotImplementedError
+        If the `config.block_type` does not correspond to a known block type.
+    ValueError
+        If parameters derived from the config are invalid for the block.
+    """
+    if config.block_type == "ConvBlock":
+        return (
+            ConvBlock(
+                in_channels=in_channels,
+                out_channels=config.out_channels,
+                kernel_size=config.kernel_size,
+                pad_size=config.pad_size,
+            ),
+            config.out_channels,
+            input_height,
+        )
+
+    if config.block_type == "FreqCoordConvDown":
+        return (
+            FreqCoordConvDownBlock(
+                in_channels=in_channels,
+                out_channels=config.out_channels,
+                input_height=input_height,
+                kernel_size=config.kernel_size,
+                pad_size=config.pad_size,
+            ),
+            config.out_channels,
+            input_height // 2,
+        )
+
+    if config.block_type == "StandardConvDown":
+        return (
+            StandardConvDownBlock(
+                in_channels=in_channels,
+                out_channels=config.out_channels,
+                kernel_size=config.kernel_size,
+                pad_size=config.pad_size,
+            ),
+            config.out_channels,
+            input_height // 2,
+        )
+
+    if config.block_type == "FreqCoordConvUp":
+        return (
+            FreqCoordConvUpBlock(
+                in_channels=in_channels,
+                out_channels=config.out_channels,
+                input_height=input_height,
+                kernel_size=config.kernel_size,
+                pad_size=config.pad_size,
+            ),
+            config.out_channels,
+            input_height * 2,
+        )
+
+    if config.block_type == "StandardConvUp":
+        return (
+            StandardConvUpBlock(
+                in_channels=in_channels,
+                out_channels=config.out_channels,
+                kernel_size=config.kernel_size,
+                pad_size=config.pad_size,
+            ),
+            config.out_channels,
+            input_height * 2,
+        )
+
+    raise NotImplementedError(f"Unknown block type {config.block_type}")
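Threading the returned channel count and height through successive calls lets a caller stack blocks without manual bookkeeping; a sketch (the two-block stack is illustrative):

    from torch import nn

    height, channels = 128, 1
    layers = []
    for cfg in [
        FreqCoordConvDownConfig(out_channels=32),
        StandardConvDownConfig(out_channels=64),
    ]:
        layer, channels, height = build_layer_from_config(
            input_height=height,
            in_channels=channels,
            config=cfg,
        )
        layers.append(layer)

    stack = nn.Sequential(*layers)  # height halves twice: 128 -> 64 -> 32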

batdetect2/models/bottleneck.py (new file, 254 lines)
@@ -0,0 +1,254 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture.

This module provides the configuration (`BottleneckConfig`) and
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
Decoder (upsampling path) in networks like U-Nets.

The bottleneck processes the lowest-resolution, highest-dimensionality feature
map produced by the Encoder. This module offers a configurable option to
include a `SelfAttention` layer within the bottleneck, allowing the model to
capture global temporal context before features are passed to the Decoder.

A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""

from typing import Optional

import torch
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv

__all__ = [
    "BottleneckConfig",
    "Bottleneck",
    "BottleneckAttn",
    "build_bottleneck",
]


class BottleneckConfig(BaseConfig):
    """Configuration for the bottleneck layer(s).

    Defines the number of channels within the bottleneck and whether to
    include a self-attention mechanism.

    Attributes
    ----------
    channels : int
        The number of output channels produced by the main convolutional layer
        within the bottleneck. This often matches the number of channels coming
        from the last encoder stage, but can be different. Must be positive.
        This also defines the channel dimensions used within the optional
        `SelfAttention` layer.
    self_attention : bool
        If True, includes a `SelfAttention` layer operating on the time
        dimension after an initial `VerticalConv` layer within the bottleneck.
        If False, only the initial `VerticalConv` (and height repetition) is
        performed.
    """

    channels: int
    self_attention: bool


class Bottleneck(nn.Module):
    """Base Bottleneck module for Encoder-Decoder architectures.

    This implementation represents the simplest bottleneck structure
    considered, primarily consisting of a `VerticalConv` layer. This layer
    collapses the frequency dimension (height) to 1, summarizing information
    across frequencies at each time step. The output is then repeated along
    the height dimension to match the original bottleneck input height before
    being passed to the decoder.

    This base version does *not* include self-attention.

    Parameters
    ----------
    input_height : int
        Height (frequency bins) of the input tensor. Must be positive.
    in_channels : int
        Number of channels in the input tensor from the encoder. Must be
        positive.
    out_channels : int
        Number of output channels. Must be positive.

    Attributes
    ----------
    in_channels : int
        Number of input channels accepted by the bottleneck.
    input_height : int
        Expected height of the input tensor.
    out_channels : int
        Number of output channels.
    conv_vert : VerticalConv
        The vertical convolution layer.

    Raises
    ------
    ValueError
        If `input_height`, `in_channels`, or `out_channels` are not positive.
    """

    def __init__(
        self,
        input_height: int,
        in_channels: int,
        out_channels: int,
    ) -> None:
        """Initialize the base Bottleneck layer."""
        super().__init__()
        self.in_channels = in_channels
        self.input_height = input_height
        self.out_channels = out_channels

        self.conv_vert = VerticalConv(
            in_channels=in_channels,
            out_channels=out_channels,
            input_height=input_height,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input features through the bottleneck.

        Applies vertical convolution and repeats the output height.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the encoder bottleneck, shape
            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
            `H_in` must match `self.input_height`.

        Returns
        -------
        torch.Tensor
            Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
            dimension `H_in` is restored via repetition after the vertical
            convolution.
        """
        x = self.conv_vert(x)
        return x.repeat([1, 1, self.input_height, 1])
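A shape trace of the base bottleneck (sizes illustrative; per the docstrings above, `VerticalConv` collapses the frequency axis to height 1 before the repeat restores it):

    import torch

    bottleneck = Bottleneck(input_height=4, in_channels=128, out_channels=256)

    x = torch.randn(2, 128, 4, 75)     # (B, C_in, H, W) from the encoder
    y = bottleneck(x)                  # conv_vert -> (2, 256, 1, 75),
    assert y.shape == (2, 256, 4, 75)  # repeat    -> (2, 256, 4, 75)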

class BottleneckAttn(Bottleneck):
    """Bottleneck module including a Self-Attention layer.

    Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
    the initial `VerticalConv`. This allows the bottleneck to capture global
    temporal dependencies in the summarized frequency features before passing
    them to the decoder.

    Sequence: VerticalConv -> SelfAttention -> Repeat Height.

    Parameters
    ----------
    input_height : int
        Height (frequency bins) of the input tensor from the encoder.
    in_channels : int
        Number of channels in the input tensor from the encoder.
    out_channels : int
        Number of output channels produced by the `VerticalConv` and
        subsequently processed and output by this bottleneck. Also determines
        the input/output channels of the internal `SelfAttention` layer.
    attention : nn.Module
        An initialized `SelfAttention` module instance.

    Raises
    ------
    ValueError
        If `input_height`, `in_channels`, or `out_channels` are not positive.
    """

    def __init__(
        self,
        input_height: int,
        in_channels: int,
        out_channels: int,
        attention: nn.Module,
    ) -> None:
        """Initialize the Bottleneck with Self-Attention."""
        super().__init__(input_height, in_channels, out_channels)
        self.attention = attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the encoder bottleneck, shape
            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
            `H_in` must match `self.input_height`.

        Returns
        -------
        torch.Tensor
            Output tensor, shape `(B, C_out, H_in, W)`, after applying
            attention and repeating the height dimension.
        """
        x = self.conv_vert(x)
        x = self.attention(x)
        return x.repeat([1, 1, self.input_height, 1])


DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
    channels=256,
    self_attention=True,
)


def build_bottleneck(
    input_height: int,
    in_channels: int,
    config: Optional[BottleneckConfig] = None,
) -> nn.Module:
    """Factory function to build the Bottleneck module from configuration.

    Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
    on the `config.self_attention` flag.

    Parameters
    ----------
    input_height : int
        Height (frequency bins) of the input tensor. Must be positive.
    in_channels : int
        Number of channels in the input tensor. Must be positive.
    config : BottleneckConfig, optional
        Configuration object specifying the bottleneck channels and whether
        to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.

    Returns
    -------
    nn.Module
        An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).

    Raises
    ------
    ValueError
        If `input_height` or `in_channels` are not positive.
    """
    config = config or DEFAULT_BOTTLENECK_CONFIG

    if config.self_attention:
        attention = SelfAttention(
            in_channels=config.channels,
            attention_channels=config.channels,
        )

        return BottleneckAttn(
            input_height=input_height,
            in_channels=in_channels,
            out_channels=config.channels,
            attention=attention,
        )

    return Bottleneck(
        input_height=input_height,
        in_channels=in_channels,
        out_channels=config.channels,
    )
267
batdetect2/models/decoder.py
Normal file
267
batdetect2/models/decoder.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
"""Constructs the Decoder part of an Encoder-Decoder neural network.
|
||||||
|
|
||||||
|
This module defines the configuration structure (`DecoderConfig`) for the layer
|
||||||
|
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
|
||||||
|
function (`build_decoder`). Decoders typically form the upsampling path in
|
||||||
|
architectures like U-Nets, taking bottleneck features
|
||||||
|
(usually from an `Encoder`) and skip connections to reconstruct
|
||||||
|
higher-resolution feature maps.
|
||||||
|
|
||||||
|
The decoder is built dynamically by stacking neural network blocks based on a
|
||||||
|
list of configuration objects provided in `DecoderConfig.layers`. Each config
|
||||||
|
object specifies the type of block (e.g., standard convolution,
|
||||||
|
coordinate-feature convolution with upsampling) and its parameters. This allows
|
||||||
|
flexible definition of decoder architectures via configuration files.
|
||||||
|
|
||||||
|
The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
||||||
|
(`residuals`) from the encoder, merging them with the upsampled feature maps
|
||||||
|
at each stage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Annotated, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import Field
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
|
from batdetect2.models.blocks import (
|
||||||
|
ConvConfig,
|
||||||
|
FreqCoordConvUpConfig,
|
||||||
|
StandardConvUpConfig,
|
||||||
|
build_layer_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DecoderConfig",
|
||||||
|
"Decoder",
|
||||||
|
"build_decoder",
|
||||||
|
"DEFAULT_DECODER_CONFIG",
|
||||||
|
]
|
||||||
|
|
||||||
|
DecoderLayerConfig = Annotated[
|
||||||
|
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig],
|
||||||
|
Field(discriminator="block_type"),
|
||||||
|
]
|
||||||
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderConfig(BaseConfig):
|
||||||
|
"""Configuration for the sequence of layers in the Decoder module.
|
||||||
|
|
||||||
|
Defines the types and parameters of the neural network blocks that
|
||||||
|
constitute the decoder's upsampling path.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
layers : List[DecoderLayerConfig]
|
||||||
|
An ordered list of configuration objects, each defining one layer or
|
||||||
|
block in the decoder sequence. Each item must be a valid block
|
||||||
|
config including a `block_type` field and necessary parameters like
|
||||||
|
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||||
|
The list must contain at least one layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
layers: List[DecoderLayerConfig] = Field(min_length=1)
|


class Decoder(nn.Module):
    """Sequential Decoder module composed of configurable upsampling layers.

    Constructs the upsampling path of an encoder-decoder network by stacking
    multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
    based on a list of layer modules provided during initialization (typically
    created by the `build_decoder` factory function).

    The `forward` method is designed to integrate skip connection tensors
    (`residuals`) from the corresponding encoder stages by adding them
    element-wise to the input of each decoder layer before processing.

    Attributes
    ----------
    in_channels : int
        Number of channels expected in the input tensor.
    out_channels : int
        Number of channels in the final output tensor produced by the last
        layer.
    input_height : int
        Height (frequency bins) expected in the input tensor.
    output_height : int
        Height (frequency bins) expected in the output tensor.
    layers : nn.ModuleList
        The sequence of instantiated upscaling layer modules.
    depth : int
        The number of upscaling layers (depth) in the decoder.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_height: int,
        output_height: int,
        layers: List[nn.Module],
    ):
        """Initialize the Decoder module.

        Note: This constructor is typically called internally by the
        `build_decoder` factory function.

        Parameters
        ----------
        in_channels : int
            Expected number of channels in the input tensor (bottleneck).
        out_channels : int
            Number of channels produced by the final layer.
        input_height : int
            Expected height of the input tensor (bottleneck).
        output_height : int
            Expected height of the output tensor.
        layers : List[nn.Module]
            A list of pre-instantiated upscaling layer modules (e.g.,
            `StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
            sequence (from bottleneck towards output resolution).
        """
        super().__init__()

        self.input_height = input_height
        self.output_height = output_height

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.layers = nn.ModuleList(layers)
        self.depth = len(self.layers)

    def forward(
        self,
        x: torch.Tensor,
        residuals: List[torch.Tensor],
    ) -> torch.Tensor:
        """Pass input through decoder layers, incorporating skip connections.

        Processes the input tensor `x` sequentially through the upscaling
        layers. At each stage, the corresponding skip connection tensor from
        the `residuals` list is added element-wise to the input before passing
        it to the upscaling block.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the previous stage (e.g., encoder bottleneck).
            Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
            `self.in_channels`.
        residuals : List[torch.Tensor]
            List containing the skip connection tensors from the corresponding
            encoder stages, ordered as produced by the encoder, i.e. from the
            shallowest stage (highest resolution) to the deepest (lowest
            resolution); the list is consumed in reverse, so the deepest
            residual is added first. The number of tensors in this list must
            match the number of decoder layers (`self.depth`). Each residual
            tensor's shape must be compatible with the input tensor `x` for
            element-wise addition (or concatenation, if the blocks were
            designed for it).

        Returns
        -------
        torch.Tensor
            The final decoded feature map tensor produced by the last layer.
            Shape `(B, C_out, H_out, W_out)`.

        Raises
        ------
        ValueError
            If the number of `residuals` provided does not match the decoder
            depth.
        RuntimeError
            If shapes mismatch during skip connection addition or layer
            processing.
        """
        if len(residuals) != len(self.layers):
            raise ValueError(
                f"Incorrect number of residuals provided. "
                f"Expected {len(self.layers)} (matching the number of layers), "
                f"but got {len(residuals)}."
            )

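        # Editorial note: with layers [l1, l2, l3] and residuals [r1, r2, r3]
        # given in encoder order (shallowest first), the loop below computes
        # l3(l2(l1(x + r3) + r2) + r1), so the deepest skip tensor is added
        # first.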
        for layer, res in zip(self.layers, residuals[::-1]):
            x = layer(x + res)

        return x


DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
    layers=[
        FreqCoordConvUpConfig(out_channels=64),
        FreqCoordConvUpConfig(out_channels=32),
        FreqCoordConvUpConfig(out_channels=32),
        ConvConfig(out_channels=32),
    ],
)
"""A default configuration for the Decoder's *layer sequence*.

Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
"""


def build_decoder(
    in_channels: int,
    input_height: int,
    config: Optional[DecoderConfig] = None,
) -> Decoder:
    """Factory function to build a Decoder instance from configuration.

    Constructs a sequential `Decoder` module based on the layer sequence
    defined in a `DecoderConfig` object and the provided input dimensions
    (bottleneck channels and height). If no config is provided, uses the
    default layer sequence from `DEFAULT_DECODER_CONFIG`.

    It iteratively builds the layers using the unified
    `build_layer_from_config` factory (from `.blocks`), tracking the changing
    number of channels and feature map height required for each subsequent
    layer.

    Parameters
    ----------
    in_channels : int
        The number of channels in the input tensor to the decoder. Must be
        > 0.
    input_height : int
        The height (frequency bins) of the input tensor to the decoder. Must
        be > 0.
    config : DecoderConfig, optional
        The configuration object detailing the sequence of layers and their
        parameters. If None, `DEFAULT_DECODER_CONFIG` is used.

    Returns
    -------
    Decoder
        An initialized `Decoder` module.

    Raises
    ------
    ValueError
        If `in_channels` or `input_height` are not positive, or if the layer
        configuration is invalid (e.g., empty list, unknown `block_type`).
    NotImplementedError
        If `build_layer_from_config` encounters an unknown `block_type`.
    """
    config = config or DEFAULT_DECODER_CONFIG

    current_channels = in_channels
    current_height = input_height

    layers = []

    for layer_config in config.layers:
        layer, current_channels, current_height = build_layer_from_config(
            in_channels=current_channels,
            input_height=current_height,
            config=layer_config,
        )
        layers.append(layer)

    return Decoder(
        in_channels=in_channels,
        out_channels=current_channels,
        input_height=input_height,
        output_height=current_height,
        layers=layers,
    )
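
# Usage sketch (editorial illustration; the bottleneck size of 256 channels
# by 16 frequency bins is an assumed example matching the output described by
# DEFAULT_ENCODER_CONFIG):
#
#     decoder = build_decoder(in_channels=256, input_height=16)
#     print(decoder.depth, decoder.out_channels, decoder.output_height)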
173  batdetect2/models/detectors.py  Normal file
@ -0,0 +1,173 @@
"""Assembles the complete BatDetect2 Detection Model.

This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.

The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
  `Detector` instance given a backbone and the number of target classes.

This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""

import torch

from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput


class Detector(DetectionModel):
    """Concrete implementation of the BatDetect2 Detection Model.

    Assembles a complete detection and classification model by combining a
    feature extraction backbone network with specific prediction heads for
    detection probability, bounding box size regression, and class
    probabilities.

    Attributes
    ----------
    backbone : BackboneModel
        The feature extraction backbone network module.
    num_classes : int
        The number of specific target classes the model predicts (derived from
        the `classifier_head`).
    classifier_head : ClassifierHead
        The prediction head responsible for generating class probabilities.
    detector_head : DetectorHead
        The prediction head responsible for generating detection probabilities.
    bbox_head : BBoxHead
        The prediction head responsible for generating bounding box size
        predictions.
    """

    backbone: BackboneModel

    def __init__(
        self,
        backbone: BackboneModel,
        classifier_head: ClassifierHead,
        detector_head: DetectorHead,
        bbox_head: BBoxHead,
    ):
        """Initialize the Detector model.

        Note: Instances are typically created using the `build_detector`
        factory function.

        Parameters
        ----------
        backbone : BackboneModel
            An initialized feature extraction backbone module (e.g., built by
            `build_backbone` from the `.backbone` module).
        classifier_head : ClassifierHead
            An initialized classification head module. The number of classes
            is inferred from this head.
        detector_head : DetectorHead
            An initialized detection head module.
        bbox_head : BBoxHead
            An initialized bounding box size prediction head module.

        Raises
        ------
        TypeError
            If the provided modules are not of the expected types.
        """
        super().__init__()

        self.backbone = backbone
        self.num_classes = classifier_head.num_classes
        self.classifier_head = classifier_head
        self.detector_head = detector_head
        self.bbox_head = bbox_head

    def forward(self, spec: torch.Tensor) -> ModelOutput:
        """Perform the forward pass of the complete detection model.

        Processes the input spectrogram through the backbone to extract
        features, then passes these features through the separate prediction
        heads to generate detection probabilities, class probabilities, and
        size predictions.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, input_channels, frequency_bins, time_bins)`. The
            shape must be compatible with the `self.backbone` input
            requirements.

        Returns
        -------
        ModelOutput
            A NamedTuple containing the four output tensors:
            - `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
            - `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
            - `class_probs`: Class probabilities (excluding background)
              `(B, num_classes, H, W)`.
            - `features`: Output feature map from the backbone
              `(B, C_out, H, W)`.
        """
        features = self.backbone(spec)
        detection = self.detector_head(features)
        classification = self.classifier_head(features)
        size_preds = self.bbox_head(features)
        return ModelOutput(
            detection_probs=detection,
            size_preds=size_preds,
            class_probs=classification,
            features=features,
        )


def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
    """Factory function to build a standard Detector model instance.

    Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
    `BBoxHead`) configured appropriately based on the output channels of the
    provided `backbone` and the specified `num_classes`. It then assembles
    these components into a `Detector` model.

    Parameters
    ----------
    num_classes : int
        The number of specific target classes for the classification head
        (excluding any implicit background class). Must be positive.
    backbone : BackboneModel
        An initialized feature extraction backbone module instance. The number
        of output channels from this backbone (`backbone.out_channels`) is used
        to configure the input channels for the prediction heads.

    Returns
    -------
    Detector
        An initialized `Detector` model instance.

    Raises
    ------
    ValueError
        If `num_classes` is not positive.
    AttributeError
        If `backbone` does not have the required `out_channels` attribute.
    """
    classifier_head = ClassifierHead(
        num_classes=num_classes,
        in_channels=backbone.out_channels,
    )
    detector_head = DetectorHead(
        in_channels=backbone.out_channels,
    )
    bbox_head = BBoxHead(
        in_channels=backbone.out_channels,
    )
    return Detector(
        backbone=backbone,
        classifier_head=classifier_head,
        detector_head=detector_head,
        bbox_head=bbox_head,
    )
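
# Usage sketch (editorial illustration): given an initialized
# `backbone: BackboneModel` (see `.backbones`), assemble the full model and
# run a forward pass; `num_classes=6` and the spectrogram shape are assumed
# example values.
#
#     detector = build_detector(num_classes=6, backbone=backbone)
#     spec = torch.rand(1, 1, 128, 512)
#     detection_probs, size_preds, class_probs, features = detector(spec)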
batdetect2/models/encoder.py
@ -1,25 +1,26 @@
"""Constructs the Encoder part of an Encoder-Decoder neural network.
|
"""Constructs the Encoder part of a configurable neural network backbone.
|
||||||
|
|
||||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
This module defines the configuration structure (`EncoderConfig`) and provides
|
||||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
the `Encoder` class (an `nn.Module`) along with a factory function
|
||||||
(`build_encoder`) to create sequential encoders commonly used as the
|
(`build_encoder`) to create sequential encoders. Encoders typically form the
|
||||||
downsampling path in architectures like U-Nets for spectrogram analysis.
|
downsampling path in architectures like U-Nets, processing input feature maps
|
||||||
|
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
|
||||||
|
representations (bottleneck features).
|
||||||
|
|
||||||
The encoder is built by stacking configurable downscaling blocks. Two types
|
The encoder is built dynamically by stacking neural network blocks based on a
|
||||||
of downscaling blocks are supported, selectable via the configuration:
|
list of configuration objects provided in `EncoderConfig.layers`. Each
|
||||||
- `StandardConvDownBlock`: A basic Conv2d -> MaxPool2d -> BN -> ReLU block.
|
configuration object specifies the type of block (e.g., standard convolution,
|
||||||
- `FreqCoordConvDownBlock`: A similar block that incorporates frequency
|
coordinate-feature convolution with downsampling) and its parameters
|
||||||
coordinate information (CoordF) before the convolution to potentially aid
|
(e.g., output channels). This allows for flexible definition of encoder
|
||||||
spatial awareness along the frequency axis.
|
architectures via configuration files.
|
||||||
|
|
||||||
The `Encoder`'s `forward` method provides access to intermediate feature maps
|
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
|
||||||
from each stage, suitable for use as skip connections in a corresponding
|
suitable for skip connections, while the `encode` method returns only the final
|
||||||
Decoder. A separate `encode` method returns only the final output (bottleneck)
|
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
||||||
features.
|
provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from enum import Enum
|
from typing import Annotated, List, Optional, Union
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -27,142 +28,44 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
FreqCoordConvDownBlock,
|
ConvConfig,
|
||||||
StandardConvDownBlock,
|
FreqCoordConvDownConfig,
|
||||||
|
StandardConvDownConfig,
|
||||||
|
build_layer_from_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DownscalingLayer",
|
|
||||||
"EncoderLayer",
|
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
"Encoder",
|
"Encoder",
|
||||||
"build_encoder",
|
"build_encoder",
|
||||||
|
"DEFAULT_ENCODER_CONFIG",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
EncoderLayerConfig = Annotated[
|
||||||
class DownscalingLayer(str, Enum):
|
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig],
|
||||||
"""Enumeration of available downscaling layer types for the Encoder.
|
Field(discriminator="block_type"),
|
||||||
|
]
|
||||||
Used in configuration to specify which block implementation to use at each
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
stage of the encoder.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
standard : str
|
|
||||||
Identifier for the `StandardConvDownBlock`.
|
|
||||||
coord : str
|
|
||||||
Identifier for the `FreqCoordConvDownBlock` (incorporates frequency
|
|
||||||
coords).
|
|
||||||
"""
|
|
||||||
|
|
||||||
standard = "ConvBlockDownStandard"
|
|
||||||
coord = "FreqCoordConvDownBlock"
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderLayer(BaseConfig):
|
|
||||||
"""Configuration for a single layer within the Encoder sequence.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
layer_type : DownscalingLayer
|
|
||||||
Specifies the type of downscaling block to use for this layer
|
|
||||||
(either 'standard' or 'coord').
|
|
||||||
channels : int
|
|
||||||
The number of output channels this layer should produce. Must be > 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
layer_type: DownscalingLayer
|
|
||||||
channels: int
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderConfig(BaseConfig):
|
class EncoderConfig(BaseConfig):
|
||||||
"""Configuration for building the entire sequential Encoder.
|
"""Configuration for building the sequential Encoder module.
|
||||||
|
|
||||||
|
Defines the sequence of neural network blocks that constitute the encoder
|
||||||
|
(downsampling path).
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
input_height : int
|
layers : List[EncoderLayerConfig]
|
||||||
The expected height (number of frequency bins) of the input spectrogram
|
An ordered list of configuration objects, each defining one layer or
|
||||||
tensor fed into the first layer of the encoder. Required for
|
block in the encoder sequence. Each item must be a valid block config
|
||||||
calculating intermediate heights, especially for CoordF layers. Must be
|
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||||
> 0.
|
`StandardConvDownConfig`) including a `block_type` field and necessary
|
||||||
layers : List[EncoderLayer]
|
parameters like `out_channels`. Input channels for each layer are
|
||||||
An ordered list defining the sequence of downscaling layers in the
|
inferred sequentially. The list must contain at least one layer.
|
||||||
encoder. Each item specifies the layer type and its output channel
|
|
||||||
count. The number of input channels for each layer is inferred from the
|
|
||||||
previous layer's output channels (or `input_channels` for the first
|
|
||||||
layer). Must contain at least one layer definition.
|
|
||||||
input_channels : int, default=1
|
|
||||||
The number of channels in the initial input tensor to the encoder
|
|
||||||
(e.g., 1 for a standard single-channel spectrogram). Must be > 0.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_height: int = Field(gt=0)
|
layers: List[EncoderLayerConfig] = Field(min_length=1)
|
||||||
layers: List[EncoderLayer] = Field(min_length=1)
|
|
||||||
input_channels: int = Field(gt=0)
|
|
||||||
|
|
||||||
|
|
||||||
def build_downscaling_layer(
|
|
||||||
in_channels: int,
|
|
||||||
out_channels: int,
|
|
||||||
input_height: int,
|
|
||||||
layer_type: DownscalingLayer,
|
|
||||||
) -> tuple[nn.Module, int, int]:
|
|
||||||
"""Build a single downscaling layer based on configuration.
|
|
||||||
|
|
||||||
Internal factory function used by `build_encoder`. Instantiates the
|
|
||||||
appropriate downscaling block (`StandardConvDownBlock` or
|
|
||||||
`FreqCoordConvDownBlock`) and returns it along with its expected output
|
|
||||||
channel count and output height (assuming 2x spatial downsampling).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
in_channels : int
|
|
||||||
Number of input channels to the layer.
|
|
||||||
out_channels : int
|
|
||||||
Desired number of output channels from the layer.
|
|
||||||
input_height : int
|
|
||||||
Height of the input feature map to this layer.
|
|
||||||
layer_type : DownscalingLayer
|
|
||||||
The type of layer to build ('standard' or 'coord').
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Tuple[nn.Module, int, int]
|
|
||||||
A tuple containing:
|
|
||||||
- The instantiated `nn.Module` layer.
|
|
||||||
- The number of output channels (`out_channels`).
|
|
||||||
- The expected output height (`input_height // 2`).
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `layer_type` is invalid.
|
|
||||||
"""
|
|
||||||
if layer_type == DownscalingLayer.standard:
|
|
||||||
return (
|
|
||||||
StandardConvDownBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
),
|
|
||||||
out_channels,
|
|
||||||
input_height // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if layer_type == DownscalingLayer.coord:
|
|
||||||
return (
|
|
||||||
FreqCoordConvDownBlock(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
input_height=input_height,
|
|
||||||
),
|
|
||||||
out_channels,
|
|
||||||
input_height // 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid downscaling layer type {layer_type}. "
|
|
||||||
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
@ -178,12 +81,14 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
input_channels : int
|
in_channels : int
|
||||||
Number of channels expected in the input tensor.
|
Number of channels expected in the input tensor.
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) expected in the input tensor.
|
Height (frequency bins) expected in the input tensor.
|
||||||
output_channels : int
|
output_channels : int
|
||||||
Number of channels in the final output tensor (bottleneck).
|
Number of channels in the final output tensor (bottleneck).
|
||||||
|
output_height : int
|
||||||
|
Height (frequency bins) expected in the output tensor.
|
||||||
layers : nn.ModuleList
|
layers : nn.ModuleList
|
||||||
The sequence of instantiated downscaling layer modules.
|
The sequence of instantiated downscaling layer modules.
|
||||||
depth : int
|
depth : int
|
||||||
@ -193,9 +98,10 @@ class Encoder(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
output_channels: int,
|
output_channels: int,
|
||||||
|
output_height: int,
|
||||||
layers: List[nn.Module],
|
layers: List[nn.Module],
|
||||||
input_height: int = 128,
|
input_height: int = 128,
|
||||||
input_channels: int = 1,
|
in_channels: int = 1,
|
||||||
):
|
):
|
||||||
"""Initialize the Encoder module.
|
"""Initialize the Encoder module.
|
||||||
|
|
||||||
@ -206,20 +112,23 @@ class Encoder(nn.Module):
|
|||||||
----------
|
----------
|
||||||
output_channels : int
|
output_channels : int
|
||||||
Number of channels produced by the final layer.
|
Number of channels produced by the final layer.
|
||||||
|
output_height : int
|
||||||
|
The expected height of the output tensor.
|
||||||
layers : List[nn.Module]
|
layers : List[nn.Module]
|
||||||
A list of pre-instantiated downscaling layer modules (e.g.,
|
A list of pre-instantiated downscaling layer modules (e.g.,
|
||||||
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
||||||
sequence.
|
sequence.
|
||||||
input_height : int, default=128
|
input_height : int, default=128
|
||||||
Expected height of the input tensor.
|
Expected height of the input tensor.
|
||||||
input_channels : int, default=1
|
in_channels : int, default=1
|
||||||
Expected number of channels in the input tensor.
|
Expected number of channels in the input tensor.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input_channels = input_channels
|
self.in_channels = in_channels
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
self.output_channels = output_channels
|
self.out_channels = output_channels
|
||||||
|
self.output_height = output_height
|
||||||
|
|
||||||
self.layers = nn.ModuleList(layers)
|
self.layers = nn.ModuleList(layers)
|
||||||
self.depth = len(self.layers)
|
self.depth = len(self.layers)
|
||||||
@ -234,7 +143,7 @@ class Encoder(nn.Module):
|
|||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
||||||
`self.input_channels` and `H_in` must match `self.input_height`.
|
`self.in_channels` and `H_in` must match `self.input_height`.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -249,10 +158,10 @@ class Encoder(nn.Module):
|
|||||||
If input tensor channel count or height does not match expected
|
If input tensor channel count or height does not match expected
|
||||||
values.
|
values.
|
||||||
"""
|
"""
|
||||||
if x.shape[1] != self.input_channels:
|
if x.shape[1] != self.in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input tensor has {x.shape[1]} channels, "
|
f"Input tensor has {x.shape[1]} channels, "
|
||||||
f"but encoder expects {self.input_channels}."
|
f"but encoder expects {self.in_channels}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if x.shape[2] != self.input_height:
|
if x.shape[2] != self.input_height:
|
||||||
@ -279,7 +188,7 @@ class Encoder(nn.Module):
|
|||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
||||||
`input_channels` and `input_height`.
|
`in_channels` and `input_height`.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -293,10 +202,10 @@ class Encoder(nn.Module):
|
|||||||
If input tensor channel count or height does not match expected
|
If input tensor channel count or height does not match expected
|
||||||
values.
|
values.
|
||||||
"""
|
"""
|
||||||
if x.shape[1] != self.input_channels:
|
if x.shape[1] != self.in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Input tensor has {x.shape[1]} channels, "
|
f"Input tensor has {x.shape[1]} channels, "
|
||||||
f"but encoder expects {self.input_channels}."
|
f"but encoder expects {self.in_channels}."
|
||||||
)
|
)
|
||||||
|
|
||||||
if x.shape[2] != self.input_height:
|
if x.shape[2] != self.input_height:
|
||||||
@ -311,19 +220,53 @@ class Encoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def build_encoder(config: EncoderConfig) -> Encoder:
|
DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
||||||
|
layers=[
|
||||||
|
FreqCoordConvDownConfig(out_channels=32),
|
||||||
|
FreqCoordConvDownConfig(out_channels=64),
|
||||||
|
FreqCoordConvDownConfig(out_channels=128),
|
||||||
|
ConvConfig(out_channels=256),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
"""Default configuration for the Encoder.
|
||||||
|
|
||||||
|
Specifies an architecture typically used in BatDetect2:
|
||||||
|
- Input: 1 channel, 128 frequency bins.
|
||||||
|
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
|
||||||
|
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
|
||||||
|
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
|
||||||
|
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder(
|
||||||
|
in_channels: int,
|
||||||
|
input_height: int,
|
||||||
|
config: Optional[EncoderConfig] = None,
|
||||||
|
) -> Encoder:
|
||||||
"""Factory function to build an Encoder instance from configuration.
|
"""Factory function to build an Encoder instance from configuration.
|
||||||
|
|
||||||
Constructs a sequential `Encoder` module based on the specifications in
|
Constructs a sequential `Encoder` module based on the layer sequence
|
||||||
an `EncoderConfig` object. It iteratively builds the specified sequence
|
defined in an `EncoderConfig` object and the provided input dimensions.
|
||||||
of downscaling layers (`StandardConvDownBlock` or `FreqCoordConvDownBlock`),
|
If no config is provided, uses the default layer sequence from
|
||||||
tracking the changing number of channels and feature map height.
|
`DEFAULT_ENCODER_CONFIG`.
|
||||||
|
|
||||||
|
It iteratively builds the layers using the unified
|
||||||
|
`build_layer_from_config` factory (from `.blocks`), tracking the changing
|
||||||
|
number of channels and feature map height required for each subsequent
|
||||||
|
layer, especially for coordinate- aware blocks.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config : EncoderConfig
|
in_channels : int
|
||||||
The configuration object detailing the encoder architecture, including
|
The number of channels expected in the input tensor to the encoder.
|
||||||
input dimensions, layer types, and channel counts for each stage.
|
Must be > 0.
|
||||||
|
input_height : int
|
||||||
|
The height (frequency bins) expected in the input tensor. Must be > 0.
|
||||||
|
Crucial for initializing coordinate-aware layers correctly.
|
||||||
|
config : EncoderConfig, optional
|
||||||
|
The configuration object detailing the sequence of layers and their
|
||||||
|
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -333,25 +276,33 @@ def build_encoder(config: EncoderConfig) -> Encoder:
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If the layer configuration is invalid (e.g., unknown layer type).
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
|
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||||
|
NotImplementedError
|
||||||
|
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||||
"""
|
"""
|
||||||
current_channels = config.input_channels
|
if in_channels <= 0 or input_height <= 0:
|
||||||
current_height = config.input_height
|
raise ValueError("in_channels and input_height must be positive.")
|
||||||
|
|
||||||
|
config = config or DEFAULT_ENCODER_CONFIG
|
||||||
|
|
||||||
|
current_channels = in_channels
|
||||||
|
current_height = input_height
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer, current_channels, current_height = build_downscaling_layer(
|
layer, current_channels, current_height = build_layer_from_config(
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
out_channels=layer_config.channels,
|
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
layer_type=layer_config.layer_type,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
|
|
||||||
return Encoder(
|
return Encoder(
|
||||||
input_height=config.input_height,
|
input_height=input_height,
|
||||||
layers=layers,
|
layers=layers,
|
||||||
input_channels=config.input_channels,
|
in_channels=in_channels,
|
||||||
output_channels=current_channels,
|
output_channels=current_channels,
|
||||||
|
output_height=current_height,
|
||||||
)
|
)
|
||||||
|
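
# Usage sketch (editorial illustration, not part of the diff; the time
# dimension 512 is an assumed example value):
#
#     from batdetect2.models.blocks import (
#         FreqCoordConvDownConfig,
#         StandardConvDownConfig,
#     )
#     from batdetect2.models.encoder import EncoderConfig, build_encoder
#
#     encoder = build_encoder(in_channels=1, input_height=128)
#     spec = torch.rand(1, 1, 128, 512)
#     residuals = encoder(spec)          # per-layer outputs, for skip connections
#     bottleneck = encoder.encode(spec)  # final bottleneck output only
#
#     # A custom, shallower stack built from block configs:
#     small = build_encoder(
#         in_channels=1,
#         input_height=128,
#         config=EncoderConfig(
#             layers=[
#                 FreqCoordConvDownConfig(out_channels=16),
#                 StandardConvDownConfig(out_channels=32),
#             ]
#         ),
#     )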
batdetect2/models/heads.py
@ -1,42 +1,199 @@
-from typing import NamedTuple
+"""Prediction Head modules for BatDetect2 models.
+
+This module defines simple `torch.nn.Module` subclasses that serve as
+prediction heads, typically attached to the output feature map of a backbone
+network.
+
+Each head is responsible for generating one specific type of output required
+by the BatDetect2 task:
+- `DetectorHead`: Predicts the probability of sound event presence.
+- `ClassifierHead`: Predicts the probability distribution over target classes.
+- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
+  box.
+
+These heads use 1x1 convolutions to map the backbone feature channels
+to the desired number of output channels for each prediction task at each
+spatial location, followed by an appropriate activation function (e.g., sigmoid
+for detection, softmax for classification, none for size regression).
+"""
 
 import torch
 from torch import nn
 
-__all__ = ["ClassifierHead"]
-
-
-class Output(NamedTuple):
-    detection: torch.Tensor
-    classification: torch.Tensor
+__all__ = [
+    "ClassifierHead",
+    "DetectorHead",
+    "BBoxHead",
+]
 
 
 class ClassifierHead(nn.Module):
+    """Prediction head for multi-class classification probabilities.
+
+    Takes an input feature map and produces a probability map where each
+    channel corresponds to a specific target class. It uses a 1x1 convolution
+    to map input channels to `num_classes + 1` outputs (one for each target
+    class plus an assumed background/generic class), applies softmax across the
+    channels, and returns the probabilities for the specific target classes
+    (excluding the last background/generic channel).
+
+    Parameters
+    ----------
+    num_classes : int
+        The number of specific target classes the model should predict
+        (excluding any background or generic category). Must be positive.
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    num_classes : int
+        Number of specific output classes.
+    in_channels : int
+        Number of input channels expected.
+    classifier : nn.Conv2d
+        The 1x1 convolutional layer used for prediction.
+        Output channels = num_classes + 1.
+
+    Raises
+    ------
+    ValueError
+        If `num_classes` or `in_channels` are not positive.
+    """
+
     def __init__(self, num_classes: int, in_channels: int):
+        """Initialize the ClassifierHead."""
         super().__init__()
+
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.classifier = nn.Conv2d(
             self.in_channels,
-            # Add one to account for the background class
             self.num_classes + 1,
             kernel_size=1,
             padding=0,
         )
 
-    def forward(self, features: torch.Tensor) -> Output:
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute class probabilities from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Class probability map tensor with shape `(B, num_classes, H, W)`.
+            Contains probabilities for the specific target classes after
+            softmax, excluding the implicit background/generic class channel.
+        """
         logits = self.classifier(features)
         probs = torch.softmax(logits, dim=1)
-        detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
-        return Output(
-            detection=detection_probs,
-            classification=probs[:, :-1],
-        )
+        return probs[:, :-1]
+
+
+class DetectorHead(nn.Module):
+    """Prediction head for sound event detection probability.
+
+    Takes an input feature map and produces a single-channel heatmap where
+    each value represents the probability ([0, 1]) of a relevant sound event
+    (of any class) being present at that spatial location.
+
+    Uses a 1x1 convolution to map input channels to 1 output channel, followed
+    by a sigmoid activation function.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels expected.
+    detector : nn.Conv2d
+        The 1x1 convolutional layer mapping to a single output channel.
+
+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
+    """
+
+    def __init__(self, in_channels: int):
+        """Initialize the DetectorHead."""
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.detector = nn.Conv2d(
+            in_channels=self.in_channels,
+            out_channels=1,
+            kernel_size=1,
+            padding=0,
+        )
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute detection probabilities from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Detection probability heatmap tensor with shape `(B, 1, H, W)`.
+            Values are in the range [0, 1] due to the sigmoid activation.
+
+        Raises
+        ------
+        RuntimeError
+            If input channel count does not match `self.in_channels`.
+        """
+        return torch.sigmoid(self.detector(features))
 
 
 class BBoxHead(nn.Module):
+    """Prediction head for bounding box size dimensions.
+
+    Takes an input feature map and produces a two-channel map where each
+    channel represents a predicted size dimension (typically width/duration and
+    height/bandwidth) for a potential sound event at that spatial location.
+
+    Uses a 1x1 convolution to map input channels to 2 output channels. No
+    activation function is typically applied, as size prediction is often
+    treated as a direct regression task. The output values usually represent
+    *scaled* dimensions that need to be un-scaled during postprocessing.
+
+    Parameters
+    ----------
+    in_channels : int
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels expected.
+    bbox : nn.Conv2d
+        The 1x1 convolutional layer mapping to 2 output channels
+        (width, height).
+
+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
+    """
+
     def __init__(self, in_channels: int):
+        """Initialize the BBoxHead."""
         super().__init__()
         self.in_channels = in_channels
 

@ -48,4 +205,19 @@ class BBoxHead(nn.Module):
         )
 
     def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Compute predicted bounding box dimensions from input features.
+
+        Parameters
+        ----------
+        features : torch.Tensor
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
+
+        Returns
+        -------
+        torch.Tensor
+            Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
+            represents scaled width, Channel 1 scaled height. These values
+            need to be un-scaled during postprocessing.
+        """
         return self.bbox(features)
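
# Usage sketch (editorial illustration; channel counts and shapes are assumed
# example values): each head maps backbone features to its task output via a
# 1x1 convolution.
#
#     features = torch.rand(2, 32, 16, 64)  # (B, C, H, W) backbone output
#     class_probs = ClassifierHead(num_classes=6, in_channels=32)(features)
#     # -> (2, 6, 16, 64); sums with the dropped background channel to 1
#     detection_probs = DetectorHead(in_channels=32)(features)
#     # -> (2, 1, 16, 64); sigmoid probabilities in [0, 1]
#     size_preds = BBoxHead(in_channels=32)(features)
#     # -> (2, 2, 16, 64); scaled width/height, no activation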