mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Update model docstrings
This commit is contained in:
parent
4207661da4
commit
ef3348d651
@ -1,29 +1,29 @@
|
||||
"""Defines and builds the neural network models used in BatDetect2.
|
||||
"""Neural network model definitions and builders for BatDetect2.
|
||||
|
||||
This package (`batdetect2.models`) contains the PyTorch implementations of the
|
||||
deep neural network architectures used for detecting and classifying bat calls
|
||||
from spectrograms. It provides modular components and configuration-driven
|
||||
assembly, allowing for experimentation and use of different architectural
|
||||
variants.
|
||||
This package contains the PyTorch implementations of the deep neural network
|
||||
architectures used to detect and classify bat echolocation calls in
|
||||
spectrograms. Components are designed to be combined through configuration
|
||||
objects, making it easy to experiment with different architectures.
|
||||
|
||||
Key Submodules:
|
||||
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
|
||||
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
|
||||
- `.blocks`: Provides reusable neural network building blocks.
|
||||
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
|
||||
- `.bottleneck`: Defines and builds the central bottleneck component.
|
||||
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
|
||||
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
|
||||
feature extraction backbone (e.g., a U-Net like structure).
|
||||
- `.heads`: Defines simple prediction heads (detection, classification, size)
|
||||
that attach to the backbone features.
|
||||
- `.detectors`: Assembles the backbone and prediction heads into the final,
|
||||
end-to-end `Detector` model.
|
||||
Key submodules
|
||||
--------------
|
||||
- ``blocks``: Reusable convolutional building blocks (downsampling,
|
||||
upsampling, attention, coord-conv variants).
|
||||
- ``encoder``: The downsampling path; reduces spatial resolution whilst
|
||||
extracting increasingly abstract features.
|
||||
- ``bottleneck``: The central component connecting encoder to decoder;
|
||||
optionally applies self-attention along the time axis.
|
||||
- ``decoder``: The upsampling path; reconstructs high-resolution feature
|
||||
maps using bottleneck output and skip connections from the encoder.
|
||||
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
|
||||
U-Net-style feature extraction backbone.
|
||||
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
|
||||
classification, and bounding-box size predictions from backbone features.
|
||||
- ``detectors``: Combines a backbone with prediction heads into the final
|
||||
end-to-end ``Detector`` model.
|
||||
|
||||
This module re-exports the most important classes, configurations, and builder
|
||||
functions from these submodules for convenient access. The primary entry point
|
||||
for creating a standard BatDetect2 model instance is the `build_model` function
|
||||
provided here.
|
||||
The primary entry point for building a full, ready-to-use BatDetect2 model
|
||||
is the ``build_model`` factory function exported from this module.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
@ -99,6 +99,31 @@ __all__ = [
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||
|
||||
Combines a preprocessor, a detection model, and a postprocessor into a
|
||||
single PyTorch module. Calling ``forward`` on a raw waveform tensor
|
||||
returns a list of detection tensors ready for downstream use.
|
||||
|
||||
This class is the top-level object produced by ``build_model``. Most
|
||||
users will not need to construct it directly.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detector : DetectionModel
|
||||
The neural network that processes spectrograms and produces raw
|
||||
detection, classification, and bounding-box outputs.
|
||||
preprocessor : PreprocessorProtocol
|
||||
Converts a raw waveform tensor into a spectrogram tensor accepted by
|
||||
``detector``.
|
||||
postprocessor : PostprocessorProtocol
|
||||
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
||||
per-clip detection tensors.
|
||||
targets : TargetProtocol
|
||||
Describes the set of target classes; used when building heads and
|
||||
during training target construction.
|
||||
"""
|
||||
|
||||
detector: DetectionModel
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
@ -118,6 +143,25 @@ class Model(torch.nn.Module):
|
||||
self.targets = targets
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
|
||||
Converts the waveform to a spectrogram, passes it through the
|
||||
detector, and postprocesses the raw outputs into detection tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : torch.Tensor
|
||||
Raw audio waveform tensor. The exact expected shape depends on
|
||||
the preprocessor, but is typically ``(batch, samples)`` or
|
||||
``(batch, channels, samples)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ClipDetectionsTensor]
|
||||
One detection tensor per clip in the batch. Each tensor encodes
|
||||
the detected events (locations, class scores, sizes) for that
|
||||
clip.
|
||||
"""
|
||||
spec = self.preprocessor(wav)
|
||||
outputs = self.detector(spec)
|
||||
return self.postprocessor(outputs)
|
||||
@ -128,7 +172,38 @@ def build_model(
|
||||
targets: TargetProtocol | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
):
|
||||
) -> "Model":
|
||||
"""Build a complete, ready-to-use BatDetect2 model.
|
||||
|
||||
Assembles a ``Model`` instance from optional configuration and component
|
||||
overrides. Any argument left as ``None`` will be replaced by a sensible
|
||||
default built with the project's own builder functions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BackboneConfig, optional
|
||||
Configuration describing the backbone architecture (encoder,
|
||||
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
|
||||
provided.
|
||||
targets : TargetProtocol, optional
|
||||
Describes the target bat species or call types to detect. Determines
|
||||
the number of output classes. Defaults to the standard BatDetect2
|
||||
target set.
|
||||
preprocessor : PreprocessorProtocol, optional
|
||||
Converts raw audio waveforms to spectrograms. Defaults to the
|
||||
standard BatDetect2 preprocessor.
|
||||
postprocessor : PostprocessorProtocol, optional
|
||||
Converts raw model outputs to detection tensors. Defaults to the
|
||||
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
|
||||
given without a matching ``postprocessor``, the default postprocessor
|
||||
will be built using the provided preprocessor so that frequency and
|
||||
time scaling remain consistent.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model
|
||||
A fully assembled ``Model`` instance ready for inference or training.
|
||||
"""
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
@ -1,21 +1,26 @@
|
||||
"""Assembles a complete Encoder-Decoder Backbone network.
|
||||
"""Assembles a complete encoder-decoder backbone network.
|
||||
|
||||
This module defines the configuration (`BackboneConfig`) and implementation
|
||||
(`Backbone`) for a standard encoder-decoder style neural network backbone.
|
||||
This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
|
||||
``nn.Module``, together with the ``build_backbone`` and
|
||||
``load_backbone_config`` helpers.
|
||||
|
||||
It orchestrates the connection between three main components, built using their
|
||||
respective configurations and factory functions from sibling modules:
|
||||
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
|
||||
at multiple resolutions and provides skip connections.
|
||||
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
|
||||
lowest resolution, optionally applying self-attention.
|
||||
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
|
||||
resolution features using bottleneck features and skip connections.
|
||||
A backbone combines three components built from the sibling modules:
|
||||
|
||||
The resulting `Backbone` module takes a spectrogram as input and outputs a
|
||||
final feature map, typically used by subsequent prediction heads. It includes
|
||||
automatic padding to handle input sizes not perfectly divisible by the
|
||||
network's total downsampling factor.
|
||||
1. **Encoder** (``batdetect2.models.encoder``) – reduces spatial resolution
|
||||
while extracting hierarchical features and storing skip-connection tensors.
|
||||
2. **Bottleneck** (``batdetect2.models.bottleneck``) – processes the
|
||||
lowest-resolution features, optionally applying self-attention.
|
||||
3. **Decoder** (``batdetect2.models.decoder``) – restores spatial resolution
|
||||
using bottleneck features and skip connections from the encoder.
|
||||
|
||||
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
|
||||
a high-resolution feature map consumed by the prediction heads in
|
||||
``batdetect2.models.detectors``.
|
||||
|
||||
Input padding is handled automatically: the backbone pads the input to be
|
||||
divisible by the total downsampling factor and strips the padding from the
|
||||
output so that the output spatial dimensions always match the input spatial
|
||||
dimensions.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal, Tuple, Union
|
||||
@ -51,6 +56,34 @@ from batdetect2.typing.models import (
|
||||
|
||||
|
||||
class UNetBackboneConfig(BaseConfig):
|
||||
"""Configuration for a U-Net-style encoder-decoder backbone.
|
||||
|
||||
All fields have sensible defaults that reproduce the standard BatDetect2
|
||||
architecture, so you can start with ``UNetBackboneConfig()`` and override
|
||||
only the fields you want to change.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Discriminator field used by the backbone registry; always
|
||||
``"UNetBackbone"``.
|
||||
input_height : int
|
||||
Number of frequency bins in the input spectrogram. Defaults to
|
||||
``128``.
|
||||
in_channels : int
|
||||
Number of channels in the input spectrogram (e.g. ``1`` for a
|
||||
standard mel-spectrogram). Defaults to ``1``.
|
||||
encoder : EncoderConfig
|
||||
Configuration for the downsampling path. Defaults to
|
||||
``DEFAULT_ENCODER_CONFIG``.
|
||||
bottleneck : BottleneckConfig
|
||||
Configuration for the bottleneck. Defaults to
|
||||
``DEFAULT_BOTTLENECK_CONFIG``.
|
||||
decoder : DecoderConfig
|
||||
Configuration for the upsampling path. Defaults to
|
||||
``DEFAULT_DECODER_CONFIG``.
|
||||
"""
|
||||
|
||||
name: Literal["UNetBackbone"] = "UNetBackbone"
|
||||
input_height: int = 128
|
||||
in_channels: int = 1
|
||||
@ -70,35 +103,36 @@ __all__ = [
|
||||
|
||||
|
||||
class UNetBackbone(BackboneModel):
|
||||
"""Encoder-Decoder Backbone Network Implementation.
|
||||
"""U-Net-style encoder-decoder backbone network.
|
||||
|
||||
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
||||
skip connections between the Encoder and Decoder. Implements the standard
|
||||
U-Net style forward pass. Includes automatic input padding to handle
|
||||
various input sizes and a final convolutional block to adjust the output
|
||||
channels.
|
||||
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||
that produces a high-resolution feature map from an input spectrogram.
|
||||
Skip connections from each encoder stage are added element-wise to the
|
||||
corresponding decoder stage input.
|
||||
|
||||
This class inherits from `BackboneModel` and implements its `forward`
|
||||
method. Instances are typically created using the `build_backbone` factory
|
||||
function.
|
||||
Input spectrograms of arbitrary width are handled automatically: the
|
||||
backbone pads the input so that its dimensions are divisible by
|
||||
``divide_factor`` and removes the padding from the output.
|
||||
|
||||
Instances are typically created via ``build_backbone``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
Expected height of the input spectrogram.
|
||||
Expected height (frequency bins) of the input spectrogram.
|
||||
out_channels : int
|
||||
Number of channels in the final output feature map.
|
||||
Number of channels in the output feature map (taken from the
|
||||
decoder's output channel count).
|
||||
encoder : EncoderProtocol
|
||||
The instantiated encoder module.
|
||||
decoder : DecoderProtocol
|
||||
The instantiated decoder module.
|
||||
bottleneck : BottleneckProtocol
|
||||
The instantiated bottleneck module.
|
||||
final_conv : ConvBlock
|
||||
Final convolutional block applied after the decoder.
|
||||
divide_factor : int
|
||||
The total downsampling factor (2^depth) applied by the encoder,
|
||||
used for automatic input padding.
|
||||
The total spatial downsampling factor applied by the encoder
|
||||
(``input_height // encoder.output_height``). The input height and
|
||||
width are padded to multiples of this value before processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -108,25 +142,19 @@ class UNetBackbone(BackboneModel):
|
||||
decoder: DecoderProtocol,
|
||||
bottleneck: BottleneckProtocol,
|
||||
):
|
||||
"""Initialize the Backbone network.
|
||||
"""Initialise the backbone network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Expected height of the input spectrogram.
|
||||
out_channels : int
|
||||
Desired number of output channels for the backbone's feature map.
|
||||
Expected height (frequency bins) of the input spectrogram.
|
||||
encoder : EncoderProtocol
|
||||
An initialized Encoder module.
|
||||
An initialised encoder module.
|
||||
decoder : DecoderProtocol
|
||||
An initialized Decoder module.
|
||||
An initialised decoder module. Its ``output_height`` must equal
|
||||
``input_height``; a ``ValueError`` is raised otherwise.
|
||||
bottleneck : BottleneckProtocol
|
||||
An initialized Bottleneck module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If component output/input channels or heights are incompatible.
|
||||
An initialised bottleneck module.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_height = input_height
|
||||
@ -143,22 +171,25 @@ class UNetBackbone(BackboneModel):
|
||||
self.divide_factor = input_height // self.encoder.output_height
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform the forward pass through the encoder-decoder backbone.
|
||||
"""Produce a feature map from an input spectrogram.
|
||||
|
||||
Applies padding, runs encoder, bottleneck, decoder (with skip
|
||||
connections), removes padding, and applies a final convolution.
|
||||
Pads the input if necessary, runs it through the encoder, then
|
||||
the bottleneck, then the decoder (incorporating encoder skip
|
||||
connections), and finally removes any padding added earlier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
|
||||
`self.encoder.input_channels` and `self.input_height`.
|
||||
Input spectrogram tensor, shape
|
||||
``(B, C_in, H_in, W_in)``. ``H_in`` must equal
|
||||
``self.input_height``; ``W_in`` can be any positive integer.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
|
||||
`C_out` is `self.out_channels`.
|
||||
Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
|
||||
``C_out`` is ``self.out_channels``. The spatial dimensions
|
||||
always match those of the input.
|
||||
"""
|
||||
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
|
||||
|
||||
@ -219,6 +250,24 @@ BackboneConfig = Annotated[
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
"""Build a backbone network from configuration.
|
||||
|
||||
Looks up the backbone class corresponding to ``config.name`` in the
|
||||
backbone registry and calls its ``from_config`` method. If no
|
||||
configuration is provided, a default ``UNetBackbone`` is returned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BackboneConfig, optional
|
||||
A configuration object describing the desired backbone. Currently
|
||||
``UNetBackboneConfig`` is the only supported type. Defaults to
|
||||
``UNetBackboneConfig()`` if not provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneModel
|
||||
An initialised backbone module.
|
||||
"""
|
||||
config = config or UNetBackboneConfig()
|
||||
return backbone_registry.build(config)
|
||||
|
||||
@ -227,26 +276,25 @@ def _pad_adjust(
|
||||
spec: torch.Tensor,
|
||||
factor: int = 32,
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
"""Pad tensor height and width to be divisible by a factor.
|
||||
"""Pad a tensor's height and width to be divisible by ``factor``.
|
||||
|
||||
Calculates the required padding for the last two dimensions (H, W) to make
|
||||
them divisible by `factor` and applies right/bottom padding using
|
||||
`torch.nn.functional.pad`.
|
||||
Adds zero-padding to the bottom and right edges of the tensor so that
|
||||
both dimensions are exact multiples of ``factor``. If both dimensions
|
||||
are already divisible, the tensor is returned unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input tensor, typically shape `(B, C, H, W)`.
|
||||
Input tensor, typically shape ``(B, C, H, W)``.
|
||||
factor : int, default=32
|
||||
The factor to make height and width divisible by.
|
||||
The factor that both H and W should be divisible by after padding.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[torch.Tensor, int, int]
|
||||
A tuple containing:
|
||||
- The padded tensor.
|
||||
- The amount of padding added to height (`h_pad`).
|
||||
- The amount of padding added to width (`w_pad`).
|
||||
- Padded tensor.
|
||||
- Number of rows added to the height (``h_pad``).
|
||||
- Number of columns added to the width (``w_pad``).
|
||||
"""
|
||||
h, w = spec.shape[-2:]
|
||||
h_pad = -h % factor
|
||||
@ -261,23 +309,25 @@ def _pad_adjust(
|
||||
def _restore_pad(
|
||||
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
|
||||
) -> torch.Tensor:
|
||||
"""Remove padding added by _pad_adjust.
|
||||
"""Remove padding previously added by ``_pad_adjust``.
|
||||
|
||||
Removes padding from the bottom and right edges of the tensor.
|
||||
Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
|
||||
right of the tensor, restoring its original spatial dimensions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
|
||||
Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
|
||||
h_pad : int, default=0
|
||||
Amount of padding previously added to the height (bottom).
|
||||
Number of rows to remove from the bottom.
|
||||
w_pad : int, default=0
|
||||
Amount of padding previously added to the width (right).
|
||||
Number of columns to remove from the right.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
|
||||
Tensor with padding removed, shape
|
||||
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
|
||||
"""
|
||||
if h_pad > 0:
|
||||
x = x[..., :-h_pad, :]
|
||||
@ -292,6 +342,36 @@ def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
"""Load a backbone configuration from a YAML or JSON file.
|
||||
|
||||
Reads the file at ``path``, optionally descends into a named sub-field,
|
||||
and validates the result against the ``BackboneConfig`` discriminated
|
||||
union.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file. Both YAML and JSON formats are
|
||||
supported.
|
||||
field : str, optional
|
||||
Dot-separated key path to the sub-field that contains the backbone
|
||||
configuration (e.g. ``"model"``). If ``None``, the root of the
|
||||
file is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneConfig
|
||||
A validated backbone configuration object (currently always a
|
||||
``UNetBackboneConfig`` instance).
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If ``path`` does not exist.
|
||||
ValidationError
|
||||
If the loaded data does not conform to a known ``BackboneConfig``
|
||||
schema.
|
||||
"""
|
||||
return load_config(
|
||||
path,
|
||||
schema=TypeAdapter(BackboneConfig),
|
||||
|
||||
@ -1,30 +1,49 @@
|
||||
"""Commonly used neural network building blocks for BatDetect2 models.
|
||||
"""Reusable convolutional building blocks for BatDetect2 models.
|
||||
|
||||
This module provides various reusable `torch.nn.Module` subclasses that form
|
||||
the fundamental building blocks for constructing convolutional neural network
|
||||
architectures, particularly encoder-decoder backbones used in BatDetect2.
|
||||
This module provides a collection of ``torch.nn.Module`` subclasses that form
|
||||
the fundamental building blocks for the encoder-decoder backbone used in
|
||||
BatDetect2. All blocks follow a consistent interface: they store
|
||||
``in_channels`` and ``out_channels`` as attributes and implement a
|
||||
``get_output_height`` method that reports how a given input height maps to an
|
||||
output height (e.g., halved by downsampling blocks, doubled by upsampling
|
||||
blocks).
|
||||
|
||||
It includes standard components like basic convolutional blocks (`ConvBlock`),
|
||||
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
|
||||
upsampling (`StandardConvUpBlock`).
|
||||
Available block families
|
||||
------------------------
|
||||
Standard blocks
|
||||
``ConvBlock`` – convolution + batch normalisation + ReLU, no change in
|
||||
spatial resolution.
|
||||
|
||||
Additionally, it features specialized layers investigated in BatDetect2
|
||||
research:
|
||||
Downsampling blocks
|
||||
``StandardConvDownBlock`` – convolution then 2×2 max-pooling, halves H
|
||||
and W.
|
||||
``FreqCoordConvDownBlock`` – like ``StandardConvDownBlock`` but prepends
|
||||
a normalised frequency-coordinate channel before the convolution
|
||||
(CoordConv concept), helping filters learn frequency-position-dependent
|
||||
patterns.
|
||||
|
||||
- `SelfAttention`: Applies self-attention along the time dimension, enabling
|
||||
the model to weigh information across the entire temporal context, often
|
||||
used in the bottleneck of an encoder-decoder.
|
||||
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
|
||||
concept by concatenating normalized frequency coordinate information as an
|
||||
extra channel to the input of convolutional layers. This explicitly provides
|
||||
spatial frequency information to filters, potentially enabling them to learn
|
||||
frequency-dependent patterns more effectively.
|
||||
Upsampling blocks
|
||||
``StandardConvUpBlock`` – bilinear interpolation then convolution,
|
||||
doubles H and W.
|
||||
``FreqCoordConvUpBlock`` – like ``StandardConvUpBlock`` but prepends a
|
||||
frequency-coordinate channel after upsampling.
|
||||
|
||||
These blocks can be used directly in custom PyTorch model definitions or
|
||||
assembled into larger architectures.
|
||||
Bottleneck blocks
|
||||
``VerticalConv`` – 1-D convolution whose kernel spans the entire
|
||||
frequency axis, collapsing H to 1 whilst preserving W.
|
||||
``SelfAttention`` – scaled dot-product self-attention along the time
|
||||
axis; typically follows a ``VerticalConv``.
|
||||
|
||||
A unified factory function `build_layer` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
Group block
|
||||
``LayerGroup`` – chains several blocks sequentially into one unit,
|
||||
useful when a single encoder or decoder "stage" requires more than one
|
||||
operation.
|
||||
|
||||
Factory function
|
||||
----------------
|
||||
``build_layer`` creates any of the above blocks from the matching
|
||||
configuration object (one of the ``*Config`` classes exported here), using
|
||||
a discriminated-union ``name`` field to dispatch to the correct class.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
@ -57,10 +76,43 @@ __all__ = [
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Abstract base class for all BatDetect2 building blocks.
|
||||
|
||||
Subclasses must set ``in_channels`` and ``out_channels`` as integer
|
||||
attributes so that factory functions can wire blocks together without
|
||||
inspecting configuration objects at runtime. They may also override
|
||||
``get_output_height`` when the block changes the height dimension (e.g.
|
||||
downsampling or upsampling blocks).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels expected in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels produced in the output tensor.
|
||||
"""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
"""Return the output height for a given input height.
|
||||
|
||||
The default implementation returns ``input_height`` unchanged,
|
||||
which is correct for blocks that do not alter spatial resolution.
|
||||
Override this in downsampling (returns ``input_height // 2``) or
|
||||
upsampling (returns ``input_height * 2``) subclasses.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (number of frequency bins) of the input feature map.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Height of the output feature map.
|
||||
"""
|
||||
return input_height
|
||||
|
||||
|
||||
@ -68,58 +120,65 @@ block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
"""Configuration for a ``SelfAttention`` block.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Discriminator field; always ``"SelfAttention"``.
|
||||
attention_channels : int
|
||||
Dimensionality of the query, key, and value projections.
|
||||
temperature : float
|
||||
Divisor applied, together with ``attention_channels``, to the
|
||||
query–key dot-product scores before the softmax. Defaults to ``1``.
|
||||
"""
|
||||
|
||||
name: Literal["SelfAttention"] = "SelfAttention"
|
||||
attention_channels: int
|
||||
temperature: float = 1
|
||||
|
||||
|
||||
class SelfAttention(Block):
|
||||
"""Self-Attention mechanism operating along the time dimension.
|
||||
"""Self-attention block operating along the time axis.
|
||||
|
||||
This module implements a scaled dot-product self-attention mechanism,
|
||||
specifically designed here to operate across the time steps of an input
|
||||
feature map, typically after spatial dimensions (like frequency) have been
|
||||
condensed or squeezed.
|
||||
Applies a scaled dot-product self-attention mechanism across the time
|
||||
steps of an input feature map. Before attention is computed the height
|
||||
dimension (frequency axis) is expected to have been reduced to 1, e.g.
|
||||
by a preceding ``VerticalConv`` layer.
|
||||
|
||||
By calculating attention weights between all pairs of time steps, it allows
|
||||
the model to capture long-range temporal dependencies and focus on relevant
|
||||
parts of the sequence. It's often employed in the bottleneck or
|
||||
intermediate layers of an encoder-decoder architecture to integrate global
|
||||
temporal context.
|
||||
|
||||
The implementation uses linear projections to create query, key, and value
|
||||
representations, computes scaled dot-product attention scores, applies
|
||||
softmax, and produces an output by weighting the values according to the
|
||||
attention scores, followed by a final linear projection. Positional encoding
|
||||
is not explicitly included in this block.
|
||||
For each time step the block computes query, key, and value projections
|
||||
with learned linear weights, then calculates attention weights from the
|
||||
query–key dot products divided by ``temperature × attention_channels``.
|
||||
The weighted sum of values is projected back to ``in_channels`` via a
|
||||
final linear layer, and the height dimension is restored so that the
|
||||
output shape matches the input shape.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels (features per time step after spatial squeeze).
|
||||
Number of input channels (features per time step). The output will
|
||||
also have ``in_channels`` channels.
|
||||
attention_channels : int
|
||||
Number of channels for the query, key, and value projections. Also the
|
||||
dimension of the output projection's input.
|
||||
Dimensionality of the query, key, and value projections.
|
||||
temperature : float, default=1.0
|
||||
Scaling factor applied *before* the final projection layer. Can be used
|
||||
to adjust the sharpness or focus of the attention mechanism, although
|
||||
scaling within the softmax (dividing by sqrt(dim)) is more common for
|
||||
standard transformers. Here it scales the weighted values.
|
||||
Divisor applied together with ``attention_channels`` when scaling
|
||||
the dot-product scores before softmax. Larger values produce softer
|
||||
(more uniform) attention distributions.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key_fun : nn.Linear
|
||||
Linear layer for key projection.
|
||||
Linear projection for keys.
|
||||
value_fun : nn.Linear
|
||||
Linear layer for value projection.
|
||||
Linear projection for values.
|
||||
query_fun : nn.Linear
|
||||
Linear layer for query projection.
|
||||
Linear projection for queries.
|
||||
pro_fun : nn.Linear
|
||||
Final linear projection layer applied after attention weighting.
|
||||
Final linear projection applied to the attended values.
|
||||
temperature : float
|
||||
Scaling factor applied before final projection.
|
||||
Scaling divisor used when computing attention scores.
|
||||
att_dim : int
|
||||
Dimensionality of the attention space (`attention_channels`).
|
||||
Dimensionality of the attention space (``attention_channels``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -148,20 +207,16 @@ class SelfAttention(Block):
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, expected shape `(B, C, H, W)`, where H is typically
|
||||
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
|
||||
applying attention along the W (time) dimension.
|
||||
Input tensor with shape ``(B, C, 1, W)``. The height dimension
|
||||
must be 1 (i.e. the frequency axis should already have been
|
||||
collapsed by a preceding ``VerticalConv`` layer).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor of the same shape as the input `(B, C, H, W)`, where
|
||||
attention has been applied across the W dimension.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If input tensor dimensions are incompatible with operations.
|
||||
Output tensor with the same shape ``(B, C, 1, W)`` as the
|
||||
input, with each time step updated by attended context from all
|
||||
other time steps.
|
||||
"""
|
||||
|
||||
x = x.squeeze(2).permute(0, 2, 1)
|
||||
@ -190,6 +245,22 @@ class SelfAttention(Block):
|
||||
return op
|
||||
|
||||
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Return the softmax attention weight matrix.
|
||||
|
||||
Useful for visualising which time steps attend to which others.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor with shape ``(B, C, 1, W)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Attention weight matrix with shape ``(B, W, W)``. Entry
|
||||
``[b, i, j]`` is the attention weight that time step ``i``
|
||||
assigns to time step ``j`` in batch item ``b``.
|
||||
"""
|
||||
x = x.squeeze(2).permute(0, 2, 1)
|
||||
|
||||
key = torch.matmul(
|
||||
@ -304,6 +375,16 @@ class ConvBlock(Block):
|
||||
|
||||
|
||||
class VerticalConvConfig(BaseConfig):
|
||||
"""Configuration for a ``VerticalConv`` block.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Discriminator field; always ``"VerticalConv"``.
|
||||
channels : int
|
||||
Number of output channels produced by the vertical convolution.
|
||||
"""
|
||||
|
||||
name: Literal["VerticalConv"] = "VerticalConv"
|
||||
channels: int
|
||||
|
||||
@ -844,12 +925,53 @@ LayerConfig = Annotated[
|
||||
|
||||
|
||||
class LayerGroupConfig(BaseConfig):
|
||||
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
||||
|
||||
Use this when a single encoder or decoder stage needs more than one
|
||||
block. The blocks are executed in the order they appear in ``layers``,
|
||||
with channel counts and heights propagated automatically.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
Discriminator field; always ``"LayerGroup"``.
|
||||
layers : List[LayerConfig]
|
||||
Ordered list of block configurations to chain together.
|
||||
"""
|
||||
|
||||
name: Literal["LayerGroup"] = "LayerGroup"
|
||||
layers: List[LayerConfig]
|
||||
|
||||
|
||||
class LayerGroup(nn.Module):
|
||||
"""Standard implementation of the `LayerGroup` architecture."""
|
||||
"""Sequential chain of blocks that acts as a single composite block.
|
||||
|
||||
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
|
||||
exposing the same ``in_channels``, ``out_channels``, and
|
||||
``get_output_height`` interface as a regular ``Block`` so it can be
|
||||
used transparently wherever a single block is expected.
|
||||
|
||||
Instances are typically constructed by ``build_layer`` when given a
|
||||
``LayerGroupConfig``; you rarely need to create them directly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
layers : list[Block]
|
||||
Pre-built block instances to chain, in execution order.
|
||||
input_height : int
|
||||
Height of the tensor entering the first block.
|
||||
input_channels : int
|
||||
Number of channels in the tensor entering the first block.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels (taken from the first block).
|
||||
out_channels : int
|
||||
Number of output channels (taken from the last block).
|
||||
layers : nn.Sequential
|
||||
The wrapped sequence of block modules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -865,9 +987,33 @@ class LayerGroup(nn.Module):
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pass input through all blocks in sequence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input feature map, shape ``(B, C_in, H, W)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output feature map after all blocks have been applied.
|
||||
"""
|
||||
return self.layers(x)
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
"""Compute the output height by propagating through all blocks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height of the input feature map.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Height after all blocks in the group have been applied.
|
||||
"""
|
||||
for block in self.layers:
|
||||
input_height = block.get_output_height(input_height) # type: ignore
|
||||
return input_height
|
||||
@ -903,40 +1049,37 @@ def build_layer(
|
||||
in_channels: int,
|
||||
config: LayerConfig,
|
||||
) -> Block:
|
||||
"""Factory function to build a specific nn.Module block from its config.
|
||||
"""Build a block from its configuration object.
|
||||
|
||||
Takes configuration object (one of the types included in the `LayerConfig`
|
||||
union) and instantiates the corresponding nn.Module block with the correct
|
||||
parameters derived from the config and the current pipeline state
|
||||
(`input_height`, `in_channels`).
|
||||
|
||||
It uses the `name` field within the `config` object to determine
|
||||
which block class to instantiate.
|
||||
Looks up the block class corresponding to ``config.name`` in the
|
||||
internal block registry and instantiates it with the given input
|
||||
dimensions. This is the standard way to construct blocks when
|
||||
assembling an encoder or decoder from a configuration file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor *to this layer*.
|
||||
Height (number of frequency bins) of the input tensor to this
|
||||
block. Required for blocks whose kernel size depends on the input
|
||||
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor *to this layer*.
|
||||
Number of channels in the input tensor to this block.
|
||||
config : LayerConfig
|
||||
A Pydantic configuration object for the desired block (e.g., an
|
||||
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
||||
by its `name` field.
|
||||
A configuration object for the desired block type. The ``name``
|
||||
field selects the block class; remaining fields supply its
|
||||
parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[nn.Module, int, int]
|
||||
A tuple containing:
|
||||
- The instantiated `nn.Module` block.
|
||||
- The number of output channels produced by the block.
|
||||
- The calculated height of the output produced by the block.
|
||||
Block
|
||||
An initialised block module ready to be added to an
|
||||
``nn.Sequential`` or ``nn.ModuleList``.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
If the `config.name` does not correspond to a known block type.
|
||||
KeyError
|
||||
If ``config.name`` does not correspond to a registered block type.
|
||||
ValueError
|
||||
If parameters derived from the config are invalid for the block.
|
||||
If the configuration parameters are invalid for the chosen block.
|
||||
"""
|
||||
return block_registry.build(config, in_channels, input_height)
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
|
||||
"""Bottleneck component for encoder-decoder network architectures.
|
||||
|
||||
This module provides the configuration (`BottleneckConfig`) and
|
||||
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
|
||||
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
|
||||
Decoder (upsampling path) in networks like U-Nets.
|
||||
The bottleneck sits between the encoder (downsampling path) and the decoder
|
||||
(upsampling path) and processes the lowest-resolution, highest-channel feature
|
||||
map produced by the encoder.
|
||||
|
||||
The bottleneck processes the lowest-resolution, highest-dimensionality feature
|
||||
map produced by the Encoder. This module offers a configurable option to include
|
||||
a `SelfAttention` layer within the bottleneck, allowing the model to capture
|
||||
global temporal context before features are passed to the Decoder.
|
||||
This module provides:
|
||||
|
||||
A factory function `build_bottleneck` constructs the appropriate bottleneck
|
||||
module based on the provided configuration.
|
||||
- ``BottleneckConfig`` – configuration class describing the number of
|
||||
internal channels and an optional sequence of additional layers (currently
|
||||
only ``SelfAttention`` is supported).
|
||||
- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a
|
||||
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
|
||||
runs one or more additional layers (e.g. self-attention along the time axis),
|
||||
then repeats the output along the height dimension to restore the original
|
||||
frequency resolution before passing features to the decoder.
|
||||
- ``build_bottleneck`` – factory function that constructs a ``Bottleneck``
|
||||
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List
|
||||
@ -37,42 +41,51 @@ __all__ = [
|
||||
|
||||
|
||||
class Bottleneck(Block):
|
||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
||||
"""Bottleneck module for encoder-decoder architectures.
|
||||
|
||||
This implementation represents the simplest bottleneck structure
|
||||
considered, primarily consisting of a `VerticalConv` layer. This layer
|
||||
collapses the frequency dimension (height) to 1, summarizing information
|
||||
across frequencies at each time step. The output is then repeated along the
|
||||
height dimension to match the original bottleneck input height before being
|
||||
passed to the decoder.
|
||||
Processes the lowest-resolution feature map that links the encoder and
|
||||
decoder. The sequence of operations is:
|
||||
|
||||
This base version does *not* include self-attention.
|
||||
1. ``VerticalConv`` – collapses the frequency axis (height) to a single
|
||||
bin by applying a convolution whose kernel spans the full height.
|
||||
2. Optional additional layers (e.g. ``SelfAttention``) – applied while
|
||||
the feature map has height 1, so they operate purely along the time
|
||||
axis.
|
||||
3. Height restoration – the single-bin output is repeated along the
|
||||
height axis to restore the original frequency resolution, producing
|
||||
a tensor that the decoder can accept.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor. Must be positive.
|
||||
Height (number of frequency bins) of the input tensor. Must be
|
||||
positive.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor from the encoder. Must be
|
||||
positive.
|
||||
out_channels : int
|
||||
Number of output channels. Must be positive.
|
||||
Number of output channels after the bottleneck. Must be positive.
|
||||
bottleneck_channels : int, optional
|
||||
Number of internal channels used by the ``VerticalConv`` layer.
|
||||
Defaults to ``out_channels`` if not provided.
|
||||
layers : List[torch.nn.Module], optional
|
||||
Additional modules (e.g. ``SelfAttention``) to apply after the
|
||||
``VerticalConv`` and before height restoration.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels accepted by the bottleneck.
|
||||
out_channels : int
|
||||
Number of output channels produced by the bottleneck.
|
||||
input_height : int
|
||||
Expected height of the input tensor.
|
||||
channels : int
|
||||
Number of output channels.
|
||||
bottleneck_channels : int
|
||||
Number of channels used internally by the vertical convolution.
|
||||
conv_vert : VerticalConv
|
||||
The vertical convolution layer.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `input_height`, `in_channels`, or `out_channels` are not positive.
|
||||
layers : nn.ModuleList
|
||||
Additional layers applied after the vertical convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -83,7 +96,23 @@ class Bottleneck(Block):
|
||||
bottleneck_channels: int | None = None,
|
||||
layers: List[torch.nn.Module] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the base Bottleneck layer."""
|
||||
"""Initialise the Bottleneck layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (number of frequency bins) of the input tensor.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels in the output tensor.
|
||||
bottleneck_channels : int, optional
|
||||
Number of internal channels for the ``VerticalConv``. Defaults
|
||||
to ``out_channels``.
|
||||
layers : List[torch.nn.Module], optional
|
||||
Additional modules applied after the ``VerticalConv``, such as
|
||||
a ``SelfAttention`` block.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.input_height = input_height
|
||||
@ -103,23 +132,24 @@ class Bottleneck(Block):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input features through the bottleneck.
|
||||
"""Process the encoder's bottleneck features.
|
||||
|
||||
Applies vertical convolution and repeats the output height.
|
||||
Applies vertical convolution, optional additional layers, then
|
||||
restores the height dimension by repetition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor from the encoder bottleneck, shape
|
||||
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
|
||||
`H_in` must match `self.input_height`.
|
||||
Input tensor from the encoder, shape
|
||||
``(B, C_in, H_in, W)``. ``C_in`` must match
|
||||
``self.in_channels`` and ``H_in`` must match
|
||||
``self.input_height``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
|
||||
dimension `H_in` is restored via repetition after the vertical
|
||||
convolution.
|
||||
Output tensor with shape ``(B, C_out, H_in, W)``. The height
|
||||
``H_in`` is restored by repeating the single-bin result.
|
||||
"""
|
||||
x = self.conv_vert(x)
|
||||
|
||||
@ -133,28 +163,22 @@ BottleneckLayerConfig = Annotated[
|
||||
SelfAttentionConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
|
||||
|
||||
|
||||
class BottleneckConfig(BaseConfig):
|
||||
"""Configuration for the bottleneck layer(s).
|
||||
|
||||
Defines the number of channels within the bottleneck and whether to include
|
||||
a self-attention mechanism.
|
||||
"""Configuration for the bottleneck component.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
channels : int
|
||||
The number of output channels produced by the main convolutional layer
|
||||
within the bottleneck. This often matches the number of channels coming
|
||||
from the last encoder stage, but can be different. Must be positive.
|
||||
This also defines the channel dimensions used within the optional
|
||||
`SelfAttention` layer.
|
||||
self_attention : bool
|
||||
If True, includes a `SelfAttention` layer operating on the time
|
||||
dimension after an initial `VerticalConv` layer within the bottleneck.
|
||||
If False, only the initial `VerticalConv` (and height repetition) is
|
||||
performed.
|
||||
Number of output channels produced by the bottleneck. This value
|
||||
is also used as the dimensionality of any optional layers (e.g.
|
||||
self-attention). Must be positive.
|
||||
layers : List[BottleneckLayerConfig]
|
||||
Ordered list of additional block configurations to apply after the
|
||||
initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
|
||||
supported. Defaults to an empty list (no extra layers).
|
||||
"""
|
||||
|
||||
channels: int
|
||||
@ -174,30 +198,37 @@ def build_bottleneck(
|
||||
in_channels: int,
|
||||
config: BottleneckConfig | None = None,
|
||||
) -> BottleneckProtocol:
|
||||
"""Factory function to build the Bottleneck module from configuration.
|
||||
"""Build a ``Bottleneck`` module from configuration.
|
||||
|
||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||
on the `config.self_attention` flag.
|
||||
Constructs a ``Bottleneck`` instance whose internal channel count and
|
||||
optional extra layers (e.g. self-attention) are controlled by
|
||||
``config``. If no configuration is provided, the default
|
||||
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
|
||||
``SelfAttention`` layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor. Must be positive.
|
||||
Height (number of frequency bins) of the input tensor from the
|
||||
encoder. Must be positive.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor. Must be positive.
|
||||
Number of channels in the input tensor from the encoder. Must be
|
||||
positive.
|
||||
config : BottleneckConfig, optional
|
||||
Configuration object specifying the bottleneck channels and whether
|
||||
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
|
||||
Configuration specifying the output channel count and any
|
||||
additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
nn.Module
|
||||
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
|
||||
BottleneckProtocol
|
||||
An initialised ``Bottleneck`` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `input_height` or `in_channels` are not positive.
|
||||
AssertionError
|
||||
If any configured layer changes the height of the feature map
|
||||
(bottleneck layers must preserve height so that it can be restored
|
||||
by repetition).
|
||||
"""
|
||||
config = config or DEFAULT_BOTTLENECK_CONFIG
|
||||
|
||||
|
||||
@ -1,21 +1,21 @@
|
||||
"""Constructs the Decoder part of an Encoder-Decoder neural network.
|
||||
"""Decoder (upsampling path) for the BatDetect2 backbone.
|
||||
|
||||
This module defines the configuration structure (`DecoderConfig`) for the layer
|
||||
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
|
||||
function (`build_decoder`). Decoders typically form the upsampling path in
|
||||
architectures like U-Nets, taking bottleneck features
|
||||
(usually from an `Encoder`) and skip connections to reconstruct
|
||||
higher-resolution feature maps.
|
||||
This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
|
||||
together with the ``build_decoder`` factory function.
|
||||
|
||||
The decoder is built dynamically by stacking neural network blocks based on a
|
||||
list of configuration objects provided in `DecoderConfig.layers`. Each config
|
||||
object specifies the type of block (e.g., standard convolution,
|
||||
coordinate-feature convolution with upsampling) and its parameters. This allows
|
||||
flexible definition of decoder architectures via configuration files.
|
||||
In a U-Net-style network the decoder progressively restores the spatial
|
||||
resolution of the feature map back towards the input resolution. At each
|
||||
stage it combines the upsampled features with the corresponding skip-connection
|
||||
tensor from the encoder (the residual) by element-wise addition before passing
|
||||
the result to the upsampling block.
|
||||
|
||||
The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
||||
(`residuals`) from the encoder, merging them with the upsampled feature maps
|
||||
at each stage.
|
||||
The decoder is fully configurable: the type, number, and parameters of the
|
||||
upsampling blocks are described by a ``DecoderConfig`` object containing an
|
||||
ordered list of block configuration objects (see ``batdetect2.models.blocks``
|
||||
for available block types).
|
||||
|
||||
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
|
||||
``build_decoder`` when no explicit configuration is supplied.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List
|
||||
@ -51,51 +51,47 @@ DecoderLayerConfig = Annotated[
|
||||
|
||||
|
||||
class DecoderConfig(BaseConfig):
|
||||
"""Configuration for the sequence of layers in the Decoder module.
|
||||
|
||||
Defines the types and parameters of the neural network blocks that
|
||||
constitute the decoder's upsampling path.
|
||||
"""Configuration for the sequential ``Decoder`` module.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
layers : List[DecoderLayerConfig]
|
||||
An ordered list of configuration objects, each defining one layer or
|
||||
block in the decoder sequence. Each item must be a valid block
|
||||
config including a `name` field and necessary parameters like
|
||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||
The list must contain at least one layer.
|
||||
Ordered list of block configuration objects defining the decoder's
|
||||
upsampling stages (from deepest to shallowest). Each entry
|
||||
specifies the block type (via its ``name`` field) and any
|
||||
block-specific parameters such as ``out_channels``. Input channels
|
||||
for each block are inferred automatically from the output of the
|
||||
previous block. Must contain at least one entry.
|
||||
"""
|
||||
|
||||
layers: List[DecoderLayerConfig] = Field(min_length=1)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Sequential Decoder module composed of configurable upsampling layers.
|
||||
"""Sequential decoder module composed of configurable upsampling layers.
|
||||
|
||||
Constructs the upsampling path of an encoder-decoder network by stacking
|
||||
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
|
||||
based on a list of layer modules provided during initialization (typically
|
||||
created by the `build_decoder` factory function).
|
||||
Executes a series of upsampling blocks in order, adding the
|
||||
corresponding encoder skip-connection tensor (residual) to the feature
|
||||
map before each block. The residuals are consumed in reverse order (from
|
||||
deepest encoder layer to shallowest) to match the spatial resolutions at
|
||||
each decoder stage.
|
||||
|
||||
The `forward` method is designed to integrate skip connection tensors
|
||||
(`residuals`) from the corresponding encoder stages, by adding them
|
||||
element-wise to the input of each decoder layer before processing.
|
||||
Instances are typically created by ``build_decoder``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels expected in the input tensor.
|
||||
Number of channels expected in the input tensor (bottleneck output).
|
||||
out_channels : int
|
||||
Number of channels in the final output tensor produced by the last
|
||||
layer.
|
||||
Number of channels in the final output feature map.
|
||||
input_height : int
|
||||
Height (frequency bins) expected in the input tensor.
|
||||
Height (frequency bins) of the input tensor.
|
||||
output_height : int
|
||||
Height (frequency bins) expected in the output tensor.
|
||||
Height (frequency bins) of the output tensor.
|
||||
layers : nn.ModuleList
|
||||
The sequence of instantiated upscaling layer modules.
|
||||
Sequence of instantiated upsampling block modules.
|
||||
depth : int
|
||||
The number of upscaling layers (depth) in the decoder.
|
||||
Number of upsampling layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -106,23 +102,24 @@ class Decoder(nn.Module):
|
||||
output_height: int,
|
||||
layers: List[nn.Module],
|
||||
):
|
||||
"""Initialize the Decoder module.
|
||||
"""Initialise the Decoder module.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_decoder` factory function.
|
||||
This constructor is typically called by the ``build_decoder``
|
||||
factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor (bottleneck output).
|
||||
out_channels : int
|
||||
Number of channels produced by the final layer.
|
||||
input_height : int
|
||||
Expected height of the input tensor (bottleneck).
|
||||
in_channels : int
|
||||
Expected number of channels in the input tensor (bottleneck).
|
||||
Height of the input tensor (bottleneck output height).
|
||||
output_height : int
|
||||
Height of the output tensor after all layers have been applied.
|
||||
layers : List[nn.Module]
|
||||
A list of pre-instantiated upscaling layer modules (e.g.,
|
||||
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
|
||||
sequence (from bottleneck towards output resolution).
|
||||
Pre-built upsampling block modules in execution order (deepest
|
||||
stage first).
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -140,43 +137,35 @@ class Decoder(nn.Module):
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Pass input through decoder layers, incorporating skip connections.
|
||||
"""Pass input through all decoder layers, incorporating skip connections.
|
||||
|
||||
Processes the input tensor `x` sequentially through the upscaling
|
||||
layers. At each stage, the corresponding skip connection tensor from
|
||||
the `residuals` list is added element-wise to the input before passing
|
||||
it to the upscaling block.
|
||||
At each stage the corresponding residual tensor is added
|
||||
element-wise to ``x`` before it is passed to the upsampling block.
|
||||
Residuals are consumed in reverse order — the last element of
|
||||
``residuals`` (the output of the deepest encoder layer) is added
|
||||
at the first decoder stage, and the first element (output of the
|
||||
shallowest encoder layer) is added at the last decoder stage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor from the previous stage (e.g., encoder bottleneck).
|
||||
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
|
||||
`self.in_channels`.
|
||||
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
|
||||
residuals : List[torch.Tensor]
|
||||
List containing the skip connection tensors from the corresponding
|
||||
encoder stages. Should be ordered from the deepest encoder layer
|
||||
output (lowest resolution) to the shallowest (highest resolution
|
||||
near input). The number of tensors in this list must match the
|
||||
number of decoder layers (`self.depth`). Each residual tensor's
|
||||
channel count must be compatible with the input tensor `x` for
|
||||
element-wise addition (or concatenation if the blocks were designed
|
||||
for it).
|
||||
Skip-connection tensors from the encoder, ordered from shallowest
|
||||
(index 0) to deepest (index -1). Must contain exactly
|
||||
``self.depth`` tensors. Each tensor must have the same spatial
|
||||
dimensions and channel count as ``x`` at the corresponding
|
||||
decoder stage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final decoded feature map tensor produced by the last layer.
|
||||
Shape `(B, C_out, H_out, W_out)`.
|
||||
Decoded feature map, shape ``(B, C_out, H_out, W)``.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of `residuals` provided does not match the decoder
|
||||
depth.
|
||||
RuntimeError
|
||||
If shapes mismatch during skip connection addition or layer
|
||||
processing.
|
||||
If the number of ``residuals`` does not equal ``self.depth``.
|
||||
"""
|
||||
if len(residuals) != len(self.layers):
|
||||
raise ValueError(
|
||||
@ -203,11 +192,17 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
||||
),
|
||||
],
|
||||
)
|
||||
"""A default configuration for the Decoder's *layer sequence*.
|
||||
"""Default decoder configuration used in standard BatDetect2 models.
|
||||
|
||||
Specifies an architecture often used in BatDetect2, consisting of three
|
||||
frequency coordinate-aware upsampling blocks followed by a standard
|
||||
convolutional block.
|
||||
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
|
||||
output has 256 channels and height 16, and produces:
|
||||
|
||||
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
|
||||
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
|
||||
- Stage 3 (``LayerGroup``):
|
||||
|
||||
- ``FreqCoordConvUp``: 32 channels, height 128.
|
||||
- ``ConvBlock``: 32 channels, height 128 (final feature map).
|
||||
"""
|
||||
|
||||
|
||||
@ -216,40 +211,36 @@ def build_decoder(
|
||||
input_height: int,
|
||||
config: DecoderConfig | None = None,
|
||||
) -> Decoder:
|
||||
"""Factory function to build a Decoder instance from configuration.
|
||||
"""Build a ``Decoder`` from configuration.
|
||||
|
||||
Constructs a sequential `Decoder` module based on the layer sequence
|
||||
defined in a `DecoderConfig` object and the provided input dimensions
|
||||
(bottleneck channels and height). If no config is provided, uses the
|
||||
default layer sequence from `DEFAULT_DECODER_CONFIG`.
|
||||
|
||||
It iteratively builds the layers using the unified `build_layer_from_config`
|
||||
factory (from `.blocks`), tracking the changing number of channels and
|
||||
feature map height required for each subsequent layer.
|
||||
Constructs a sequential ``Decoder`` by iterating over the block
|
||||
configurations in ``config.layers``, building each block with
|
||||
``build_layer``, and tracking the channel count and feature-map height
|
||||
as they change through the sequence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
The number of channels in the input tensor to the decoder. Must be > 0.
|
||||
Number of channels in the input tensor (bottleneck output). Must
|
||||
be positive.
|
||||
input_height : int
|
||||
The height (frequency bins) of the input tensor to the decoder. Must be
|
||||
> 0.
|
||||
Height (number of frequency bins) of the input tensor. Must be
|
||||
positive.
|
||||
config : DecoderConfig, optional
|
||||
The configuration object detailing the sequence of layers and their
|
||||
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
|
||||
Configuration specifying the layer sequence. Defaults to
|
||||
``DEFAULT_DECODER_CONFIG`` if not provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Decoder
|
||||
An initialized `Decoder` module.
|
||||
An initialised ``Decoder`` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` or `input_height` are not positive, or if the layer
|
||||
configuration is invalid (e.g., empty list, unknown `name`).
|
||||
NotImplementedError
|
||||
If `build_layer_from_config` encounters an unknown `name`.
|
||||
If ``in_channels`` or ``input_height`` are not positive.
|
||||
KeyError
|
||||
If a layer configuration specifies an unknown block type.
|
||||
"""
|
||||
config = config or DEFAULT_DECODER_CONFIG
|
||||
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
"""Assembles the complete BatDetect2 Detection Model.
|
||||
"""Assembles the complete BatDetect2 detection model.
|
||||
|
||||
This module defines the concrete `Detector` class, which implements the
|
||||
`DetectionModel` interface defined in `.types`. It combines a feature
|
||||
extraction backbone with specific prediction heads to create the end-to-end
|
||||
neural network used for detecting bat calls, predicting their size, and
|
||||
classifying them.
|
||||
This module defines the ``Detector`` class, which combines a backbone
|
||||
feature extractor with prediction heads for detection, classification, and
|
||||
bounding-box size regression.
|
||||
|
||||
The primary components are:
|
||||
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
||||
Components
|
||||
----------
|
||||
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
||||
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||
- ``build_detector`` – factory function that builds a ready-to-use
|
||||
``Detector`` from a backbone configuration and a target class count.
|
||||
|
||||
This module focuses purely on the neural network architecture definition. The
|
||||
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
Note that ``Detector`` operates purely on spectrogram tensors; raw audio
|
||||
preprocessing and output postprocessing are handled by
|
||||
``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -32,25 +35,30 @@ __all__ = [
|
||||
|
||||
|
||||
class Detector(DetectionModel):
|
||||
"""Concrete implementation of the BatDetect2 Detection Model.
|
||||
"""Complete BatDetect2 detection and classification model.
|
||||
|
||||
Assembles a complete detection and classification model by combining a
|
||||
feature extraction backbone network with specific prediction heads for
|
||||
detection probability, bounding box size regression, and class
|
||||
probabilities.
|
||||
Combines a backbone feature extractor with two prediction heads:
|
||||
|
||||
- ``ClassifierHead``: predicts per-class probabilities at each
|
||||
time–frequency location.
|
||||
- ``BBoxHead``: predicts call duration and bandwidth at each location.
|
||||
|
||||
The detection probability map is derived from the class probabilities by
|
||||
summing across the class dimension (i.e. the probability that *any* class
|
||||
is present), rather than from a separate detection head.
|
||||
|
||||
Instances are typically created via ``build_detector``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
The feature extraction backbone network module.
|
||||
The feature extraction backbone.
|
||||
num_classes : int
|
||||
The number of specific target classes the model predicts (derived from
|
||||
the `classifier_head`).
|
||||
Number of target classes (inferred from the classifier head).
|
||||
classifier_head : ClassifierHead
|
||||
The prediction head responsible for generating class probabilities.
|
||||
Produces per-class probability maps from backbone features.
|
||||
bbox_head : BBoxHead
|
||||
The prediction head responsible for generating bounding box size
|
||||
predictions.
|
||||
Produces duration and bandwidth predictions from backbone features.
|
||||
"""
|
||||
|
||||
backbone: BackboneModel
|
||||
@ -61,26 +69,21 @@ class Detector(DetectionModel):
|
||||
classifier_head: ClassifierHead,
|
||||
bbox_head: BBoxHead,
|
||||
):
|
||||
"""Initialize the Detector model.
|
||||
"""Initialise the Detector model.
|
||||
|
||||
Note: Instances are typically created using the `build_detector`
|
||||
This constructor is typically called by the ``build_detector``
|
||||
factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
An initialized feature extraction backbone module (e.g., built by
|
||||
`build_backbone` from the `.backbone` module).
|
||||
An initialised backbone module (e.g. built by
|
||||
``build_backbone``).
|
||||
classifier_head : ClassifierHead
|
||||
An initialized classification head module. The number of classes
|
||||
is inferred from this head.
|
||||
An initialised classification head. The ``num_classes``
|
||||
attribute is read from this head.
|
||||
bbox_head : BBoxHead
|
||||
An initialized bounding box size prediction head module.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If the provided modules are not of the expected types.
|
||||
An initialised bounding-box size prediction head.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -90,31 +93,34 @@ class Detector(DetectionModel):
|
||||
self.bbox_head = bbox_head
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
"""Perform the forward pass of the complete detection model.
|
||||
"""Run the complete detection model on an input spectrogram.
|
||||
|
||||
Processes the input spectrogram through the backbone to extract
|
||||
features, then passes these features through the separate prediction
|
||||
heads to generate detection probabilities, class probabilities, and
|
||||
size predictions.
|
||||
Passes the spectrogram through the backbone to produce a feature
|
||||
map, then applies the classifier and bounding-box heads. The
|
||||
detection probability map is derived by summing the per-class
|
||||
probability maps across the class dimension; no separate detection
|
||||
head is used.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, input_channels, frequency_bins, time_bins)`. The
|
||||
shape must be compatible with the `self.backbone` input
|
||||
requirements.
|
||||
Input spectrogram tensor, shape
|
||||
``(batch_size, channels, frequency_bins, time_bins)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ModelOutput
|
||||
A NamedTuple containing the four output tensors:
|
||||
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
|
||||
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
|
||||
- `class_probs`: Class probabilities (excluding background)
|
||||
`(B, num_classes, H, W)`.
|
||||
- `features`: Output feature map from the backbone
|
||||
`(B, C_out, H, W)`.
|
||||
A named tuple with four fields:
|
||||
|
||||
- ``detection_probs`` – ``(B, 1, H, W)`` – probability that a
|
||||
call of any class is present at each location. Derived by
|
||||
summing ``class_probs`` over the class dimension.
|
||||
- ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel
|
||||
0) and bandwidth (channel 1) predictions at each location.
|
||||
- ``class_probs`` – ``(B, num_classes, H, W)`` – per-class
|
||||
probabilities at each location.
|
||||
- ``features`` – ``(B, C_out, H, W)`` – raw backbone feature
|
||||
map.
|
||||
"""
|
||||
features = self.backbone(spec)
|
||||
classification = self.classifier_head(features)
|
||||
@ -131,30 +137,33 @@ class Detector(DetectionModel):
|
||||
def build_detector(
|
||||
num_classes: int, config: BackboneConfig | None = None
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
"""Build a complete BatDetect2 detection model.
|
||||
|
||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||
and a ``BBoxHead`` sized to the backbone's output channel count, and
|
||||
returns them wrapped in a ``Detector``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(required for the `ClassifierHead`). Must be positive.
|
||||
Number of target bat species or call types to predict. Must be
|
||||
positive.
|
||||
config : BackboneConfig, optional
|
||||
Configuration object specifying the architecture of the backbone
|
||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||
within the respective builder functions (`build_encoder`, etc.) will be
|
||||
used to construct a default backbone architecture.
|
||||
Backbone architecture configuration. Defaults to
|
||||
``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
|
||||
not provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
An initialized `Detector` model instance.
|
||||
An initialised ``Detector`` instance ready for training or
|
||||
inference.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive, or if errors occur during the
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
If ``num_classes`` is not positive, or if the backbone
|
||||
configuration is invalid.
|
||||
"""
|
||||
config = config or UNetBackboneConfig()
|
||||
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
"""Constructs the Encoder part of a configurable neural network backbone.
|
||||
"""Encoder (downsampling path) for the BatDetect2 backbone.
|
||||
|
||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
||||
(`build_encoder`) to create sequential encoders. Encoders typically form the
|
||||
downsampling path in architectures like U-Nets, processing input feature maps
|
||||
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
|
||||
representations (bottleneck features).
|
||||
This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
|
||||
together with the ``build_encoder`` factory function.
|
||||
|
||||
The encoder is built dynamically by stacking neural network blocks based on a
|
||||
list of configuration objects provided in `EncoderConfig.layers`. Each
|
||||
configuration object specifies the type of block (e.g., standard convolution,
|
||||
coordinate-feature convolution with downsampling) and its parameters
|
||||
(e.g., output channels). This allows for flexible definition of encoder
|
||||
architectures via configuration files.
|
||||
In a U-Net-style network the encoder progressively reduces the spatial
|
||||
resolution of the spectrogram whilst increasing the number of feature
|
||||
channels. Each layer in the encoder produces a feature map that is stored
|
||||
for use as a skip connection in the corresponding decoder layer.
|
||||
|
||||
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
|
||||
suitable for skip connections, while the `encode` method returns only the final
|
||||
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
||||
provided.
|
||||
The encoder is fully configurable: the type, number, and parameters of the
|
||||
downsampling blocks are described by an ``EncoderConfig`` object containing
|
||||
an ordered list of block configuration objects (see ``batdetect2.models.blocks``
|
||||
for available block types).
|
||||
|
||||
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
|
||||
so that skip connections are available to the decoder.
|
||||
``Encoder.encode`` returns only the final output (the input to the bottleneck).
|
||||
|
||||
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
|
||||
``build_encoder`` when no explicit configuration is supplied.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List
|
||||
@ -53,35 +54,32 @@ EncoderLayerConfig = Annotated[
|
||||
|
||||
|
||||
class EncoderConfig(BaseConfig):
|
||||
"""Configuration for building the sequential Encoder module.
|
||||
|
||||
Defines the sequence of neural network blocks that constitute the encoder
|
||||
(downsampling path).
|
||||
"""Configuration for the sequential ``Encoder`` module.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
layers : List[EncoderLayerConfig]
|
||||
An ordered list of configuration objects, each defining one layer or
|
||||
block in the encoder sequence. Each item must be a valid block config
|
||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||
`StandardConvDownConfig`) including a `name` field and necessary
|
||||
parameters like `out_channels`. Input channels for each layer are
|
||||
inferred sequentially. The list must contain at least one layer.
|
||||
Ordered list of block configuration objects defining the encoder's
|
||||
downsampling stages. Each entry specifies the block type (via its
|
||||
``name`` field) and any block-specific parameters such as
|
||||
``out_channels``. Input channels for each block are inferred
|
||||
automatically from the output of the previous block. Must contain
|
||||
at least one entry.
|
||||
"""
|
||||
|
||||
layers: List[EncoderLayerConfig] = Field(min_length=1)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Sequential Encoder module composed of configurable downscaling layers.
|
||||
"""Sequential encoder module composed of configurable downsampling layers.
|
||||
|
||||
Constructs the downsampling path of an encoder-decoder network by stacking
|
||||
multiple downscaling blocks.
|
||||
Executes a series of downsampling blocks in order, storing the output of
|
||||
each block so that it can be passed as a skip connection to the
|
||||
corresponding decoder layer.
|
||||
|
||||
The `forward` method executes the sequence and returns the output feature
|
||||
map from *each* downscaling stage, facilitating the implementation of skip
|
||||
connections in U-Net-like architectures. The `encode` method returns only
|
||||
the final output tensor (bottleneck features).
|
||||
``forward`` returns the outputs of *all* layers (useful when skip
|
||||
connections are needed). ``encode`` returns only the final output
|
||||
(the input to the bottleneck).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@ -89,14 +87,14 @@ class Encoder(nn.Module):
|
||||
Number of channels expected in the input tensor.
|
||||
input_height : int
|
||||
Height (frequency bins) expected in the input tensor.
|
||||
output_channels : int
|
||||
Number of channels in the final output tensor (bottleneck).
|
||||
out_channels : int
|
||||
Number of channels in the final output tensor (bottleneck input).
|
||||
output_height : int
|
||||
Height (frequency bins) expected in the output tensor.
|
||||
Height (frequency bins) of the final output tensor.
|
||||
layers : nn.ModuleList
|
||||
The sequence of instantiated downscaling layer modules.
|
||||
Sequence of instantiated downsampling block modules.
|
||||
depth : int
|
||||
The number of downscaling layers in the encoder.
|
||||
Number of downsampling layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -107,23 +105,22 @@ class Encoder(nn.Module):
|
||||
input_height: int = 128,
|
||||
in_channels: int = 1,
|
||||
):
|
||||
"""Initialize the Encoder module.
|
||||
"""Initialise the Encoder module.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_encoder` factory function, which prepares the `layers` list.
|
||||
This constructor is typically called by the ``build_encoder`` factory
|
||||
function, which takes care of building the ``layers`` list from a
|
||||
configuration object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_channels : int
|
||||
Number of channels produced by the final layer.
|
||||
output_height : int
|
||||
The expected height of the output tensor.
|
||||
Height of the output tensor after all layers have been applied.
|
||||
layers : List[nn.Module]
|
||||
A list of pre-instantiated downscaling layer modules (e.g.,
|
||||
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
||||
sequence.
|
||||
Pre-built downsampling block modules in execution order.
|
||||
input_height : int, default=128
|
||||
Expected height of the input tensor.
|
||||
Expected height of the input tensor (frequency bins).
|
||||
in_channels : int, default=1
|
||||
Expected number of channels in the input tensor.
|
||||
"""
|
||||
@ -138,29 +135,30 @@ class Encoder(nn.Module):
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Pass input through encoder layers, returns all intermediate outputs.
|
||||
"""Pass input through all encoder layers and return every output.
|
||||
|
||||
This method is typically used when the Encoder is part of a U-Net or
|
||||
similar architecture requiring skip connections.
|
||||
Used when skip connections are needed (e.g. in a U-Net decoder).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
||||
`self.in_channels` and `H_in` must match `self.input_height`.
|
||||
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
|
||||
``C_in`` must match ``self.in_channels`` and ``H_in`` must
|
||||
match ``self.input_height``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[torch.Tensor]
|
||||
A list containing the output tensors from *each* downscaling layer
|
||||
in the sequence. `outputs[0]` is the output of the first layer,
|
||||
`outputs[-1]` is the final output (bottleneck) of the encoder.
|
||||
Output tensors from every layer in order.
|
||||
``outputs[0]`` is the output of the first (shallowest) layer;
|
||||
``outputs[-1]`` is the output of the last (deepest) layer,
|
||||
which serves as the input to the bottleneck.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
If the input channel count or height does not match the
|
||||
expected values.
|
||||
"""
|
||||
if x.shape[1] != self.in_channels:
|
||||
raise ValueError(
|
||||
@ -183,28 +181,29 @@ class Encoder(nn.Module):
|
||||
return outputs
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pass input through encoder layers, returning only the final output.
|
||||
"""Pass input through all encoder layers and return only the final output.
|
||||
|
||||
This method provides access to the bottleneck features produced after
|
||||
the last downscaling layer.
|
||||
Use this when skip connections are not needed and you only require
|
||||
the final encoder output (the input to the bottleneck).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
||||
`in_channels` and `input_height`.
|
||||
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
|
||||
Must satisfy the same shape requirements as ``forward``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final output tensor (bottleneck features) from the last layer
|
||||
of the encoder. Shape `(B, C_out, H_out, W_out)`.
|
||||
Output of the last encoder layer, shape
|
||||
``(B, C_out, H_out, W)``, where ``C_out`` is
|
||||
``self.out_channels`` and ``H_out`` is ``self.output_height``.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
If the input channel count or height does not match the
|
||||
expected values.
|
||||
"""
|
||||
if x.shape[1] != self.in_channels:
|
||||
raise ValueError(
|
||||
@ -236,14 +235,17 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
||||
),
|
||||
],
|
||||
)
|
||||
"""Default configuration for the Encoder.
|
||||
"""Default encoder configuration used in standard BatDetect2 models.
|
||||
|
||||
Specifies an architecture typically used in BatDetect2:
|
||||
- Input: 1 channel, 128 frequency bins.
|
||||
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
|
||||
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
|
||||
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
|
||||
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
|
||||
Assumes a 1-channel input with 128 frequency bins and produces the
|
||||
following feature maps:
|
||||
|
||||
- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
|
||||
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
|
||||
- Stage 3 (``LayerGroup``):
|
||||
|
||||
- ``FreqCoordConvDown``: 128 channels, height 16.
|
||||
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
|
||||
"""
|
||||
|
||||
|
||||
@ -252,42 +254,38 @@ def build_encoder(
|
||||
input_height: int,
|
||||
config: EncoderConfig | None = None,
|
||||
) -> Encoder:
|
||||
"""Factory function to build an Encoder instance from configuration.
|
||||
"""Build an ``Encoder`` from configuration.
|
||||
|
||||
Constructs a sequential `Encoder` module based on the layer sequence
|
||||
defined in an `EncoderConfig` object and the provided input dimensions.
|
||||
If no config is provided, uses the default layer sequence from
|
||||
`DEFAULT_ENCODER_CONFIG`.
|
||||
|
||||
It iteratively builds the layers using the unified
|
||||
`build_layer_from_config` factory (from `.blocks`), tracking the changing
|
||||
number of channels and feature map height required for each subsequent
|
||||
layer, especially for coordinate- aware blocks.
|
||||
Constructs a sequential ``Encoder`` by iterating over the block
|
||||
configurations in ``config.layers``, building each block with
|
||||
``build_layer``, and tracking the channel count and feature-map height
|
||||
as they change through the sequence.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
The number of channels expected in the input tensor to the encoder.
|
||||
Must be > 0.
|
||||
Number of channels in the input spectrogram tensor. Must be
|
||||
positive.
|
||||
input_height : int
|
||||
The height (frequency bins) expected in the input tensor. Must be > 0.
|
||||
Crucial for initializing coordinate-aware layers correctly.
|
||||
Height (number of frequency bins) of the input spectrogram.
|
||||
Must be positive and should be divisible by
|
||||
``2 ** (number of downsampling stages)`` to avoid size mismatches
|
||||
later in the network.
|
||||
config : EncoderConfig, optional
|
||||
The configuration object detailing the sequence of layers and their
|
||||
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
|
||||
Configuration specifying the layer sequence. Defaults to
|
||||
``DEFAULT_ENCODER_CONFIG`` if not provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Encoder
|
||||
An initialized `Encoder` module.
|
||||
An initialised ``Encoder`` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` or `input_height` are not positive, or if the layer
|
||||
configuration is invalid (e.g., empty list, unknown `name`).
|
||||
NotImplementedError
|
||||
If `build_layer_from_config` encounters an unknown `name`.
|
||||
If ``in_channels`` or ``input_height`` are not positive.
|
||||
KeyError
|
||||
If a layer configuration specifies an unknown block type.
|
||||
"""
|
||||
if in_channels <= 0 or input_height <= 0:
|
||||
raise ValueError("in_channels and input_height must be positive.")
|
||||
|
||||
@ -1,20 +1,19 @@
|
||||
"""Prediction Head modules for BatDetect2 models.
|
||||
"""Prediction heads attached to the backbone feature map.
|
||||
|
||||
This module defines simple `torch.nn.Module` subclasses that serve as
|
||||
prediction heads, typically attached to the output feature map of a backbone
|
||||
network
|
||||
Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
|
||||
convolution to map backbone feature channels to one specific type of
|
||||
output required by BatDetect2:
|
||||
|
||||
Each head is responsible for generating one specific type of output required
|
||||
by the BatDetect2 task:
|
||||
- `DetectorHead`: Predicts the probability of sound event presence.
|
||||
- `ClassifierHead`: Predicts the probability distribution over target classes.
|
||||
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
|
||||
box.
|
||||
- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
|
||||
activation).
|
||||
- ``ClassifierHead``: multi-class probability map over the target bat
|
||||
species / call types (softmax activation).
|
||||
- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
|
||||
bandwidth (frequency axis) at each location (no activation; raw
|
||||
regression output).
|
||||
|
||||
These heads use 1x1 convolutions to map the backbone feature channels
|
||||
to the desired number of output channels for each prediction task at each
|
||||
spatial location, followed by an appropriate activation function (e.g., sigmoid
|
||||
for detection, softmax for classification, none for size regression).
|
||||
All three heads share the same input feature map produced by the backbone,
|
||||
so they can be evaluated in parallel in a single forward pass.
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -28,42 +27,35 @@ __all__ = [
|
||||
|
||||
|
||||
class ClassifierHead(nn.Module):
|
||||
"""Prediction head for multi-class classification probabilities.
|
||||
"""Prediction head for species / call-type classification probabilities.
|
||||
|
||||
Takes an input feature map and produces a probability map where each
|
||||
channel corresponds to a specific target class. It uses a 1x1 convolution
|
||||
to map input channels to `num_classes + 1` outputs (one for each target
|
||||
class plus an assumed background/generic class), applies softmax across the
|
||||
channels, and returns the probabilities for the specific target classes
|
||||
(excluding the last background/generic channel).
|
||||
Takes a backbone feature map and produces a probability map where each
|
||||
channel corresponds to a target class. Internally the 1×1 convolution
|
||||
maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
|
||||
represents a generic background / unknown category); a softmax is then
|
||||
applied across the channel dimension and the background channel is
|
||||
discarded before returning.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(excluding any background or generic category). Must be positive.
|
||||
Number of target classes (bat species or call types) to predict,
|
||||
excluding the background category. Must be positive.
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
Number of channels in the backbone feature map. Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
num_classes : int
|
||||
Number of specific output classes.
|
||||
Number of specific output classes (background excluded).
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
classifier : nn.Conv2d
|
||||
The 1x1 convolutional layer used for prediction.
|
||||
Output channels = num_classes + 1.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` or `in_channels` are not positive.
|
||||
1×1 convolution with ``num_classes + 1`` output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, in_channels: int):
|
||||
"""Initialize the ClassifierHead."""
|
||||
"""Initialise the ClassifierHead."""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute class probabilities from input features.
|
||||
"""Compute per-class probabilities from backbone features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Class probability map tensor with shape `(B, num_classes, H, W)`.
|
||||
Contains probabilities for the specific target classes after
|
||||
softmax, excluding the implicit background/generic class channel.
|
||||
Class probability map, shape ``(B, num_classes, H, W)``.
|
||||
Values are softmax probabilities in the range [0, 1] and
|
||||
sum to at most 1 per location (the background probability
|
||||
is discarded).
|
||||
"""
|
||||
logits = self.classifier(features)
|
||||
probs = torch.softmax(logits, dim=1)
|
||||
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
|
||||
|
||||
|
||||
class DetectorHead(nn.Module):
|
||||
"""Prediction head for sound event detection probability.
|
||||
"""Prediction head for detection probability (is a call present here?).
|
||||
|
||||
Takes an input feature map and produces a single-channel heatmap where
|
||||
each value represents the probability ([0, 1]) of a relevant sound event
|
||||
(of any class) being present at that spatial location.
|
||||
Produces a single-channel heatmap where each value indicates the
|
||||
probability ([0, 1]) that a bat call of *any* species is present at
|
||||
that time–frequency location in the spectrogram.
|
||||
|
||||
Uses a 1x1 convolution to map input channels to 1 output channel, followed
|
||||
by a sigmoid activation function.
|
||||
Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
|
||||
sigmoid activation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
Number of channels in the backbone feature map. Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
detector : nn.Conv2d
|
||||
The 1x1 convolutional layer mapping to a single output channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` is not positive.
|
||||
1×1 convolution with a single output channel.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
"""Initialize the DetectorHead."""
|
||||
"""Initialise the DetectorHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute detection probabilities from input features.
|
||||
"""Compute detection probabilities from backbone features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
|
||||
Values are in the range [0, 1] due to the sigmoid activation.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If input channel count does not match `self.in_channels`.
|
||||
Detection probability heatmap, shape ``(B, 1, H, W)``.
|
||||
Values are in the range [0, 1].
|
||||
"""
|
||||
return torch.sigmoid(self.detector(features))
|
||||
|
||||
|
||||
class BBoxHead(nn.Module):
|
||||
"""Prediction head for bounding box size dimensions.
|
||||
"""Prediction head for bounding box size (duration and bandwidth).
|
||||
|
||||
Takes an input feature map and produces a two-channel map where each
|
||||
channel represents a predicted size dimension (typically width/duration and
|
||||
height/bandwidth) for a potential sound event at that spatial location.
|
||||
Produces a two-channel map where channel 0 predicts the scaled duration
|
||||
(time-axis extent) and channel 1 predicts the scaled bandwidth
|
||||
(frequency-axis extent) of the call at each spectrogram location.
|
||||
|
||||
Uses a 1x1 convolution to map input channels to 2 output channels. No
|
||||
activation function is typically applied, as size prediction is often
|
||||
treated as a direct regression task. The output values usually represent
|
||||
*scaled* dimensions that need to be un-scaled during postprocessing.
|
||||
Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
|
||||
activation function (raw regression output). The predicted values are
|
||||
in a scaled space and must be converted to real units (seconds and Hz)
|
||||
during postprocessing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
Number of channels in the backbone feature map. Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
bbox : nn.Conv2d
|
||||
The 1x1 convolutional layer mapping to 2 output channels
|
||||
(width, height).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` is not positive.
|
||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
"""Initialize the BBoxHead."""
|
||||
"""Initialise the BBoxHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute predicted bounding box dimensions from input features.
|
||||
"""Predict call duration and bandwidth from backbone features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
|
||||
represents scaled width, Channel 1 scaled height. These values
|
||||
need to be un-scaled during postprocessing.
|
||||
Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
|
||||
the predicted scaled duration; channel 1 is the predicted
|
||||
scaled bandwidth. Values must be rescaled to real units during
|
||||
postprocessing.
|
||||
"""
|
||||
return self.bbox(features)
|
||||
|
||||
@ -1,14 +1,4 @@
|
||||
"""Tests for backbone configuration loading and the backbone registry.
|
||||
|
||||
Covers:
|
||||
- UNetBackboneConfig default construction and field values.
|
||||
- build_backbone with default and explicit configs.
|
||||
- load_backbone_config loading from a YAML file.
|
||||
- load_backbone_config with a nested field path.
|
||||
- load_backbone_config round-trip: YAML → config → build_backbone.
|
||||
- Registry registration and dispatch for UNetBackbone.
|
||||
- BackboneConfig discriminated union validation.
|
||||
"""
|
||||
"""Tests for backbone configuration loading and the backbone registry."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
@ -25,10 +15,6 @@ from batdetect2.models.backbones import (
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UNetBackboneConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_unet_backbone_config_defaults():
|
||||
"""Default config has expected field values."""
|
||||
@ -57,11 +43,6 @@ def test_unet_backbone_config_extra_fields_ignored():
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_backbone_default():
|
||||
"""Building with no config uses UNetBackbone defaults."""
|
||||
backbone = build_backbone()
|
||||
@ -83,15 +64,9 @@ def test_build_backbone_custom_config():
|
||||
def test_build_backbone_returns_backbone_model():
|
||||
"""build_backbone always returns a BackboneModel instance."""
|
||||
backbone = build_backbone()
|
||||
|
||||
assert isinstance(backbone, BackboneModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_has_unet_backbone():
|
||||
"""The backbone registry has UNetBackbone registered."""
|
||||
config_types = backbone_registry.get_config_types()
|
||||
@ -124,11 +99,6 @@ def test_registry_build_unknown_name_raises():
|
||||
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BackboneConfig discriminated union
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_backbone_config_validates_unet_from_dict():
|
||||
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
|
||||
from pydantic import TypeAdapter
|
||||
@ -151,11 +121,6 @@ def test_backbone_config_invalid_name_raises():
|
||||
adapter.validate_python({"name": "NonExistentBackbone"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_backbone_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_backbone_config_from_yaml(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
@ -218,11 +183,6 @@ deprecated_field: 99
|
||||
assert config.input_height == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip: YAML → config → build_backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_round_trip_yaml_to_build_backbone(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user