mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Update model docstrings
This commit is contained in:
parent
4207661da4
commit
ef3348d651
@ -1,29 +1,29 @@
|
|||||||
"""Defines and builds the neural network models used in BatDetect2.
|
"""Neural network model definitions and builders for BatDetect2.
|
||||||
|
|
||||||
This package (`batdetect2.models`) contains the PyTorch implementations of the
|
This package contains the PyTorch implementations of the deep neural network
|
||||||
deep neural network architectures used for detecting and classifying bat calls
|
architectures used to detect and classify bat echolocation calls in
|
||||||
from spectrograms. It provides modular components and configuration-driven
|
spectrograms. Components are designed to be combined through configuration
|
||||||
assembly, allowing for experimentation and use of different architectural
|
objects, making it easy to experiment with different architectures.
|
||||||
variants.
|
|
||||||
|
|
||||||
Key Submodules:
|
Key submodules
|
||||||
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
|
--------------
|
||||||
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
|
- ``blocks``: Reusable convolutional building blocks (downsampling,
|
||||||
- `.blocks`: Provides reusable neural network building blocks.
|
upsampling, attention, coord-conv variants).
|
||||||
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
|
- ``encoder``: The downsampling path; reduces spatial resolution whilst
|
||||||
- `.bottleneck`: Defines and builds the central bottleneck component.
|
extracting increasingly abstract features.
|
||||||
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
|
- ``bottleneck``: The central component connecting encoder to decoder;
|
||||||
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
|
optionally applies self-attention along the time axis.
|
||||||
feature extraction backbone (e.g., a U-Net like structure).
|
- ``decoder``: The upsampling path; reconstructs high-resolution feature
|
||||||
- `.heads`: Defines simple prediction heads (detection, classification, size)
|
maps using bottleneck output and skip connections from the encoder.
|
||||||
that attach to the backbone features.
|
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
|
||||||
- `.detectors`: Assembles the backbone and prediction heads into the final,
|
U-Net-style feature extraction backbone.
|
||||||
end-to-end `Detector` model.
|
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
|
||||||
|
classification, and bounding-box size predictions from backbone features.
|
||||||
|
- ``detectors``: Combines a backbone with prediction heads into the final
|
||||||
|
end-to-end ``Detector`` model.
|
||||||
|
|
||||||
This module re-exports the most important classes, configurations, and builder
|
The primary entry point for building a full, ready-to-use BatDetect2 model
|
||||||
functions from these submodules for convenient access. The primary entry point
|
is the ``build_model`` factory function exported from this module.
|
||||||
for creating a standard BatDetect2 model instance is the `build_model` function
|
|
||||||
provided here.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -99,6 +99,31 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
|
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||||
|
|
||||||
|
Combines a preprocessor, a detection model, and a postprocessor into a
|
||||||
|
single PyTorch module. Calling ``forward`` on a raw waveform tensor
|
||||||
|
returns a list of detection tensors ready for downstream use.
|
||||||
|
|
||||||
|
This class is the top-level object produced by ``build_model``. Most
|
||||||
|
users will not need to construct it directly.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
detector : DetectionModel
|
||||||
|
The neural network that processes spectrograms and produces raw
|
||||||
|
detection, classification, and bounding-box outputs.
|
||||||
|
preprocessor : PreprocessorProtocol
|
||||||
|
Converts a raw waveform tensor into a spectrogram tensor accepted by
|
||||||
|
``detector``.
|
||||||
|
postprocessor : PostprocessorProtocol
|
||||||
|
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
||||||
|
per-clip detection tensors.
|
||||||
|
targets : TargetProtocol
|
||||||
|
Describes the set of target classes; used when building heads and
|
||||||
|
during training target construction.
|
||||||
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
@ -118,6 +143,25 @@ class Model(torch.nn.Module):
|
|||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
||||||
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
|
|
||||||
|
Converts the waveform to a spectrogram, passes it through the
|
||||||
|
detector, and postprocesses the raw outputs into detection tensors.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
wav : torch.Tensor
|
||||||
|
Raw audio waveform tensor. The exact expected shape depends on
|
||||||
|
the preprocessor, but is typically ``(batch, samples)`` or
|
||||||
|
``(batch, channels, samples)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[ClipDetectionsTensor]
|
||||||
|
One detection tensor per clip in the batch. Each tensor encodes
|
||||||
|
the detected events (locations, class scores, sizes) for that
|
||||||
|
clip.
|
||||||
|
"""
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
outputs = self.detector(spec)
|
outputs = self.detector(spec)
|
||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
@ -128,7 +172,38 @@ def build_model(
|
|||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
):
|
) -> "Model":
|
||||||
|
"""Build a complete, ready-to-use BatDetect2 model.
|
||||||
|
|
||||||
|
Assembles a ``Model`` instance from optional configuration and component
|
||||||
|
overrides. Any argument left as ``None`` will be replaced by a sensible
|
||||||
|
default built with the project's own builder functions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : BackboneConfig, optional
|
||||||
|
Configuration describing the backbone architecture (encoder,
|
||||||
|
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
|
||||||
|
provided.
|
||||||
|
targets : TargetProtocol, optional
|
||||||
|
Describes the target bat species or call types to detect. Determines
|
||||||
|
the number of output classes. Defaults to the standard BatDetect2
|
||||||
|
target set.
|
||||||
|
preprocessor : PreprocessorProtocol, optional
|
||||||
|
Converts raw audio waveforms to spectrograms. Defaults to the
|
||||||
|
standard BatDetect2 preprocessor.
|
||||||
|
postprocessor : PostprocessorProtocol, optional
|
||||||
|
Converts raw model outputs to detection tensors. Defaults to the
|
||||||
|
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
|
||||||
|
given without a matching ``postprocessor``, the default postprocessor
|
||||||
|
will be built using the provided preprocessor so that frequency and
|
||||||
|
time scaling remain consistent.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Model
|
||||||
|
A fully assembled ``Model`` instance ready for inference or training.
|
||||||
|
"""
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
|||||||
@ -1,21 +1,26 @@
|
|||||||
"""Assembles a complete Encoder-Decoder Backbone network.
|
"""Assembles a complete encoder-decoder backbone network.
|
||||||
|
|
||||||
This module defines the configuration (`BackboneConfig`) and implementation
|
This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
|
||||||
(`Backbone`) for a standard encoder-decoder style neural network backbone.
|
``nn.Module``, together with the ``build_backbone`` and
|
||||||
|
``load_backbone_config`` helpers.
|
||||||
|
|
||||||
It orchestrates the connection between three main components, built using their
|
A backbone combines three components built from the sibling modules:
|
||||||
respective configurations and factory functions from sibling modules:
|
|
||||||
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
|
|
||||||
at multiple resolutions and provides skip connections.
|
|
||||||
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
|
|
||||||
lowest resolution, optionally applying self-attention.
|
|
||||||
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
|
|
||||||
resolution features using bottleneck features and skip connections.
|
|
||||||
|
|
||||||
The resulting `Backbone` module takes a spectrogram as input and outputs a
|
1. **Encoder** (``batdetect2.models.encoder``) – reduces spatial resolution
|
||||||
final feature map, typically used by subsequent prediction heads. It includes
|
while extracting hierarchical features and storing skip-connection tensors.
|
||||||
automatic padding to handle input sizes not perfectly divisible by the
|
2. **Bottleneck** (``batdetect2.models.bottleneck``) – processes the
|
||||||
network's total downsampling factor.
|
lowest-resolution features, optionally applying self-attention.
|
||||||
|
3. **Decoder** (``batdetect2.models.decoder``) – restores spatial resolution
|
||||||
|
using bottleneck features and skip connections from the encoder.
|
||||||
|
|
||||||
|
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
|
||||||
|
a high-resolution feature map consumed by the prediction heads in
|
||||||
|
``batdetect2.models.detectors``.
|
||||||
|
|
||||||
|
Input padding is handled automatically: the backbone pads the input to be
|
||||||
|
divisible by the total downsampling factor and strips the padding from the
|
||||||
|
output so that the output spatial dimensions always match the input spatial
|
||||||
|
dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Tuple, Union
|
from typing import Annotated, Literal, Tuple, Union
|
||||||
@ -51,6 +56,34 @@ from batdetect2.typing.models import (
|
|||||||
|
|
||||||
|
|
||||||
class UNetBackboneConfig(BaseConfig):
|
class UNetBackboneConfig(BaseConfig):
|
||||||
|
"""Configuration for a U-Net-style encoder-decoder backbone.
|
||||||
|
|
||||||
|
All fields have sensible defaults that reproduce the standard BatDetect2
|
||||||
|
architecture, so you can start with ``UNetBackboneConfig()`` and override
|
||||||
|
only the fields you want to change.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Discriminator field used by the backbone registry; always
|
||||||
|
``"UNetBackbone"``.
|
||||||
|
input_height : int
|
||||||
|
Number of frequency bins in the input spectrogram. Defaults to
|
||||||
|
``128``.
|
||||||
|
in_channels : int
|
||||||
|
Number of channels in the input spectrogram (e.g. ``1`` for a
|
||||||
|
standard mel-spectrogram). Defaults to ``1``.
|
||||||
|
encoder : EncoderConfig
|
||||||
|
Configuration for the downsampling path. Defaults to
|
||||||
|
``DEFAULT_ENCODER_CONFIG``.
|
||||||
|
bottleneck : BottleneckConfig
|
||||||
|
Configuration for the bottleneck. Defaults to
|
||||||
|
``DEFAULT_BOTTLENECK_CONFIG``.
|
||||||
|
decoder : DecoderConfig
|
||||||
|
Configuration for the upsampling path. Defaults to
|
||||||
|
``DEFAULT_DECODER_CONFIG``.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["UNetBackbone"] = "UNetBackbone"
|
name: Literal["UNetBackbone"] = "UNetBackbone"
|
||||||
input_height: int = 128
|
input_height: int = 128
|
||||||
in_channels: int = 1
|
in_channels: int = 1
|
||||||
@ -70,35 +103,36 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class UNetBackbone(BackboneModel):
|
class UNetBackbone(BackboneModel):
|
||||||
"""Encoder-Decoder Backbone Network Implementation.
|
"""U-Net-style encoder-decoder backbone network.
|
||||||
|
|
||||||
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||||
skip connections between the Encoder and Decoder. Implements the standard
|
that produces a high-resolution feature map from an input spectrogram.
|
||||||
U-Net style forward pass. Includes automatic input padding to handle
|
Skip connections from each encoder stage are added element-wise to the
|
||||||
various input sizes and a final convolutional block to adjust the output
|
corresponding decoder stage input.
|
||||||
channels.
|
|
||||||
|
|
||||||
This class inherits from `BackboneModel` and implements its `forward`
|
Input spectrograms of arbitrary width are handled automatically: the
|
||||||
method. Instances are typically created using the `build_backbone` factory
|
backbone pads the input so that its dimensions are divisible by
|
||||||
function.
|
``divide_factor`` and removes the padding from the output.
|
||||||
|
|
||||||
|
Instances are typically created via ``build_backbone``.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
input_height : int
|
input_height : int
|
||||||
Expected height of the input spectrogram.
|
Expected height (frequency bins) of the input spectrogram.
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of channels in the final output feature map.
|
Number of channels in the output feature map (taken from the
|
||||||
|
decoder's output channel count).
|
||||||
encoder : EncoderProtocol
|
encoder : EncoderProtocol
|
||||||
The instantiated encoder module.
|
The instantiated encoder module.
|
||||||
decoder : DecoderProtocol
|
decoder : DecoderProtocol
|
||||||
The instantiated decoder module.
|
The instantiated decoder module.
|
||||||
bottleneck : BottleneckProtocol
|
bottleneck : BottleneckProtocol
|
||||||
The instantiated bottleneck module.
|
The instantiated bottleneck module.
|
||||||
final_conv : ConvBlock
|
|
||||||
Final convolutional block applied after the decoder.
|
|
||||||
divide_factor : int
|
divide_factor : int
|
||||||
The total downsampling factor (2^depth) applied by the encoder,
|
The total spatial downsampling factor applied by the encoder
|
||||||
used for automatic input padding.
|
(``input_height // encoder.output_height``). The input height and
|
||||||
|
width are padded to multiples of this value before processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -108,25 +142,19 @@ class UNetBackbone(BackboneModel):
|
|||||||
decoder: DecoderProtocol,
|
decoder: DecoderProtocol,
|
||||||
bottleneck: BottleneckProtocol,
|
bottleneck: BottleneckProtocol,
|
||||||
):
|
):
|
||||||
"""Initialize the Backbone network.
|
"""Initialise the backbone network.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
input_height : int
|
input_height : int
|
||||||
Expected height of the input spectrogram.
|
Expected height (frequency bins) of the input spectrogram.
|
||||||
out_channels : int
|
|
||||||
Desired number of output channels for the backbone's feature map.
|
|
||||||
encoder : EncoderProtocol
|
encoder : EncoderProtocol
|
||||||
An initialized Encoder module.
|
An initialised encoder module.
|
||||||
decoder : DecoderProtocol
|
decoder : DecoderProtocol
|
||||||
An initialized Decoder module.
|
An initialised decoder module. Its ``output_height`` must equal
|
||||||
|
``input_height``; a ``ValueError`` is raised otherwise.
|
||||||
bottleneck : BottleneckProtocol
|
bottleneck : BottleneckProtocol
|
||||||
An initialized Bottleneck module.
|
An initialised bottleneck module.
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If component output/input channels or heights are incompatible.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
@ -143,22 +171,25 @@ class UNetBackbone(BackboneModel):
|
|||||||
self.divide_factor = input_height // self.encoder.output_height
|
self.divide_factor = input_height // self.encoder.output_height
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
"""Perform the forward pass through the encoder-decoder backbone.
|
"""Produce a feature map from an input spectrogram.
|
||||||
|
|
||||||
Applies padding, runs encoder, bottleneck, decoder (with skip
|
Pads the input if necessary, runs it through the encoder, then
|
||||||
connections), removes padding, and applies a final convolution.
|
the bottleneck, then the decoder (incorporating encoder skip
|
||||||
|
connections), and finally removes any padding added earlier.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
|
Input spectrogram tensor, shape
|
||||||
`self.encoder.input_channels` and `self.input_height`.
|
``(B, C_in, H_in, W_in)``. ``H_in`` must equal
|
||||||
|
``self.input_height``; ``W_in`` can be any positive integer.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
|
Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
|
||||||
`C_out` is `self.out_channels`.
|
``C_out`` is ``self.out_channels``. The spatial dimensions
|
||||||
|
always match those of the input.
|
||||||
"""
|
"""
|
||||||
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
|
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
|
||||||
|
|
||||||
@ -219,6 +250,24 @@ BackboneConfig = Annotated[
|
|||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||||
|
"""Build a backbone network from configuration.
|
||||||
|
|
||||||
|
Looks up the backbone class corresponding to ``config.name`` in the
|
||||||
|
backbone registry and calls its ``from_config`` method. If no
|
||||||
|
configuration is provided, a default ``UNetBackbone`` is returned.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : BackboneConfig, optional
|
||||||
|
A configuration object describing the desired backbone. Currently
|
||||||
|
``UNetBackboneConfig`` is the only supported type. Defaults to
|
||||||
|
``UNetBackboneConfig()`` if not provided.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BackboneModel
|
||||||
|
An initialised backbone module.
|
||||||
|
"""
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
return backbone_registry.build(config)
|
return backbone_registry.build(config)
|
||||||
|
|
||||||
@ -227,26 +276,25 @@ def _pad_adjust(
|
|||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
factor: int = 32,
|
factor: int = 32,
|
||||||
) -> Tuple[torch.Tensor, int, int]:
|
) -> Tuple[torch.Tensor, int, int]:
|
||||||
"""Pad tensor height and width to be divisible by a factor.
|
"""Pad a tensor's height and width to be divisible by ``factor``.
|
||||||
|
|
||||||
Calculates the required padding for the last two dimensions (H, W) to make
|
Adds zero-padding to the bottom and right edges of the tensor so that
|
||||||
them divisible by `factor` and applies right/bottom padding using
|
both dimensions are exact multiples of ``factor``. If both dimensions
|
||||||
`torch.nn.functional.pad`.
|
are already divisible, the tensor is returned unchanged.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Input tensor, typically shape `(B, C, H, W)`.
|
Input tensor, typically shape ``(B, C, H, W)``.
|
||||||
factor : int, default=32
|
factor : int, default=32
|
||||||
The factor to make height and width divisible by.
|
The factor that both H and W should be divisible by after padding.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[torch.Tensor, int, int]
|
Tuple[torch.Tensor, int, int]
|
||||||
A tuple containing:
|
- Padded tensor.
|
||||||
- The padded tensor.
|
- Number of rows added to the height (``h_pad``).
|
||||||
- The amount of padding added to height (`h_pad`).
|
- Number of columns added to the width (``w_pad``).
|
||||||
- The amount of padding added to width (`w_pad`).
|
|
||||||
"""
|
"""
|
||||||
h, w = spec.shape[-2:]
|
h, w = spec.shape[-2:]
|
||||||
h_pad = -h % factor
|
h_pad = -h % factor
|
||||||
@ -261,23 +309,25 @@ def _pad_adjust(
|
|||||||
def _restore_pad(
|
def _restore_pad(
|
||||||
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
|
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Remove padding added by _pad_adjust.
|
"""Remove padding previously added by ``_pad_adjust``.
|
||||||
|
|
||||||
Removes padding from the bottom and right edges of the tensor.
|
Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
|
||||||
|
right of the tensor, restoring its original spatial dimensions.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
|
Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
|
||||||
h_pad : int, default=0
|
h_pad : int, default=0
|
||||||
Amount of padding previously added to the height (bottom).
|
Number of rows to remove from the bottom.
|
||||||
w_pad : int, default=0
|
w_pad : int, default=0
|
||||||
Amount of padding previously added to the width (right).
|
Number of columns to remove from the right.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
|
Tensor with padding removed, shape
|
||||||
|
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
|
||||||
"""
|
"""
|
||||||
if h_pad > 0:
|
if h_pad > 0:
|
||||||
x = x[..., :-h_pad, :]
|
x = x[..., :-h_pad, :]
|
||||||
@ -292,6 +342,36 @@ def load_backbone_config(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
) -> BackboneConfig:
|
) -> BackboneConfig:
|
||||||
|
"""Load a backbone configuration from a YAML or JSON file.
|
||||||
|
|
||||||
|
Reads the file at ``path``, optionally descends into a named sub-field,
|
||||||
|
and validates the result against the ``BackboneConfig`` discriminated
|
||||||
|
union.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file. Both YAML and JSON formats are
|
||||||
|
supported.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated key path to the sub-field that contains the backbone
|
||||||
|
configuration (e.g. ``"model"``). If ``None``, the root of the
|
||||||
|
file is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BackboneConfig
|
||||||
|
A validated backbone configuration object (currently always a
|
||||||
|
``UNetBackboneConfig`` instance).
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If ``path`` does not exist.
|
||||||
|
ValidationError
|
||||||
|
If the loaded data does not conform to a known ``BackboneConfig``
|
||||||
|
schema.
|
||||||
|
"""
|
||||||
return load_config(
|
return load_config(
|
||||||
path,
|
path,
|
||||||
schema=TypeAdapter(BackboneConfig),
|
schema=TypeAdapter(BackboneConfig),
|
||||||
|
|||||||
@ -1,30 +1,49 @@
|
|||||||
"""Commonly used neural network building blocks for BatDetect2 models.
|
"""Reusable convolutional building blocks for BatDetect2 models.
|
||||||
|
|
||||||
This module provides various reusable `torch.nn.Module` subclasses that form
|
This module provides a collection of ``torch.nn.Module`` subclasses that form
|
||||||
the fundamental building blocks for constructing convolutional neural network
|
the fundamental building blocks for the encoder-decoder backbone used in
|
||||||
architectures, particularly encoder-decoder backbones used in BatDetect2.
|
BatDetect2. All blocks follow a consistent interface: they store
|
||||||
|
``in_channels`` and ``out_channels`` as attributes and implement a
|
||||||
|
``get_output_height`` method that reports how a given input height maps to an
|
||||||
|
output height (e.g., halved by downsampling blocks, doubled by upsampling
|
||||||
|
blocks).
|
||||||
|
|
||||||
It includes standard components like basic convolutional blocks (`ConvBlock`),
|
Available block families
|
||||||
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
|
------------------------
|
||||||
upsampling (`StandardConvUpBlock`).
|
Standard blocks
|
||||||
|
``ConvBlock`` – convolution + batch normalisation + ReLU, no change in
|
||||||
|
spatial resolution.
|
||||||
|
|
||||||
Additionally, it features specialized layers investigated in BatDetect2
|
Downsampling blocks
|
||||||
research:
|
``StandardConvDownBlock`` – convolution then 2×2 max-pooling, halves H
|
||||||
|
and W.
|
||||||
|
``FreqCoordConvDownBlock`` – like ``StandardConvDownBlock`` but prepends
|
||||||
|
a normalised frequency-coordinate channel before the convolution
|
||||||
|
(CoordConv concept), helping filters learn frequency-position-dependent
|
||||||
|
patterns.
|
||||||
|
|
||||||
- `SelfAttention`: Applies self-attention along the time dimension, enabling
|
Upsampling blocks
|
||||||
the model to weigh information across the entire temporal context, often
|
``StandardConvUpBlock`` – bilinear interpolation then convolution,
|
||||||
used in the bottleneck of an encoder-decoder.
|
doubles H and W.
|
||||||
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
|
``FreqCoordConvUpBlock`` – like ``StandardConvUpBlock`` but prepends a
|
||||||
concept by concatenating normalized frequency coordinate information as an
|
frequency-coordinate channel after upsampling.
|
||||||
extra channel to the input of convolutional layers. This explicitly provides
|
|
||||||
spatial frequency information to filters, potentially enabling them to learn
|
|
||||||
frequency-dependent patterns more effectively.
|
|
||||||
|
|
||||||
These blocks can be used directly in custom PyTorch model definitions or
|
Bottleneck blocks
|
||||||
assembled into larger architectures.
|
``VerticalConv`` – 1-D convolution whose kernel spans the entire
|
||||||
|
frequency axis, collapsing H to 1 whilst preserving W.
|
||||||
|
``SelfAttention`` – scaled dot-product self-attention along the time
|
||||||
|
axis; typically follows a ``VerticalConv``.
|
||||||
|
|
||||||
A unified factory function `build_layer` allows creating instances
|
Group block
|
||||||
of these blocks based on configuration objects.
|
``LayerGroup`` – chains several blocks sequentially into one unit,
|
||||||
|
useful when a single encoder or decoder "stage" requires more than one
|
||||||
|
operation.
|
||||||
|
|
||||||
|
Factory function
|
||||||
|
----------------
|
||||||
|
``build_layer`` creates any of the above blocks from the matching
|
||||||
|
configuration object (one of the ``*Config`` classes exported here), using
|
||||||
|
a discriminated-union ``name`` field to dispatch to the correct class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Tuple, Union
|
from typing import Annotated, List, Literal, Tuple, Union
|
||||||
@ -57,10 +76,43 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
|
"""Abstract base class for all BatDetect2 building blocks.
|
||||||
|
|
||||||
|
Subclasses must set ``in_channels`` and ``out_channels`` as integer
|
||||||
|
attributes so that factory functions can wire blocks together without
|
||||||
|
inspecting configuration objects at runtime. They may also override
|
||||||
|
``get_output_height`` when the block changes the height dimension (e.g.
|
||||||
|
downsampling or upsampling blocks).
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
in_channels : int
|
||||||
|
Number of channels expected in the input tensor.
|
||||||
|
out_channels : int
|
||||||
|
Number of channels produced in the output tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
"""Return the output height for a given input height.
|
||||||
|
|
||||||
|
The default implementation returns ``input_height`` unchanged,
|
||||||
|
which is correct for blocks that do not alter spatial resolution.
|
||||||
|
Override this in downsampling (returns ``input_height // 2``) or
|
||||||
|
upsampling (returns ``input_height * 2``) subclasses.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_height : int
|
||||||
|
Height (number of frequency bins) of the input feature map.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
Height of the output feature map.
|
||||||
|
"""
|
||||||
return input_height
|
return input_height
|
||||||
|
|
||||||
|
|
||||||
@ -68,58 +120,65 @@ block_registry: Registry[Block, [int, int]] = Registry("block")
|
|||||||
|
|
||||||
|
|
||||||
class SelfAttentionConfig(BaseConfig):
|
class SelfAttentionConfig(BaseConfig):
|
||||||
|
"""Configuration for a ``SelfAttention`` block.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Discriminator field; always ``"SelfAttention"``.
|
||||||
|
attention_channels : int
|
||||||
|
Dimensionality of the query, key, and value projections.
|
||||||
|
temperature : float
|
||||||
|
Scaling factor for the attention scores; the query–key dot products
|
||||||
|
are divided by ``temperature × attention_channels``. Defaults to ``1``.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["SelfAttention"] = "SelfAttention"
|
name: Literal["SelfAttention"] = "SelfAttention"
|
||||||
attention_channels: int
|
attention_channels: int
|
||||||
temperature: float = 1
|
temperature: float = 1
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(Block):
|
class SelfAttention(Block):
|
||||||
"""Self-Attention mechanism operating along the time dimension.
|
"""Self-attention block operating along the time axis.
|
||||||
|
|
||||||
This module implements a scaled dot-product self-attention mechanism,
|
Applies a scaled dot-product self-attention mechanism across the time
|
||||||
specifically designed here to operate across the time steps of an input
|
steps of an input feature map. Before attention is computed the height
|
||||||
feature map, typically after spatial dimensions (like frequency) have been
|
dimension (frequency axis) is expected to have been reduced to 1, e.g.
|
||||||
condensed or squeezed.
|
by a preceding ``VerticalConv`` layer.
|
||||||
|
|
||||||
By calculating attention weights between all pairs of time steps, it allows
|
For each time step the block computes query, key, and value projections
|
||||||
the model to capture long-range temporal dependencies and focus on relevant
|
with learned linear weights, then calculates attention weights from the
|
||||||
parts of the sequence. It's often employed in the bottleneck or
|
query–key dot products divided by ``temperature × attention_channels``.
|
||||||
intermediate layers of an encoder-decoder architecture to integrate global
|
The weighted sum of values is projected back to ``in_channels`` via a
|
||||||
temporal context.
|
final linear layer, and the height dimension is restored so that the
|
||||||
|
output shape matches the input shape.
|
||||||
The implementation uses linear projections to create query, key, and value
|
|
||||||
representations, computes scaled dot-product attention scores, applies
|
|
||||||
softmax, and produces an output by weighting the values according to the
|
|
||||||
attention scores, followed by a final linear projection. Positional encoding
|
|
||||||
is not explicitly included in this block.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of input channels (features per time step after spatial squeeze).
|
Number of input channels (features per time step). The output will
|
||||||
|
also have ``in_channels`` channels.
|
||||||
attention_channels : int
|
attention_channels : int
|
||||||
Number of channels for the query, key, and value projections. Also the
|
Dimensionality of the query, key, and value projections.
|
||||||
dimension of the output projection's input.
|
|
||||||
temperature : float, default=1.0
|
temperature : float, default=1.0
|
||||||
Scaling factor applied *before* the final projection layer. Can be used
|
Divisor applied together with ``attention_channels`` when scaling
|
||||||
to adjust the sharpness or focus of the attention mechanism, although
|
the dot-product scores before softmax. Larger values produce softer
|
||||||
scaling within the softmax (dividing by sqrt(dim)) is more common for
|
(more uniform) attention distributions.
|
||||||
standard transformers. Here it scales the weighted values.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
key_fun : nn.Linear
|
key_fun : nn.Linear
|
||||||
Linear layer for key projection.
|
Linear projection for keys.
|
||||||
value_fun : nn.Linear
|
value_fun : nn.Linear
|
||||||
Linear layer for value projection.
|
Linear projection for values.
|
||||||
query_fun : nn.Linear
|
query_fun : nn.Linear
|
||||||
Linear layer for query projection.
|
Linear projection for queries.
|
||||||
pro_fun : nn.Linear
|
pro_fun : nn.Linear
|
||||||
Final linear projection layer applied after attention weighting.
|
Final linear projection applied to the attended values.
|
||||||
temperature : float
|
temperature : float
|
||||||
Scaling factor applied before final projection.
|
Scaling divisor used when computing attention scores.
|
||||||
att_dim : int
|
att_dim : int
|
||||||
Dimensionality of the attention space (`attention_channels`).
|
Dimensionality of the attention space (``attention_channels``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -148,20 +207,16 @@ class SelfAttention(Block):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor, expected shape `(B, C, H, W)`, where H is typically
|
Input tensor with shape ``(B, C, 1, W)``. The height dimension
|
||||||
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
|
must be 1 (i.e. the frequency axis should already have been
|
||||||
applying attention along the W (time) dimension.
|
collapsed by a preceding ``VerticalConv`` layer).
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Output tensor of the same shape as the input `(B, C, H, W)`, where
|
Output tensor with the same shape ``(B, C, 1, W)`` as the
|
||||||
attention has been applied across the W dimension.
|
input, with each time step updated by attended context from all
|
||||||
|
other time steps.
|
||||||
Raises
|
|
||||||
------
|
|
||||||
RuntimeError
|
|
||||||
If input tensor dimensions are incompatible with operations.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x = x.squeeze(2).permute(0, 2, 1)
|
x = x.squeeze(2).permute(0, 2, 1)
|
||||||
@ -190,6 +245,22 @@ class SelfAttention(Block):
|
|||||||
return op
|
return op
|
||||||
|
|
||||||
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Return the softmax attention weight matrix.
|
||||||
|
|
||||||
|
Useful for visualising which time steps attend to which others.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Input tensor with shape ``(B, C, 1, W)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Attention weight matrix with shape ``(B, W, W)``. Entry
|
||||||
|
``[b, i, j]`` is the attention weight that time step ``i``
|
||||||
|
assigns to time step ``j`` in batch item ``b``.
|
||||||
|
"""
|
||||||
x = x.squeeze(2).permute(0, 2, 1)
|
x = x.squeeze(2).permute(0, 2, 1)
|
||||||
|
|
||||||
key = torch.matmul(
|
key = torch.matmul(
|
||||||
@ -304,6 +375,16 @@ class ConvBlock(Block):
|
|||||||
|
|
||||||
|
|
||||||
class VerticalConvConfig(BaseConfig):
|
class VerticalConvConfig(BaseConfig):
|
||||||
|
"""Configuration for a ``VerticalConv`` block.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Discriminator field; always ``"VerticalConv"``.
|
||||||
|
channels : int
|
||||||
|
Number of output channels produced by the vertical convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["VerticalConv"] = "VerticalConv"
|
name: Literal["VerticalConv"] = "VerticalConv"
|
||||||
channels: int
|
channels: int
|
||||||
|
|
||||||
@ -844,12 +925,53 @@ LayerConfig = Annotated[
|
|||||||
|
|
||||||
|
|
||||||
class LayerGroupConfig(BaseConfig):
|
class LayerGroupConfig(BaseConfig):
|
||||||
|
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
||||||
|
|
||||||
|
Use this when a single encoder or decoder stage needs more than one
|
||||||
|
block. The blocks are executed in the order they appear in ``layers``,
|
||||||
|
with channel counts and heights propagated automatically.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Discriminator field; always ``"LayerGroup"``.
|
||||||
|
layers : List[LayerConfig]
|
||||||
|
Ordered list of block configurations to chain together.
|
||||||
|
"""
|
||||||
|
|
||||||
name: Literal["LayerGroup"] = "LayerGroup"
|
name: Literal["LayerGroup"] = "LayerGroup"
|
||||||
layers: List[LayerConfig]
|
layers: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
class LayerGroup(nn.Module):
|
class LayerGroup(nn.Module):
|
||||||
"""Standard implementation of the `LayerGroup` architecture."""
|
"""Sequential chain of blocks that acts as a single composite block.
|
||||||
|
|
||||||
|
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
|
||||||
|
exposing the same ``in_channels``, ``out_channels``, and
|
||||||
|
``get_output_height`` interface as a regular ``Block`` so it can be
|
||||||
|
used transparently wherever a single block is expected.
|
||||||
|
|
||||||
|
Instances are typically constructed by ``build_layer`` when given a
|
||||||
|
``LayerGroupConfig``; you rarely need to create them directly.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
layers : list[Block]
|
||||||
|
Pre-built block instances to chain, in execution order.
|
||||||
|
input_height : int
|
||||||
|
Height of the tensor entering the first block.
|
||||||
|
input_channels : int
|
||||||
|
Number of channels in the tensor entering the first block.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
in_channels : int
|
||||||
|
Number of input channels (taken from the first block).
|
||||||
|
out_channels : int
|
||||||
|
Number of output channels (taken from the last block).
|
||||||
|
layers : nn.Sequential
|
||||||
|
The wrapped sequence of block modules.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -865,9 +987,33 @@ class LayerGroup(nn.Module):
|
|||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Pass input through all blocks in sequence.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Input feature map, shape ``(B, C_in, H, W)``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Output feature map after all blocks have been applied.
|
||||||
|
"""
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
"""Compute the output height by propagating through all blocks.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_height : int
|
||||||
|
Height of the input feature map.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
Height after all blocks in the group have been applied.
|
||||||
|
"""
|
||||||
for block in self.layers:
|
for block in self.layers:
|
||||||
input_height = block.get_output_height(input_height) # type: ignore
|
input_height = block.get_output_height(input_height) # type: ignore
|
||||||
return input_height
|
return input_height
|
||||||
@ -903,40 +1049,37 @@ def build_layer(
|
|||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: LayerConfig,
|
config: LayerConfig,
|
||||||
) -> Block:
|
) -> Block:
|
||||||
"""Factory function to build a specific nn.Module block from its config.
|
"""Build a block from its configuration object.
|
||||||
|
|
||||||
Takes configuration object (one of the types included in the `LayerConfig`
|
Looks up the block class corresponding to ``config.name`` in the
|
||||||
union) and instantiates the corresponding nn.Module block with the correct
|
internal block registry and instantiates it with the given input
|
||||||
parameters derived from the config and the current pipeline state
|
dimensions. This is the standard way to construct blocks when
|
||||||
(`input_height`, `in_channels`).
|
assembling an encoder or decoder from a configuration file.
|
||||||
|
|
||||||
It uses the `name` field within the `config` object to determine
|
|
||||||
which block class to instantiate.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) of the input tensor *to this layer*.
|
Height (number of frequency bins) of the input tensor to this
|
||||||
|
block. Required for blocks whose kernel size depends on the input
|
||||||
|
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input tensor *to this layer*.
|
Number of channels in the input tensor to this block.
|
||||||
config : LayerConfig
|
config : LayerConfig
|
||||||
A Pydantic configuration object for the desired block (e.g., an
|
A configuration object for the desired block type. The ``name``
|
||||||
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
field selects the block class; remaining fields supply its
|
||||||
by its `name` field.
|
parameters.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Tuple[nn.Module, int, int]
|
Block
|
||||||
A tuple containing:
|
An initialised block module ready to be added to an
|
||||||
- The instantiated `nn.Module` block.
|
``nn.Sequential`` or ``nn.ModuleList``.
|
||||||
- The number of output channels produced by the block.
|
|
||||||
- The calculated height of the output produced by the block.
|
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
NotImplementedError
|
KeyError
|
||||||
If the `config.name` does not correspond to a known block type.
|
If ``config.name`` does not correspond to a registered block type.
|
||||||
ValueError
|
ValueError
|
||||||
If parameters derived from the config are invalid for the block.
|
If the configuration parameters are invalid for the chosen block.
|
||||||
"""
|
"""
|
||||||
return block_registry.build(config, in_channels, input_height)
|
return block_registry.build(config, in_channels, input_height)
|
||||||
|
|||||||
@ -1,17 +1,21 @@
|
|||||||
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
|
"""Bottleneck component for encoder-decoder network architectures.
|
||||||
|
|
||||||
This module provides the configuration (`BottleneckConfig`) and
|
The bottleneck sits between the encoder (downsampling path) and the decoder
|
||||||
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
|
(upsampling path) and processes the lowest-resolution, highest-channel feature
|
||||||
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
|
map produced by the encoder.
|
||||||
Decoder (upsampling path) in networks like U-Nets.
|
|
||||||
|
|
||||||
The bottleneck processes the lowest-resolution, highest-dimensionality feature
|
This module provides:
|
||||||
map produced by the Encoder. This module offers a configurable option to include
|
|
||||||
a `SelfAttention` layer within the bottleneck, allowing the model to capture
|
|
||||||
global temporal context before features are passed to the Decoder.
|
|
||||||
|
|
||||||
A factory function `build_bottleneck` constructs the appropriate bottleneck
|
- ``BottleneckConfig`` – configuration class describing the number of
|
||||||
module based on the provided configuration.
|
internal channels and an optional sequence of additional layers (currently
|
||||||
|
only ``SelfAttention`` is supported).
|
||||||
|
- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a
|
||||||
|
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
|
||||||
|
runs one or more additional layers (e.g. self-attention along the time axis),
|
||||||
|
then repeats the output along the height dimension to restore the original
|
||||||
|
frequency resolution before passing features to the decoder.
|
||||||
|
- ``build_bottleneck`` – factory function that constructs a ``Bottleneck``
|
||||||
|
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List
|
from typing import Annotated, List
|
||||||
@ -37,42 +41,51 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Bottleneck(Block):
|
class Bottleneck(Block):
|
||||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
"""Bottleneck module for encoder-decoder architectures.
|
||||||
|
|
||||||
This implementation represents the simplest bottleneck structure
|
Processes the lowest-resolution feature map that links the encoder and
|
||||||
considered, primarily consisting of a `VerticalConv` layer. This layer
|
decoder. The sequence of operations is:
|
||||||
collapses the frequency dimension (height) to 1, summarizing information
|
|
||||||
across frequencies at each time step. The output is then repeated along the
|
|
||||||
height dimension to match the original bottleneck input height before being
|
|
||||||
passed to the decoder.
|
|
||||||
|
|
||||||
This base version does *not* include self-attention.
|
1. ``VerticalConv`` – collapses the frequency axis (height) to a single
|
||||||
|
bin by applying a convolution whose kernel spans the full height.
|
||||||
|
2. Optional additional layers (e.g. ``SelfAttention``) – applied while
|
||||||
|
the feature map has height 1, so they operate purely along the time
|
||||||
|
axis.
|
||||||
|
3. Height restoration – the single-bin output is repeated along the
|
||||||
|
height axis to restore the original frequency resolution, producing
|
||||||
|
a tensor that the decoder can accept.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) of the input tensor. Must be positive.
|
Height (number of frequency bins) of the input tensor. Must be
|
||||||
|
positive.
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input tensor from the encoder. Must be
|
Number of channels in the input tensor from the encoder. Must be
|
||||||
positive.
|
positive.
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of output channels. Must be positive.
|
Number of output channels after the bottleneck. Must be positive.
|
||||||
|
bottleneck_channels : int, optional
|
||||||
|
Number of internal channels used by the ``VerticalConv`` layer.
|
||||||
|
Defaults to ``out_channels`` if not provided.
|
||||||
|
layers : List[torch.nn.Module], optional
|
||||||
|
Additional modules (e.g. ``SelfAttention``) to apply after the
|
||||||
|
``VerticalConv`` and before height restoration.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of input channels accepted by the bottleneck.
|
Number of input channels accepted by the bottleneck.
|
||||||
|
out_channels : int
|
||||||
|
Number of output channels produced by the bottleneck.
|
||||||
input_height : int
|
input_height : int
|
||||||
Expected height of the input tensor.
|
Expected height of the input tensor.
|
||||||
channels : int
|
bottleneck_channels : int
|
||||||
Number of output channels.
|
Number of channels used internally by the vertical convolution.
|
||||||
conv_vert : VerticalConv
|
conv_vert : VerticalConv
|
||||||
The vertical convolution layer.
|
The vertical convolution layer.
|
||||||
|
layers : nn.ModuleList
|
||||||
Raises
|
Additional layers applied after the vertical convolution.
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `input_height`, `in_channels`, or `out_channels` are not positive.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -83,7 +96,23 @@ class Bottleneck(Block):
|
|||||||
bottleneck_channels: int | None = None,
|
bottleneck_channels: int | None = None,
|
||||||
layers: List[torch.nn.Module] | None = None,
|
layers: List[torch.nn.Module] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the base Bottleneck layer."""
|
"""Initialise the Bottleneck layer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_height : int
|
||||||
|
Height (number of frequency bins) of the input tensor.
|
||||||
|
in_channels : int
|
||||||
|
Number of channels in the input tensor.
|
||||||
|
out_channels : int
|
||||||
|
Number of channels in the output tensor.
|
||||||
|
bottleneck_channels : int, optional
|
||||||
|
Number of internal channels for the ``VerticalConv``. Defaults
|
||||||
|
to ``out_channels``.
|
||||||
|
layers : List[torch.nn.Module], optional
|
||||||
|
Additional modules applied after the ``VerticalConv``, such as
|
||||||
|
a ``SelfAttention`` block.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
@ -103,23 +132,24 @@ class Bottleneck(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Process input features through the bottleneck.
|
"""Process the encoder's bottleneck features.
|
||||||
|
|
||||||
Applies vertical convolution and repeats the output height.
|
Applies vertical convolution, optional additional layers, then
|
||||||
|
restores the height dimension by repetition.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor from the encoder bottleneck, shape
|
Input tensor from the encoder, shape
|
||||||
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
|
``(B, C_in, H_in, W)``. ``C_in`` must match
|
||||||
`H_in` must match `self.input_height`.
|
``self.in_channels`` and ``H_in`` must match
|
||||||
|
``self.input_height``.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
|
Output tensor with shape ``(B, C_out, H_in, W)``. The height
|
||||||
dimension `H_in` is restored via repetition after the vertical
|
``H_in`` is restored by repeating the single-bin result.
|
||||||
convolution.
|
|
||||||
"""
|
"""
|
||||||
x = self.conv_vert(x)
|
x = self.conv_vert(x)
|
||||||
|
|
||||||
@ -133,28 +163,22 @@ BottleneckLayerConfig = Annotated[
|
|||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
|
||||||
|
|
||||||
|
|
||||||
class BottleneckConfig(BaseConfig):
|
class BottleneckConfig(BaseConfig):
|
||||||
"""Configuration for the bottleneck layer(s).
|
"""Configuration for the bottleneck component.
|
||||||
|
|
||||||
Defines the number of channels within the bottleneck and whether to include
|
|
||||||
a self-attention mechanism.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
channels : int
|
channels : int
|
||||||
The number of output channels produced by the main convolutional layer
|
Number of output channels produced by the bottleneck. This value
|
||||||
within the bottleneck. This often matches the number of channels coming
|
is also used as the dimensionality of any optional layers (e.g.
|
||||||
from the last encoder stage, but can be different. Must be positive.
|
self-attention). Must be positive.
|
||||||
This also defines the channel dimensions used within the optional
|
layers : List[BottleneckLayerConfig]
|
||||||
`SelfAttention` layer.
|
Ordered list of additional block configurations to apply after the
|
||||||
self_attention : bool
|
initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
|
||||||
If True, includes a `SelfAttention` layer operating on the time
|
supported. Defaults to an empty list (no extra layers).
|
||||||
dimension after an initial `VerticalConv` layer within the bottleneck.
|
|
||||||
If False, only the initial `VerticalConv` (and height repetition) is
|
|
||||||
performed.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
channels: int
|
channels: int
|
||||||
@ -174,30 +198,37 @@ def build_bottleneck(
|
|||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: BottleneckConfig | None = None,
|
config: BottleneckConfig | None = None,
|
||||||
) -> BottleneckProtocol:
|
) -> BottleneckProtocol:
|
||||||
"""Factory function to build the Bottleneck module from configuration.
|
"""Build a ``Bottleneck`` module from configuration.
|
||||||
|
|
||||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
Constructs a ``Bottleneck`` instance whose internal channel count and
|
||||||
on the `config.self_attention` flag.
|
optional extra layers (e.g. self-attention) are controlled by
|
||||||
|
``config``. If no configuration is provided, the default
|
||||||
|
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
|
||||||
|
``SelfAttention`` layer.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) of the input tensor. Must be positive.
|
Height (number of frequency bins) of the input tensor from the
|
||||||
|
encoder. Must be positive.
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input tensor. Must be positive.
|
Number of channels in the input tensor from the encoder. Must be
|
||||||
|
positive.
|
||||||
config : BottleneckConfig, optional
|
config : BottleneckConfig, optional
|
||||||
Configuration object specifying the bottleneck channels and whether
|
Configuration specifying the output channel count and any
|
||||||
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
|
additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
nn.Module
|
BottleneckProtocol
|
||||||
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
|
An initialised ``Bottleneck`` module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
AssertionError
|
||||||
If `input_height` or `in_channels` are not positive.
|
If any configured layer changes the height of the feature map
|
||||||
|
(bottleneck layers must preserve height so that it can be restored
|
||||||
|
by repetition).
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_BOTTLENECK_CONFIG
|
config = config or DEFAULT_BOTTLENECK_CONFIG
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,21 @@
|
|||||||
"""Constructs the Decoder part of an Encoder-Decoder neural network.
|
"""Decoder (upsampling path) for the BatDetect2 backbone.
|
||||||
|
|
||||||
This module defines the configuration structure (`DecoderConfig`) for the layer
|
This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
|
||||||
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
|
together with the ``build_decoder`` factory function.
|
||||||
function (`build_decoder`). Decoders typically form the upsampling path in
|
|
||||||
architectures like U-Nets, taking bottleneck features
|
|
||||||
(usually from an `Encoder`) and skip connections to reconstruct
|
|
||||||
higher-resolution feature maps.
|
|
||||||
|
|
||||||
The decoder is built dynamically by stacking neural network blocks based on a
|
In a U-Net-style network the decoder progressively restores the spatial
|
||||||
list of configuration objects provided in `DecoderConfig.layers`. Each config
|
resolution of the feature map back towards the input resolution. At each
|
||||||
object specifies the type of block (e.g., standard convolution,
|
stage it combines the current feature map with the corresponding skip-connection
|
||||||
coordinate-feature convolution with upsampling) and its parameters. This allows
|
tensor from the encoder (the residual) by element-wise addition before passing
|
||||||
flexible definition of decoder architectures via configuration files.
|
the result to the upsampling block.
|
||||||
|
|
||||||
The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
The decoder is fully configurable: the type, number, and parameters of the
|
||||||
(`residuals`) from the encoder, merging them with the upsampled feature maps
|
upsampling blocks are described by a ``DecoderConfig`` object containing an
|
||||||
at each stage.
|
ordered list of block configuration objects (see ``batdetect2.models.blocks``
|
||||||
|
for available block types).
|
||||||
|
|
||||||
|
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
|
||||||
|
``build_decoder`` when no explicit configuration is supplied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List
|
from typing import Annotated, List
|
||||||
@ -51,51 +51,47 @@ DecoderLayerConfig = Annotated[
|
|||||||
|
|
||||||
|
|
||||||
class DecoderConfig(BaseConfig):
|
class DecoderConfig(BaseConfig):
|
||||||
"""Configuration for the sequence of layers in the Decoder module.
|
"""Configuration for the sequential ``Decoder`` module.
|
||||||
|
|
||||||
Defines the types and parameters of the neural network blocks that
|
|
||||||
constitute the decoder's upsampling path.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
layers : List[DecoderLayerConfig]
|
layers : List[DecoderLayerConfig]
|
||||||
An ordered list of configuration objects, each defining one layer or
|
Ordered list of block configuration objects defining the decoder's
|
||||||
block in the decoder sequence. Each item must be a valid block
|
upsampling stages (from deepest to shallowest). Each entry
|
||||||
config including a `name` field and necessary parameters like
|
specifies the block type (via its ``name`` field) and any
|
||||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
block-specific parameters such as ``out_channels``. Input channels
|
||||||
The list must contain at least one layer.
|
for each block are inferred automatically from the output of the
|
||||||
|
previous block. Must contain at least one entry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
layers: List[DecoderLayerConfig] = Field(min_length=1)
|
layers: List[DecoderLayerConfig] = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""Sequential Decoder module composed of configurable upsampling layers.
|
"""Sequential decoder module composed of configurable upsampling layers.
|
||||||
|
|
||||||
Constructs the upsampling path of an encoder-decoder network by stacking
|
Executes a series of upsampling blocks in order, adding the
|
||||||
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
|
corresponding encoder skip-connection tensor (residual) to the feature
|
||||||
based on a list of layer modules provided during initialization (typically
|
map before each block. The residuals are consumed in reverse order (from
|
||||||
created by the `build_decoder` factory function).
|
deepest encoder layer to shallowest) to match the spatial resolutions at
|
||||||
|
each decoder stage.
|
||||||
|
|
||||||
The `forward` method is designed to integrate skip connection tensors
|
Instances are typically created by ``build_decoder``.
|
||||||
(`residuals`) from the corresponding encoder stages, by adding them
|
|
||||||
element-wise to the input of each decoder layer before processing.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels expected in the input tensor.
|
Number of channels expected in the input tensor (bottleneck output).
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of channels in the final output tensor produced by the last
|
Number of channels in the final output feature map.
|
||||||
layer.
|
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) expected in the input tensor.
|
Height (frequency bins) of the input tensor.
|
||||||
output_height : int
|
output_height : int
|
||||||
Height (frequency bins) expected in the output tensor.
|
Height (frequency bins) of the output tensor.
|
||||||
layers : nn.ModuleList
|
layers : nn.ModuleList
|
||||||
The sequence of instantiated upscaling layer modules.
|
Sequence of instantiated upsampling block modules.
|
||||||
depth : int
|
depth : int
|
||||||
The number of upscaling layers (depth) in the decoder.
|
Number of upsampling layers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -106,23 +102,24 @@ class Decoder(nn.Module):
|
|||||||
output_height: int,
|
output_height: int,
|
||||||
layers: List[nn.Module],
|
layers: List[nn.Module],
|
||||||
):
|
):
|
||||||
"""Initialize the Decoder module.
|
"""Initialise the Decoder module.
|
||||||
|
|
||||||
Note: This constructor is typically called internally by the
|
This constructor is typically called by the ``build_decoder``
|
||||||
`build_decoder` factory function.
|
factory function.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
in_channels : int
|
||||||
|
Number of channels in the input tensor (bottleneck output).
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of channels produced by the final layer.
|
Number of channels produced by the final layer.
|
||||||
input_height : int
|
input_height : int
|
||||||
Expected height of the input tensor (bottleneck).
|
Height of the input tensor (bottleneck output height).
|
||||||
in_channels : int
|
output_height : int
|
||||||
Expected number of channels in the input tensor (bottleneck).
|
Height of the output tensor after all layers have been applied.
|
||||||
layers : List[nn.Module]
|
layers : List[nn.Module]
|
||||||
A list of pre-instantiated upscaling layer modules (e.g.,
|
Pre-built upsampling block modules in execution order (deepest
|
||||||
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
|
stage first).
|
||||||
sequence (from bottleneck towards output resolution).
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -140,43 +137,35 @@ class Decoder(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residuals: List[torch.Tensor],
|
residuals: List[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass input through decoder layers, incorporating skip connections.
|
"""Pass input through all decoder layers, incorporating skip connections.
|
||||||
|
|
||||||
Processes the input tensor `x` sequentially through the upscaling
|
At each stage the corresponding residual tensor is added
|
||||||
layers. At each stage, the corresponding skip connection tensor from
|
element-wise to ``x`` before it is passed to the upsampling block.
|
||||||
the `residuals` list is added element-wise to the input before passing
|
Residuals are consumed in reverse order — the last element of
|
||||||
it to the upscaling block.
|
``residuals`` (the output of the deepest encoder layer) is added
|
||||||
|
at the first decoder stage, and the first element (output of the
|
||||||
|
shallowest encoder layer) is added at the last decoder stage.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor from the previous stage (e.g., encoder bottleneck).
|
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
|
||||||
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
|
|
||||||
`self.in_channels`.
|
|
||||||
residuals : List[torch.Tensor]
|
residuals : List[torch.Tensor]
|
||||||
List containing the skip connection tensors from the corresponding
|
Skip-connection tensors from the encoder, ordered from shallowest
|
||||||
encoder stages. Should be ordered from the deepest encoder layer
|
(index 0) to deepest (index -1). Must contain exactly
|
||||||
output (lowest resolution) to the shallowest (highest resolution
|
``self.depth`` tensors. Each tensor must have the same spatial
|
||||||
near input). The number of tensors in this list must match the
|
dimensions and channel count as ``x`` at the corresponding
|
||||||
number of decoder layers (`self.depth`). Each residual tensor's
|
decoder stage.
|
||||||
channel count must be compatible with the input tensor `x` for
|
|
||||||
element-wise addition (or concatenation if the blocks were designed
|
|
||||||
for it).
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
The final decoded feature map tensor produced by the last layer.
|
Decoded feature map, shape ``(B, C_out, H_out, W)``.
|
||||||
Shape `(B, C_out, H_out, W_out)`.
|
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If the number of `residuals` provided does not match the decoder
|
If the number of ``residuals`` does not equal ``self.depth``.
|
||||||
depth.
|
|
||||||
RuntimeError
|
|
||||||
If shapes mismatch during skip connection addition or layer
|
|
||||||
processing.
|
|
||||||
"""
|
"""
|
||||||
if len(residuals) != len(self.layers):
|
if len(residuals) != len(self.layers):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -203,11 +192,17 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""A default configuration for the Decoder's *layer sequence*.
|
"""Default decoder configuration used in standard BatDetect2 models.
|
||||||
|
|
||||||
Specifies an architecture often used in BatDetect2, consisting of three
|
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
|
||||||
frequency coordinate-aware upsampling blocks followed by a standard
|
output has 256 channels and height 16, and produces:
|
||||||
convolutional block.
|
|
||||||
|
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
|
||||||
|
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
|
||||||
|
- Stage 3 (``LayerGroup``):
|
||||||
|
|
||||||
|
- ``FreqCoordConvUp``: 32 channels, height 128.
|
||||||
|
- ``ConvBlock``: 32 channels, height 128 (final feature map).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -216,40 +211,36 @@ def build_decoder(
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
config: DecoderConfig | None = None,
|
config: DecoderConfig | None = None,
|
||||||
) -> Decoder:
|
) -> Decoder:
|
||||||
"""Factory function to build a Decoder instance from configuration.
|
"""Build a ``Decoder`` from configuration.
|
||||||
|
|
||||||
Constructs a sequential `Decoder` module based on the layer sequence
|
Constructs a sequential ``Decoder`` by iterating over the block
|
||||||
defined in a `DecoderConfig` object and the provided input dimensions
|
configurations in ``config.layers``, building each block with
|
||||||
(bottleneck channels and height). If no config is provided, uses the
|
``build_layer``, and tracking the channel count and feature-map height
|
||||||
default layer sequence from `DEFAULT_DECODER_CONFIG`.
|
as they change through the sequence.
|
||||||
|
|
||||||
It iteratively builds the layers using the unified `build_layer_from_config`
|
|
||||||
factory (from `.blocks`), tracking the changing number of channels and
|
|
||||||
feature map height required for each subsequent layer.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
The number of channels in the input tensor to the decoder. Must be > 0.
|
Number of channels in the input tensor (bottleneck output). Must
|
||||||
|
be positive.
|
||||||
input_height : int
|
input_height : int
|
||||||
The height (frequency bins) of the input tensor to the decoder. Must be
|
Height (number of frequency bins) of the input tensor. Must be
|
||||||
> 0.
|
positive.
|
||||||
config : DecoderConfig, optional
|
config : DecoderConfig, optional
|
||||||
The configuration object detailing the sequence of layers and their
|
Configuration specifying the layer sequence. Defaults to
|
||||||
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
|
``DEFAULT_DECODER_CONFIG`` if not provided.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Decoder
|
Decoder
|
||||||
An initialized `Decoder` module.
|
An initialised ``Decoder`` module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If ``in_channels`` or ``input_height`` are not positive.
|
||||||
configuration is invalid (e.g., empty list, unknown `name`).
|
KeyError
|
||||||
NotImplementedError
|
If a layer configuration specifies an unknown block type.
|
||||||
If `build_layer_from_config` encounters an unknown `name`.
|
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_DECODER_CONFIG
|
config = config or DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
|
|||||||
@ -1,17 +1,20 @@
|
|||||||
"""Assembles the complete BatDetect2 Detection Model.
|
"""Assembles the complete BatDetect2 detection model.
|
||||||
|
|
||||||
This module defines the concrete `Detector` class, which implements the
|
This module defines the ``Detector`` class, which combines a backbone
|
||||||
`DetectionModel` interface defined in `.types`. It combines a feature
|
feature extractor with prediction heads for detection, classification, and
|
||||||
extraction backbone with specific prediction heads to create the end-to-end
|
bounding-box size regression.
|
||||||
neural network used for detecting bat calls, predicting their size, and
|
|
||||||
classifying them.
|
|
||||||
|
|
||||||
The primary components are:
|
Components
|
||||||
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
----------
|
||||||
|
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
||||||
|
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||||
|
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||||
|
- ``build_detector`` – factory function that builds a ready-to-use
|
||||||
|
``Detector`` from a backbone configuration and a target class count.
|
||||||
|
|
||||||
This module focuses purely on the neural network architecture definition. The
|
Note that ``Detector`` operates purely on spectrogram tensors; raw audio
|
||||||
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
preprocessing and output postprocessing are handled by
|
||||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -32,25 +35,30 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class Detector(DetectionModel):
|
class Detector(DetectionModel):
|
||||||
"""Concrete implementation of the BatDetect2 Detection Model.
|
"""Complete BatDetect2 detection and classification model.
|
||||||
|
|
||||||
Assembles a complete detection and classification model by combining a
|
Combines a backbone feature extractor with two prediction heads:
|
||||||
feature extraction backbone network with specific prediction heads for
|
|
||||||
detection probability, bounding box size regression, and class
|
- ``ClassifierHead``: predicts per-class probabilities at each
|
||||||
probabilities.
|
time–frequency location.
|
||||||
|
- ``BBoxHead``: predicts call duration and bandwidth at each location.
|
||||||
|
|
||||||
|
The detection probability map is derived from the class probabilities by
|
||||||
|
summing across the class dimension (i.e. the probability that *any* class
|
||||||
|
is present), rather than from a separate detection head.
|
||||||
|
|
||||||
|
Instances are typically created via ``build_detector``.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneModel
|
||||||
The feature extraction backbone network module.
|
The feature extraction backbone.
|
||||||
num_classes : int
|
num_classes : int
|
||||||
The number of specific target classes the model predicts (derived from
|
Number of target classes (inferred from the classifier head).
|
||||||
the `classifier_head`).
|
|
||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
The prediction head responsible for generating class probabilities.
|
Produces per-class probability maps from backbone features.
|
||||||
bbox_head : BBoxHead
|
bbox_head : BBoxHead
|
||||||
The prediction head responsible for generating bounding box size
|
Produces duration and bandwidth predictions from backbone features.
|
||||||
predictions.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
backbone: BackboneModel
|
backbone: BackboneModel
|
||||||
@ -61,26 +69,21 @@ class Detector(DetectionModel):
|
|||||||
classifier_head: ClassifierHead,
|
classifier_head: ClassifierHead,
|
||||||
bbox_head: BBoxHead,
|
bbox_head: BBoxHead,
|
||||||
):
|
):
|
||||||
"""Initialize the Detector model.
|
"""Initialise the Detector model.
|
||||||
|
|
||||||
Note: Instances are typically created using the `build_detector`
|
This constructor is typically called by the ``build_detector``
|
||||||
factory function.
|
factory function.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneModel
|
||||||
An initialized feature extraction backbone module (e.g., built by
|
An initialised backbone module (e.g. built by
|
||||||
`build_backbone` from the `.backbone` module).
|
``build_backbone``).
|
||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
An initialized classification head module. The number of classes
|
An initialised classification head. The ``num_classes``
|
||||||
is inferred from this head.
|
attribute is read from this head.
|
||||||
bbox_head : BBoxHead
|
bbox_head : BBoxHead
|
||||||
An initialized bounding box size prediction head module.
|
An initialised bounding-box size prediction head.
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
TypeError
|
|
||||||
If the provided modules are not of the expected types.
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -90,31 +93,34 @@ class Detector(DetectionModel):
|
|||||||
self.bbox_head = bbox_head
|
self.bbox_head = bbox_head
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
"""Perform the forward pass of the complete detection model.
|
"""Run the complete detection model on an input spectrogram.
|
||||||
|
|
||||||
Processes the input spectrogram through the backbone to extract
|
Passes the spectrogram through the backbone to produce a feature
|
||||||
features, then passes these features through the separate prediction
|
map, then applies the classifier and bounding-box heads. The
|
||||||
heads to generate detection probabilities, class probabilities, and
|
detection probability map is derived by summing the per-class
|
||||||
size predictions.
|
probability maps across the class dimension; no separate detection
|
||||||
|
head is used.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
spec : torch.Tensor
|
spec : torch.Tensor
|
||||||
Input spectrogram tensor, typically with shape
|
Input spectrogram tensor, shape
|
||||||
`(batch_size, input_channels, frequency_bins, time_bins)`. The
|
``(batch_size, channels, frequency_bins, time_bins)``.
|
||||||
shape must be compatible with the `self.backbone` input
|
|
||||||
requirements.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
ModelOutput
|
ModelOutput
|
||||||
A NamedTuple containing the four output tensors:
|
A named tuple with four fields:
|
||||||
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
|
|
||||||
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
|
- ``detection_probs`` – ``(B, 1, H, W)`` – probability that a
|
||||||
- `class_probs`: Class probabilities (excluding background)
|
call of any class is present at each location. Derived by
|
||||||
`(B, num_classes, H, W)`.
|
summing ``class_probs`` over the class dimension.
|
||||||
- `features`: Output feature map from the backbone
|
- ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel
|
||||||
`(B, C_out, H, W)`.
|
0) and bandwidth (channel 1) predictions at each location.
|
||||||
|
- ``class_probs`` – ``(B, num_classes, H, W)`` – per-class
|
||||||
|
probabilities at each location.
|
||||||
|
- ``features`` – ``(B, C_out, H, W)`` – raw backbone feature
|
||||||
|
map.
|
||||||
"""
|
"""
|
||||||
features = self.backbone(spec)
|
features = self.backbone(spec)
|
||||||
classification = self.classifier_head(features)
|
classification = self.classifier_head(features)
|
||||||
@ -131,30 +137,33 @@ class Detector(DetectionModel):
|
|||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int, config: BackboneConfig | None = None
|
num_classes: int, config: BackboneConfig | None = None
|
||||||
) -> DetectionModel:
|
) -> DetectionModel:
|
||||||
"""Build the complete BatDetect2 detection model.
|
"""Build a complete BatDetect2 detection model.
|
||||||
|
|
||||||
|
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||||
|
and a ``BBoxHead`` sized to the backbone's output channel count, and
|
||||||
|
returns them wrapped in a ``Detector``.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
num_classes : int
|
num_classes : int
|
||||||
The number of specific target classes the model should predict
|
Number of target bat species or call types to predict. Must be
|
||||||
(required for the `ClassifierHead`). Must be positive.
|
positive.
|
||||||
config : BackboneConfig, optional
|
config : BackboneConfig, optional
|
||||||
Configuration object specifying the architecture of the backbone
|
Backbone architecture configuration. Defaults to
|
||||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
|
||||||
within the respective builder functions (`build_encoder`, etc.) will be
|
not provided.
|
||||||
used to construct a default backbone architecture.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
DetectionModel
|
DetectionModel
|
||||||
An initialized `Detector` model instance.
|
An initialised ``Detector`` instance ready for training or
|
||||||
|
inference.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `num_classes` is not positive, or if errors occur during the
|
If ``num_classes`` is not positive, or if the backbone
|
||||||
construction of the backbone or detector components (e.g., incompatible
|
configuration is invalid.
|
||||||
configurations, invalid parameters).
|
|
||||||
"""
|
"""
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
|
|
||||||
|
|||||||
@ -1,23 +1,24 @@
|
|||||||
"""Constructs the Encoder part of a configurable neural network backbone.
|
"""Encoder (downsampling path) for the BatDetect2 backbone.
|
||||||
|
|
||||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
|
||||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
together with the ``build_encoder`` factory function.
|
||||||
(`build_encoder`) to create sequential encoders. Encoders typically form the
|
|
||||||
downsampling path in architectures like U-Nets, processing input feature maps
|
|
||||||
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
|
|
||||||
representations (bottleneck features).
|
|
||||||
|
|
||||||
The encoder is built dynamically by stacking neural network blocks based on a
|
In a U-Net-style network the encoder progressively reduces the spatial
|
||||||
list of configuration objects provided in `EncoderConfig.layers`. Each
|
resolution of the spectrogram whilst increasing the number of feature
|
||||||
configuration object specifies the type of block (e.g., standard convolution,
|
channels. Each layer in the encoder produces a feature map that is stored
|
||||||
coordinate-feature convolution with downsampling) and its parameters
|
for use as a skip connection in the corresponding decoder layer.
|
||||||
(e.g., output channels). This allows for flexible definition of encoder
|
|
||||||
architectures via configuration files.
|
|
||||||
|
|
||||||
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
|
The encoder is fully configurable: the type, number, and parameters of the
|
||||||
suitable for skip connections, while the `encode` method returns only the final
|
downsampling blocks are described by an ``EncoderConfig`` object containing
|
||||||
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
an ordered list of block configuration objects (see ``batdetect2.models.blocks``
|
||||||
provided.
|
for available block types).
|
||||||
|
|
||||||
|
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
|
||||||
|
so that skip connections are available to the decoder.
|
||||||
|
``Encoder.encode`` returns only the final output (the input to the bottleneck).
|
||||||
|
|
||||||
|
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
|
||||||
|
``build_encoder`` when no explicit configuration is supplied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List
|
from typing import Annotated, List
|
||||||
@ -53,35 +54,32 @@ EncoderLayerConfig = Annotated[
|
|||||||
|
|
||||||
|
|
||||||
class EncoderConfig(BaseConfig):
|
class EncoderConfig(BaseConfig):
|
||||||
"""Configuration for building the sequential Encoder module.
|
"""Configuration for the sequential ``Encoder`` module.
|
||||||
|
|
||||||
Defines the sequence of neural network blocks that constitute the encoder
|
|
||||||
(downsampling path).
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
layers : List[EncoderLayerConfig]
|
layers : List[EncoderLayerConfig]
|
||||||
An ordered list of configuration objects, each defining one layer or
|
Ordered list of block configuration objects defining the encoder's
|
||||||
block in the encoder sequence. Each item must be a valid block config
|
downsampling stages. Each entry specifies the block type (via its
|
||||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
``name`` field) and any block-specific parameters such as
|
||||||
`StandardConvDownConfig`) including a `name` field and necessary
|
``out_channels``. Input channels for each block are inferred
|
||||||
parameters like `out_channels`. Input channels for each layer are
|
automatically from the output of the previous block. Must contain
|
||||||
inferred sequentially. The list must contain at least one layer.
|
at least one entry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
layers: List[EncoderLayerConfig] = Field(min_length=1)
|
layers: List[EncoderLayerConfig] = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
"""Sequential Encoder module composed of configurable downscaling layers.
|
"""Sequential encoder module composed of configurable downsampling layers.
|
||||||
|
|
||||||
Constructs the downsampling path of an encoder-decoder network by stacking
|
Executes a series of downsampling blocks in order, storing the output of
|
||||||
multiple downscaling blocks.
|
each block so that it can be passed as a skip connection to the
|
||||||
|
corresponding decoder layer.
|
||||||
|
|
||||||
The `forward` method executes the sequence and returns the output feature
|
``forward`` returns the outputs of *all* layers (useful when skip
|
||||||
map from *each* downscaling stage, facilitating the implementation of skip
|
connections are needed). ``encode`` returns only the final output
|
||||||
connections in U-Net-like architectures. The `encode` method returns only
|
(the input to the bottleneck).
|
||||||
the final output tensor (bottleneck features).
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
@ -89,14 +87,14 @@ class Encoder(nn.Module):
|
|||||||
Number of channels expected in the input tensor.
|
Number of channels expected in the input tensor.
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (frequency bins) expected in the input tensor.
|
Height (frequency bins) expected in the input tensor.
|
||||||
output_channels : int
|
out_channels : int
|
||||||
Number of channels in the final output tensor (bottleneck).
|
Number of channels in the final output tensor (bottleneck input).
|
||||||
output_height : int
|
output_height : int
|
||||||
Height (frequency bins) expected in the output tensor.
|
Height (frequency bins) of the final output tensor.
|
||||||
layers : nn.ModuleList
|
layers : nn.ModuleList
|
||||||
The sequence of instantiated downscaling layer modules.
|
Sequence of instantiated downsampling block modules.
|
||||||
depth : int
|
depth : int
|
||||||
The number of downscaling layers in the encoder.
|
Number of downsampling layers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -107,23 +105,22 @@ class Encoder(nn.Module):
|
|||||||
input_height: int = 128,
|
input_height: int = 128,
|
||||||
in_channels: int = 1,
|
in_channels: int = 1,
|
||||||
):
|
):
|
||||||
"""Initialize the Encoder module.
|
"""Initialise the Encoder module.
|
||||||
|
|
||||||
Note: This constructor is typically called internally by the
|
This constructor is typically called by the ``build_encoder`` factory
|
||||||
`build_encoder` factory function, which prepares the `layers` list.
|
function, which takes care of building the ``layers`` list from a
|
||||||
|
configuration object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
output_channels : int
|
output_channels : int
|
||||||
Number of channels produced by the final layer.
|
Number of channels produced by the final layer.
|
||||||
output_height : int
|
output_height : int
|
||||||
The expected height of the output tensor.
|
Height of the output tensor after all layers have been applied.
|
||||||
layers : List[nn.Module]
|
layers : List[nn.Module]
|
||||||
A list of pre-instantiated downscaling layer modules (e.g.,
|
Pre-built downsampling block modules in execution order.
|
||||||
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
|
||||||
sequence.
|
|
||||||
input_height : int, default=128
|
input_height : int, default=128
|
||||||
Expected height of the input tensor.
|
Expected height of the input tensor (frequency bins).
|
||||||
in_channels : int, default=1
|
in_channels : int, default=1
|
||||||
Expected number of channels in the input tensor.
|
Expected number of channels in the input tensor.
|
||||||
"""
|
"""
|
||||||
@ -138,29 +135,30 @@ class Encoder(nn.Module):
|
|||||||
self.depth = len(self.layers)
|
self.depth = len(self.layers)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
"""Pass input through encoder layers, returns all intermediate outputs.
|
"""Pass input through all encoder layers and return every output.
|
||||||
|
|
||||||
This method is typically used when the Encoder is part of a U-Net or
|
Used when skip connections are needed (e.g. in a U-Net decoder).
|
||||||
similar architecture requiring skip connections.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
|
||||||
`self.in_channels` and `H_in` must match `self.input_height`.
|
``C_in`` must match ``self.in_channels`` and ``H_in`` must
|
||||||
|
match ``self.input_height``.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
List[torch.Tensor]
|
List[torch.Tensor]
|
||||||
A list containing the output tensors from *each* downscaling layer
|
Output tensors from every layer in order.
|
||||||
in the sequence. `outputs[0]` is the output of the first layer,
|
``outputs[0]`` is the output of the first (shallowest) layer;
|
||||||
`outputs[-1]` is the final output (bottleneck) of the encoder.
|
``outputs[-1]`` is the output of the last (deepest) layer,
|
||||||
|
which serves as the input to the bottleneck.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If input tensor channel count or height does not match expected
|
If the input channel count or height does not match the
|
||||||
values.
|
expected values.
|
||||||
"""
|
"""
|
||||||
if x.shape[1] != self.in_channels:
|
if x.shape[1] != self.in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -183,28 +181,29 @@ class Encoder(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""Pass input through encoder layers, returning only the final output.
|
"""Pass input through all encoder layers and return only the final output.
|
||||||
|
|
||||||
This method provides access to the bottleneck features produced after
|
Use this when skip connections are not needed and you only require
|
||||||
the last downscaling layer.
|
the bottleneck feature map.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
|
||||||
`in_channels` and `input_height`.
|
Must satisfy the same shape requirements as ``forward``.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
The final output tensor (bottleneck features) from the last layer
|
Output of the last encoder layer, shape
|
||||||
of the encoder. Shape `(B, C_out, H_out, W_out)`.
|
``(B, C_out, H_out, W)``, where ``C_out`` is
|
||||||
|
``self.out_channels`` and ``H_out`` is ``self.output_height``.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If input tensor channel count or height does not match expected
|
If the input channel count or height does not match the
|
||||||
values.
|
expected values.
|
||||||
"""
|
"""
|
||||||
if x.shape[1] != self.in_channels:
|
if x.shape[1] != self.in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -236,14 +235,17 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""Default configuration for the Encoder.
|
"""Default encoder configuration used in standard BatDetect2 models.
|
||||||
|
|
||||||
Specifies an architecture typically used in BatDetect2:
|
Assumes a 1-channel input with 128 frequency bins and produces the
|
||||||
- Input: 1 channel, 128 frequency bins.
|
following feature maps:
|
||||||
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
|
|
||||||
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
|
- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
|
||||||
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
|
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
|
||||||
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
|
- Stage 3 (``LayerGroup``):
|
||||||
|
|
||||||
|
- ``FreqCoordConvDown``: 128 channels, height 16.
|
||||||
|
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -252,42 +254,38 @@ def build_encoder(
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
config: EncoderConfig | None = None,
|
config: EncoderConfig | None = None,
|
||||||
) -> Encoder:
|
) -> Encoder:
|
||||||
"""Factory function to build an Encoder instance from configuration.
|
"""Build an ``Encoder`` from configuration.
|
||||||
|
|
||||||
Constructs a sequential `Encoder` module based on the layer sequence
|
Constructs a sequential ``Encoder`` by iterating over the block
|
||||||
defined in an `EncoderConfig` object and the provided input dimensions.
|
configurations in ``config.layers``, building each block with
|
||||||
If no config is provided, uses the default layer sequence from
|
``build_layer``, and tracking the channel count and feature-map height
|
||||||
`DEFAULT_ENCODER_CONFIG`.
|
as they change through the sequence.
|
||||||
|
|
||||||
It iteratively builds the layers using the unified
|
|
||||||
`build_layer_from_config` factory (from `.blocks`), tracking the changing
|
|
||||||
number of channels and feature map height required for each subsequent
|
|
||||||
layer, especially for coordinate-aware blocks.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
The number of channels expected in the input tensor to the encoder.
|
Number of channels in the input spectrogram tensor. Must be
|
||||||
Must be > 0.
|
positive.
|
||||||
input_height : int
|
input_height : int
|
||||||
The height (frequency bins) expected in the input tensor. Must be > 0.
|
Height (number of frequency bins) of the input spectrogram.
|
||||||
Crucial for initializing coordinate-aware layers correctly.
|
Must be positive and should be divisible by
|
||||||
|
``2 ** (number of downsampling stages)`` to avoid size mismatches
|
||||||
|
later in the network.
|
||||||
config : EncoderConfig, optional
|
config : EncoderConfig, optional
|
||||||
The configuration object detailing the sequence of layers and their
|
Configuration specifying the layer sequence. Defaults to
|
||||||
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
|
``DEFAULT_ENCODER_CONFIG`` if not provided.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Encoder
|
Encoder
|
||||||
An initialized `Encoder` module.
|
An initialised ``Encoder`` module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If `in_channels` or `input_height` are not positive, or if the layer
|
If ``in_channels`` or ``input_height`` are not positive.
|
||||||
configuration is invalid (e.g., empty list, unknown `name`).
|
KeyError
|
||||||
NotImplementedError
|
If a layer configuration specifies an unknown block type.
|
||||||
If `build_layer_from_config` encounters an unknown `name`.
|
|
||||||
"""
|
"""
|
||||||
if in_channels <= 0 or input_height <= 0:
|
if in_channels <= 0 or input_height <= 0:
|
||||||
raise ValueError("in_channels and input_height must be positive.")
|
raise ValueError("in_channels and input_height must be positive.")
|
||||||
|
|||||||
@ -1,20 +1,19 @@
|
|||||||
"""Prediction Head modules for BatDetect2 models.
|
"""Prediction heads attached to the backbone feature map.
|
||||||
|
|
||||||
This module defines simple `torch.nn.Module` subclasses that serve as
|
Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
|
||||||
prediction heads, typically attached to the output feature map of a backbone
|
convolution to map backbone feature channels to one specific type of
|
||||||
network.
|
output required by BatDetect2:
|
||||||
|
|
||||||
Each head is responsible for generating one specific type of output required
|
- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
|
||||||
by the BatDetect2 task:
|
activation).
|
||||||
- `DetectorHead`: Predicts the probability of sound event presence.
|
- ``ClassifierHead``: multi-class probability map over the target bat
|
||||||
- `ClassifierHead`: Predicts the probability distribution over target classes.
|
species / call types (softmax activation).
|
||||||
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
|
- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
|
||||||
box.
|
bandwidth (frequency axis) at each location (no activation; raw
|
||||||
|
regression output).
|
||||||
|
|
||||||
These heads use 1x1 convolutions to map the backbone feature channels
|
All three heads share the same input feature map produced by the backbone,
|
||||||
to the desired number of output channels for each prediction task at each
|
so they can be evaluated in parallel in a single forward pass.
|
||||||
spatial location, followed by an appropriate activation function (e.g., sigmoid
|
|
||||||
for detection, softmax for classification, none for size regression).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -28,42 +27,35 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class ClassifierHead(nn.Module):
|
class ClassifierHead(nn.Module):
|
||||||
"""Prediction head for multi-class classification probabilities.
|
"""Prediction head for species / call-type classification probabilities.
|
||||||
|
|
||||||
Takes an input feature map and produces a probability map where each
|
Takes a backbone feature map and produces a probability map where each
|
||||||
channel corresponds to a specific target class. It uses a 1x1 convolution
|
channel corresponds to a target class. Internally the 1×1 convolution
|
||||||
to map input channels to `num_classes + 1` outputs (one for each target
|
maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
|
||||||
class plus an assumed background/generic class), applies softmax across the
|
represents a generic background / unknown category); a softmax is then
|
||||||
channels, and returns the probabilities for the specific target classes
|
applied across the channel dimension and the background channel is
|
||||||
(excluding the last background/generic channel).
|
discarded before returning.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
num_classes : int
|
num_classes : int
|
||||||
The number of specific target classes the model should predict
|
Number of target classes (bat species or call types) to predict,
|
||||||
(excluding any background or generic category). Must be positive.
|
excluding the background category. Must be positive.
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input feature map tensor from the backbone.
|
Number of channels in the backbone feature map. Must be positive.
|
||||||
Must be positive.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
num_classes : int
|
num_classes : int
|
||||||
Number of specific output classes.
|
Number of specific output classes (background excluded).
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of input channels expected.
|
Number of input channels expected.
|
||||||
classifier : nn.Conv2d
|
classifier : nn.Conv2d
|
||||||
The 1x1 convolutional layer used for prediction.
|
1×1 convolution with ``num_classes + 1`` output channels.
|
||||||
Output channels = num_classes + 1.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `num_classes` or `in_channels` are not positive.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes: int, in_channels: int):
|
def __init__(self, num_classes: int, in_channels: int):
|
||||||
"""Initialize the ClassifierHead."""
|
"""Initialise the ClassifierHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute class probabilities from input features.
|
"""Compute per-class probabilities from backbone features.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
Input feature map tensor from the backbone, typically with shape
|
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Class probability map tensor with shape `(B, num_classes, H, W)`.
|
Class probability map, shape ``(B, num_classes, H, W)``.
|
||||||
Contains probabilities for the specific target classes after
|
Values are softmax probabilities in the range [0, 1] and
|
||||||
softmax, excluding the implicit background/generic class channel.
|
sum to less than 1 per location (the background probability
|
||||||
|
is discarded).
|
||||||
"""
|
"""
|
||||||
logits = self.classifier(features)
|
logits = self.classifier(features)
|
||||||
probs = torch.softmax(logits, dim=1)
|
probs = torch.softmax(logits, dim=1)
|
||||||
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DetectorHead(nn.Module):
|
class DetectorHead(nn.Module):
|
||||||
"""Prediction head for sound event detection probability.
|
"""Prediction head for detection probability (is a call present here?).
|
||||||
|
|
||||||
Takes an input feature map and produces a single-channel heatmap where
|
Produces a single-channel heatmap where each value indicates the
|
||||||
each value represents the probability ([0, 1]) of a relevant sound event
|
probability ([0, 1]) that a bat call of *any* species is present at
|
||||||
(of any class) being present at that spatial location.
|
that time–frequency location in the spectrogram.
|
||||||
|
|
||||||
Uses a 1x1 convolution to map input channels to 1 output channel, followed
|
Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
|
||||||
by a sigmoid activation function.
|
sigmoid activation.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input feature map tensor from the backbone.
|
Number of channels in the backbone feature map. Must be positive.
|
||||||
Must be positive.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of input channels expected.
|
Number of input channels expected.
|
||||||
detector : nn.Conv2d
|
detector : nn.Conv2d
|
||||||
The 1x1 convolutional layer mapping to a single output channel.
|
1×1 convolution with a single output channel.
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `in_channels` is not positive.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int):
|
def __init__(self, in_channels: int):
|
||||||
"""Initialize the DetectorHead."""
|
"""Initialise the DetectorHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute detection probabilities from input features.
|
"""Compute detection probabilities from backbone features.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
Input feature map tensor from the backbone, typically with shape
|
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
|
Detection probability heatmap, shape ``(B, 1, H, W)``.
|
||||||
Values are in the range [0, 1] due to the sigmoid activation.
|
Values are in the range [0, 1].
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
RuntimeError
|
|
||||||
If input channel count does not match `self.in_channels`.
|
|
||||||
"""
|
"""
|
||||||
return torch.sigmoid(self.detector(features))
|
return torch.sigmoid(self.detector(features))
|
||||||
|
|
||||||
|
|
||||||
class BBoxHead(nn.Module):
|
class BBoxHead(nn.Module):
|
||||||
"""Prediction head for bounding box size dimensions.
|
"""Prediction head for bounding box size (duration and bandwidth).
|
||||||
|
|
||||||
Takes an input feature map and produces a two-channel map where each
|
Produces a two-channel map where channel 0 predicts the scaled duration
|
||||||
channel represents a predicted size dimension (typically width/duration and
|
(time-axis extent) and channel 1 predicts the scaled bandwidth
|
||||||
height/bandwidth) for a potential sound event at that spatial location.
|
(frequency-axis extent) of the call at each spectrogram location.
|
||||||
|
|
||||||
Uses a 1x1 convolution to map input channels to 2 output channels. No
|
Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
|
||||||
activation function is typically applied, as size prediction is often
|
activation function (raw regression output). The predicted values are
|
||||||
treated as a direct regression task. The output values usually represent
|
in a scaled space and must be converted to real units (seconds and Hz)
|
||||||
*scaled* dimensions that need to be un-scaled during postprocessing.
|
during postprocessing.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input feature map tensor from the backbone.
|
Number of channels in the backbone feature map. Must be positive.
|
||||||
Must be positive.
|
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of input channels expected.
|
Number of input channels expected.
|
||||||
bbox : nn.Conv2d
|
bbox : nn.Conv2d
|
||||||
The 1x1 convolutional layer mapping to 2 output channels
|
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||||
(width, height).
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `in_channels` is not positive.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int):
|
def __init__(self, in_channels: int):
|
||||||
"""Initialize the BBoxHead."""
|
"""Initialise the BBoxHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute predicted bounding box dimensions from input features.
|
"""Predict call duration and bandwidth from backbone features.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
features : torch.Tensor
|
features : torch.Tensor
|
||||||
Input feature map tensor from the backbone, typically with shape
|
Backbone feature map, shape ``(B, C_in, H, W)``.
|
||||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
|
Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
|
||||||
represents scaled width, Channel 1 scaled height. These values
|
the predicted scaled duration; channel 1 is the predicted
|
||||||
need to be un-scaled during postprocessing.
|
scaled bandwidth. Values must be rescaled to real units during
|
||||||
|
postprocessing.
|
||||||
"""
|
"""
|
||||||
return self.bbox(features)
|
return self.bbox(features)
|
||||||
|
|||||||
@ -1,14 +1,4 @@
|
|||||||
"""Tests for backbone configuration loading and the backbone registry.
|
"""Tests for backbone configuration loading and the backbone registry."""
|
||||||
|
|
||||||
Covers:
|
|
||||||
- UNetBackboneConfig default construction and field values.
|
|
||||||
- build_backbone with default and explicit configs.
|
|
||||||
- load_backbone_config loading from a YAML file.
|
|
||||||
- load_backbone_config with a nested field path.
|
|
||||||
- load_backbone_config round-trip: YAML → config → build_backbone.
|
|
||||||
- Registry registration and dispatch for UNetBackbone.
|
|
||||||
- BackboneConfig discriminated union validation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@ -25,10 +15,6 @@ from batdetect2.models.backbones import (
|
|||||||
)
|
)
|
||||||
from batdetect2.typing.models import BackboneModel
|
from batdetect2.typing.models import BackboneModel
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# UNetBackboneConfig
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_unet_backbone_config_defaults():
|
def test_unet_backbone_config_defaults():
|
||||||
"""Default config has expected field values."""
|
"""Default config has expected field values."""
|
||||||
@ -57,11 +43,6 @@ def test_unet_backbone_config_extra_fields_ignored():
|
|||||||
assert not hasattr(config, "unknown_field")
|
assert not hasattr(config, "unknown_field")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# build_backbone
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_backbone_default():
|
def test_build_backbone_default():
|
||||||
"""Building with no config uses UNetBackbone defaults."""
|
"""Building with no config uses UNetBackbone defaults."""
|
||||||
backbone = build_backbone()
|
backbone = build_backbone()
|
||||||
@ -83,15 +64,9 @@ def test_build_backbone_custom_config():
|
|||||||
def test_build_backbone_returns_backbone_model():
|
def test_build_backbone_returns_backbone_model():
|
||||||
"""build_backbone always returns a BackboneModel instance."""
|
"""build_backbone always returns a BackboneModel instance."""
|
||||||
backbone = build_backbone()
|
backbone = build_backbone()
|
||||||
|
|
||||||
assert isinstance(backbone, BackboneModel)
|
assert isinstance(backbone, BackboneModel)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Registry
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_registry_has_unet_backbone():
|
def test_registry_has_unet_backbone():
|
||||||
"""The backbone registry has UNetBackbone registered."""
|
"""The backbone registry has UNetBackbone registered."""
|
||||||
config_types = backbone_registry.get_config_types()
|
config_types = backbone_registry.get_config_types()
|
||||||
@ -124,11 +99,6 @@ def test_registry_build_unknown_name_raises():
|
|||||||
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# BackboneConfig discriminated union
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_backbone_config_validates_unet_from_dict():
|
def test_backbone_config_validates_unet_from_dict():
|
||||||
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
|
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
@ -151,11 +121,6 @@ def test_backbone_config_invalid_name_raises():
|
|||||||
adapter.validate_python({"name": "NonExistentBackbone"})
|
adapter.validate_python({"name": "NonExistentBackbone"})
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# load_backbone_config
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_backbone_config_from_yaml(
|
def test_load_backbone_config_from_yaml(
|
||||||
create_temp_yaml: Callable[[str], Path],
|
create_temp_yaml: Callable[[str], Path],
|
||||||
):
|
):
|
||||||
@ -218,11 +183,6 @@ deprecated_field: 99
|
|||||||
assert config.input_height == 128
|
assert config.input_height == 128
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Round-trip: YAML → config → build_backbone
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_round_trip_yaml_to_build_backbone(
|
def test_round_trip_yaml_to_build_backbone(
|
||||||
create_temp_yaml: Callable[[str], Path],
|
create_temp_yaml: Callable[[str], Path],
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user