Update model docstrings

This commit is contained in:
mbsantiago 2026-03-08 16:34:17 +00:00
parent 4207661da4
commit ef3348d651
9 changed files with 873 additions and 613 deletions

View File

@ -1,29 +1,29 @@
"""Defines and builds the neural network models used in BatDetect2. """Neural network model definitions and builders for BatDetect2.
This package (`batdetect2.models`) contains the PyTorch implementations of the This package contains the PyTorch implementations of the deep neural network
deep neural network architectures used for detecting and classifying bat calls architectures used to detect and classify bat echolocation calls in
from spectrograms. It provides modular components and configuration-driven spectrograms. Components are designed to be combined through configuration
assembly, allowing for experimentation and use of different architectural objects, making it easy to experiment with different architectures.
variants.
Key Submodules: Key submodules
- `.types`: Defines core data structures (`ModelOutput`) and abstract base --------------
classes (`BackboneModel`, `DetectionModel`) establishing interfaces. - ``blocks``: Reusable convolutional building blocks (downsampling,
- `.blocks`: Provides reusable neural network building blocks. upsampling, attention, coord-conv variants).
- `.encoder`: Defines and builds the downsampling path (encoder) of the network. - ``encoder``: The downsampling path; reduces spatial resolution whilst
- `.bottleneck`: Defines and builds the central bottleneck component. extracting increasingly abstract features.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network. - ``bottleneck``: The central component connecting encoder to decoder;
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete optionally applies self-attention along the time axis.
feature extraction backbone (e.g., a U-Net like structure). - ``decoder``: The upsampling path; reconstructs high-resolution feature
- `.heads`: Defines simple prediction heads (detection, classification, size) maps using bottleneck output and skip connections from the encoder.
that attach to the backbone features. - ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
- `.detectors`: Assembles the backbone and prediction heads into the final, U-Net-style feature extraction backbone.
end-to-end `Detector` model. - ``heads``: Lightweight 1×1 convolutional heads that produce detection,
classification, and bounding-box size predictions from backbone features.
- ``detectors``: Combines a backbone with prediction heads into the final
end-to-end ``Detector`` model.
This module re-exports the most important classes, configurations, and builder The primary entry point for building a full, ready-to-use BatDetect2 model
functions from these submodules for convenient access. The primary entry point is the ``build_model`` factory function exported from this module.
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
""" """
from typing import List from typing import List
@ -99,6 +99,31 @@ __all__ = [
class Model(torch.nn.Module): class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
Combines a preprocessor, a detection model, and a postprocessor into a
single PyTorch module. Calling ``forward`` on a raw waveform tensor
returns a list of detection tensors ready for downstream use.
This class is the top-level object produced by ``build_model``. Most
users will not need to construct it directly.
Attributes
----------
detector : DetectionModel
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
Converts a raw waveform tensor into a spectrogram tensor accepted by
``detector``.
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
"""
detector: DetectionModel detector: DetectionModel
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
@ -118,6 +143,25 @@ class Model(torch.nn.Module):
self.targets = targets self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
Converts the waveform to a spectrogram, passes it through the
detector, and postprocesses the raw outputs into detection tensors.
Parameters
----------
wav : torch.Tensor
Raw audio waveform tensor. The exact expected shape depends on
the preprocessor, but is typically ``(batch, samples)`` or
``(batch, channels, samples)``.
Returns
-------
List[ClipDetectionsTensor]
One detection tensor per clip in the batch. Each tensor encodes
the detected events (locations, class scores, sizes) for that
clip.
"""
spec = self.preprocessor(wav) spec = self.preprocessor(wav)
outputs = self.detector(spec) outputs = self.detector(spec)
return self.postprocessor(outputs) return self.postprocessor(outputs)
@ -128,7 +172,38 @@ def build_model(
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None,
): ) -> "Model":
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from optional configuration and component
overrides. Any argument left as ``None`` will be replaced by a sensible
default built with the project's own builder functions.
Parameters
----------
config : BackboneConfig, optional
Configuration describing the backbone architecture (encoder,
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
provided.
targets : TargetProtocol, optional
Describes the target bat species or call types to detect. Determines
the number of output classes. Defaults to the standard BatDetect2
target set.
preprocessor : PreprocessorProtocol, optional
Converts raw audio waveforms to spectrograms. Defaults to the
standard BatDetect2 preprocessor.
postprocessor : PostprocessorProtocol, optional
Converts raw model outputs to detection tensors. Defaults to the
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
given without a matching ``postprocessor``, the default postprocessor
will be built using the provided preprocessor so that frequency and
time scaling remain consistent.
Returns
-------
Model
A fully assembled ``Model`` instance ready for inference or training.
"""
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets

View File

@ -1,21 +1,26 @@
"""Assembles a complete Encoder-Decoder Backbone network. """Assembles a complete encoder-decoder backbone network.
This module defines the configuration (`BackboneConfig`) and implementation This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
(`Backbone`) for a standard encoder-decoder style neural network backbone. ``nn.Module``, together with the ``build_backbone`` and
``load_backbone_config`` helpers.
It orchestrates the connection between three main components, built using their A backbone combines three components built from the sibling modules:
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
The resulting `Backbone` module takes a spectrogram as input and outputs a 1. **Encoder** (``batdetect2.models.encoder``) reduces spatial resolution
final feature map, typically used by subsequent prediction heads. It includes while extracting hierarchical features and storing skip-connection tensors.
automatic padding to handle input sizes not perfectly divisible by the 2. **Bottleneck** (``batdetect2.models.bottleneck``) processes the
network's total downsampling factor. lowest-resolution features, optionally applying self-attention.
3. **Decoder** (``batdetect2.models.decoder``) restores spatial resolution
using bottleneck features and skip connections from the encoder.
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
a high-resolution feature map consumed by the prediction heads in
``batdetect2.models.detectors``.
Input padding is handled automatically: the backbone pads the input to be
divisible by the total downsampling factor and strips the padding from the
output so that the output spatial dimensions always match the input spatial
dimensions.
""" """
from typing import Annotated, Literal, Tuple, Union from typing import Annotated, Literal, Tuple, Union
@ -51,6 +56,34 @@ from batdetect2.typing.models import (
class UNetBackboneConfig(BaseConfig): class UNetBackboneConfig(BaseConfig):
"""Configuration for a U-Net-style encoder-decoder backbone.
All fields have sensible defaults that reproduce the standard BatDetect2
architecture, so you can start with ``UNetBackboneConfig()`` and override
only the fields you want to change.
Attributes
----------
name : str
Discriminator field used by the backbone registry; always
``"UNetBackbone"``.
input_height : int
Number of frequency bins in the input spectrogram. Defaults to
``128``.
in_channels : int
Number of channels in the input spectrogram (e.g. ``1`` for a
standard single-channel spectrogram). Defaults to ``1``.
encoder : EncoderConfig
Configuration for the downsampling path. Defaults to
``DEFAULT_ENCODER_CONFIG``.
bottleneck : BottleneckConfig
Configuration for the bottleneck. Defaults to
``DEFAULT_BOTTLENECK_CONFIG``.
decoder : DecoderConfig
Configuration for the upsampling path. Defaults to
``DEFAULT_DECODER_CONFIG``.
"""
name: Literal["UNetBackbone"] = "UNetBackbone" name: Literal["UNetBackbone"] = "UNetBackbone"
input_height: int = 128 input_height: int = 128
in_channels: int = 1 in_channels: int = 1
@ -70,35 +103,36 @@ __all__ = [
class UNetBackbone(BackboneModel): class UNetBackbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation. """U-Net-style encoder-decoder backbone network.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using Combines an encoder, a bottleneck, and a decoder into a single module
skip connections between the Encoder and Decoder. Implements the standard that produces a high-resolution feature map from an input spectrogram.
U-Net style forward pass. Includes automatic input padding to handle Skip connections from each encoder stage are added element-wise to the
various input sizes and a final convolutional block to adjust the output corresponding decoder stage input.
channels.
This class inherits from `BackboneModel` and implements its `forward` Input spectrograms of arbitrary width are handled automatically: the
method. Instances are typically created using the `build_backbone` factory backbone pads the input so that its dimensions are divisible by
function. ``divide_factor`` and removes the padding from the output.
Instances are typically created via ``build_backbone``.
Attributes Attributes
---------- ----------
input_height : int input_height : int
Expected height of the input spectrogram. Expected height (frequency bins) of the input spectrogram.
out_channels : int out_channels : int
Number of channels in the final output feature map. Number of channels in the output feature map (taken from the
decoder's output channel count).
encoder : EncoderProtocol encoder : EncoderProtocol
The instantiated encoder module. The instantiated encoder module.
decoder : DecoderProtocol decoder : DecoderProtocol
The instantiated decoder module. The instantiated decoder module.
bottleneck : BottleneckProtocol bottleneck : BottleneckProtocol
The instantiated bottleneck module. The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int divide_factor : int
The total downsampling factor (2^depth) applied by the encoder, The total spatial downsampling factor applied by the encoder
used for automatic input padding. (``input_height // encoder.output_height``). The input width is
padded to be a multiple of this value before processing.
""" """
def __init__( def __init__(
@ -108,25 +142,19 @@ class UNetBackbone(BackboneModel):
decoder: DecoderProtocol, decoder: DecoderProtocol,
bottleneck: BottleneckProtocol, bottleneck: BottleneckProtocol,
): ):
"""Initialize the Backbone network. """Initialise the backbone network.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Expected height of the input spectrogram. Expected height (frequency bins) of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
encoder : EncoderProtocol encoder : EncoderProtocol
An initialized Encoder module. An initialised encoder module.
decoder : DecoderProtocol decoder : DecoderProtocol
An initialized Decoder module. An initialised decoder module. Its ``output_height`` must equal
``input_height``; a ``ValueError`` is raised otherwise.
bottleneck : BottleneckProtocol bottleneck : BottleneckProtocol
An initialized Bottleneck module. An initialised bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
""" """
super().__init__() super().__init__()
self.input_height = input_height self.input_height = input_height
@ -143,22 +171,25 @@ class UNetBackbone(BackboneModel):
self.divide_factor = input_height // self.encoder.output_height self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass through the encoder-decoder backbone. """Produce a feature map from an input spectrogram.
Applies padding, runs encoder, bottleneck, decoder (with skip Pads the input if necessary, runs it through the encoder, then
connections), removes padding, and applies a final convolution. the bottleneck, then the decoder (incorporating encoder skip
connections), and finally removes any padding added earlier.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match Input spectrogram tensor, shape
`self.encoder.input_channels` and `self.input_height`. ``(B, C_in, H_in, W_in)``. ``H_in`` must equal
``self.input_height``; ``W_in`` can be any positive integer.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
`C_out` is `self.out_channels`. ``C_out`` is ``self.out_channels``. The spatial dimensions
always match those of the input.
""" """
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor) spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
@ -219,6 +250,24 @@ BackboneConfig = Annotated[
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
"""Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the
backbone registry and calls its ``from_config`` method. If no
configuration is provided, a default ``UNetBackbone`` is returned.
Parameters
----------
config : BackboneConfig, optional
A configuration object describing the desired backbone. Currently
``UNetBackboneConfig`` is the only supported type. Defaults to
``UNetBackboneConfig()`` if not provided.
Returns
-------
BackboneModel
An initialised backbone module.
"""
config = config or UNetBackboneConfig() config = config or UNetBackboneConfig()
return backbone_registry.build(config) return backbone_registry.build(config)
@ -227,26 +276,25 @@ def _pad_adjust(
spec: torch.Tensor, spec: torch.Tensor,
factor: int = 32, factor: int = 32,
) -> Tuple[torch.Tensor, int, int]: ) -> Tuple[torch.Tensor, int, int]:
"""Pad tensor height and width to be divisible by a factor. """Pad a tensor's height and width to be divisible by ``factor``.
Calculates the required padding for the last two dimensions (H, W) to make Adds zero-padding to the bottom and right edges of the tensor so that
them divisible by `factor` and applies right/bottom padding using both dimensions are exact multiples of ``factor``. If both dimensions
`torch.nn.functional.pad`. are already divisible, the tensor is returned unchanged.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`. Input tensor, typically shape ``(B, C, H, W)``.
factor : int, default=32 factor : int, default=32
The factor to make height and width divisible by. The factor that both H and W should be divisible by after padding.
Returns Returns
------- -------
Tuple[torch.Tensor, int, int] Tuple[torch.Tensor, int, int]
A tuple containing: - Padded tensor.
- The padded tensor. - Number of rows added to the height (``h_pad``).
- The amount of padding added to height (`h_pad`). - Number of columns added to the width (``w_pad``).
- The amount of padding added to width (`w_pad`).
""" """
h, w = spec.shape[-2:] h, w = spec.shape[-2:]
h_pad = -h % factor h_pad = -h % factor
@ -261,23 +309,25 @@ def _pad_adjust(
def _restore_pad( def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor: ) -> torch.Tensor:
"""Remove padding added by _pad_adjust. """Remove padding previously added by ``_pad_adjust``.
Removes padding from the bottom and right edges of the tensor. Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
right of the tensor, restoring its original spatial dimensions.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`. Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
h_pad : int, default=0 h_pad : int, default=0
Amount of padding previously added to the height (bottom). Number of rows to remove from the bottom.
w_pad : int, default=0 w_pad : int, default=0
Amount of padding previously added to the width (right). Number of columns to remove from the right.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`. Tensor with padding removed, shape
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
""" """
if h_pad > 0: if h_pad > 0:
x = x[..., :-h_pad, :] x = x[..., :-h_pad, :]
@ -292,6 +342,36 @@ def load_backbone_config(
path: data.PathLike, path: data.PathLike,
field: str | None = None, field: str | None = None,
) -> BackboneConfig: ) -> BackboneConfig:
"""Load a backbone configuration from a YAML or JSON file.
Reads the file at ``path``, optionally descends into a named sub-field,
and validates the result against the ``BackboneConfig`` discriminated
union.
Parameters
----------
path : PathLike
Path to the configuration file. Both YAML and JSON formats are
supported.
field : str, optional
Dot-separated key path to the sub-field that contains the backbone
configuration (e.g. ``"model"``). If ``None``, the root of the
file is used.
Returns
-------
BackboneConfig
A validated backbone configuration object (currently always a
``UNetBackboneConfig`` instance).
Raises
------
FileNotFoundError
If ``path`` does not exist.
ValidationError
If the loaded data does not conform to a known ``BackboneConfig``
schema.
"""
return load_config( return load_config(
path, path,
schema=TypeAdapter(BackboneConfig), schema=TypeAdapter(BackboneConfig),

View File

@ -1,30 +1,49 @@
"""Commonly used neural network building blocks for BatDetect2 models. """Reusable convolutional building blocks for BatDetect2 models.
This module provides various reusable `torch.nn.Module` subclasses that form This module provides a collection of ``torch.nn.Module`` subclasses that form
the fundamental building blocks for constructing convolutional neural network the fundamental building blocks for the encoder-decoder backbone used in
architectures, particularly encoder-decoder backbones used in BatDetect2. BatDetect2. All blocks follow a consistent interface: they store
``in_channels`` and ``out_channels`` as attributes and implement a
``get_output_height`` method that reports how a given input height maps to an
output height (e.g., halved by downsampling blocks, doubled by upsampling
blocks).
It includes standard components like basic convolutional blocks (`ConvBlock`), Available block families
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with ------------------------
upsampling (`StandardConvUpBlock`). Standard blocks
``ConvBlock`` — convolution + batch normalisation + ReLU, no change in
spatial resolution.
Additionally, it features specialized layers investigated in BatDetect2 Downsampling blocks
research: ``StandardConvDownBlock`` — convolution then 2×2 max-pooling, halves H
and W.
``FreqCoordConvDownBlock`` — like ``StandardConvDownBlock`` but prepends
a normalised frequency-coordinate channel before the convolution
(CoordConv concept), helping filters learn frequency-position-dependent
patterns.
- `SelfAttention`: Applies self-attention along the time dimension, enabling Upsampling blocks
the model to weigh information across the entire temporal context, often ``StandardConvUpBlock`` — bilinear interpolation then convolution,
used in the bottleneck of an encoder-decoder. doubles H and W.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv" ``FreqCoordConvUpBlock`` — like ``StandardConvUpBlock`` but prepends a
concept by concatenating normalized frequency coordinate information as an frequency-coordinate channel after upsampling.
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
These blocks can be used directly in custom PyTorch model definitions or Bottleneck blocks
assembled into larger architectures. ``VerticalConv`` — 1-D convolution whose kernel spans the entire
frequency axis, collapsing H to 1 whilst preserving W.
``SelfAttention`` — scaled dot-product self-attention along the time
axis; typically follows a ``VerticalConv``.
A unified factory function `build_layer` allows creating instances Group block
of these blocks based on configuration objects. ``LayerGroup`` — chains several blocks sequentially into one unit,
useful when a single encoder or decoder "stage" requires more than one
operation.
Factory function
----------------
``build_layer`` — creates any of the above blocks from the matching
configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
""" """
from typing import Annotated, List, Literal, Tuple, Union from typing import Annotated, List, Literal, Tuple, Union
@ -57,10 +76,43 @@ __all__ = [
class Block(nn.Module): class Block(nn.Module):
"""Abstract base class for all BatDetect2 building blocks.
Subclasses must set ``in_channels`` and ``out_channels`` as integer
attributes so that factory functions can wire blocks together without
inspecting configuration objects at runtime. They may also override
``get_output_height`` when the block changes the height dimension (e.g.
downsampling or upsampling blocks).
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
out_channels : int
Number of channels produced in the output tensor.
"""
in_channels: int in_channels: int
out_channels: int out_channels: int
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
"""Return the output height for a given input height.
The default implementation returns ``input_height`` unchanged,
which is correct for blocks that do not alter spatial resolution.
Override this in downsampling (returns ``input_height // 2``) or
upsampling (returns ``input_height * 2``) subclasses.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input feature map.
Returns
-------
int
Height of the output feature map.
"""
return input_height return input_height
@ -68,58 +120,65 @@ block_registry: Registry[Block, [int, int]] = Registry("block")
class SelfAttentionConfig(BaseConfig): class SelfAttentionConfig(BaseConfig):
"""Configuration for a ``SelfAttention`` block.
Attributes
----------
name : str
Discriminator field; always ``"SelfAttention"``.
attention_channels : int
Dimensionality of the query, key, and value projections.
temperature : float
Divisor applied together with ``attention_channels`` when scaling
the dot-product scores before softmax. Defaults to ``1``.
"""
name: Literal["SelfAttention"] = "SelfAttention" name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int attention_channels: int
temperature: float = 1 temperature: float = 1
class SelfAttention(Block): class SelfAttention(Block):
"""Self-Attention mechanism operating along the time dimension. """Self-attention block operating along the time axis.
This module implements a scaled dot-product self-attention mechanism, Applies a scaled dot-product self-attention mechanism across the time
specifically designed here to operate across the time steps of an input steps of an input feature map. Before attention is computed the height
feature map, typically after spatial dimensions (like frequency) have been dimension (frequency axis) is expected to have been reduced to 1, e.g.
condensed or squeezed. by a preceding ``VerticalConv`` layer.
By calculating attention weights between all pairs of time steps, it allows For each time step the block computes query, key, and value projections
the model to capture long-range temporal dependencies and focus on relevant with learned linear weights, then calculates attention weights from the
parts of the sequence. It's often employed in the bottleneck or query-key dot products divided by ``temperature × attention_channels``.
intermediate layers of an encoder-decoder architecture to integrate global The weighted sum of values is projected back to ``in_channels`` via a
temporal context. final linear layer, and the height dimension is restored so that the
output shape matches the input shape.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of input channels (features per time step after spatial squeeze). Number of input channels (features per time step). The output will
also have ``in_channels`` channels.
attention_channels : int attention_channels : int
Number of channels for the query, key, and value projections. Also the Dimensionality of the query, key, and value projections.
dimension of the output projection's input.
temperature : float, default=1.0 temperature : float, default=1.0
Scaling factor applied *before* the final projection layer. Can be used Divisor applied together with ``attention_channels`` when scaling
to adjust the sharpness or focus of the attention mechanism, although the dot-product scores before softmax. Larger values produce softer
scaling within the softmax (dividing by sqrt(dim)) is more common for (more uniform) attention distributions.
standard transformers. Here it scales the weighted values.
Attributes Attributes
---------- ----------
key_fun : nn.Linear key_fun : nn.Linear
Linear layer for key projection. Linear projection for keys.
value_fun : nn.Linear value_fun : nn.Linear
Linear layer for value projection. Linear projection for values.
query_fun : nn.Linear query_fun : nn.Linear
Linear layer for query projection. Linear projection for queries.
pro_fun : nn.Linear pro_fun : nn.Linear
Final linear projection layer applied after attention weighting. Final linear projection applied to the attended values.
temperature : float temperature : float
Scaling factor applied before final projection. Scaling divisor used when computing attention scores.
att_dim : int att_dim : int
Dimensionality of the attention space (`attention_channels`). Dimensionality of the attention space (``attention_channels``).
""" """
def __init__( def __init__(
@ -148,20 +207,16 @@ class SelfAttention(Block):
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H is typically Input tensor with shape ``(B, C, 1, W)``. The height dimension
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before must be 1 (i.e. the frequency axis should already have been
applying attention along the W (time) dimension. collapsed by a preceding ``VerticalConv`` layer).
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where Output tensor with the same shape ``(B, C, 1, W)`` as the
attention has been applied across the W dimension. input, with each time step updated by attended context from all
other time steps.
Raises
------
RuntimeError
If input tensor dimensions are incompatible with operations.
""" """
x = x.squeeze(2).permute(0, 2, 1) x = x.squeeze(2).permute(0, 2, 1)
@ -190,6 +245,22 @@ class SelfAttention(Block):
return op return op
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor: def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
"""Return the softmax attention weight matrix.
Useful for visualising which time steps attend to which others.
Parameters
----------
x : torch.Tensor
Input tensor with shape ``(B, C, 1, W)``.
Returns
-------
torch.Tensor
Attention weight matrix with shape ``(B, W, W)``. Entry
``[b, i, j]`` is the attention weight that time step ``i``
assigns to time step ``j`` in batch item ``b``.
"""
x = x.squeeze(2).permute(0, 2, 1) x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul( key = torch.matmul(
@ -304,6 +375,16 @@ class ConvBlock(Block):
class VerticalConvConfig(BaseConfig): class VerticalConvConfig(BaseConfig):
"""Configuration for a ``VerticalConv`` block.
Attributes
----------
name : str
Discriminator field; always ``"VerticalConv"``.
channels : int
Number of output channels produced by the vertical convolution.
"""
name: Literal["VerticalConv"] = "VerticalConv" name: Literal["VerticalConv"] = "VerticalConv"
channels: int channels: int
@ -844,12 +925,53 @@ LayerConfig = Annotated[
class LayerGroupConfig(BaseConfig): class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
Use this when a single encoder or decoder stage needs more than one
block. The blocks are executed in the order they appear in ``layers``,
with channel counts and heights propagated automatically.
Attributes
----------
name : str
Discriminator field; always ``"LayerGroup"``.
layers : List[LayerConfig]
Ordered list of block configurations to chain together.
"""
name: Literal["LayerGroup"] = "LayerGroup" name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig] layers: List[LayerConfig]
class LayerGroup(nn.Module): class LayerGroup(nn.Module):
"""Standard implementation of the `LayerGroup` architecture.""" """Sequential chain of blocks that acts as a single composite block.
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
exposing the same ``in_channels``, ``out_channels``, and
``get_output_height`` interface as a regular ``Block`` so it can be
used transparently wherever a single block is expected.
Instances are typically constructed by ``build_layer`` when given a
``LayerGroupConfig``; you rarely need to create them directly.
Parameters
----------
layers : list[Block]
Pre-built block instances to chain, in execution order.
input_height : int
Height of the tensor entering the first block.
input_channels : int
Number of channels in the tensor entering the first block.
Attributes
----------
in_channels : int
Number of input channels (taken from the first block).
out_channels : int
Number of output channels (taken from the last block).
layers : nn.Sequential
The wrapped sequence of block modules.
"""
def __init__( def __init__(
self, self,
@ -865,9 +987,33 @@ class LayerGroup(nn.Module):
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through all blocks in sequence.
Parameters
----------
x : torch.Tensor
Input feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Output feature map after all blocks have been applied.
"""
return self.layers(x) return self.layers(x)
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
"""Compute the output height by propagating through all blocks.
Parameters
----------
input_height : int
Height of the input feature map.
Returns
-------
int
Height after all blocks in the group have been applied.
"""
for block in self.layers: for block in self.layers:
input_height = block.get_output_height(input_height) # type: ignore input_height = block.get_output_height(input_height) # type: ignore
return input_height return input_height
@ -903,40 +1049,37 @@ def build_layer(
in_channels: int, in_channels: int,
config: LayerConfig, config: LayerConfig,
) -> Block: ) -> Block:
"""Factory function to build a specific nn.Module block from its config. """Build a block from its configuration object.
Takes configuration object (one of the types included in the `LayerConfig` Looks up the block class corresponding to ``config.name`` in the
union) and instantiates the corresponding nn.Module block with the correct internal block registry and instantiates it with the given input
parameters derived from the config and the current pipeline state dimensions. This is the standard way to construct blocks when
(`input_height`, `in_channels`). assembling an encoder or decoder from a configuration file.
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor *to this layer*. Height (number of frequency bins) of the input tensor to this
block. Required for blocks whose kernel size depends on the input
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
in_channels : int in_channels : int
Number of channels in the input tensor *to this layer*. Number of channels in the input tensor to this block.
config : LayerConfig config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an A configuration object for the desired block type. The ``name``
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified field selects the block class; remaining fields supply its
by its `name` field. parameters.
Returns Returns
------- -------
Tuple[nn.Module, int, int] Block
A tuple containing: An initialised block module ready to be added to an
- The instantiated `nn.Module` block. ``nn.Sequential`` or ``nn.ModuleList``.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Raises Raises
------ ------
NotImplementedError KeyError
If the `config.name` does not correspond to a known block type. If ``config.name`` does not correspond to a registered block type.
ValueError ValueError
If parameters derived from the config are invalid for the block. If the configuration parameters are invalid for the chosen block.
""" """
return block_registry.build(config, in_channels, input_height) return block_registry.build(config, in_channels, input_height)

View File

@ -1,17 +1,21 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture. """Bottleneck component for encoder-decoder network architectures.
This module provides the configuration (`BottleneckConfig`) and The bottleneck sits between the encoder (downsampling path) and the decoder
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the (upsampling path) and processes the lowest-resolution, highest-channel feature
bottleneck layer(s) that typically connect the Encoder (downsampling path) and map produced by the encoder.
Decoder (upsampling path) in networks like U-Nets.
The bottleneck processes the lowest-resolution, highest-dimensionality feature This module provides:
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
A factory function `build_bottleneck` constructs the appropriate bottleneck - ``BottleneckConfig`` – configuration dataclass describing the number of
module based on the provided configuration. internal channels and an optional sequence of additional layers (currently
only ``SelfAttention`` is supported).
- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
runs one or more additional layers (e.g. self-attention along the time axis),
then repeats the output along the height dimension to restore the original
frequency resolution before passing features to the decoder.
- ``build_bottleneck`` – factory function that constructs a ``Bottleneck``
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
""" """
from typing import Annotated, List from typing import Annotated, List
@ -37,42 +41,51 @@ __all__ = [
class Bottleneck(Block): class Bottleneck(Block):
"""Base Bottleneck module for Encoder-Decoder architectures. """Bottleneck module for encoder-decoder architectures.
This implementation represents the simplest bottleneck structure Processes the lowest-resolution feature map that links the encoder and
considered, primarily consisting of a `VerticalConv` layer. This layer decoder. The sequence of operations is:
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
This base version does *not* include self-attention. 1. ``VerticalConv`` collapses the frequency axis (height) to a single
bin by applying a convolution whose kernel spans the full height.
2. Optional additional layers (e.g. ``SelfAttention``) applied while
the feature map has height 1, so they operate purely along the time
axis.
3. Height restoration – the single-bin output is repeated along the
height axis to restore the original frequency resolution, producing
a tensor that the decoder can accept.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor. Must be positive. Height (number of frequency bins) of the input tensor. Must be
positive.
in_channels : int in_channels : int
Number of channels in the input tensor from the encoder. Must be Number of channels in the input tensor from the encoder. Must be
positive. positive.
out_channels : int out_channels : int
Number of output channels. Must be positive. Number of output channels after the bottleneck. Must be positive.
bottleneck_channels : int, optional
Number of internal channels used by the ``VerticalConv`` layer.
Defaults to ``out_channels`` if not provided.
layers : List[torch.nn.Module], optional
Additional modules (e.g. ``SelfAttention``) to apply after the
``VerticalConv`` and before height restoration.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels accepted by the bottleneck. Number of input channels accepted by the bottleneck.
out_channels : int
Number of output channels produced by the bottleneck.
input_height : int input_height : int
Expected height of the input tensor. Expected height of the input tensor.
channels : int bottleneck_channels : int
Number of output channels. Number of channels used internally by the vertical convolution.
conv_vert : VerticalConv conv_vert : VerticalConv
The vertical convolution layer. The vertical convolution layer.
layers : nn.ModuleList
Raises Additional layers applied after the vertical convolution.
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
""" """
def __init__( def __init__(
@ -83,7 +96,23 @@ class Bottleneck(Block):
bottleneck_channels: int | None = None, bottleneck_channels: int | None = None,
layers: List[torch.nn.Module] | None = None, layers: List[torch.nn.Module] | None = None,
) -> None: ) -> None:
"""Initialize the base Bottleneck layer.""" """Initialise the Bottleneck layer.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input tensor.
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels in the output tensor.
bottleneck_channels : int, optional
Number of internal channels for the ``VerticalConv``. Defaults
to ``out_channels``.
layers : List[torch.nn.Module], optional
Additional modules applied after the ``VerticalConv``, such as
a ``SelfAttention`` block.
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.input_height = input_height self.input_height = input_height
@ -103,23 +132,24 @@ class Bottleneck(Block):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck. """Process the encoder's bottleneck features.
Applies vertical convolution and repeats the output height. Applies vertical convolution, optional additional layers, then
restores the height dimension by repetition.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor from the encoder bottleneck, shape Input tensor from the encoder, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`, ``(B, C_in, H_in, W)``. ``C_in`` must match
`H_in` must match `self.input_height`. ``self.in_channels`` and ``H_in`` must match
``self.input_height``.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height Output tensor with shape ``(B, C_out, H_in, W)``. The height
dimension `H_in` is restored via repetition after the vertical ``H_in`` is restored by repeating the single-bin result.
convolution.
""" """
x = self.conv_vert(x) x = self.conv_vert(x)
@ -133,28 +163,22 @@ BottleneckLayerConfig = Annotated[
SelfAttentionConfig, SelfAttentionConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in the Bottleneck."""
class BottleneckConfig(BaseConfig): class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s). """Configuration for the bottleneck component.
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes Attributes
---------- ----------
channels : int channels : int
The number of output channels produced by the main convolutional layer Number of output channels produced by the bottleneck. This value
within the bottleneck. This often matches the number of channels coming is also used as the dimensionality of any optional layers (e.g.
from the last encoder stage, but can be different. Must be positive. self-attention). Must be positive.
This also defines the channel dimensions used within the optional layers : List[BottleneckLayerConfig]
`SelfAttention` layer. Ordered list of additional block configurations to apply after the
self_attention : bool initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
If True, includes a `SelfAttention` layer operating on the time supported. Defaults to an empty list (no extra layers).
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
""" """
channels: int channels: int
@ -174,30 +198,37 @@ def build_bottleneck(
in_channels: int, in_channels: int,
config: BottleneckConfig | None = None, config: BottleneckConfig | None = None,
) -> BottleneckProtocol: ) -> BottleneckProtocol:
"""Factory function to build the Bottleneck module from configuration. """Build a ``Bottleneck`` module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based Constructs a ``Bottleneck`` instance whose internal channel count and
on the `config.self_attention` flag. optional extra layers (e.g. self-attention) are controlled by
``config``. If no configuration is provided, the default
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
``SelfAttention`` layer.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor. Must be positive. Height (number of frequency bins) of the input tensor from the
encoder. Must be positive.
in_channels : int in_channels : int
Number of channels in the input tensor. Must be positive. Number of channels in the input tensor from the encoder. Must be
positive.
config : BottleneckConfig, optional config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether Configuration specifying the output channel count and any
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None. additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
Returns Returns
------- -------
nn.Module BottleneckProtocol
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`). An initialised ``Bottleneck`` module.
Raises Raises
------ ------
ValueError AssertionError
If `input_height` or `in_channels` are not positive. If any configured layer changes the height of the feature map
(bottleneck layers must preserve height so that it can be restored
by repetition).
""" """
config = config or DEFAULT_BOTTLENECK_CONFIG config = config or DEFAULT_BOTTLENECK_CONFIG

View File

@ -1,21 +1,21 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network. """Decoder (upsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`DecoderConfig`) for the layer This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory together with the ``build_decoder`` factory function.
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
The decoder is built dynamically by stacking neural network blocks based on a In a U-Net-style network the decoder progressively restores the spatial
list of configuration objects provided in `DecoderConfig.layers`. Each config resolution of the feature map back towards the input resolution. At each
object specifies the type of block (e.g., standard convolution, stage it combines the upsampled features with the corresponding skip-connection
coordinate-feature convolution with upsampling) and its parameters. This allows tensor from the encoder (the residual) by element-wise addition before passing
flexible definition of decoder architectures via configuration files. the result to the upsampling block.
The `Decoder`'s `forward` method is designed to accept skip connection tensors The decoder is fully configurable: the type, number, and parameters of the
(`residuals`) from the encoder, merging them with the upsampled feature maps upsampling blocks are described by a ``DecoderConfig`` object containing an
at each stage. ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
""" """
from typing import Annotated, List from typing import Annotated, List
@ -51,51 +51,47 @@ DecoderLayerConfig = Annotated[
class DecoderConfig(BaseConfig): class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module. """Configuration for the sequential ``Decoder`` module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
Attributes Attributes
---------- ----------
layers : List[DecoderLayerConfig] layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or Ordered list of block configuration objects defining the decoder's
block in the decoder sequence. Each item must be a valid block upsampling stages (from deepest to shallowest). Each entry
config including a `name` field and necessary parameters like specifies the block type (via its ``name`` field) and any
`out_channels`. Input channels for each layer are inferred sequentially. block-specific parameters such as ``out_channels``. Input channels
The list must contain at least one layer. for each block are inferred automatically from the output of the
previous block. Must contain at least one entry.
""" """
layers: List[DecoderLayerConfig] = Field(min_length=1) layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module): class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers. """Sequential decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking Executes a series of upsampling blocks in order, adding the
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`) corresponding encoder skip-connection tensor (residual) to the feature
based on a list of layer modules provided during initialization (typically map before each block. The residuals are consumed in reverse order (from
created by the `build_decoder` factory function). deepest encoder layer to shallowest) to match the spatial resolutions at
each decoder stage.
The `forward` method is designed to integrate skip connection tensors Instances are typically created by ``build_decoder``.
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of channels expected in the input tensor. Number of channels expected in the input tensor (bottleneck output).
out_channels : int out_channels : int
Number of channels in the final output tensor produced by the last Number of channels in the final output feature map.
layer.
input_height : int input_height : int
Height (frequency bins) expected in the input tensor. Height (frequency bins) of the input tensor.
output_height : int output_height : int
Height (frequency bins) expected in the output tensor. Height (frequency bins) of the output tensor.
layers : nn.ModuleList layers : nn.ModuleList
The sequence of instantiated upscaling layer modules. Sequence of instantiated upsampling block modules.
depth : int depth : int
The number of upscaling layers (depth) in the decoder. Number of upsampling layers.
""" """
def __init__( def __init__(
@ -106,23 +102,24 @@ class Decoder(nn.Module):
output_height: int, output_height: int,
layers: List[nn.Module], layers: List[nn.Module],
): ):
"""Initialize the Decoder module. """Initialise the Decoder module.
Note: This constructor is typically called internally by the This constructor is typically called by the ``build_decoder``
`build_decoder` factory function. factory function.
Parameters Parameters
---------- ----------
in_channels : int
Number of channels in the input tensor (bottleneck output).
out_channels : int out_channels : int
Number of channels produced by the final layer. Number of channels produced by the final layer.
input_height : int input_height : int
Expected height of the input tensor (bottleneck). Height of the input tensor (bottleneck output height).
in_channels : int output_height : int
Expected number of channels in the input tensor (bottleneck). Height of the output tensor after all layers have been applied.
layers : List[nn.Module] layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g., Pre-built upsampling block modules in execution order (deepest
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired stage first).
sequence (from bottleneck towards output resolution).
""" """
super().__init__() super().__init__()
@ -140,43 +137,35 @@ class Decoder(nn.Module):
x: torch.Tensor, x: torch.Tensor,
residuals: List[torch.Tensor], residuals: List[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections. """Pass input through all decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling At each stage the corresponding residual tensor is added
layers. At each stage, the corresponding skip connection tensor from element-wise to ``x`` before it is passed to the upsampling block.
the `residuals` list is added element-wise to the input before passing Residuals are consumed in reverse order – the last element of
it to the upscaling block. ``residuals`` (the output of the deepest encoder layer) is added
at the first decoder stage, and the first element (output of the
shallowest encoder layer) is added at the last decoder stage.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck). Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
residuals : List[torch.Tensor] residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding Skip-connection tensors from the encoder, ordered from shallowest
encoder stages. Should be ordered from the deepest encoder layer (index 0) to deepest (index -1). Must contain exactly
output (lowest resolution) to the shallowest (highest resolution ``self.depth`` tensors. Each tensor must have the same spatial
near input). The number of tensors in this list must match the dimensions and channel count as ``x`` at the corresponding
number of decoder layers (`self.depth`). Each residual tensor's decoder stage.
channel count must be compatible with the input tensor `x` for
element-wise addition (or concatenation if the blocks were designed
for it).
Returns Returns
------- -------
torch.Tensor torch.Tensor
The final decoded feature map tensor produced by the last layer. Decoded feature map, shape ``(B, C_out, H_out, W)``.
Shape `(B, C_out, H_out, W_out)`.
Raises Raises
------ ------
ValueError ValueError
If the number of `residuals` provided does not match the decoder If the number of ``residuals`` does not equal ``self.depth``.
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
""" """
if len(residuals) != len(self.layers): if len(residuals) != len(self.layers):
raise ValueError( raise ValueError(
@ -203,11 +192,17 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
), ),
], ],
) )
"""A default configuration for the Decoder's *layer sequence*. """Default decoder configuration used in standard BatDetect2 models.
Specifies an architecture often used in BatDetect2, consisting of three Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
frequency coordinate-aware upsampling blocks followed by a standard output has 256 channels and height 16, and produces:
convolutional block.
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvUp``: 32 channels, height 128.
- ``ConvBlock``: 32 channels, height 128 (final feature map).
""" """
@ -216,40 +211,36 @@ def build_decoder(
input_height: int, input_height: int,
config: DecoderConfig | None = None, config: DecoderConfig | None = None,
) -> Decoder: ) -> Decoder:
"""Factory function to build a Decoder instance from configuration. """Build a ``Decoder`` from configuration.
Constructs a sequential `Decoder` module based on the layer sequence Constructs a sequential ``Decoder`` by iterating over the block
defined in a `DecoderConfig` object and the provided input dimensions configurations in ``config.layers``, building each block with
(bottleneck channels and height). If no config is provided, uses the ``build_layer``, and tracking the channel count and feature-map height
default layer sequence from `DEFAULT_DECODER_CONFIG`. as they change through the sequence.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0. Number of channels in the input tensor (bottleneck output). Must
be positive.
input_height : int input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be Height (number of frequency bins) of the input tensor. Must be
> 0. positive.
config : DecoderConfig, optional config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their Configuration specifying the layer sequence. Defaults to
parameters. If None, `DEFAULT_DECODER_CONFIG` is used. ``DEFAULT_DECODER_CONFIG`` if not provided.
Returns Returns
------- -------
Decoder Decoder
An initialized `Decoder` module. An initialised ``Decoder`` module.
Raises Raises
------ ------
ValueError ValueError
If `in_channels` or `input_height` are not positive, or if the layer If ``in_channels`` or ``input_height`` are not positive.
configuration is invalid (e.g., empty list, unknown `name`). KeyError
NotImplementedError If a layer configuration specifies an unknown block type.
If `build_layer_from_config` encounters an unknown `name`.
""" """
config = config or DEFAULT_DECODER_CONFIG config = config or DEFAULT_DECODER_CONFIG

View File

@ -1,17 +1,20 @@
"""Assembles the complete BatDetect2 Detection Model. """Assembles the complete BatDetect2 detection model.
This module defines the concrete `Detector` class, which implements the This module defines the ``Detector`` class, which combines a backbone
`DetectionModel` interface defined in `.types`. It combines a feature feature extractor with prediction heads for detection, classification, and
extraction backbone with specific prediction heads to create the end-to-end bounding-box size regression.
neural network used for detecting bat calls, predicting their size, and
classifying them.
The primary components are: Components
- `Detector`: The `torch.nn.Module` subclass representing the complete model. ----------
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` – factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
This module focuses purely on the neural network architecture definition. The Note that ``Detector`` operates purely on spectrogram tensors; raw audio
logic for preprocessing inputs and postprocessing/decoding outputs resides in preprocessing and output postprocessing are handled by
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. ``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
""" """
import torch import torch
@ -32,25 +35,30 @@ __all__ = [
class Detector(DetectionModel): class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model. """Complete BatDetect2 detection and classification model.
Assembles a complete detection and classification model by combining a Combines a backbone feature extractor with two prediction heads:
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class - ``ClassifierHead``: predicts per-class probabilities at each
probabilities. time–frequency location.
- ``BBoxHead``: predicts call duration and bandwidth at each location.
The detection probability map is derived from the class probabilities by
summing across the class dimension (i.e. the probability that *any* class
is present), rather than from a separate detection head.
Instances are typically created via ``build_detector``.
Attributes Attributes
---------- ----------
backbone : BackboneModel backbone : BackboneModel
The feature extraction backbone network module. The feature extraction backbone.
num_classes : int num_classes : int
The number of specific target classes the model predicts (derived from Number of target classes (inferred from the classifier head).
the `classifier_head`).
classifier_head : ClassifierHead classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities. Produces per-class probability maps from backbone features.
bbox_head : BBoxHead bbox_head : BBoxHead
The prediction head responsible for generating bounding box size Produces duration and bandwidth predictions from backbone features.
predictions.
""" """
backbone: BackboneModel backbone: BackboneModel
@ -61,26 +69,21 @@ class Detector(DetectionModel):
classifier_head: ClassifierHead, classifier_head: ClassifierHead,
bbox_head: BBoxHead, bbox_head: BBoxHead,
): ):
"""Initialize the Detector model. """Initialise the Detector model.
Note: Instances are typically created using the `build_detector` This constructor is typically called by the ``build_detector``
factory function. factory function.
Parameters Parameters
---------- ----------
backbone : BackboneModel backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by An initialised backbone module (e.g. built by
`build_backbone` from the `.backbone` module). ``build_backbone``).
classifier_head : ClassifierHead classifier_head : ClassifierHead
An initialized classification head module. The number of classes An initialised classification head. The ``num_classes``
is inferred from this head. attribute is read from this head.
bbox_head : BBoxHead bbox_head : BBoxHead
An initialized bounding box size prediction head module. An initialised bounding-box size prediction head.
Raises
------
TypeError
If the provided modules are not of the expected types.
""" """
super().__init__() super().__init__()
@ -90,31 +93,34 @@ class Detector(DetectionModel):
self.bbox_head = bbox_head self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model. """Run the complete detection model on an input spectrogram.
Processes the input spectrogram through the backbone to extract Passes the spectrogram through the backbone to produce a feature
features, then passes these features through the separate prediction map, then applies the classifier and bounding-box heads. The
heads to generate detection probabilities, class probabilities, and detection probability map is derived by summing the per-class
size predictions. probability maps across the class dimension; no separate detection
head is used.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input spectrogram tensor, typically with shape Input spectrogram tensor, shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The ``(batch_size, channels, frequency_bins, time_bins)``.
shape must be compatible with the `self.backbone` input
requirements.
Returns Returns
------- -------
ModelOutput ModelOutput
A NamedTuple containing the four output tensors: A named tuple with four fields:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`. - ``detection_probs`` – ``(B, 1, H, W)`` – probability that a
- `class_probs`: Class probabilities (excluding background) call of any class is present at each location. Derived by
`(B, num_classes, H, W)`. summing ``class_probs`` over the class dimension.
- `features`: Output feature map from the backbone - ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel
`(B, C_out, H, W)`. 0) and bandwidth (channel 1) predictions at each location.
- ``class_probs`` – ``(B, num_classes, H, W)`` – per-class
probabilities at each location.
- ``features`` – ``(B, C_out, H, W)`` – raw backbone feature
map.
""" """
features = self.backbone(spec) features = self.backbone(spec)
classification = self.classifier_head(features) classification = self.classifier_head(features)
@ -131,30 +137,33 @@ class Detector(DetectionModel):
def build_detector( def build_detector(
num_classes: int, config: BackboneConfig | None = None num_classes: int, config: BackboneConfig | None = None
) -> DetectionModel: ) -> DetectionModel:
"""Build the complete BatDetect2 detection model. """Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
and a ``BBoxHead`` sized to the backbone's output channel count, and
returns them wrapped in a ``Detector``.
Parameters Parameters
---------- ----------
num_classes : int num_classes : int
The number of specific target classes the model should predict Number of target bat species or call types to predict. Must be
(required for the `ClassifierHead`). Must be positive. positive.
config : BackboneConfig, optional config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone Backbone architecture configuration. Defaults to
(encoder, bottleneck, decoder). If None, default configurations defined ``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
within the respective builder functions (`build_encoder`, etc.) will be not provided.
used to construct a default backbone architecture.
Returns Returns
------- -------
DetectionModel DetectionModel
An initialized `Detector` model instance. An initialised ``Detector`` instance ready for training or
inference.
Raises Raises
------ ------
ValueError ValueError
If `num_classes` is not positive, or if errors occur during the If ``num_classes`` is not positive, or if the backbone
construction of the backbone or detector components (e.g., incompatible configuration is invalid.
configurations, invalid parameters).
""" """
config = config or UNetBackboneConfig() config = config or UNetBackboneConfig()

View File

@ -1,23 +1,24 @@
"""Constructs the Encoder part of a configurable neural network backbone. """Encoder (downsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`EncoderConfig`) and provides This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
the `Encoder` class (an `nn.Module`) along with a factory function together with the ``build_encoder`` factory function.
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
The encoder is built dynamically by stacking neural network blocks based on a In a U-Net-style network the encoder progressively reduces the spatial
list of configuration objects provided in `EncoderConfig.layers`. Each resolution of the spectrogram whilst increasing the number of feature
configuration object specifies the type of block (e.g., standard convolution, channels. Each layer in the encoder produces a feature map that is stored
coordinate-feature convolution with downsampling) and its parameters for use as a skip connection in the corresponding decoder layer.
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
The `Encoder`'s `forward` method returns outputs from all intermediate layers, The encoder is fully configurable: the type, number, and parameters of the
suitable for skip connections, while the `encode` method returns only the final downsampling blocks are described by an ``EncoderConfig`` object containing
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also an ordered list of block configuration objects (see ``batdetect2.models.blocks``
provided. for available block types).
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
so that skip connections are available to the decoder.
``Encoder.encode`` returns only the final output (the input to the bottleneck).
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
``build_encoder`` when no explicit configuration is supplied.
""" """
from typing import Annotated, List from typing import Annotated, List
@ -53,35 +54,32 @@ EncoderLayerConfig = Annotated[
class EncoderConfig(BaseConfig): class EncoderConfig(BaseConfig):
"""Configuration for building the sequential Encoder module. """Configuration for the sequential ``Encoder`` module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
Attributes Attributes
---------- ----------
layers : List[EncoderLayerConfig] layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or Ordered list of block configuration objects defining the encoder's
block in the encoder sequence. Each item must be a valid block config downsampling stages. Each entry specifies the block type (via its
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`, ``name`` field) and any block-specific parameters such as
`StandardConvDownConfig`) including a `name` field and necessary ``out_channels``. Input channels for each block are inferred
parameters like `out_channels`. Input channels for each layer are automatically from the output of the previous block. Must contain
inferred sequentially. The list must contain at least one layer. at least one entry.
""" """
layers: List[EncoderLayerConfig] = Field(min_length=1) layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module): class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers. """Sequential encoder module composed of configurable downsampling layers.
Constructs the downsampling path of an encoder-decoder network by stacking Executes a series of downsampling blocks in order, storing the output of
multiple downscaling blocks. each block so that it can be passed as a skip connection to the
corresponding decoder layer.
The `forward` method executes the sequence and returns the output feature ``forward`` returns the outputs of *all* layers (useful when skip
map from *each* downscaling stage, facilitating the implementation of skip connections are needed). ``encode`` returns only the final output
connections in U-Net-like architectures. The `encode` method returns only (the input to the bottleneck).
the final output tensor (bottleneck features).
Attributes Attributes
---------- ----------
@ -89,14 +87,14 @@ class Encoder(nn.Module):
Number of channels expected in the input tensor. Number of channels expected in the input tensor.
input_height : int input_height : int
Height (frequency bins) expected in the input tensor. Height (frequency bins) expected in the input tensor.
output_channels : int out_channels : int
Number of channels in the final output tensor (bottleneck). Number of channels in the final output tensor (bottleneck input).
output_height : int output_height : int
Height (frequency bins) expected in the output tensor. Height (frequency bins) of the final output tensor.
layers : nn.ModuleList layers : nn.ModuleList
The sequence of instantiated downscaling layer modules. Sequence of instantiated downsampling block modules.
depth : int depth : int
The number of downscaling layers in the encoder. Number of downsampling layers.
""" """
def __init__( def __init__(
@ -107,23 +105,22 @@ class Encoder(nn.Module):
input_height: int = 128, input_height: int = 128,
in_channels: int = 1, in_channels: int = 1,
): ):
"""Initialize the Encoder module. """Initialise the Encoder module.
Note: This constructor is typically called internally by the This constructor is typically called by the ``build_encoder`` factory
`build_encoder` factory function, which prepares the `layers` list. function, which takes care of building the ``layers`` list from a
configuration object.
Parameters Parameters
---------- ----------
output_channels : int output_channels : int
Number of channels produced by the final layer. Number of channels produced by the final layer.
output_height : int output_height : int
The expected height of the output tensor. Height of the output tensor after all layers have been applied.
layers : List[nn.Module] layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g., Pre-built downsampling block modules in execution order.
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
input_height : int, default=128 input_height : int, default=128
Expected height of the input tensor. Expected height of the input tensor (frequency bins).
in_channels : int, default=1 in_channels : int, default=1
Expected number of channels in the input tensor. Expected number of channels in the input tensor.
""" """
@ -138,29 +135,30 @@ class Encoder(nn.Module):
self.depth = len(self.layers) self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs. """Pass input through all encoder layers and return every output.
This method is typically used when the Encoder is part of a U-Net or Used when skip connections are needed (e.g. in a U-Net decoder).
similar architecture requiring skip connections.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
`self.in_channels` and `H_in` must match `self.input_height`. ``C_in`` must match ``self.in_channels`` and ``H_in`` must
match ``self.input_height``.
Returns Returns
------- -------
List[torch.Tensor] List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer Output tensors from every layer in order.
in the sequence. `outputs[0]` is the output of the first layer, ``outputs[0]`` is the output of the first (shallowest) layer;
`outputs[-1]` is the final output (bottleneck) of the encoder. ``outputs[-1]`` is the output of the last (deepest) layer,
which serves as the input to the bottleneck.
Raises Raises
------ ------
ValueError ValueError
If input tensor channel count or height does not match expected If the input channel count or height does not match the
values. expected values.
""" """
if x.shape[1] != self.in_channels: if x.shape[1] != self.in_channels:
raise ValueError( raise ValueError(
@ -183,28 +181,29 @@ class Encoder(nn.Module):
return outputs return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor: def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output. """Pass input through all encoder layers and return only the final output.
This method provides access to the bottleneck features produced after Use this when skip connections are not needed and you only require
the last downscaling layer. the bottleneck feature map.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
`in_channels` and `input_height`. Must satisfy the same shape requirements as ``forward``.
Returns Returns
------- -------
torch.Tensor torch.Tensor
The final output tensor (bottleneck features) from the last layer Output of the last encoder layer, shape
of the encoder. Shape `(B, C_out, H_out, W_out)`. ``(B, C_out, H_out, W)``, where ``C_out`` is
``self.out_channels`` and ``H_out`` is ``self.output_height``.
Raises Raises
------ ------
ValueError ValueError
If input tensor channel count or height does not match expected If the input channel count or height does not match the
values. expected values.
""" """
if x.shape[1] != self.in_channels: if x.shape[1] != self.in_channels:
raise ValueError( raise ValueError(
@ -236,14 +235,17 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
), ),
], ],
) )
"""Default configuration for the Encoder. """Default encoder configuration used in standard BatDetect2 models.
Specifies an architecture typically used in BatDetect2: Assumes a 1-channel input with 128 frequency bins and produces the
- Input: 1 channel, 128 frequency bins. following feature maps:
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32 - Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
- Layer 3: FreqCoordConvDown -> 128 channels, H=16 - Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck) - Stage 3 (``LayerGroup``):
- ``FreqCoordConvDown``: 128 channels, height 16.
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
""" """
@ -252,42 +254,38 @@ def build_encoder(
input_height: int, input_height: int,
config: EncoderConfig | None = None, config: EncoderConfig | None = None,
) -> Encoder: ) -> Encoder:
"""Factory function to build an Encoder instance from configuration. """Build an ``Encoder`` from configuration.
Constructs a sequential `Encoder` module based on the layer sequence Constructs a sequential ``Encoder`` by iterating over the block
defined in an `EncoderConfig` object and the provided input dimensions. configurations in ``config.layers``, building each block with
If no config is provided, uses the default layer sequence from ``build_layer``, and tracking the channel count and feature-map height
`DEFAULT_ENCODER_CONFIG`. as they change through the sequence.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate- aware blocks.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
The number of channels expected in the input tensor to the encoder. Number of channels in the input spectrogram tensor. Must be
Must be > 0. positive.
input_height : int input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0. Height (number of frequency bins) of the input spectrogram.
Crucial for initializing coordinate-aware layers correctly. Must be positive and should be divisible by
``2 ** (number of downsampling stages)`` to avoid size mismatches
later in the network.
config : EncoderConfig, optional config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their Configuration specifying the layer sequence. Defaults to
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used. ``DEFAULT_ENCODER_CONFIG`` if not provided.
Returns Returns
------- -------
Encoder Encoder
An initialized `Encoder` module. An initialised ``Encoder`` module.
Raises Raises
------ ------
ValueError ValueError
If `in_channels` or `input_height` are not positive, or if the layer If ``in_channels`` or ``input_height`` are not positive.
configuration is invalid (e.g., empty list, unknown `name`). KeyError
NotImplementedError If a layer configuration specifies an unknown block type.
If `build_layer_from_config` encounters an unknown `name`.
""" """
if in_channels <= 0 or input_height <= 0: if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.") raise ValueError("in_channels and input_height must be positive.")

View File

@ -1,20 +1,19 @@
"""Prediction Head modules for BatDetect2 models. """Prediction heads attached to the backbone feature map.
This module defines simple `torch.nn.Module` subclasses that serve as Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
prediction heads, typically attached to the output feature map of a backbone convolution to map backbone feature channels to one specific type of
network output required by BatDetect2:
Each head is responsible for generating one specific type of output required - ``DetectorHead``: single-channel detection probability heatmap (sigmoid
by the BatDetect2 task: activation).
- `DetectorHead`: Predicts the probability of sound event presence. - ``ClassifierHead``: multi-class probability map over the target bat
- `ClassifierHead`: Predicts the probability distribution over target classes. species / call types (softmax activation).
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding - ``BBoxHead``: two-channel map of predicted call duration (time axis) and
box. bandwidth (frequency axis) at each location (no activation; raw
regression output).
These heads use 1x1 convolutions to map the backbone feature channels All three heads share the same input feature map produced by the backbone,
to the desired number of output channels for each prediction task at each so they can be evaluated in parallel in a single forward pass.
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
""" """
import torch import torch
@ -28,42 +27,35 @@ __all__ = [
class ClassifierHead(nn.Module): class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities. """Prediction head for species / call-type classification probabilities.
Takes an input feature map and produces a probability map where each Takes a backbone feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution channel corresponds to a target class. Internally the 1×1 convolution
to map input channels to `num_classes + 1` outputs (one for each target maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
class plus an assumed background/generic class), applies softmax across the represents a generic background / unknown category); a softmax is then
channels, and returns the probabilities for the specific target classes applied across the channel dimension and the background channel is
(excluding the last background/generic channel). discarded before returning.
Parameters Parameters
---------- ----------
num_classes : int num_classes : int
The number of specific target classes the model should predict Number of target classes (bat species or call types) to predict,
(excluding any background or generic category). Must be positive. excluding the background category. Must be positive.
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
num_classes : int num_classes : int
Number of specific output classes. Number of specific output classes (background excluded).
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
classifier : nn.Conv2d classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction. 1×1 convolution with ``num_classes + 1`` output channels.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
""" """
def __init__(self, num_classes: int, in_channels: int): def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead.""" """Initialise the ClassifierHead."""
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features. """Compute per-class probabilities from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`. Class probability map, shape ``(B, num_classes, H, W)``.
Contains probabilities for the specific target classes after Values are softmax probabilities in the range [0, 1] and
softmax, excluding the implicit background/generic class channel. sum to less than 1 per location (the background probability
is discarded).
""" """
logits = self.classifier(features) logits = self.classifier(features)
probs = torch.softmax(logits, dim=1) probs = torch.softmax(logits, dim=1)
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
class DetectorHead(nn.Module): class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability. """Prediction head for detection probability (is a call present here?).
Takes an input feature map and produces a single-channel heatmap where Produces a single-channel heatmap where each value indicates the
each value represents the probability ([0, 1]) of a relevant sound event probability ([0, 1]) that a bat call of *any* species is present at
(of any class) being present at that spatial location. that time-frequency location in the spectrogram.
Uses a 1x1 convolution to map input channels to 1 output channel, followed Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
by a sigmoid activation function. sigmoid activation.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
detector : nn.Conv2d detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel. 1×1 convolution with a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
""" """
def __init__(self, in_channels: int): def __init__(self, in_channels: int):
"""Initialize the DetectorHead.""" """Initialise the DetectorHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features. """Compute detection probabilities from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`. Detection probability heatmap, shape ``(B, 1, H, W)``.
Values are in the range [0, 1] due to the sigmoid activation. Values are in the range [0, 1].
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
""" """
return torch.sigmoid(self.detector(features)) return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module): class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions. """Prediction head for bounding box size (duration and bandwidth).
Takes an input feature map and produces a two-channel map where each Produces a two-channel map where channel 0 predicts the scaled duration
channel represents a predicted size dimension (typically width/duration and (time-axis extent) and channel 1 predicts the scaled bandwidth
height/bandwidth) for a potential sound event at that spatial location. (frequency-axis extent) of the call at each spectrogram location.
Uses a 1x1 convolution to map input channels to 2 output channels. No Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
activation function is typically applied, as size prediction is often activation function (raw regression output). The predicted values are
treated as a direct regression task. The output values usually represent in a scaled space and must be converted to real units (seconds and Hz)
*scaled* dimensions that need to be un-scaled during postprocessing. during postprocessing.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
bbox : nn.Conv2d bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels 1×1 convolution with 2 output channels (duration, bandwidth).
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
""" """
def __init__(self, in_channels: int): def __init__(self, in_channels: int):
"""Initialize the BBoxHead.""" """Initialise the BBoxHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features. """Predict call duration and bandwidth from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
represents scaled width, Channel 1 scaled height. These values the predicted scaled duration; channel 1 is the predicted
need to be un-scaled during postprocessing. scaled bandwidth. Values must be rescaled to real units during
postprocessing.
""" """
return self.bbox(features) return self.bbox(features)

View File

@ -1,14 +1,4 @@
"""Tests for backbone configuration loading and the backbone registry. """Tests for backbone configuration loading and the backbone registry."""
Covers:
- UNetBackboneConfig default construction and field values.
- build_backbone with default and explicit configs.
- load_backbone_config loading from a YAML file.
- load_backbone_config with a nested field path.
- load_backbone_config round-trip: YAML → config → build_backbone.
- Registry registration and dispatch for UNetBackbone.
- BackboneConfig discriminated union validation.
"""
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
@ -25,10 +15,6 @@ from batdetect2.models.backbones import (
) )
from batdetect2.typing.models import BackboneModel from batdetect2.typing.models import BackboneModel
# ---------------------------------------------------------------------------
# UNetBackboneConfig
# ---------------------------------------------------------------------------
def test_unet_backbone_config_defaults(): def test_unet_backbone_config_defaults():
"""Default config has expected field values.""" """Default config has expected field values."""
@ -57,11 +43,6 @@ def test_unet_backbone_config_extra_fields_ignored():
assert not hasattr(config, "unknown_field") assert not hasattr(config, "unknown_field")
# ---------------------------------------------------------------------------
# build_backbone
# ---------------------------------------------------------------------------
def test_build_backbone_default(): def test_build_backbone_default():
"""Building with no config uses UNetBackbone defaults.""" """Building with no config uses UNetBackbone defaults."""
backbone = build_backbone() backbone = build_backbone()
@ -83,15 +64,9 @@ def test_build_backbone_custom_config():
def test_build_backbone_returns_backbone_model(): def test_build_backbone_returns_backbone_model():
"""build_backbone always returns a BackboneModel instance.""" """build_backbone always returns a BackboneModel instance."""
backbone = build_backbone() backbone = build_backbone()
assert isinstance(backbone, BackboneModel) assert isinstance(backbone, BackboneModel)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def test_registry_has_unet_backbone(): def test_registry_has_unet_backbone():
"""The backbone registry has UNetBackbone registered.""" """The backbone registry has UNetBackbone registered."""
config_types = backbone_registry.get_config_types() config_types = backbone_registry.get_config_types()
@ -124,11 +99,6 @@ def test_registry_build_unknown_name_raises():
backbone_registry.build(FakeConfig()) # type: ignore[arg-type] backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# BackboneConfig discriminated union
# ---------------------------------------------------------------------------
def test_backbone_config_validates_unet_from_dict(): def test_backbone_config_validates_unet_from_dict():
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name.""" """BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -151,11 +121,6 @@ def test_backbone_config_invalid_name_raises():
adapter.validate_python({"name": "NonExistentBackbone"}) adapter.validate_python({"name": "NonExistentBackbone"})
# ---------------------------------------------------------------------------
# load_backbone_config
# ---------------------------------------------------------------------------
def test_load_backbone_config_from_yaml( def test_load_backbone_config_from_yaml(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
): ):
@ -218,11 +183,6 @@ deprecated_field: 99
assert config.input_height == 128 assert config.input_height == 128
# ---------------------------------------------------------------------------
# Round-trip: YAML → config → build_backbone
# ---------------------------------------------------------------------------
def test_round_trip_yaml_to_build_backbone( def test_round_trip_yaml_to_build_backbone(
create_temp_yaml: Callable[[str], Path], create_temp_yaml: Callable[[str], Path],
): ):