Update model docstrings

This commit is contained in:
mbsantiago 2026-03-08 16:34:17 +00:00
parent 4207661da4
commit ef3348d651
9 changed files with 873 additions and 613 deletions

View File

@ -1,29 +1,29 @@
"""Defines and builds the neural network models used in BatDetect2.
"""Neural network model definitions and builders for BatDetect2.
This package (`batdetect2.models`) contains the PyTorch implementations of the
deep neural network architectures used for detecting and classifying bat calls
from spectrograms. It provides modular components and configuration-driven
assembly, allowing for experimentation and use of different architectural
variants.
This package contains the PyTorch implementations of the deep neural network
architectures used to detect and classify bat echolocation calls in
spectrograms. Components are designed to be combined through configuration
objects, making it easy to experiment with different architectures.
Key Submodules:
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
- `.blocks`: Provides reusable neural network building blocks.
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
- `.bottleneck`: Defines and builds the central bottleneck component.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
feature extraction backbone (e.g., a U-Net like structure).
- `.heads`: Defines simple prediction heads (detection, classification, size)
that attach to the backbone features.
- `.detectors`: Assembles the backbone and prediction heads into the final,
end-to-end `Detector` model.
Key submodules
--------------
- ``blocks``: Reusable convolutional building blocks (downsampling,
upsampling, attention, coord-conv variants).
- ``encoder``: The downsampling path; reduces spatial resolution whilst
extracting increasingly abstract features.
- ``bottleneck``: The central component connecting encoder to decoder;
optionally applies self-attention along the time axis.
- ``decoder``: The upsampling path; reconstructs high-resolution feature
maps using bottleneck output and skip connections from the encoder.
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
U-Net-style feature extraction backbone.
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
classification, and bounding-box size predictions from backbone features.
- ``detectors``: Combines a backbone with prediction heads into the final
end-to-end ``Detector`` model.
This module re-exports the most important classes, configurations, and builder
functions from these submodules for convenient access. The primary entry point
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import List
@ -99,6 +99,31 @@ __all__ = [
class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
Combines a preprocessor, a detection model, and a postprocessor into a
single PyTorch module. Calling ``forward`` on a raw waveform tensor
returns a list of detection tensors ready for downstream use.
This class is the top-level object produced by ``build_model``. Most
users will not need to construct it directly.
Attributes
----------
detector : DetectionModel
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
Converts a raw waveform tensor into a spectrogram tensor accepted by
``detector``.
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
"""
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
@ -118,6 +143,25 @@ class Model(torch.nn.Module):
self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
    """Detect calls in a raw waveform, end to end.

    The waveform is first turned into a spectrogram by the
    preprocessor, the spectrogram is fed to the detection network, and
    the network's raw outputs are finally converted into per-clip
    detection tensors by the postprocessor.

    Parameters
    ----------
    wav : torch.Tensor
        Raw audio tensor. Its required shape is dictated by the
        configured preprocessor; typically ``(batch, samples)`` or
        ``(batch, channels, samples)``.

    Returns
    -------
    List[ClipDetectionsTensor]
        One tensor of detections (locations, class scores, and sizes)
        per clip in the batch.
    """
    return self.postprocessor(self.detector(self.preprocessor(wav)))
@ -128,7 +172,38 @@ def build_model(
targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
):
) -> "Model":
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from optional configuration and component
overrides. Any argument left as ``None`` will be replaced by a sensible
default built with the project's own builder functions.
Parameters
----------
config : BackboneConfig, optional
Configuration describing the backbone architecture (encoder,
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
provided.
targets : TargetProtocol, optional
Describes the target bat species or call types to detect. Determines
the number of output classes. Defaults to the standard BatDetect2
target set.
preprocessor : PreprocessorProtocol, optional
Converts raw audio waveforms to spectrograms. Defaults to the
standard BatDetect2 preprocessor.
postprocessor : PostprocessorProtocol, optional
Converts raw model outputs to detection tensors. Defaults to the
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
given without a matching ``postprocessor``, the default postprocessor
will be built using the provided preprocessor so that frequency and
time scaling remain consistent.
Returns
-------
Model
A fully assembled ``Model`` instance ready for inference or training.
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets

View File

@ -1,21 +1,26 @@
"""Assembles a complete Encoder-Decoder Backbone network.
"""Assembles a complete encoder-decoder backbone network.
This module defines the configuration (`BackboneConfig`) and implementation
(`Backbone`) for a standard encoder-decoder style neural network backbone.
This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
``nn.Module``, together with the ``build_backbone`` and
``load_backbone_config`` helpers.
It orchestrates the connection between three main components, built using their
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
A backbone combines three components built from the sibling modules:
The resulting `Backbone` module takes a spectrogram as input and outputs a
final feature map, typically used by subsequent prediction heads. It includes
automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
1. **Encoder** (``batdetect2.models.encoder``) reduces spatial resolution
while extracting hierarchical features and storing skip-connection tensors.
2. **Bottleneck** (``batdetect2.models.bottleneck``) processes the
lowest-resolution features, optionally applying self-attention.
3. **Decoder** (``batdetect2.models.decoder``) restores spatial resolution
using bottleneck features and skip connections from the encoder.
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
a high-resolution feature map consumed by the prediction heads in
``batdetect2.models.detectors``.
Input padding is handled automatically: the backbone pads the input to be
divisible by the total downsampling factor and strips the padding from the
output so that the output spatial dimensions always match the input spatial
dimensions.
"""
from typing import Annotated, Literal, Tuple, Union
@ -51,6 +56,34 @@ from batdetect2.typing.models import (
class UNetBackboneConfig(BaseConfig):
"""Configuration for a U-Net-style encoder-decoder backbone.
All fields have sensible defaults that reproduce the standard BatDetect2
architecture, so you can start with ``UNetBackboneConfig()`` and override
only the fields you want to change.
Attributes
----------
name : str
Discriminator field used by the backbone registry; always
``"UNetBackbone"``.
input_height : int
Number of frequency bins in the input spectrogram. Defaults to
``128``.
in_channels : int
Number of channels in the input spectrogram (e.g. ``1`` for a
standard mel-spectrogram). Defaults to ``1``.
encoder : EncoderConfig
Configuration for the downsampling path. Defaults to
``DEFAULT_ENCODER_CONFIG``.
bottleneck : BottleneckConfig
Configuration for the bottleneck. Defaults to
``DEFAULT_BOTTLENECK_CONFIG``.
decoder : DecoderConfig
Configuration for the upsampling path. Defaults to
``DEFAULT_DECODER_CONFIG``.
"""
name: Literal["UNetBackbone"] = "UNetBackbone"
input_height: int = 128
in_channels: int = 1
@ -70,35 +103,36 @@ __all__ = [
class UNetBackbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation.
"""U-Net-style encoder-decoder backbone network.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
skip connections between the Encoder and Decoder. Implements the standard
U-Net style forward pass. Includes automatic input padding to handle
various input sizes and a final convolutional block to adjust the output
channels.
Combines an encoder, a bottleneck, and a decoder into a single module
that produces a high-resolution feature map from an input spectrogram.
Skip connections from each encoder stage are added element-wise to the
corresponding decoder stage input.
This class inherits from `BackboneModel` and implements its `forward`
method. Instances are typically created using the `build_backbone` factory
function.
Input spectrograms of arbitrary width are handled automatically: the
backbone pads the input so that its dimensions are divisible by
``divide_factor`` and removes the padding from the output.
Instances are typically created via ``build_backbone``.
Attributes
----------
input_height : int
Expected height of the input spectrogram.
Expected height (frequency bins) of the input spectrogram.
out_channels : int
Number of channels in the final output feature map.
Number of channels in the output feature map (taken from the
decoder's output channel count).
encoder : EncoderProtocol
The instantiated encoder module.
decoder : DecoderProtocol
The instantiated decoder module.
bottleneck : BottleneckProtocol
The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int
The total downsampling factor (2^depth) applied by the encoder,
used for automatic input padding.
The total spatial downsampling factor applied by the encoder
(``input_height // encoder.output_height``). The input width is
padded to be a multiple of this value before processing.
"""
def __init__(
@ -108,25 +142,19 @@ class UNetBackbone(BackboneModel):
decoder: DecoderProtocol,
bottleneck: BottleneckProtocol,
):
"""Initialize the Backbone network.
"""Initialise the backbone network.
Parameters
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
Expected height (frequency bins) of the input spectrogram.
encoder : EncoderProtocol
An initialized Encoder module.
An initialised encoder module.
decoder : DecoderProtocol
An initialized Decoder module.
An initialised decoder module. Its ``output_height`` must equal
``input_height``; a ``ValueError`` is raised otherwise.
bottleneck : BottleneckProtocol
An initialized Bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
An initialised bottleneck module.
"""
super().__init__()
self.input_height = input_height
@ -143,22 +171,25 @@ class UNetBackbone(BackboneModel):
self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass through the encoder-decoder backbone.
"""Produce a feature map from an input spectrogram.
Applies padding, runs encoder, bottleneck, decoder (with skip
connections), removes padding, and applies a final convolution.
Pads the input if necessary, runs it through the encoder, then
the bottleneck, then the decoder (incorporating encoder skip
connections), and finally removes any padding added earlier.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
`self.encoder.input_channels` and `self.input_height`.
Input spectrogram tensor, shape
``(B, C_in, H_in, W_in)``. ``H_in`` must equal
``self.input_height``; ``W_in`` can be any positive integer.
Returns
-------
torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
`C_out` is `self.out_channels`.
Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
``C_out`` is ``self.out_channels``. The spatial dimensions
always match those of the input.
"""
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
@ -219,6 +250,24 @@ BackboneConfig = Annotated[
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
    """Construct a backbone network from an optional configuration.

    Dispatches on ``config.name`` through the backbone registry, which
    invokes the matching class's ``from_config`` constructor. When
    called without arguments, a default ``UNetBackbone`` is produced.

    Parameters
    ----------
    config : BackboneConfig, optional
        Backbone description; ``UNetBackboneConfig`` is currently the
        only supported variant. When ``None``, ``UNetBackboneConfig()``
        is used.

    Returns
    -------
    BackboneModel
        A ready-to-use backbone module.
    """
    if config is None:
        config = UNetBackboneConfig()
    return backbone_registry.build(config)
@ -227,26 +276,25 @@ def _pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
"""Pad tensor height and width to be divisible by a factor.
"""Pad a tensor's height and width to be divisible by ``factor``.
Calculates the required padding for the last two dimensions (H, W) to make
them divisible by `factor` and applies right/bottom padding using
`torch.nn.functional.pad`.
Adds zero-padding to the bottom and right edges of the tensor so that
both dimensions are exact multiples of ``factor``. If both dimensions
are already divisible, the tensor is returned unchanged.
Parameters
----------
spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`.
Input tensor, typically shape ``(B, C, H, W)``.
factor : int, default=32
The factor to make height and width divisible by.
The factor that both H and W should be divisible by after padding.
Returns
-------
Tuple[torch.Tensor, int, int]
A tuple containing:
- The padded tensor.
- The amount of padding added to height (`h_pad`).
- The amount of padding added to width (`w_pad`).
- Padded tensor.
- Number of rows added to the height (``h_pad``).
- Number of columns added to the width (``w_pad``).
"""
h, w = spec.shape[-2:]
h_pad = -h % factor
@ -261,23 +309,25 @@ def _pad_adjust(
def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
"""Remove padding added by _pad_adjust.
"""Remove padding previously added by ``_pad_adjust``.
Removes padding from the bottom and right edges of the tensor.
Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
right of the tensor, restoring its original spatial dimensions.
Parameters
----------
x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
h_pad : int, default=0
Amount of padding previously added to the height (bottom).
Number of rows to remove from the bottom.
w_pad : int, default=0
Amount of padding previously added to the width (right).
Number of columns to remove from the right.
Returns
-------
torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
Tensor with padding removed, shape
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
"""
if h_pad > 0:
x = x[..., :-h_pad, :]
@ -292,6 +342,36 @@ def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
"""Load a backbone configuration from a YAML or JSON file.
Reads the file at ``path``, optionally descends into a named sub-field,
and validates the result against the ``BackboneConfig`` discriminated
union.
Parameters
----------
path : PathLike
Path to the configuration file. Both YAML and JSON formats are
supported.
field : str, optional
Dot-separated key path to the sub-field that contains the backbone
configuration (e.g. ``"model"``). If ``None``, the root of the
file is used.
Returns
-------
BackboneConfig
A validated backbone configuration object (currently always a
``UNetBackboneConfig`` instance).
Raises
------
FileNotFoundError
If ``path`` does not exist.
ValidationError
If the loaded data does not conform to a known ``BackboneConfig``
schema.
"""
return load_config(
path,
schema=TypeAdapter(BackboneConfig),

View File

@ -1,30 +1,49 @@
"""Commonly used neural network building blocks for BatDetect2 models.
"""Reusable convolutional building blocks for BatDetect2 models.
This module provides various reusable `torch.nn.Module` subclasses that form
the fundamental building blocks for constructing convolutional neural network
architectures, particularly encoder-decoder backbones used in BatDetect2.
This module provides a collection of ``torch.nn.Module`` subclasses that form
the fundamental building blocks for the encoder-decoder backbone used in
BatDetect2. All blocks follow a consistent interface: they store
``in_channels`` and ``out_channels`` as attributes and implement a
``get_output_height`` method that reports how a given input height maps to an
output height (e.g., halved by downsampling blocks, doubled by upsampling
blocks).
It includes standard components like basic convolutional blocks (`ConvBlock`),
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
upsampling (`StandardConvUpBlock`).
Available block families
------------------------
Standard blocks
``ConvBlock`` — convolution + batch normalisation + ReLU, no change in
spatial resolution.
Additionally, it features specialized layers investigated in BatDetect2
research:
Downsampling blocks
``StandardConvDownBlock`` — convolution then 2×2 max-pooling, halves H
and W.
``FreqCoordConvDownBlock`` — like ``StandardConvDownBlock`` but prepends
a normalised frequency-coordinate channel before the convolution
(CoordConv concept), helping filters learn frequency-position-dependent
patterns.
- `SelfAttention`: Applies self-attention along the time dimension, enabling
the model to weigh information across the entire temporal context, often
used in the bottleneck of an encoder-decoder.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
concept by concatenating normalized frequency coordinate information as an
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
Upsampling blocks
``StandardConvUpBlock`` — bilinear interpolation then convolution,
doubles H and W.
``FreqCoordConvUpBlock`` — like ``StandardConvUpBlock`` but prepends a
frequency-coordinate channel after upsampling.
These blocks can be used directly in custom PyTorch model definitions or
assembled into larger architectures.
Bottleneck blocks
``VerticalConv`` — 1-D convolution whose kernel spans the entire
frequency axis, collapsing H to 1 whilst preserving W.
``SelfAttention`` — scaled dot-product self-attention along the time
axis; typically follows a ``VerticalConv``.
A unified factory function `build_layer` allows creating instances
of these blocks based on configuration objects.
Group block
``LayerGroup`` — chains several blocks sequentially into one unit,
useful when a single encoder or decoder "stage" requires more than one
operation.
Factory function
----------------
``build_layer`` — creates any of the above blocks from the matching
configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
"""
from typing import Annotated, List, Literal, Tuple, Union
@ -57,10 +76,43 @@ __all__ = [
class Block(nn.Module):
    """Common base class for every BatDetect2 building block.

    Concrete subclasses are expected to expose ``in_channels`` and
    ``out_channels`` as plain integer attributes, which lets factory
    code chain blocks together without consulting their configuration
    objects. Blocks that change the frequency (height) dimension —
    downsampling or upsampling blocks — should also override
    ``get_output_height``.

    Attributes
    ----------
    in_channels : int
        Channel count of the tensors this block consumes.
    out_channels : int
        Channel count of the tensors this block produces.
    """

    in_channels: int
    out_channels: int

    def get_output_height(self, input_height: int) -> int:
        """Map an input height to the corresponding output height.

        By default the height passes through unchanged, which is
        correct for resolution-preserving blocks. Downsampling blocks
        typically return ``input_height // 2`` and upsampling blocks
        ``input_height * 2``.

        Parameters
        ----------
        input_height : int
            Number of frequency bins in the input feature map.

        Returns
        -------
        int
            Number of frequency bins in the output feature map.
        """
        return input_height
@ -68,58 +120,65 @@ block_registry: Registry[Block, [int, int]] = Registry("block")
class SelfAttentionConfig(BaseConfig):
"""Configuration for a ``SelfAttention`` block.
Attributes
----------
name : str
Discriminator field; always ``"SelfAttention"``.
attention_channels : int
Dimensionality of the query, key, and value projections.
temperature : float
Scaling factor applied to the weighted values before the final
linear projection. Defaults to ``1``.
"""
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(Block):
"""Self-Attention mechanism operating along the time dimension.
"""Self-attention block operating along the time axis.
This module implements a scaled dot-product self-attention mechanism,
specifically designed here to operate across the time steps of an input
feature map, typically after spatial dimensions (like frequency) have been
condensed or squeezed.
Applies a scaled dot-product self-attention mechanism across the time
steps of an input feature map. Before attention is computed the height
dimension (frequency axis) is expected to have been reduced to 1, e.g.
by a preceding ``VerticalConv`` layer.
By calculating attention weights between all pairs of time steps, it allows
the model to capture long-range temporal dependencies and focus on relevant
parts of the sequence. It's often employed in the bottleneck or
intermediate layers of an encoder-decoder architecture to integrate global
temporal context.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
For each time step the block computes query, key, and value projections
with learned linear weights, then calculates attention weights from the
query–key dot products divided by ``temperature × attention_channels``.
The weighted sum of values is projected back to ``in_channels`` via a
final linear layer, and the height dimension is restored so that the
output shape matches the input shape.
Parameters
----------
in_channels : int
Number of input channels (features per time step after spatial squeeze).
Number of input channels (features per time step). The output will
also have ``in_channels`` channels.
attention_channels : int
Number of channels for the query, key, and value projections. Also the
dimension of the output projection's input.
Dimensionality of the query, key, and value projections.
temperature : float, default=1.0
Scaling factor applied *before* the final projection layer. Can be used
to adjust the sharpness or focus of the attention mechanism, although
scaling within the softmax (dividing by sqrt(dim)) is more common for
standard transformers. Here it scales the weighted values.
Divisor applied together with ``attention_channels`` when scaling
the dot-product scores before softmax. Larger values produce softer
(more uniform) attention distributions.
Attributes
----------
key_fun : nn.Linear
Linear layer for key projection.
Linear projection for keys.
value_fun : nn.Linear
Linear layer for value projection.
Linear projection for values.
query_fun : nn.Linear
Linear layer for query projection.
Linear projection for queries.
pro_fun : nn.Linear
Final linear projection layer applied after attention weighting.
Final linear projection applied to the attended values.
temperature : float
Scaling factor applied before final projection.
Scaling divisor used when computing attention scores.
att_dim : int
Dimensionality of the attention space (`attention_channels`).
Dimensionality of the attention space (``attention_channels``).
"""
def __init__(
@ -148,20 +207,16 @@ class SelfAttention(Block):
Parameters
----------
x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H is typically
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
applying attention along the W (time) dimension.
Input tensor with shape ``(B, C, 1, W)``. The height dimension
must be 1 (i.e. the frequency axis should already have been
collapsed by a preceding ``VerticalConv`` layer).
Returns
-------
torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where
attention has been applied across the W dimension.
Raises
------
RuntimeError
If input tensor dimensions are incompatible with operations.
Output tensor with the same shape ``(B, C, 1, W)`` as the
input, with each time step updated by attended context from all
other time steps.
"""
x = x.squeeze(2).permute(0, 2, 1)
@ -190,6 +245,22 @@ class SelfAttention(Block):
return op
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
"""Return the softmax attention weight matrix.
Useful for visualising which time steps attend to which others.
Parameters
----------
x : torch.Tensor
Input tensor with shape ``(B, C, 1, W)``.
Returns
-------
torch.Tensor
Attention weight matrix with shape ``(B, W, W)``. Entry
``[b, i, j]`` is the attention weight that time step ``i``
assigns to time step ``j`` in batch item ``b``.
"""
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
@ -304,6 +375,16 @@ class ConvBlock(Block):
class VerticalConvConfig(BaseConfig):
"""Configuration for a ``VerticalConv`` block.
Attributes
----------
name : str
Discriminator field; always ``"VerticalConv"``.
channels : int
Number of output channels produced by the vertical convolution.
"""
name: Literal["VerticalConv"] = "VerticalConv"
channels: int
@ -844,12 +925,53 @@ LayerConfig = Annotated[
class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
Use this when a single encoder or decoder stage needs more than one
block. The blocks are executed in the order they appear in ``layers``,
with channel counts and heights propagated automatically.
Attributes
----------
name : str
Discriminator field; always ``"LayerGroup"``.
layers : List[LayerConfig]
Ordered list of block configurations to chain together.
"""
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
class LayerGroup(nn.Module):
"""Standard implementation of the `LayerGroup` architecture."""
"""Sequential chain of blocks that acts as a single composite block.
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
exposing the same ``in_channels``, ``out_channels``, and
``get_output_height`` interface as a regular ``Block`` so it can be
used transparently wherever a single block is expected.
Instances are typically constructed by ``build_layer`` when given a
``LayerGroupConfig``; you rarely need to create them directly.
Parameters
----------
layers : list[Block]
Pre-built block instances to chain, in execution order.
input_height : int
Height of the tensor entering the first block.
input_channels : int
Number of channels in the tensor entering the first block.
Attributes
----------
in_channels : int
Number of input channels (taken from the first block).
out_channels : int
Number of output channels (taken from the last block).
layers : nn.Sequential
The wrapped sequence of block modules.
"""
def __init__(
self,
@ -865,9 +987,33 @@ class LayerGroup(nn.Module):
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply every block in the group to ``x``, in order.

    Parameters
    ----------
    x : torch.Tensor
        Input feature map, shape ``(B, C_in, H, W)``.

    Returns
    -------
    torch.Tensor
        Result of running ``x`` through the full chain of blocks.
    """
    return self.layers(x)
def get_output_height(self, input_height: int) -> int:
    """Propagate a height value through every block in the group.

    Parameters
    ----------
    input_height : int
        Height (frequency bins) of the feature map entering the first
        block.

    Returns
    -------
    int
        Height of the feature map leaving the last block.
    """
    height = input_height
    for member in self.layers:
        height = member.get_output_height(height)  # type: ignore
    return height
@ -903,40 +1049,37 @@ def build_layer(
in_channels: int,
config: LayerConfig,
) -> Block:
"""Factory function to build a specific nn.Module block from its config.
"""Build a block from its configuration object.
Takes configuration object (one of the types included in the `LayerConfig`
union) and instantiates the corresponding nn.Module block with the correct
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Looks up the block class corresponding to ``config.name`` in the
internal block registry and instantiates it with the given input
dimensions. This is the standard way to construct blocks when
assembling an encoder or decoder from a configuration file.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor *to this layer*.
Height (number of frequency bins) of the input tensor to this
block. Required for blocks whose kernel size depends on the input
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
in_channels : int
Number of channels in the input tensor *to this layer*.
Number of channels in the input tensor to this block.
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `name` field.
A configuration object for the desired block type. The ``name``
field selects the block class; remaining fields supply its
parameters.
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` block.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Block
An initialised block module ready to be added to an
``nn.Sequential`` or ``nn.ModuleList``.
Raises
------
NotImplementedError
If the `config.name` does not correspond to a known block type.
KeyError
If ``config.name`` does not correspond to a registered block type.
ValueError
If parameters derived from the config are invalid for the block.
If the configuration parameters are invalid for the chosen block.
"""
return block_registry.build(config, in_channels, input_height)

View File

@ -1,17 +1,21 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
"""Bottleneck component for encoder-decoder network architectures.
This module provides the configuration (`BottleneckConfig`) and
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
Decoder (upsampling path) in networks like U-Nets.
The bottleneck sits between the encoder (downsampling path) and the decoder
(upsampling path) and processes the lowest-resolution, highest-channel feature
map produced by the encoder.
The bottleneck processes the lowest-resolution, highest-dimensionality feature
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
This module provides:
A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
- ``BottleneckConfig`` — configuration dataclass describing the number of
internal channels and an optional sequence of additional layers (currently
only ``SelfAttention`` is supported).
- ``Bottleneck`` — the ``torch.nn.Module`` implementation. It first applies a
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
runs one or more additional layers (e.g. self-attention along the time axis),
then repeats the output along the height dimension to restore the original
frequency resolution before passing features to the decoder.
- ``build_bottleneck`` — factory function that constructs a ``Bottleneck``
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
"""
from typing import Annotated, List
@ -37,42 +41,51 @@ __all__ = [
class Bottleneck(Block):
"""Base Bottleneck module for Encoder-Decoder architectures.
"""Bottleneck module for encoder-decoder architectures.
This implementation represents the simplest bottleneck structure
considered, primarily consisting of a `VerticalConv` layer. This layer
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
Processes the lowest-resolution feature map that links the encoder and
decoder. The sequence of operations is:
This base version does *not* include self-attention.
1. ``VerticalConv`` — collapses the frequency axis (height) to a single
bin by applying a convolution whose kernel spans the full height.
2. Optional additional layers — e.g. ``SelfAttention`` — applied while
the feature map has height 1, so they operate purely along the time
axis.
3. Height restoration — the single-bin output is repeated along the
height axis to restore the original frequency resolution, producing
a tensor that the decoder can accept.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
Height (number of frequency bins) of the input tensor. Must be
positive.
in_channels : int
Number of channels in the input tensor from the encoder. Must be
positive.
out_channels : int
Number of output channels. Must be positive.
Number of output channels after the bottleneck. Must be positive.
bottleneck_channels : int, optional
Number of internal channels used by the ``VerticalConv`` layer.
Defaults to ``out_channels`` if not provided.
layers : List[torch.nn.Module], optional
Additional modules (e.g. ``SelfAttention``) to apply after the
``VerticalConv`` and before height restoration.
Attributes
----------
in_channels : int
Number of input channels accepted by the bottleneck.
out_channels : int
Number of output channels produced by the bottleneck.
input_height : int
Expected height of the input tensor.
channels : int
Number of output channels.
bottleneck_channels : int
Number of channels used internally by the vertical convolution.
conv_vert : VerticalConv
The vertical convolution layer.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
layers : nn.ModuleList
Additional layers applied after the vertical convolution.
"""
def __init__(
@ -83,7 +96,23 @@ class Bottleneck(Block):
bottleneck_channels: int | None = None,
layers: List[torch.nn.Module] | None = None,
) -> None:
"""Initialize the base Bottleneck layer."""
"""Initialise the Bottleneck layer.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input tensor.
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels in the output tensor.
bottleneck_channels : int, optional
Number of internal channels for the ``VerticalConv``. Defaults
to ``out_channels``.
layers : List[torch.nn.Module], optional
Additional modules applied after the ``VerticalConv``, such as
a ``SelfAttention`` block.
"""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
@ -103,23 +132,24 @@ class Bottleneck(Block):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck.
"""Process the encoder's bottleneck features.
Applies vertical convolution and repeats the output height.
Applies vertical convolution, optional additional layers, then
restores the height dimension by repetition.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Input tensor from the encoder, shape
``(B, C_in, H_in, W)``. ``C_in`` must match
``self.in_channels`` and ``H_in`` must match
``self.input_height``.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
dimension `H_in` is restored via repetition after the vertical
convolution.
Output tensor with shape ``(B, C_out, H_in, W)``. The height
``H_in`` is restored by repeating the single-bin result.
"""
x = self.conv_vert(x)
@ -133,28 +163,22 @@ BottleneckLayerConfig = Annotated[
SelfAttentionConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
"""Configuration for the bottleneck component.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
Number of output channels produced by the bottleneck. This value
is also used as the dimensionality of any optional layers (e.g.
self-attention). Must be positive.
layers : List[BottleneckLayerConfig]
Ordered list of additional block configurations to apply after the
initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
supported. Defaults to an empty list (no extra layers).
"""
channels: int
@ -174,30 +198,37 @@ def build_bottleneck(
in_channels: int,
config: BottleneckConfig | None = None,
) -> BottleneckProtocol:
"""Factory function to build the Bottleneck module from configuration.
"""Build a ``Bottleneck`` module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
on the `config.self_attention` flag.
Constructs a ``Bottleneck`` instance whose internal channel count and
optional extra layers (e.g. self-attention) are controlled by
``config``. If no configuration is provided, the default
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
``SelfAttention`` layer.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
Height (number of frequency bins) of the input tensor from the
encoder. Must be positive.
in_channels : int
Number of channels in the input tensor. Must be positive.
Number of channels in the input tensor from the encoder. Must be
positive.
config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
Configuration specifying the output channel count and any
additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
Returns
-------
nn.Module
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
BottleneckProtocol
An initialised ``Bottleneck`` module.
Raises
------
ValueError
If `input_height` or `in_channels` are not positive.
AssertionError
If any configured layer changes the height of the feature map
(bottleneck layers must preserve height so that it can be restored
by repetition).
"""
config = config or DEFAULT_BOTTLENECK_CONFIG

View File

@ -1,21 +1,21 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network.
"""Decoder (upsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`DecoderConfig`) for the layer
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
together with the ``build_decoder`` factory function.
The decoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `DecoderConfig.layers`. Each config
object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with upsampling) and its parameters. This allows
flexible definition of decoder architectures via configuration files.
In a U-Net-style network the decoder progressively restores the spatial
resolution of the feature map back towards the input resolution. At each
stage it combines the upsampled features with the corresponding skip-connection
tensor from the encoder (the residual) by element-wise addition before passing
the result to the upsampling block.
The `Decoder`'s `forward` method is designed to accept skip connection tensors
(`residuals`) from the encoder, merging them with the upsampled feature maps
at each stage.
The decoder is fully configurable: the type, number, and parameters of the
upsampling blocks are described by a ``DecoderConfig`` object containing an
ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
"""
from typing import Annotated, List
@ -51,51 +51,47 @@ DecoderLayerConfig = Annotated[
class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
"""Configuration for the sequential ``Decoder`` module.
Attributes
----------
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
Ordered list of block configuration objects defining the decoder's
upsampling stages (from deepest to shallowest). Each entry
specifies the block type (via its ``name`` field) and any
block-specific parameters such as ``out_channels``. Input channels
for each block are inferred automatically from the output of the
previous block. Must contain at least one entry.
"""
layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers.
"""Sequential decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
based on a list of layer modules provided during initialization (typically
created by the `build_decoder` factory function).
Executes a series of upsampling blocks in order, adding the
corresponding encoder skip-connection tensor (residual) to the feature
map before each block. The residuals are consumed in reverse order (from
deepest encoder layer to shallowest) to match the spatial resolutions at
each decoder stage.
The `forward` method is designed to integrate skip connection tensors
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Instances are typically created by ``build_decoder``.
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
Number of channels expected in the input tensor (bottleneck output).
out_channels : int
Number of channels in the final output tensor produced by the last
layer.
Number of channels in the final output feature map.
input_height : int
Height (frequency bins) expected in the input tensor.
Height (frequency bins) of the input tensor.
output_height : int
Height (frequency bins) expected in the output tensor.
Height (frequency bins) of the output tensor.
layers : nn.ModuleList
The sequence of instantiated upscaling layer modules.
Sequence of instantiated upsampling block modules.
depth : int
The number of upscaling layers (depth) in the decoder.
Number of upsampling layers.
"""
def __init__(
@ -106,23 +102,24 @@ class Decoder(nn.Module):
output_height: int,
layers: List[nn.Module],
):
"""Initialize the Decoder module.
"""Initialise the Decoder module.
Note: This constructor is typically called internally by the
`build_decoder` factory function.
This constructor is typically called by the ``build_decoder``
factory function.
Parameters
----------
in_channels : int
Number of channels in the input tensor (bottleneck output).
out_channels : int
Number of channels produced by the final layer.
input_height : int
Expected height of the input tensor (bottleneck).
in_channels : int
Expected number of channels in the input tensor (bottleneck).
Height of the input tensor (bottleneck output height).
output_height : int
Height of the output tensor after all layers have been applied.
layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g.,
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
sequence (from bottleneck towards output resolution).
Pre-built upsampling block modules in execution order (deepest
stage first).
"""
super().__init__()
@ -140,43 +137,35 @@ class Decoder(nn.Module):
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections.
"""Pass input through all decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling
layers. At each stage, the corresponding skip connection tensor from
the `residuals` list is added element-wise to the input before passing
it to the upscaling block.
At each stage the corresponding residual tensor is added
element-wise to ``x`` before it is passed to the upsampling block.
Residuals are consumed in reverse order — the last element of
``residuals`` (the output of the deepest encoder layer) is added
at the first decoder stage, and the first element (output of the
shallowest encoder layer) is added at the last decoder stage.
Parameters
----------
x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck).
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding
encoder stages. Should be ordered from the deepest encoder layer
output (lowest resolution) to the shallowest (highest resolution
near input). The number of tensors in this list must match the
number of decoder layers (`self.depth`). Each residual tensor's
channel count must be compatible with the input tensor `x` for
element-wise addition (or concatenation if the blocks were designed
for it).
Skip-connection tensors from the encoder, ordered from shallowest
(index 0) to deepest (index -1). Must contain exactly
``self.depth`` tensors. Each tensor must have the same spatial
dimensions and channel count as ``x`` at the corresponding
decoder stage.
Returns
-------
torch.Tensor
The final decoded feature map tensor produced by the last layer.
Shape `(B, C_out, H_out, W_out)`.
Decoded feature map, shape ``(B, C_out, H_out, W)``.
Raises
------
ValueError
If the number of `residuals` provided does not match the decoder
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
If the number of ``residuals`` does not equal ``self.depth``.
"""
if len(residuals) != len(self.layers):
raise ValueError(
@ -203,11 +192,17 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
),
],
)
"""A default configuration for the Decoder's *layer sequence*.
"""Default decoder configuration used in standard BatDetect2 models.
Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
output has 256 channels and height 16, and produces:
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvUp``: 32 channels, height 128.
- ``ConvBlock``: 32 channels, height 128 (final feature map).
"""
@ -216,40 +211,36 @@ def build_decoder(
input_height: int,
config: DecoderConfig | None = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.
"""Build a ``Decoder`` from configuration.
Constructs a sequential `Decoder` module based on the layer sequence
defined in a `DecoderConfig` object and the provided input dimensions
(bottleneck channels and height). If no config is provided, uses the
default layer sequence from `DEFAULT_DECODER_CONFIG`.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Constructs a sequential ``Decoder`` by iterating over the block
configurations in ``config.layers``, building each block with
``build_layer``, and tracking the channel count and feature-map height
as they change through the sequence.
Parameters
----------
in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0.
Number of channels in the input tensor (bottleneck output). Must
be positive.
input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be
> 0.
Height (number of frequency bins) of the input tensor. Must be
positive.
config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
Configuration specifying the layer sequence. Defaults to
``DEFAULT_DECODER_CONFIG`` if not provided.
Returns
-------
Decoder
An initialized `Decoder` module.
An initialised ``Decoder`` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `name`.
If ``in_channels`` or ``input_height`` are not positive.
KeyError
If a layer configuration specifies an unknown block type.
"""
config = config or DEFAULT_DECODER_CONFIG

View File

@ -1,17 +1,20 @@
"""Assembles the complete BatDetect2 Detection Model.
"""Assembles the complete BatDetect2 detection model.
This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.
This module defines the ``Detector`` class, which combines a backbone
feature extractor with prediction heads for detection, classification, and
bounding-box size regression.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
Components
----------
- ``Detector`` — the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` — factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
Note that ``Detector`` operates purely on spectrogram tensors; raw audio
preprocessing and output postprocessing are handled by
``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
"""
import torch
@ -32,25 +35,30 @@ __all__ = [
class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model.
"""Complete BatDetect2 detection and classification model.
Assembles a complete detection and classification model by combining a
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class
probabilities.
Combines a backbone feature extractor with two prediction heads:
- ``ClassifierHead``: predicts per-class probabilities at each
time–frequency location.
- ``BBoxHead``: predicts call duration and bandwidth at each location.
The detection probability map is derived from the class probabilities by
summing across the class dimension (i.e. the probability that *any* class
is present), rather than from a separate detection head.
Instances are typically created via ``build_detector``.
Attributes
----------
backbone : BackboneModel
The feature extraction backbone network module.
The feature extraction backbone.
num_classes : int
The number of specific target classes the model predicts (derived from
the `classifier_head`).
Number of target classes (inferred from the classifier head).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
Produces per-class probability maps from backbone features.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
Produces duration and bandwidth predictions from backbone features.
"""
backbone: BackboneModel
@ -61,26 +69,21 @@ class Detector(DetectionModel):
classifier_head: ClassifierHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
"""Initialise the Detector model.
Note: Instances are typically created using the `build_detector`
This constructor is typically called by the ``build_detector``
factory function.
Parameters
----------
backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by
`build_backbone` from the `.backbone` module).
An initialised backbone module (e.g. built by
``build_backbone``).
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
An initialised classification head. The ``num_classes``
attribute is read from this head.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.
Raises
------
TypeError
If the provided modules are not of the expected types.
An initialised bounding-box size prediction head.
"""
super().__init__()
@ -90,31 +93,34 @@ class Detector(DetectionModel):
self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model.
"""Run the complete detection model on an input spectrogram.
Processes the input spectrogram through the backbone to extract
features, then passes these features through the separate prediction
heads to generate detection probabilities, class probabilities, and
size predictions.
Passes the spectrogram through the backbone to produce a feature
map, then applies the classifier and bounding-box heads. The
detection probability map is derived by summing the per-class
probability maps across the class dimension; no separate detection
head is used.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The
shape must be compatible with the `self.backbone` input
requirements.
Input spectrogram tensor, shape
``(batch_size, channels, frequency_bins, time_bins)``.
Returns
-------
ModelOutput
A NamedTuple containing the four output tensors:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
- `class_probs`: Class probabilities (excluding background)
`(B, num_classes, H, W)`.
- `features`: Output feature map from the backbone
`(B, C_out, H, W)`.
A named tuple with four fields:
- ``detection_probs`` — ``(B, 1, H, W)`` — probability that a
call of any class is present at each location. Derived by
summing ``class_probs`` over the class dimension.
- ``size_preds`` — ``(B, 2, H, W)`` — scaled duration (channel
0) and bandwidth (channel 1) predictions at each location.
- ``class_probs`` — ``(B, num_classes, H, W)`` — per-class
probabilities at each location.
- ``features`` — ``(B, C_out, H, W)`` — raw backbone feature
map.
"""
features = self.backbone(spec)
classification = self.classifier_head(features)
@ -131,30 +137,33 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int, config: BackboneConfig | None = None
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
"""Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
and a ``BBoxHead`` sized to the backbone's output channel count, and
returns them wrapped in a ``Detector``.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
Number of target bat species or call types to predict. Must be
positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Backbone architecture configuration. Defaults to
``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
not provided.
Returns
-------
DetectionModel
An initialized `Detector` model instance.
An initialised ``Detector`` instance ready for training or
inference.
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
If ``num_classes`` is not positive, or if the backbone
configuration is invalid.
"""
config = config or UNetBackboneConfig()

View File

@ -1,23 +1,24 @@
"""Constructs the Encoder part of a configurable neural network backbone.
"""Encoder (downsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`EncoderConfig`) and provides
the `Encoder` class (an `nn.Module`) along with a factory function
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
together with the ``build_encoder`` factory function.
The encoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `EncoderConfig.layers`. Each
configuration object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with downsampling) and its parameters
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
In a U-Net-style network the encoder progressively reduces the spatial
resolution of the spectrogram whilst increasing the number of feature
channels. Each layer in the encoder produces a feature map that is stored
for use as a skip connection in the corresponding decoder layer.
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
suitable for skip connections, while the `encode` method returns only the final
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided.
The encoder is fully configurable: the type, number, and parameters of the
downsampling blocks are described by an ``EncoderConfig`` object containing
an ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
so that skip connections are available to the decoder.
``Encoder.encode`` returns only the final output (the input to the bottleneck).
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
``build_encoder`` when no explicit configuration is supplied.
"""
from typing import Annotated, List
@ -53,35 +54,32 @@ EncoderLayerConfig = Annotated[
class EncoderConfig(BaseConfig):
"""Configuration for building the sequential Encoder module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
"""Configuration for the sequential ``Encoder`` module.
Attributes
----------
layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `name` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
Ordered list of block configuration objects defining the encoder's
downsampling stages. Each entry specifies the block type (via its
``name`` field) and any block-specific parameters such as
``out_channels``. Input channels for each block are inferred
automatically from the output of the previous block. Must contain
at least one entry.
"""
layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers.
"""Sequential encoder module composed of configurable downsampling layers.
Constructs the downsampling path of an encoder-decoder network by stacking
multiple downscaling blocks.
Executes a series of downsampling blocks in order, storing the output of
each block so that it can be passed as a skip connection to the
corresponding decoder layer.
The `forward` method executes the sequence and returns the output feature
map from *each* downscaling stage, facilitating the implementation of skip
connections in U-Net-like architectures. The `encode` method returns only
the final output tensor (bottleneck features).
``forward`` returns the outputs of *all* layers (useful when skip
connections are needed). ``encode`` returns only the final output
(the input to the bottleneck).
Attributes
----------
@ -89,14 +87,14 @@ class Encoder(nn.Module):
Number of channels expected in the input tensor.
input_height : int
Height (frequency bins) expected in the input tensor.
output_channels : int
Number of channels in the final output tensor (bottleneck).
out_channels : int
Number of channels in the final output tensor (bottleneck input).
output_height : int
Height (frequency bins) expected in the output tensor.
Height (frequency bins) of the final output tensor.
layers : nn.ModuleList
The sequence of instantiated downscaling layer modules.
Sequence of instantiated downsampling block modules.
depth : int
The number of downscaling layers in the encoder.
Number of downsampling layers.
"""
def __init__(
@ -107,23 +105,22 @@ class Encoder(nn.Module):
input_height: int = 128,
in_channels: int = 1,
):
"""Initialize the Encoder module.
"""Initialise the Encoder module.
Note: This constructor is typically called internally by the
`build_encoder` factory function, which prepares the `layers` list.
This constructor is typically called by the ``build_encoder`` factory
function, which takes care of building the ``layers`` list from a
configuration object.
Parameters
----------
output_channels : int
Number of channels produced by the final layer.
output_height : int
The expected height of the output tensor.
Height of the output tensor after all layers have been applied.
layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g.,
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
Pre-built downsampling block modules in execution order.
input_height : int, default=128
Expected height of the input tensor.
Expected height of the input tensor (frequency bins).
in_channels : int, default=1
Expected number of channels in the input tensor.
"""
@ -138,29 +135,30 @@ class Encoder(nn.Module):
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs.
"""Pass input through all encoder layers and return every output.
This method is typically used when the Encoder is part of a U-Net or
similar architecture requiring skip connections.
Used when skip connections are needed (e.g. in a U-Net decoder).
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
`self.in_channels` and `H_in` must match `self.input_height`.
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
``C_in`` must match ``self.in_channels`` and ``H_in`` must
match ``self.input_height``.
Returns
-------
List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer
in the sequence. `outputs[0]` is the output of the first layer,
`outputs[-1]` is the final output (bottleneck) of the encoder.
Output tensors from every layer in order.
``outputs[0]`` is the output of the first (shallowest) layer;
``outputs[-1]`` is the output of the last (deepest) layer,
which serves as the input to the bottleneck.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
If the input channel count or height does not match the
expected values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
@ -183,28 +181,29 @@ class Encoder(nn.Module):
return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output.
"""Pass input through all encoder layers and return only the final output.
This method provides access to the bottleneck features produced after
the last downscaling layer.
Use this when skip connections are not needed and you only require
the bottleneck feature map.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
`in_channels` and `input_height`.
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
Must satisfy the same shape requirements as ``forward``.
Returns
-------
torch.Tensor
The final output tensor (bottleneck features) from the last layer
of the encoder. Shape `(B, C_out, H_out, W_out)`.
Output of the last encoder layer, shape
``(B, C_out, H_out, W)``, where ``C_out`` is
``self.out_channels`` and ``H_out`` is ``self.output_height``.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
If the input channel count or height does not match the
expected values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
@ -236,14 +235,17 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
),
],
)
"""Default configuration for the Encoder.
"""Default encoder configuration used in standard BatDetect2 models.
Specifies an architecture typically used in BatDetect2:
- Input: 1 channel, 128 frequency bins.
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
Assumes a 1-channel input with 128 frequency bins and produces the
following feature maps:
- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvDown``: 128 channels, height 16.
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
"""
@ -252,42 +254,38 @@ def build_encoder(
input_height: int,
config: EncoderConfig | None = None,
) -> Encoder:
"""Factory function to build an Encoder instance from configuration.
"""Build an ``Encoder`` from configuration.
Constructs a sequential `Encoder` module based on the layer sequence
defined in an `EncoderConfig` object and the provided input dimensions.
If no config is provided, uses the default layer sequence from
`DEFAULT_ENCODER_CONFIG`.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate- aware blocks.
Constructs a sequential ``Encoder`` by iterating over the block
configurations in ``config.layers``, building each block with
``build_layer``, and tracking the channel count and feature-map height
as they change through the sequence.
Parameters
----------
in_channels : int
The number of channels expected in the input tensor to the encoder.
Must be > 0.
Number of channels in the input spectrogram tensor. Must be
positive.
input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0.
Crucial for initializing coordinate-aware layers correctly.
Height (number of frequency bins) of the input spectrogram.
Must be positive and should be divisible by
``2 ** (number of downsampling stages)`` to avoid size mismatches
later in the network.
config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
Configuration specifying the layer sequence. Defaults to
``DEFAULT_ENCODER_CONFIG`` if not provided.
Returns
-------
Encoder
An initialized `Encoder` module.
An initialised ``Encoder`` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `name`.
If ``in_channels`` or ``input_height`` are not positive.
KeyError
If a layer configuration specifies an unknown block type.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")

View File

@ -1,20 +1,19 @@
"""Prediction Head modules for BatDetect2 models.
"""Prediction heads attached to the backbone feature map.
This module defines simple `torch.nn.Module` subclasses that serve as
prediction heads, typically attached to the output feature map of a backbone
network
Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
convolution to map backbone feature channels to one specific type of
output required by BatDetect2:
Each head is responsible for generating one specific type of output required
by the BatDetect2 task:
- `DetectorHead`: Predicts the probability of sound event presence.
- `ClassifierHead`: Predicts the probability distribution over target classes.
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
box.
- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
activation).
- ``ClassifierHead``: multi-class probability map over the target bat
species / call types (softmax activation).
- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
bandwidth (frequency axis) at each location (no activation; raw
regression output).
These heads use 1x1 convolutions to map the backbone feature channels
to the desired number of output channels for each prediction task at each
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
All three heads share the same input feature map produced by the backbone,
so they can be evaluated in parallel in a single forward pass.
"""
import torch
@ -28,42 +27,35 @@ __all__ = [
class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities.
"""Prediction head for species / call-type classification probabilities.
Takes an input feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution
to map input channels to `num_classes + 1` outputs (one for each target
class plus an assumed background/generic class), applies softmax across the
channels, and returns the probabilities for the specific target classes
(excluding the last background/generic channel).
Takes a backbone feature map and produces a probability map where each
channel corresponds to a target class. Internally the 1×1 convolution
maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
represents a generic background / unknown category); a softmax is then
applied across the channel dimension and the background channel is
discarded before returning.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(excluding any background or generic category). Must be positive.
Number of target classes (bat species or call types) to predict,
excluding the background category. Must be positive.
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
num_classes : int
Number of specific output classes.
Number of specific output classes (background excluded).
in_channels : int
Number of input channels expected.
classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
1×1 convolution with ``num_classes + 1`` output channels.
"""
def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead."""
"""Initialise the ClassifierHead."""
super().__init__()
self.num_classes = num_classes
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features.
"""Compute per-class probabilities from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`.
Contains probabilities for the specific target classes after
softmax, excluding the implicit background/generic class channel.
Class probability map, shape ``(B, num_classes, H, W)``.
Values are softmax probabilities in the range [0, 1] and
sum to less than 1 per location (the background probability
is discarded).
"""
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability.
"""Prediction head for detection probability (is a call present here?).
Takes an input feature map and produces a single-channel heatmap where
each value represents the probability ([0, 1]) of a relevant sound event
(of any class) being present at that spatial location.
Produces a single-channel heatmap where each value indicates the
probability ([0, 1]) that a bat call of *any* species is present at
that timefrequency location in the spectrogram.
Uses a 1x1 convolution to map input channels to 1 output channel, followed
by a sigmoid activation function.
Applies a 1×1 convolution mapping ``in_channels`` 1, followed by
sigmoid activation.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
1×1 convolution with a single output channel.
"""
def __init__(self, in_channels: int):
"""Initialize the DetectorHead."""
"""Initialise the DetectorHead."""
super().__init__()
self.in_channels = in_channels
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features.
"""Compute detection probabilities from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
Values are in the range [0, 1] due to the sigmoid activation.
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
Detection probability heatmap, shape ``(B, 1, H, W)``.
Values are in the range [0, 1].
"""
return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions.
"""Prediction head for bounding box size (duration and bandwidth).
Takes an input feature map and produces a two-channel map where each
channel represents a predicted size dimension (typically width/duration and
height/bandwidth) for a potential sound event at that spatial location.
Produces a two-channel map where channel 0 predicts the scaled duration
(time-axis extent) and channel 1 predicts the scaled bandwidth
(frequency-axis extent) of the call at each spectrogram location.
Uses a 1x1 convolution to map input channels to 2 output channels. No
activation function is typically applied, as size prediction is often
treated as a direct regression task. The output values usually represent
*scaled* dimensions that need to be un-scaled during postprocessing.
Applies a 1×1 convolution mapping ``in_channels`` 2 with no
activation function (raw regression output). The predicted values are
in a scaled space and must be converted to real units (seconds and Hz)
during postprocessing.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
1×1 convolution with 2 output channels (duration, bandwidth).
"""
def __init__(self, in_channels: int):
"""Initialize the BBoxHead."""
"""Initialise the BBoxHead."""
super().__init__()
self.in_channels = in_channels
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features.
"""Predict call duration and bandwidth from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
represents scaled width, Channel 1 scaled height. These values
need to be un-scaled during postprocessing.
Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
the predicted scaled duration; channel 1 is the predicted
scaled bandwidth. Values must be rescaled to real units during
postprocessing.
"""
return self.bbox(features)

View File

@ -1,14 +1,4 @@
"""Tests for backbone configuration loading and the backbone registry.
Covers:
- UNetBackboneConfig default construction and field values.
- build_backbone with default and explicit configs.
- load_backbone_config loading from a YAML file.
- load_backbone_config with a nested field path.
- load_backbone_config round-trip: YAML config build_backbone.
- Registry registration and dispatch for UNetBackbone.
- BackboneConfig discriminated union validation.
"""
"""Tests for backbone configuration loading and the backbone registry."""
from pathlib import Path
from typing import Callable
@ -25,10 +15,6 @@ from batdetect2.models.backbones import (
)
from batdetect2.typing.models import BackboneModel
# ---------------------------------------------------------------------------
# UNetBackboneConfig
# ---------------------------------------------------------------------------
def test_unet_backbone_config_defaults():
"""Default config has expected field values."""
@ -57,11 +43,6 @@ def test_unet_backbone_config_extra_fields_ignored():
assert not hasattr(config, "unknown_field")
# ---------------------------------------------------------------------------
# build_backbone
# ---------------------------------------------------------------------------
def test_build_backbone_default():
"""Building with no config uses UNetBackbone defaults."""
backbone = build_backbone()
@ -83,15 +64,9 @@ def test_build_backbone_custom_config():
def test_build_backbone_returns_backbone_model():
"""build_backbone always returns a BackboneModel instance."""
backbone = build_backbone()
assert isinstance(backbone, BackboneModel)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def test_registry_has_unet_backbone():
"""The backbone registry has UNetBackbone registered."""
config_types = backbone_registry.get_config_types()
@ -124,11 +99,6 @@ def test_registry_build_unknown_name_raises():
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# BackboneConfig discriminated union
# ---------------------------------------------------------------------------
def test_backbone_config_validates_unet_from_dict():
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
from pydantic import TypeAdapter
@ -151,11 +121,6 @@ def test_backbone_config_invalid_name_raises():
adapter.validate_python({"name": "NonExistentBackbone"})
# ---------------------------------------------------------------------------
# load_backbone_config
# ---------------------------------------------------------------------------
def test_load_backbone_config_from_yaml(
create_temp_yaml: Callable[[str], Path],
):
@ -218,11 +183,6 @@ deprecated_field: 99
assert config.input_height == 128
# ---------------------------------------------------------------------------
# Round-trip: YAML → config → build_backbone
# ---------------------------------------------------------------------------
def test_round_trip_yaml_to_build_backbone(
create_temp_yaml: Callable[[str], Path],
):