From ef3348d651bdca33b00a012d8de52a340c89ad4f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 8 Mar 2026 16:34:17 +0000 Subject: [PATCH] Update model docstrings --- src/batdetect2/models/__init__.py | 123 ++++++++--- src/batdetect2/models/backbones.py | 210 +++++++++++++------ src/batdetect2/models/blocks.py | 307 ++++++++++++++++++++-------- src/batdetect2/models/bottleneck.py | 157 ++++++++------ src/batdetect2/models/decoder.py | 179 ++++++++-------- src/batdetect2/models/detectors.py | 131 ++++++------ src/batdetect2/models/encoder.py | 186 +++++++++-------- src/batdetect2/models/heads.py | 151 ++++++-------- tests/test_models/test_backbones.py | 42 +--- 9 files changed, 873 insertions(+), 613 deletions(-) diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 195e4a2..ec6605d 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -1,29 +1,29 @@ -"""Defines and builds the neural network models used in BatDetect2. +"""Neural network model definitions and builders for BatDetect2. -This package (`batdetect2.models`) contains the PyTorch implementations of the -deep neural network architectures used for detecting and classifying bat calls -from spectrograms. It provides modular components and configuration-driven -assembly, allowing for experimentation and use of different architectural -variants. +This package contains the PyTorch implementations of the deep neural network +architectures used to detect and classify bat echolocation calls in +spectrograms. Components are designed to be combined through configuration +objects, making it easy to experiment with different architectures. -Key Submodules: -- `.types`: Defines core data structures (`ModelOutput`) and abstract base - classes (`BackboneModel`, `DetectionModel`) establishing interfaces. -- `.blocks`: Provides reusable neural network building blocks. -- `.encoder`: Defines and builds the downsampling path (encoder) of the network. 
-- `.bottleneck`: Defines and builds the central bottleneck component. -- `.decoder`: Defines and builds the upsampling path (decoder) of the network. -- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete - feature extraction backbone (e.g., a U-Net like structure). -- `.heads`: Defines simple prediction heads (detection, classification, size) - that attach to the backbone features. -- `.detectors`: Assembles the backbone and prediction heads into the final, - end-to-end `Detector` model. +Key submodules +-------------- +- ``blocks``: Reusable convolutional building blocks (downsampling, + upsampling, attention, coord-conv variants). +- ``encoder``: The downsampling path; reduces spatial resolution whilst + extracting increasingly abstract features. +- ``bottleneck``: The central component connecting encoder to decoder; + optionally applies self-attention along the time axis. +- ``decoder``: The upsampling path; reconstructs high-resolution feature + maps using bottleneck output and skip connections from the encoder. +- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete + U-Net-style feature extraction backbone. +- ``heads``: Lightweight 1×1 convolutional heads that produce detection, + classification, and bounding-box size predictions from backbone features. +- ``detectors``: Combines a backbone with prediction heads into the final + end-to-end ``Detector`` model. -This module re-exports the most important classes, configurations, and builder -functions from these submodules for convenient access. The primary entry point -for creating a standard BatDetect2 model instance is the `build_model` function -provided here. +The primary entry point for building a full, ready-to-use BatDetect2 model +is the ``build_model`` factory function exported from this module. 
""" from typing import List @@ -99,6 +99,31 @@ __all__ = [ class Model(torch.nn.Module): + """End-to-end BatDetect2 model wrapping preprocessing and postprocessing. + + Combines a preprocessor, a detection model, and a postprocessor into a + single PyTorch module. Calling ``forward`` on a raw waveform tensor + returns a list of detection tensors ready for downstream use. + + This class is the top-level object produced by ``build_model``. Most + users will not need to construct it directly. + + Attributes + ---------- + detector : DetectionModel + The neural network that processes spectrograms and produces raw + detection, classification, and bounding-box outputs. + preprocessor : PreprocessorProtocol + Converts a raw waveform tensor into a spectrogram tensor accepted by + ``detector``. + postprocessor : PostprocessorProtocol + Converts the raw ``ModelOutput`` from ``detector`` into a list of + per-clip detection tensors. + targets : TargetProtocol + Describes the set of target classes; used when building heads and + during training target construction. + """ + detector: DetectionModel preprocessor: PreprocessorProtocol postprocessor: PostprocessorProtocol @@ -118,6 +143,25 @@ class Model(torch.nn.Module): self.targets = targets def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: + """Run the full detection pipeline on a waveform tensor. + + Converts the waveform to a spectrogram, passes it through the + detector, and postprocesses the raw outputs into detection tensors. + + Parameters + ---------- + wav : torch.Tensor + Raw audio waveform tensor. The exact expected shape depends on + the preprocessor, but is typically ``(batch, samples)`` or + ``(batch, channels, samples)``. + + Returns + ------- + List[ClipDetectionsTensor] + One detection tensor per clip in the batch. Each tensor encodes + the detected events (locations, class scores, sizes) for that + clip. 
+ """ spec = self.preprocessor(wav) outputs = self.detector(spec) return self.postprocessor(outputs) @@ -128,7 +172,38 @@ def build_model( targets: TargetProtocol | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, -): +) -> "Model": + """Build a complete, ready-to-use BatDetect2 model. + + Assembles a ``Model`` instance from optional configuration and component + overrides. Any argument left as ``None`` will be replaced by a sensible + default built with the project's own builder functions. + + Parameters + ---------- + config : BackboneConfig, optional + Configuration describing the backbone architecture (encoder, + bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not + provided. + targets : TargetProtocol, optional + Describes the target bat species or call types to detect. Determines + the number of output classes. Defaults to the standard BatDetect2 + target set. + preprocessor : PreprocessorProtocol, optional + Converts raw audio waveforms to spectrograms. Defaults to the + standard BatDetect2 preprocessor. + postprocessor : PostprocessorProtocol, optional + Converts raw model outputs to detection tensors. Defaults to the + standard BatDetect2 postprocessor. If a custom ``preprocessor`` is + given without a matching ``postprocessor``, the default postprocessor + will be built using the provided preprocessor so that frequency and + time scaling remain consistent. + + Returns + ------- + Model + A fully assembled ``Model`` instance ready for inference or training. 
+ """ from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index a83555a..ca7c937 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -1,21 +1,26 @@ -"""Assembles a complete Encoder-Decoder Backbone network. +"""Assembles a complete encoder-decoder backbone network. -This module defines the configuration (`BackboneConfig`) and implementation -(`Backbone`) for a standard encoder-decoder style neural network backbone. +This module defines ``UNetBackboneConfig`` and the ``UNetBackbone`` +``nn.Module``, together with the ``build_backbone`` and +``load_backbone_config`` helpers. -It orchestrates the connection between three main components, built using their -respective configurations and factory functions from sibling modules: -1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features - at multiple resolutions and provides skip connections. -2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the - lowest resolution, optionally applying self-attention. -3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high- - resolution features using bottleneck features and skip connections. +A backbone combines three components built from the sibling modules: -The resulting `Backbone` module takes a spectrogram as input and outputs a -final feature map, typically used by subsequent prediction heads. It includes -automatic padding to handle input sizes not perfectly divisible by the -network's total downsampling factor. +1. **Encoder** (``batdetect2.models.encoder``) – reduces spatial resolution + while extracting hierarchical features and storing skip-connection tensors. +2. 
**Bottleneck** (``batdetect2.models.bottleneck``) – processes the + lowest-resolution features, optionally applying self-attention. +3. **Decoder** (``batdetect2.models.decoder``) – restores spatial resolution + using bottleneck features and skip connections from the encoder. + +The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns +a high-resolution feature map consumed by the prediction heads in +``batdetect2.models.detectors``. + +Input padding is handled automatically: the backbone pads the input to be +divisible by the total downsampling factor and strips the padding from the +output so that the output spatial dimensions always match the input spatial +dimensions. """ from typing import Annotated, Literal, Tuple, Union @@ -51,6 +56,34 @@ from batdetect2.typing.models import ( class UNetBackboneConfig(BaseConfig): + """Configuration for a U-Net-style encoder-decoder backbone. + + All fields have sensible defaults that reproduce the standard BatDetect2 + architecture, so you can start with ``UNetBackboneConfig()`` and override + only the fields you want to change. + + Attributes + ---------- + name : str + Discriminator field used by the backbone registry; always + ``"UNetBackbone"``. + input_height : int + Number of frequency bins in the input spectrogram. Defaults to + ``128``. + in_channels : int + Number of channels in the input spectrogram (e.g. ``1`` for a + standard mel-spectrogram). Defaults to ``1``. + encoder : EncoderConfig + Configuration for the downsampling path. Defaults to + ``DEFAULT_ENCODER_CONFIG``. + bottleneck : BottleneckConfig + Configuration for the bottleneck. Defaults to + ``DEFAULT_BOTTLENECK_CONFIG``. + decoder : DecoderConfig + Configuration for the upsampling path. Defaults to + ``DEFAULT_DECODER_CONFIG``. 
+ """ + name: Literal["UNetBackbone"] = "UNetBackbone" input_height: int = 128 in_channels: int = 1 @@ -70,35 +103,36 @@ __all__ = [ class UNetBackbone(BackboneModel): - """Encoder-Decoder Backbone Network Implementation. + """U-Net-style encoder-decoder backbone network. - Combines an Encoder, Bottleneck, and Decoder module sequentially, using - skip connections between the Encoder and Decoder. Implements the standard - U-Net style forward pass. Includes automatic input padding to handle - various input sizes and a final convolutional block to adjust the output - channels. + Combines an encoder, a bottleneck, and a decoder into a single module + that produces a high-resolution feature map from an input spectrogram. + Skip connections from each encoder stage are added element-wise to the + corresponding decoder stage input. - This class inherits from `BackboneModel` and implements its `forward` - method. Instances are typically created using the `build_backbone` factory - function. + Input spectrograms of arbitrary width are handled automatically: the + backbone pads the input so that its dimensions are divisible by + ``divide_factor`` and removes the padding from the output. + + Instances are typically created via ``build_backbone``. Attributes ---------- input_height : int - Expected height of the input spectrogram. + Expected height (frequency bins) of the input spectrogram. out_channels : int - Number of channels in the final output feature map. + Number of channels in the output feature map (taken from the + decoder's output channel count). encoder : EncoderProtocol The instantiated encoder module. decoder : DecoderProtocol The instantiated decoder module. bottleneck : BottleneckProtocol The instantiated bottleneck module. - final_conv : ConvBlock - Final convolutional block applied after the decoder. divide_factor : int - The total downsampling factor (2^depth) applied by the encoder, - used for automatic input padding. 
+ The total spatial downsampling factor applied by the encoder + (``input_height // encoder.output_height``). The input width is + padded to be a multiple of this value before processing. """ def __init__( @@ -108,25 +142,19 @@ class UNetBackbone(BackboneModel): decoder: DecoderProtocol, bottleneck: BottleneckProtocol, ): - """Initialize the Backbone network. + """Initialise the backbone network. Parameters ---------- input_height : int - Expected height of the input spectrogram. - out_channels : int - Desired number of output channels for the backbone's feature map. + Expected height (frequency bins) of the input spectrogram. encoder : EncoderProtocol - An initialized Encoder module. + An initialised encoder module. decoder : DecoderProtocol - An initialized Decoder module. + An initialised decoder module. Its ``output_height`` must equal + ``input_height``; a ``ValueError`` is raised otherwise. bottleneck : BottleneckProtocol - An initialized Bottleneck module. - - Raises - ------ - ValueError - If component output/input channels or heights are incompatible. + An initialised bottleneck module. """ super().__init__() self.input_height = input_height @@ -143,22 +171,25 @@ class UNetBackbone(BackboneModel): self.divide_factor = input_height // self.encoder.output_height def forward(self, spec: torch.Tensor) -> torch.Tensor: - """Perform the forward pass through the encoder-decoder backbone. + """Produce a feature map from an input spectrogram. - Applies padding, runs encoder, bottleneck, decoder (with skip - connections), removes padding, and applies a final convolution. + Pads the input if necessary, runs it through the encoder, then + the bottleneck, then the decoder (incorporating encoder skip + connections), and finally removes any padding added earlier. Parameters ---------- spec : torch.Tensor - Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match - `self.encoder.input_channels` and `self.input_height`. 
+ Input spectrogram tensor, shape + ``(B, C_in, H_in, W_in)``. ``H_in`` must equal + ``self.input_height``; ``W_in`` can be any positive integer. Returns ------- torch.Tensor - Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where - `C_out` is `self.out_channels`. + Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where + ``C_out`` is ``self.out_channels``. The spatial dimensions + always match those of the input. """ spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor) @@ -219,6 +250,24 @@ BackboneConfig = Annotated[ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: + """Build a backbone network from configuration. + + Looks up the backbone class corresponding to ``config.name`` in the + backbone registry and calls its ``from_config`` method. If no + configuration is provided, a default ``UNetBackbone`` is returned. + + Parameters + ---------- + config : BackboneConfig, optional + A configuration object describing the desired backbone. Currently + ``UNetBackboneConfig`` is the only supported type. Defaults to + ``UNetBackboneConfig()`` if not provided. + + Returns + ------- + BackboneModel + An initialised backbone module. + """ config = config or UNetBackboneConfig() return backbone_registry.build(config) @@ -227,26 +276,25 @@ def _pad_adjust( spec: torch.Tensor, factor: int = 32, ) -> Tuple[torch.Tensor, int, int]: - """Pad tensor height and width to be divisible by a factor. + """Pad a tensor's height and width to be divisible by ``factor``. - Calculates the required padding for the last two dimensions (H, W) to make - them divisible by `factor` and applies right/bottom padding using - `torch.nn.functional.pad`. + Adds zero-padding to the bottom and right edges of the tensor so that + both dimensions are exact multiples of ``factor``. If both dimensions + are already divisible, the tensor is returned unchanged. Parameters ---------- spec : torch.Tensor - Input tensor, typically shape `(B, C, H, W)`. 
+ Input tensor, typically shape ``(B, C, H, W)``. factor : int, default=32 - The factor to make height and width divisible by. + The factor that both H and W should be divisible by after padding. Returns ------- Tuple[torch.Tensor, int, int] - A tuple containing: - - The padded tensor. - - The amount of padding added to height (`h_pad`). - - The amount of padding added to width (`w_pad`). + - Padded tensor. + - Number of rows added to the height (``h_pad``). + - Number of columns added to the width (``w_pad``). """ h, w = spec.shape[-2:] h_pad = -h % factor @@ -261,23 +309,25 @@ def _pad_adjust( def _restore_pad( x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 ) -> torch.Tensor: - """Remove padding added by _pad_adjust. + """Remove padding previously added by ``_pad_adjust``. - Removes padding from the bottom and right edges of the tensor. + Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the + right of the tensor, restoring its original spatial dimensions. Parameters ---------- x : torch.Tensor - Padded tensor, typically shape `(B, C, H_padded, W_padded)`. + Padded tensor, typically shape ``(B, C, H_padded, W_padded)``. h_pad : int, default=0 - Amount of padding previously added to the height (bottom). + Number of rows to remove from the bottom. w_pad : int, default=0 - Amount of padding previously added to the width (right). + Number of columns to remove from the right. Returns ------- torch.Tensor - Tensor with padding removed, shape `(B, C, H_original, W_original)`. + Tensor with padding removed, shape + ``(B, C, H_padded - h_pad, W_padded - w_pad)``. """ if h_pad > 0: x = x[..., :-h_pad, :] @@ -292,6 +342,36 @@ def load_backbone_config( path: data.PathLike, field: str | None = None, ) -> BackboneConfig: + """Load a backbone configuration from a YAML or JSON file. + + Reads the file at ``path``, optionally descends into a named sub-field, + and validates the result against the ``BackboneConfig`` discriminated + union. 
+ + Parameters + ---------- + path : PathLike + Path to the configuration file. Both YAML and JSON formats are + supported. + field : str, optional + Dot-separated key path to the sub-field that contains the backbone + configuration (e.g. ``"model"``). If ``None``, the root of the + file is used. + + Returns + ------- + BackboneConfig + A validated backbone configuration object (currently always a + ``UNetBackboneConfig`` instance). + + Raises + ------ + FileNotFoundError + If ``path`` does not exist. + ValidationError + If the loaded data does not conform to a known ``BackboneConfig`` + schema. + """ return load_config( path, schema=TypeAdapter(BackboneConfig), diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index e21ce2f..71d8496 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -1,30 +1,49 @@ -"""Commonly used neural network building blocks for BatDetect2 models. +"""Reusable convolutional building blocks for BatDetect2 models. -This module provides various reusable `torch.nn.Module` subclasses that form -the fundamental building blocks for constructing convolutional neural network -architectures, particularly encoder-decoder backbones used in BatDetect2. +This module provides a collection of ``torch.nn.Module`` subclasses that form +the fundamental building blocks for the encoder-decoder backbone used in +BatDetect2. All blocks follow a consistent interface: they store +``in_channels`` and ``out_channels`` as attributes and implement a +``get_output_height`` method that reports how a given input height maps to an +output height (e.g., halved by downsampling blocks, doubled by upsampling +blocks). -It includes standard components like basic convolutional blocks (`ConvBlock`), -blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with -upsampling (`StandardConvUpBlock`). 
+Available block families +------------------------ +Standard blocks + ``ConvBlock`` – convolution + batch normalisation + ReLU, no change in + spatial resolution. -Additionally, it features specialized layers investigated in BatDetect2 -research: +Downsampling blocks + ``StandardConvDownBlock`` – convolution then 2×2 max-pooling, halves H + and W. + ``FreqCoordConvDownBlock`` – like ``StandardConvDownBlock`` but prepends + a normalised frequency-coordinate channel before the convolution + (CoordConv concept), helping filters learn frequency-position-dependent + patterns. -- `SelfAttention`: Applies self-attention along the time dimension, enabling - the model to weigh information across the entire temporal context, often - used in the bottleneck of an encoder-decoder. -- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv" - concept by concatenating normalized frequency coordinate information as an - extra channel to the input of convolutional layers. This explicitly provides - spatial frequency information to filters, potentially enabling them to learn - frequency-dependent patterns more effectively. +Upsampling blocks + ``StandardConvUpBlock`` – bilinear interpolation then convolution, + doubles H and W. + ``FreqCoordConvUpBlock`` – like ``StandardConvUpBlock`` but prepends a + frequency-coordinate channel after upsampling. -These blocks can be used directly in custom PyTorch model definitions or -assembled into larger architectures. +Bottleneck blocks + ``VerticalConv`` – 1-D convolution whose kernel spans the entire + frequency axis, collapsing H to 1 whilst preserving W. + ``SelfAttention`` – scaled dot-product self-attention along the time + axis; typically follows a ``VerticalConv``. -A unified factory function `build_layer` allows creating instances -of these blocks based on configuration objects. 
+Group block + ``LayerGroup`` – chains several blocks sequentially into one unit, + useful when a single encoder or decoder "stage" requires more than one + operation. + +Factory function +---------------- +``build_layer`` creates any of the above blocks from the matching +configuration object (one of the ``*Config`` classes exported here), using +a discriminated-union ``name`` field to dispatch to the correct class. """ from typing import Annotated, List, Literal, Tuple, Union @@ -57,10 +76,43 @@ __all__ = [ class Block(nn.Module): + """Abstract base class for all BatDetect2 building blocks. + + Subclasses must set ``in_channels`` and ``out_channels`` as integer + attributes so that factory functions can wire blocks together without + inspecting configuration objects at runtime. They may also override + ``get_output_height`` when the block changes the height dimension (e.g. + downsampling or upsampling blocks). + + Attributes + ---------- + in_channels : int + Number of channels expected in the input tensor. + out_channels : int + Number of channels produced in the output tensor. + """ + in_channels: int out_channels: int def get_output_height(self, input_height: int) -> int: + """Return the output height for a given input height. + + The default implementation returns ``input_height`` unchanged, + which is correct for blocks that do not alter spatial resolution. + Override this in downsampling (returns ``input_height // 2``) or + upsampling (returns ``input_height * 2``) subclasses. + + Parameters + ---------- + input_height : int + Height (number of frequency bins) of the input feature map. + + Returns + ------- + int + Height of the output feature map. + """ return input_height @@ -68,58 +120,65 @@ block_registry: Registry[Block, [int, int]] = Registry("block") class SelfAttentionConfig(BaseConfig): + """Configuration for a ``SelfAttention`` block. + + Attributes + ---------- + name : str + Discriminator field; always ``"SelfAttention"``. 
+ attention_channels : int + Dimensionality of the query, key, and value projections. + temperature : float + Scaling factor applied to the weighted values before the final + linear projection. Defaults to ``1``. + """ + name: Literal["SelfAttention"] = "SelfAttention" attention_channels: int temperature: float = 1 class SelfAttention(Block): - """Self-Attention mechanism operating along the time dimension. + """Self-attention block operating along the time axis. - This module implements a scaled dot-product self-attention mechanism, - specifically designed here to operate across the time steps of an input - feature map, typically after spatial dimensions (like frequency) have been - condensed or squeezed. + Applies a scaled dot-product self-attention mechanism across the time + steps of an input feature map. Before attention is computed the height + dimension (frequency axis) is expected to have been reduced to 1, e.g. + by a preceding ``VerticalConv`` layer. - By calculating attention weights between all pairs of time steps, it allows - the model to capture long-range temporal dependencies and focus on relevant - parts of the sequence. It's often employed in the bottleneck or - intermediate layers of an encoder-decoder architecture to integrate global - temporal context. - - The implementation uses linear projections to create query, key, and value - representations, computes scaled dot-product attention scores, applies - softmax, and produces an output by weighting the values according to the - attention scores, followed by a final linear projection. Positional encoding - is not explicitly included in this block. + For each time step the block computes query, key, and value projections + with learned linear weights, then calculates attention weights from the + query–key dot products divided by ``temperature × attention_channels``. 
+ The weighted sum of values is projected back to ``in_channels`` via a + final linear layer, and the height dimension is restored so that the + output shape matches the input shape. Parameters ---------- in_channels : int - Number of input channels (features per time step after spatial squeeze). + Number of input channels (features per time step). The output will + also have ``in_channels`` channels. attention_channels : int - Number of channels for the query, key, and value projections. Also the - dimension of the output projection's input. + Dimensionality of the query, key, and value projections. temperature : float, default=1.0 - Scaling factor applied *before* the final projection layer. Can be used - to adjust the sharpness or focus of the attention mechanism, although - scaling within the softmax (dividing by sqrt(dim)) is more common for - standard transformers. Here it scales the weighted values. + Divisor applied together with ``attention_channels`` when scaling + the dot-product scores before softmax. Larger values produce softer + (more uniform) attention distributions. Attributes ---------- key_fun : nn.Linear - Linear layer for key projection. + Linear projection for keys. value_fun : nn.Linear - Linear layer for value projection. + Linear projection for values. query_fun : nn.Linear - Linear layer for query projection. + Linear projection for queries. pro_fun : nn.Linear - Final linear projection layer applied after attention weighting. + Final linear projection applied to the attended values. temperature : float - Scaling factor applied before final projection. + Scaling divisor used when computing attention scores. att_dim : int - Dimensionality of the attention space (`attention_channels`). + Dimensionality of the attention space (``attention_channels``). 
""" def __init__( @@ -148,20 +207,16 @@ class SelfAttention(Block): Parameters ---------- x : torch.Tensor - Input tensor, expected shape `(B, C, H, W)`, where H is typically - squeezed (e.g., H=1 after a `VerticalConv` or pooling) before - applying attention along the W (time) dimension. + Input tensor with shape ``(B, C, 1, W)``. The height dimension + must be 1 (i.e. the frequency axis should already have been + collapsed by a preceding ``VerticalConv`` layer). Returns ------- torch.Tensor - Output tensor of the same shape as the input `(B, C, H, W)`, where - attention has been applied across the W dimension. - - Raises - ------ - RuntimeError - If input tensor dimensions are incompatible with operations. + Output tensor with the same shape ``(B, C, 1, W)`` as the + input, with each time step updated by attended context from all + other time steps. """ x = x.squeeze(2).permute(0, 2, 1) @@ -190,6 +245,22 @@ class SelfAttention(Block): return op def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor: + """Return the softmax attention weight matrix. + + Useful for visualising which time steps attend to which others. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape ``(B, C, 1, W)``. + + Returns + ------- + torch.Tensor + Attention weight matrix with shape ``(B, W, W)``. Entry + ``[b, i, j]`` is the attention weight that time step ``i`` + assigns to time step ``j`` in batch item ``b``. + """ x = x.squeeze(2).permute(0, 2, 1) key = torch.matmul( @@ -304,6 +375,16 @@ class ConvBlock(Block): class VerticalConvConfig(BaseConfig): + """Configuration for a ``VerticalConv`` block. + + Attributes + ---------- + name : str + Discriminator field; always ``"VerticalConv"``. + channels : int + Number of output channels produced by the vertical convolution. 
+ """ + name: Literal["VerticalConv"] = "VerticalConv" channels: int @@ -844,12 +925,53 @@ LayerConfig = Annotated[ class LayerGroupConfig(BaseConfig): + """Configuration for a ``LayerGroup`` — a sequential chain of blocks. + + Use this when a single encoder or decoder stage needs more than one + block. The blocks are executed in the order they appear in ``layers``, + with channel counts and heights propagated automatically. + + Attributes + ---------- + name : str + Discriminator field; always ``"LayerGroup"``. + layers : List[LayerConfig] + Ordered list of block configurations to chain together. + """ + name: Literal["LayerGroup"] = "LayerGroup" layers: List[LayerConfig] class LayerGroup(nn.Module): - """Standard implementation of the `LayerGroup` architecture.""" + """Sequential chain of blocks that acts as a single composite block. + + Wraps multiple ``Block`` instances in an ``nn.Sequential`` container, + exposing the same ``in_channels``, ``out_channels``, and + ``get_output_height`` interface as a regular ``Block`` so it can be + used transparently wherever a single block is expected. + + Instances are typically constructed by ``build_layer`` when given a + ``LayerGroupConfig``; you rarely need to create them directly. + + Parameters + ---------- + layers : list[Block] + Pre-built block instances to chain, in execution order. + input_height : int + Height of the tensor entering the first block. + input_channels : int + Number of channels in the tensor entering the first block. + + Attributes + ---------- + in_channels : int + Number of input channels (taken from the first block). + out_channels : int + Number of output channels (taken from the last block). + layers : nn.Sequential + The wrapped sequence of block modules. + """ def __init__( self, @@ -865,9 +987,33 @@ class LayerGroup(nn.Module): self.layers = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Pass input through all blocks in sequence. 
+ + Parameters + ---------- + x : torch.Tensor + Input feature map, shape ``(B, C_in, H, W)``. + + Returns + ------- + torch.Tensor + Output feature map after all blocks have been applied. + """ return self.layers(x) def get_output_height(self, input_height: int) -> int: + """Compute the output height by propagating through all blocks. + + Parameters + ---------- + input_height : int + Height of the input feature map. + + Returns + ------- + int + Height after all blocks in the group have been applied. + """ for block in self.layers: input_height = block.get_output_height(input_height) # type: ignore return input_height @@ -903,40 +1049,37 @@ def build_layer( in_channels: int, config: LayerConfig, ) -> Block: - """Factory function to build a specific nn.Module block from its config. + """Build a block from its configuration object. - Takes configuration object (one of the types included in the `LayerConfig` - union) and instantiates the corresponding nn.Module block with the correct - parameters derived from the config and the current pipeline state - (`input_height`, `in_channels`). - - It uses the `name` field within the `config` object to determine - which block class to instantiate. + Looks up the block class corresponding to ``config.name`` in the + internal block registry and instantiates it with the given input + dimensions. This is the standard way to construct blocks when + assembling an encoder or decoder from a configuration file. Parameters ---------- input_height : int - Height (frequency bins) of the input tensor *to this layer*. + Height (number of frequency bins) of the input tensor to this + block. Required for blocks whose kernel size depends on the input + height (e.g. ``VerticalConv``) and for coordinate-aware blocks. in_channels : int - Number of channels in the input tensor *to this layer*. + Number of channels in the input tensor to this block. 
config : LayerConfig - A Pydantic configuration object for the desired block (e.g., an - instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified - by its `name` field. + A configuration object for the desired block type. The ``name`` + field selects the block class; remaining fields supply its + parameters. Returns ------- - Tuple[nn.Module, int, int] - A tuple containing: - - The instantiated `nn.Module` block. - - The number of output channels produced by the block. - - The calculated height of the output produced by the block. + Block + An initialised block module ready to be added to an + ``nn.Sequential`` or ``nn.ModuleList``. Raises ------ - NotImplementedError - If the `config.name` does not correspond to a known block type. + KeyError + If ``config.name`` does not correspond to a registered block type. ValueError - If parameters derived from the config are invalid for the block. + If the configuration parameters are invalid for the chosen block. """ return block_registry.build(config, in_channels, input_height) diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 420ec2c..55fe549 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -1,17 +1,21 @@ -"""Defines the Bottleneck component of an Encoder-Decoder architecture. +"""Bottleneck component for encoder-decoder network architectures. -This module provides the configuration (`BottleneckConfig`) and -`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the -bottleneck layer(s) that typically connect the Encoder (downsampling path) and -Decoder (upsampling path) in networks like U-Nets. +The bottleneck sits between the encoder (downsampling path) and the decoder +(upsampling path) and processes the lowest-resolution, highest-channel feature +map produced by the encoder. -The bottleneck processes the lowest-resolution, highest-dimensionality feature -map produced by the Encoder. 
This module offers a configurable option to include -a `SelfAttention` layer within the bottleneck, allowing the model to capture -global temporal context before features are passed to the Decoder. +This module provides: -A factory function `build_bottleneck` constructs the appropriate bottleneck -module based on the provided configuration. +- ``BottleneckConfig`` – configuration dataclass describing the number of + internal channels and an optional sequence of additional layers (currently + only ``SelfAttention`` is supported). +- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a + ``VerticalConv`` to collapse the frequency axis to a single bin, optionally + runs one or more additional layers (e.g. self-attention along the time axis), + then repeats the output along the height dimension to restore the original + frequency resolution before passing features to the decoder. +- ``build_bottleneck`` – factory function that constructs a ``Bottleneck`` + instance from a ``BottleneckConfig`` and the encoder's output dimensions. """ from typing import Annotated, List @@ -37,42 +41,51 @@ __all__ = [ class Bottleneck(Block): - """Base Bottleneck module for Encoder-Decoder architectures. + """Bottleneck module for encoder-decoder architectures. - This implementation represents the simplest bottleneck structure - considered, primarily consisting of a `VerticalConv` layer. This layer - collapses the frequency dimension (height) to 1, summarizing information - across frequencies at each time step. The output is then repeated along the - height dimension to match the original bottleneck input height before being - passed to the decoder. + Processes the lowest-resolution feature map that links the encoder and + decoder. The sequence of operations is: - This base version does *not* include self-attention. + 1. ``VerticalConv`` – collapses the frequency axis (height) to a single + bin by applying a convolution whose kernel spans the full height. + 2. 
Optional additional layers (e.g. ``SelfAttention``) – applied while + the feature map has height 1, so they operate purely along the time + axis. + 3. Height restoration – the single-bin output is repeated along the + height axis to restore the original frequency resolution, producing + a tensor that the decoder can accept. Parameters ---------- input_height : int - Height (frequency bins) of the input tensor. Must be positive. + Height (number of frequency bins) of the input tensor. Must be + positive. in_channels : int Number of channels in the input tensor from the encoder. Must be positive. out_channels : int - Number of output channels. Must be positive. + Number of output channels after the bottleneck. Must be positive. + bottleneck_channels : int, optional + Number of internal channels used by the ``VerticalConv`` layer. + Defaults to ``out_channels`` if not provided. + layers : List[torch.nn.Module], optional + Additional modules (e.g. ``SelfAttention``) to apply after the + ``VerticalConv`` and before height restoration. Attributes ---------- in_channels : int Number of input channels accepted by the bottleneck. + out_channels : int + Number of output channels produced by the bottleneck. input_height : int Expected height of the input tensor. - channels : int - Number of output channels. + bottleneck_channels : int + Number of channels used internally by the vertical convolution. conv_vert : VerticalConv The vertical convolution layer. - - Raises - ------ - ValueError - If `input_height`, `in_channels`, or `out_channels` are not positive. + layers : nn.ModuleList + Additional layers applied after the vertical convolution. """ def __init__( @@ -83,7 +96,23 @@ class Bottleneck(Block): bottleneck_channels: int | None = None, layers: List[torch.nn.Module] | None = None, ) -> None: - """Initialize the base Bottleneck layer.""" + """Initialise the Bottleneck layer. 
+ + Parameters + ---------- + input_height : int + Height (number of frequency bins) of the input tensor. + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels in the output tensor. + bottleneck_channels : int, optional + Number of internal channels for the ``VerticalConv``. Defaults + to ``out_channels``. + layers : List[torch.nn.Module], optional + Additional modules applied after the ``VerticalConv``, such as + a ``SelfAttention`` block. + """ super().__init__() self.in_channels = in_channels self.input_height = input_height @@ -103,23 +132,24 @@ class Bottleneck(Block): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Process input features through the bottleneck. + """Process the encoder's bottleneck features. - Applies vertical convolution and repeats the output height. + Applies vertical convolution, optional additional layers, then + restores the height dimension by repetition. Parameters ---------- x : torch.Tensor - Input tensor from the encoder bottleneck, shape - `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`, - `H_in` must match `self.input_height`. + Input tensor from the encoder, shape + ``(B, C_in, H_in, W)``. ``C_in`` must match + ``self.in_channels`` and ``H_in`` must match + ``self.input_height``. Returns ------- torch.Tensor - Output tensor, shape `(B, C_out, H_in, W)`. Note that the height - dimension `H_in` is restored via repetition after the vertical - convolution. + Output tensor with shape ``(B, C_out, H_in, W)``. The height + ``H_in`` is restored by repeating the single-bin result. """ x = self.conv_vert(x) @@ -133,28 +163,22 @@ BottleneckLayerConfig = Annotated[ SelfAttentionConfig, Field(discriminator="name"), ] -"""Type alias for the discriminated union of block configs usable in Decoder.""" +"""Type alias for the discriminated union of block configs usable in the Bottleneck.""" class BottleneckConfig(BaseConfig): - """Configuration for the bottleneck layer(s). 
- - Defines the number of channels within the bottleneck and whether to include - a self-attention mechanism. + """Configuration for the bottleneck component. Attributes ---------- channels : int - The number of output channels produced by the main convolutional layer - within the bottleneck. This often matches the number of channels coming - from the last encoder stage, but can be different. Must be positive. - This also defines the channel dimensions used within the optional - `SelfAttention` layer. - self_attention : bool - If True, includes a `SelfAttention` layer operating on the time - dimension after an initial `VerticalConv` layer within the bottleneck. - If False, only the initial `VerticalConv` (and height repetition) is - performed. + Number of output channels produced by the bottleneck. This value + is also used as the dimensionality of any optional layers (e.g. + self-attention). Must be positive. + layers : List[BottleneckLayerConfig] + Ordered list of additional block configurations to apply after the + initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is + supported. Defaults to an empty list (no extra layers). """ channels: int @@ -174,30 +198,37 @@ def build_bottleneck( in_channels: int, config: BottleneckConfig | None = None, ) -> BottleneckProtocol: - """Factory function to build the Bottleneck module from configuration. + """Build a ``Bottleneck`` module from configuration. - Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based - on the `config.self_attention` flag. + Constructs a ``Bottleneck`` instance whose internal channel count and + optional extra layers (e.g. self-attention) are controlled by + ``config``. If no configuration is provided, the default + ``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a + ``SelfAttention`` layer. Parameters ---------- input_height : int - Height (frequency bins) of the input tensor. Must be positive. 
+ Height (number of frequency bins) of the input tensor from the + encoder. Must be positive. in_channels : int - Number of channels in the input tensor. Must be positive. + Number of channels in the input tensor from the encoder. Must be + positive. config : BottleneckConfig, optional - Configuration object specifying the bottleneck channels and whether - to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None. + Configuration specifying the output channel count and any + additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``. Returns ------- - nn.Module - An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`). + BottleneckProtocol + An initialised ``Bottleneck`` module. Raises ------ - ValueError - If `input_height` or `in_channels` are not positive. + AssertionError + If any configured layer changes the height of the feature map + (bottleneck layers must preserve height so that it can be restored + by repetition). """ config = config or DEFAULT_BOTTLENECK_CONFIG diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index 1673653..82e1150 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -1,21 +1,21 @@ -"""Constructs the Decoder part of an Encoder-Decoder neural network. +"""Decoder (upsampling path) for the BatDetect2 backbone. -This module defines the configuration structure (`DecoderConfig`) for the layer -sequence and provides the `Decoder` class (an `nn.Module`) along with a factory -function (`build_decoder`). Decoders typically form the upsampling path in -architectures like U-Nets, taking bottleneck features -(usually from an `Encoder`) and skip connections to reconstruct -higher-resolution feature maps. +This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``, +together with the ``build_decoder`` factory function. 
-The decoder is built dynamically by stacking neural network blocks based on a -list of configuration objects provided in `DecoderConfig.layers`. Each config -object specifies the type of block (e.g., standard convolution, -coordinate-feature convolution with upsampling) and its parameters. This allows -flexible definition of decoder architectures via configuration files. +In a U-Net-style network the decoder progressively restores the spatial +resolution of the feature map back towards the input resolution. At each +stage it combines the upsampled features with the corresponding skip-connection +tensor from the encoder (the residual) by element-wise addition before passing +the result to the upsampling block. -The `Decoder`'s `forward` method is designed to accept skip connection tensors -(`residuals`) from the encoder, merging them with the upsampled feature maps -at each stage. +The decoder is fully configurable: the type, number, and parameters of the +upsampling blocks are described by a ``DecoderConfig`` object containing an +ordered list of block configuration objects (see ``batdetect2.models.blocks`` +for available block types). + +A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by +``build_decoder`` when no explicit configuration is supplied. """ from typing import Annotated, List @@ -51,51 +51,47 @@ DecoderLayerConfig = Annotated[ class DecoderConfig(BaseConfig): - """Configuration for the sequence of layers in the Decoder module. - - Defines the types and parameters of the neural network blocks that - constitute the decoder's upsampling path. + """Configuration for the sequential ``Decoder`` module. Attributes ---------- layers : List[DecoderLayerConfig] - An ordered list of configuration objects, each defining one layer or - block in the decoder sequence. Each item must be a valid block - config including a `name` field and necessary parameters like - `out_channels`. Input channels for each layer are inferred sequentially. 
- The list must contain at least one layer. + Ordered list of block configuration objects defining the decoder's + upsampling stages (from deepest to shallowest). Each entry + specifies the block type (via its ``name`` field) and any + block-specific parameters such as ``out_channels``. Input channels + for each block are inferred automatically from the output of the + previous block. Must contain at least one entry. """ layers: List[DecoderLayerConfig] = Field(min_length=1) class Decoder(nn.Module): - """Sequential Decoder module composed of configurable upsampling layers. + """Sequential decoder module composed of configurable upsampling layers. - Constructs the upsampling path of an encoder-decoder network by stacking - multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`) - based on a list of layer modules provided during initialization (typically - created by the `build_decoder` factory function). + Executes a series of upsampling blocks in order, adding the + corresponding encoder skip-connection tensor (residual) to the feature + map before each block. The residuals are consumed in reverse order (from + deepest encoder layer to shallowest) to match the spatial resolutions at + each decoder stage. - The `forward` method is designed to integrate skip connection tensors - (`residuals`) from the corresponding encoder stages, by adding them - element-wise to the input of each decoder layer before processing. + Instances are typically created by ``build_decoder``. Attributes ---------- in_channels : int - Number of channels expected in the input tensor. + Number of channels expected in the input tensor (bottleneck output). out_channels : int - Number of channels in the final output tensor produced by the last - layer. + Number of channels in the final output feature map. input_height : int - Height (frequency bins) expected in the input tensor. + Height (frequency bins) of the input tensor. 
output_height : int - Height (frequency bins) expected in the output tensor. + Height (frequency bins) of the output tensor. layers : nn.ModuleList - The sequence of instantiated upscaling layer modules. + Sequence of instantiated upsampling block modules. depth : int - The number of upscaling layers (depth) in the decoder. + Number of upsampling layers. """ def __init__( @@ -106,23 +102,24 @@ class Decoder(nn.Module): output_height: int, layers: List[nn.Module], ): - """Initialize the Decoder module. + """Initialise the Decoder module. - Note: This constructor is typically called internally by the - `build_decoder` factory function. + This constructor is typically called by the ``build_decoder`` + factory function. Parameters ---------- + in_channels : int + Number of channels in the input tensor (bottleneck output). out_channels : int Number of channels produced by the final layer. input_height : int - Expected height of the input tensor (bottleneck). - in_channels : int - Expected number of channels in the input tensor (bottleneck). + Height of the input tensor (bottleneck output height). + output_height : int + Height of the output tensor after all layers have been applied. layers : List[nn.Module] - A list of pre-instantiated upscaling layer modules (e.g., - `StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired - sequence (from bottleneck towards output resolution). + Pre-built upsampling block modules in execution order (deepest + stage first). """ super().__init__() @@ -140,43 +137,35 @@ class Decoder(nn.Module): x: torch.Tensor, residuals: List[torch.Tensor], ) -> torch.Tensor: - """Pass input through decoder layers, incorporating skip connections. + """Pass input through all decoder layers, incorporating skip connections. - Processes the input tensor `x` sequentially through the upscaling - layers. 
At each stage, the corresponding skip connection tensor from
-        the `residuals` list is added element-wise to the input before passing
-        it to the upscaling block.
+        At each stage the corresponding residual tensor is added
+        element-wise to ``x`` before it is passed to the upsampling block.
+        Residuals are consumed in reverse order — the last element of
+        ``residuals`` (the output of the deepest encoder layer) is added
+        at the first decoder stage, and the first element (output of the
+        shallowest encoder layer) is added at the last decoder stage.
 
         Parameters
         ----------
         x : torch.Tensor
-            Input tensor from the previous stage (e.g., encoder bottleneck).
-            Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
-            `self.in_channels`.
+            Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
         residuals : List[torch.Tensor]
-            List containing the skip connection tensors from the corresponding
-            encoder stages. Should be ordered from the deepest encoder layer
-            output (lowest resolution) to the shallowest (highest resolution
-            near input). The number of tensors in this list must match the
-            number of decoder layers (`self.depth`). Each residual tensor's
-            channel count must be compatible with the input tensor `x` for
-            element-wise addition (or concatenation if the blocks were designed
-            for it).
+            Skip-connection tensors from the encoder, ordered from shallowest
+            (index 0) to deepest (index -1). Must contain exactly
+            ``self.depth`` tensors. Each tensor must have the same spatial
+            dimensions and channel count as ``x`` at the corresponding
+            decoder stage.
 
         Returns
         -------
         torch.Tensor
-            The final decoded feature map tensor produced by the last layer.
-            Shape `(B, C_out, H_out, W_out)`.
+            Decoded feature map, shape ``(B, C_out, H_out, W)``.
 
         Raises
         ------
         ValueError
-            If the number of `residuals` provided does not match the decoder
-            depth.
-        RuntimeError
-            If shapes mismatch during skip connection addition or layer
-            processing. 
+ If the number of ``residuals`` does not equal ``self.depth``. """ if len(residuals) != len(self.layers): raise ValueError( @@ -203,11 +192,17 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig( ), ], ) -"""A default configuration for the Decoder's *layer sequence*. +"""Default decoder configuration used in standard BatDetect2 models. -Specifies an architecture often used in BatDetect2, consisting of three -frequency coordinate-aware upsampling blocks followed by a standard -convolutional block. +Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck +output has 256 channels and height 16, and produces: + +- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32. +- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64. +- Stage 3 (``LayerGroup``): + + - ``FreqCoordConvUp``: 32 channels, height 128. + - ``ConvBlock``: 32 channels, height 128 (final feature map). """ @@ -216,40 +211,36 @@ def build_decoder( input_height: int, config: DecoderConfig | None = None, ) -> Decoder: - """Factory function to build a Decoder instance from configuration. + """Build a ``Decoder`` from configuration. - Constructs a sequential `Decoder` module based on the layer sequence - defined in a `DecoderConfig` object and the provided input dimensions - (bottleneck channels and height). If no config is provided, uses the - default layer sequence from `DEFAULT_DECODER_CONFIG`. - - It iteratively builds the layers using the unified `build_layer_from_config` - factory (from `.blocks`), tracking the changing number of channels and - feature map height required for each subsequent layer. + Constructs a sequential ``Decoder`` by iterating over the block + configurations in ``config.layers``, building each block with + ``build_layer``, and tracking the channel count and feature-map height + as they change through the sequence. Parameters ---------- in_channels : int - The number of channels in the input tensor to the decoder. Must be > 0. 
+ Number of channels in the input tensor (bottleneck output). Must + be positive. input_height : int - The height (frequency bins) of the input tensor to the decoder. Must be - > 0. + Height (number of frequency bins) of the input tensor. Must be + positive. config : DecoderConfig, optional - The configuration object detailing the sequence of layers and their - parameters. If None, `DEFAULT_DECODER_CONFIG` is used. + Configuration specifying the layer sequence. Defaults to + ``DEFAULT_DECODER_CONFIG`` if not provided. Returns ------- Decoder - An initialized `Decoder` module. + An initialised ``Decoder`` module. Raises ------ ValueError - If `in_channels` or `input_height` are not positive, or if the layer - configuration is invalid (e.g., empty list, unknown `name`). - NotImplementedError - If `build_layer_from_config` encounters an unknown `name`. + If ``in_channels`` or ``input_height`` are not positive. + KeyError + If a layer configuration specifies an unknown block type. """ config = config or DEFAULT_DECODER_CONFIG diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 5c22ba8..55fbabf 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -1,17 +1,20 @@ -"""Assembles the complete BatDetect2 Detection Model. +"""Assembles the complete BatDetect2 detection model. -This module defines the concrete `Detector` class, which implements the -`DetectionModel` interface defined in `.types`. It combines a feature -extraction backbone with specific prediction heads to create the end-to-end -neural network used for detecting bat calls, predicting their size, and -classifying them. +This module defines the ``Detector`` class, which combines a backbone +feature extractor with prediction heads for detection, classification, and +bounding-box size regression. -The primary components are: -- `Detector`: The `torch.nn.Module` subclass representing the complete model. 
+Components +---------- +- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone + (``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to + produce a ``ModelOutput`` tuple from an input spectrogram. +- ``build_detector`` – factory function that builds a ready-to-use + ``Detector`` from a backbone configuration and a target class count. -This module focuses purely on the neural network architecture definition. The -logic for preprocessing inputs and postprocessing/decoding outputs resides in -the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. +Note that ``Detector`` operates purely on spectrogram tensors; raw audio +preprocessing and output postprocessing are handled by +``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively. """ import torch @@ -32,25 +35,30 @@ __all__ = [ class Detector(DetectionModel): - """Concrete implementation of the BatDetect2 Detection Model. + """Complete BatDetect2 detection and classification model. - Assembles a complete detection and classification model by combining a - feature extraction backbone network with specific prediction heads for - detection probability, bounding box size regression, and class - probabilities. + Combines a backbone feature extractor with two prediction heads: + + - ``ClassifierHead``: predicts per-class probabilities at each + time–frequency location. + - ``BBoxHead``: predicts call duration and bandwidth at each location. + + The detection probability map is derived from the class probabilities by + summing across the class dimension (i.e. the probability that *any* class + is present), rather than from a separate detection head. + + Instances are typically created via ``build_detector``. Attributes ---------- backbone : BackboneModel - The feature extraction backbone network module. + The feature extraction backbone. num_classes : int - The number of specific target classes the model predicts (derived from - the `classifier_head`). 
+ Number of target classes (inferred from the classifier head). classifier_head : ClassifierHead - The prediction head responsible for generating class probabilities. + Produces per-class probability maps from backbone features. bbox_head : BBoxHead - The prediction head responsible for generating bounding box size - predictions. + Produces duration and bandwidth predictions from backbone features. """ backbone: BackboneModel @@ -61,26 +69,21 @@ class Detector(DetectionModel): classifier_head: ClassifierHead, bbox_head: BBoxHead, ): - """Initialize the Detector model. + """Initialise the Detector model. - Note: Instances are typically created using the `build_detector` + This constructor is typically called by the ``build_detector`` factory function. Parameters ---------- backbone : BackboneModel - An initialized feature extraction backbone module (e.g., built by - `build_backbone` from the `.backbone` module). + An initialised backbone module (e.g. built by + ``build_backbone``). classifier_head : ClassifierHead - An initialized classification head module. The number of classes - is inferred from this head. + An initialised classification head. The ``num_classes`` + attribute is read from this head. bbox_head : BBoxHead - An initialized bounding box size prediction head module. - - Raises - ------ - TypeError - If the provided modules are not of the expected types. + An initialised bounding-box size prediction head. """ super().__init__() @@ -90,31 +93,34 @@ class Detector(DetectionModel): self.bbox_head = bbox_head def forward(self, spec: torch.Tensor) -> ModelOutput: - """Perform the forward pass of the complete detection model. + """Run the complete detection model on an input spectrogram. - Processes the input spectrogram through the backbone to extract - features, then passes these features through the separate prediction - heads to generate detection probabilities, class probabilities, and - size predictions. 
+ Passes the spectrogram through the backbone to produce a feature + map, then applies the classifier and bounding-box heads. The + detection probability map is derived by summing the per-class + probability maps across the class dimension; no separate detection + head is used. Parameters ---------- spec : torch.Tensor - Input spectrogram tensor, typically with shape - `(batch_size, input_channels, frequency_bins, time_bins)`. The - shape must be compatible with the `self.backbone` input - requirements. + Input spectrogram tensor, shape + ``(batch_size, channels, frequency_bins, time_bins)``. Returns ------- ModelOutput - A NamedTuple containing the four output tensors: - - `detection_probs`: Detection probability heatmap `(B, 1, H, W)`. - - `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`. - - `class_probs`: Class probabilities (excluding background) - `(B, num_classes, H, W)`. - - `features`: Output feature map from the backbone - `(B, C_out, H, W)`. + A named tuple with four fields: + + - ``detection_probs`` – ``(B, 1, H, W)`` – probability that a + call of any class is present at each location. Derived by + summing ``class_probs`` over the class dimension. + - ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel + 0) and bandwidth (channel 1) predictions at each location. + - ``class_probs`` – ``(B, num_classes, H, W)`` – per-class + probabilities at each location. + - ``features`` – ``(B, C_out, H, W)`` – raw backbone feature + map. """ features = self.backbone(spec) classification = self.classifier_head(features) @@ -131,30 +137,33 @@ class Detector(DetectionModel): def build_detector( num_classes: int, config: BackboneConfig | None = None ) -> DetectionModel: - """Build the complete BatDetect2 detection model. + """Build a complete BatDetect2 detection model. 
+ + Constructs a backbone from ``config``, attaches a ``ClassifierHead`` + and a ``BBoxHead`` sized to the backbone's output channel count, and + returns them wrapped in a ``Detector``. Parameters ---------- num_classes : int - The number of specific target classes the model should predict - (required for the `ClassifierHead`). Must be positive. + Number of target bat species or call types to predict. Must be + positive. config : BackboneConfig, optional - Configuration object specifying the architecture of the backbone - (encoder, bottleneck, decoder). If None, default configurations defined - within the respective builder functions (`build_encoder`, etc.) will be - used to construct a default backbone architecture. + Backbone architecture configuration. Defaults to + ``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if + not provided. Returns ------- DetectionModel - An initialized `Detector` model instance. + An initialised ``Detector`` instance ready for training or + inference. Raises ------ ValueError - If `num_classes` is not positive, or if errors occur during the - construction of the backbone or detector components (e.g., incompatible - configurations, invalid parameters). + If ``num_classes`` is not positive, or if the backbone + configuration is invalid. """ config = config or UNetBackboneConfig() diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index 142e9ab..24650e5 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -1,23 +1,24 @@ -"""Constructs the Encoder part of a configurable neural network backbone. +"""Encoder (downsampling path) for the BatDetect2 backbone. -This module defines the configuration structure (`EncoderConfig`) and provides -the `Encoder` class (an `nn.Module`) along with a factory function -(`build_encoder`) to create sequential encoders. 
Encoders typically form the
-downsampling path in architectures like U-Nets, processing input feature maps
-(like spectrograms) to produce lower-resolution, higher-dimensionality feature
-representations (bottleneck features).
+This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
+together with the ``build_encoder`` factory function.
 
-The encoder is built dynamically by stacking neural network blocks based on a
-list of configuration objects provided in `EncoderConfig.layers`. Each
-configuration object specifies the type of block (e.g., standard convolution,
-coordinate-feature convolution with downsampling) and its parameters
-(e.g., output channels). This allows for flexible definition of encoder
-architectures via configuration files.
+In a U-Net-style network the encoder progressively reduces the spatial
+resolution of the spectrogram whilst increasing the number of feature
+channels. Each layer in the encoder produces a feature map that is stored
+for use as a skip connection in the corresponding decoder layer.
 
-The `Encoder`'s `forward` method returns outputs from all intermediate layers,
-suitable for skip connections, while the `encode` method returns only the final
-bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
-provided.
+The encoder is fully configurable: the type, number, and parameters of the
+downsampling blocks are described by an ``EncoderConfig`` object containing
+an ordered list of block configurations (see ``batdetect2.models.blocks``
+for available block types).
+
+``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
+so that skip connections are available to the decoder.
+``Encoder.encode`` returns only the final output (the input to the bottleneck).
+
+A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
+``build_encoder`` when no explicit configuration is supplied.
""" from typing import Annotated, List @@ -53,35 +54,32 @@ EncoderLayerConfig = Annotated[ class EncoderConfig(BaseConfig): - """Configuration for building the sequential Encoder module. - - Defines the sequence of neural network blocks that constitute the encoder - (downsampling path). + """Configuration for the sequential ``Encoder`` module. Attributes ---------- layers : List[EncoderLayerConfig] - An ordered list of configuration objects, each defining one layer or - block in the encoder sequence. Each item must be a valid block config - (e.g., `ConvConfig`, `FreqCoordConvDownConfig`, - `StandardConvDownConfig`) including a `name` field and necessary - parameters like `out_channels`. Input channels for each layer are - inferred sequentially. The list must contain at least one layer. + Ordered list of block configuration objects defining the encoder's + downsampling stages. Each entry specifies the block type (via its + ``name`` field) and any block-specific parameters such as + ``out_channels``. Input channels for each block are inferred + automatically from the output of the previous block. Must contain + at least one entry. """ layers: List[EncoderLayerConfig] = Field(min_length=1) class Encoder(nn.Module): - """Sequential Encoder module composed of configurable downscaling layers. + """Sequential encoder module composed of configurable downsampling layers. - Constructs the downsampling path of an encoder-decoder network by stacking - multiple downscaling blocks. + Executes a series of downsampling blocks in order, storing the output of + each block so that it can be passed as a skip connection to the + corresponding decoder layer. - The `forward` method executes the sequence and returns the output feature - map from *each* downscaling stage, facilitating the implementation of skip - connections in U-Net-like architectures. The `encode` method returns only - the final output tensor (bottleneck features). 
+ ``forward`` returns the outputs of *all* layers (useful when skip + connections are needed). ``encode`` returns only the final output + (the input to the bottleneck). Attributes ---------- @@ -89,14 +87,14 @@ class Encoder(nn.Module): Number of channels expected in the input tensor. input_height : int Height (frequency bins) expected in the input tensor. - output_channels : int - Number of channels in the final output tensor (bottleneck). + out_channels : int + Number of channels in the final output tensor (bottleneck input). output_height : int - Height (frequency bins) expected in the output tensor. + Height (frequency bins) of the final output tensor. layers : nn.ModuleList - The sequence of instantiated downscaling layer modules. + Sequence of instantiated downsampling block modules. depth : int - The number of downscaling layers in the encoder. + Number of downsampling layers. """ def __init__( @@ -107,23 +105,22 @@ class Encoder(nn.Module): input_height: int = 128, in_channels: int = 1, ): - """Initialize the Encoder module. + """Initialise the Encoder module. - Note: This constructor is typically called internally by the - `build_encoder` factory function, which prepares the `layers` list. + This constructor is typically called by the ``build_encoder`` factory + function, which takes care of building the ``layers`` list from a + configuration object. Parameters ---------- output_channels : int Number of channels produced by the final layer. output_height : int - The expected height of the output tensor. + Height of the output tensor after all layers have been applied. layers : List[nn.Module] - A list of pre-instantiated downscaling layer modules (e.g., - `StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired - sequence. + Pre-built downsampling block modules in execution order. input_height : int, default=128 - Expected height of the input tensor. + Expected height of the input tensor (frequency bins). 
in_channels : int, default=1
             Expected number of channels in the input tensor.
         """
         super().__init__()
@@ -138,29 +135,30 @@ class Encoder(nn.Module):
         self.depth = len(self.layers)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        """Pass input through encoder layers, returns all intermediate outputs.
+        """Pass input through all encoder layers and return every output.
 
-        This method is typically used when the Encoder is part of a U-Net or
-        similar architecture requiring skip connections.
+        Used when skip connections are needed (e.g. in a U-Net decoder).
 
         Parameters
         ----------
         x : torch.Tensor
-            Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
-            `self.in_channels` and `H_in` must match `self.input_height`.
+            Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
+            ``C_in`` must match ``self.in_channels`` and ``H_in`` must
+            match ``self.input_height``.
 
         Returns
         -------
         List[torch.Tensor]
-            A list containing the output tensors from *each* downscaling layer
-            in the sequence. `outputs[0]` is the output of the first layer,
-            `outputs[-1]` is the final output (bottleneck) of the encoder.
+            Output tensors from every layer in order.
+            ``outputs[0]`` is the output of the first (shallowest) layer;
+            ``outputs[-1]`` is the output of the last (deepest) layer,
+            which serves as the input to the bottleneck.
 
         Raises
        ------
         ValueError
-            If input tensor channel count or height does not match expected
-            values.
+            If the input channel count or height does not match the
+            expected values.
         """
         if x.shape[1] != self.in_channels:
             raise ValueError(
@@ -183,28 +181,29 @@ class Encoder(nn.Module):
         return outputs
 
     def encode(self, x: torch.Tensor) -> torch.Tensor:
-        """Pass input through encoder layers, returning only the final output.
+        """Pass input through all layers and return only the final output.
 
-        This method provides access to the bottleneck features produced after
-        the last downscaling layer.
+ Use this when skip connections are not needed and you only require + the bottleneck feature map. Parameters ---------- x : torch.Tensor - Input tensor, shape `(B, C_in, H_in, W)`. Must match expected - `in_channels` and `input_height`. + Input spectrogram feature map, shape ``(B, C_in, H_in, W)``. + Must satisfy the same shape requirements as ``forward``. Returns ------- torch.Tensor - The final output tensor (bottleneck features) from the last layer - of the encoder. Shape `(B, C_out, H_out, W_out)`. + Output of the last encoder layer, shape + ``(B, C_out, H_out, W)``, where ``C_out`` is + ``self.out_channels`` and ``H_out`` is ``self.output_height``. Raises ------ ValueError - If input tensor channel count or height does not match expected - values. + If the input channel count or height does not match the + expected values. """ if x.shape[1] != self.in_channels: raise ValueError( @@ -236,14 +235,17 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig( ), ], ) -"""Default configuration for the Encoder. +"""Default encoder configuration used in standard BatDetect2 models. -Specifies an architecture typically used in BatDetect2: -- Input: 1 channel, 128 frequency bins. -- Layer 1: FreqCoordConvDown -> 32 channels, H=64 -- Layer 2: FreqCoordConvDown -> 64 channels, H=32 -- Layer 3: FreqCoordConvDown -> 128 channels, H=16 -- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck) +Assumes a 1-channel input with 128 frequency bins and produces the +following feature maps: + +- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64. +- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32. +- Stage 3 (``LayerGroup``): + + - ``FreqCoordConvDown``: 128 channels, height 16. + - ``ConvBlock``: 256 channels, height 16 (bottleneck input). """ @@ -252,42 +254,38 @@ def build_encoder( input_height: int, config: EncoderConfig | None = None, ) -> Encoder: - """Factory function to build an Encoder instance from configuration. + """Build an ``Encoder`` from configuration. 
- Constructs a sequential `Encoder` module based on the layer sequence - defined in an `EncoderConfig` object and the provided input dimensions. - If no config is provided, uses the default layer sequence from - `DEFAULT_ENCODER_CONFIG`. - - It iteratively builds the layers using the unified - `build_layer_from_config` factory (from `.blocks`), tracking the changing - number of channels and feature map height required for each subsequent - layer, especially for coordinate- aware blocks. + Constructs a sequential ``Encoder`` by iterating over the block + configurations in ``config.layers``, building each block with + ``build_layer``, and tracking the channel count and feature-map height + as they change through the sequence. Parameters ---------- in_channels : int - The number of channels expected in the input tensor to the encoder. - Must be > 0. + Number of channels in the input spectrogram tensor. Must be + positive. input_height : int - The height (frequency bins) expected in the input tensor. Must be > 0. - Crucial for initializing coordinate-aware layers correctly. + Height (number of frequency bins) of the input spectrogram. + Must be positive and should be divisible by + ``2 ** (number of downsampling stages)`` to avoid size mismatches + later in the network. config : EncoderConfig, optional - The configuration object detailing the sequence of layers and their - parameters. If None, `DEFAULT_ENCODER_CONFIG` is used. + Configuration specifying the layer sequence. Defaults to + ``DEFAULT_ENCODER_CONFIG`` if not provided. Returns ------- Encoder - An initialized `Encoder` module. + An initialised ``Encoder`` module. Raises ------ ValueError - If `in_channels` or `input_height` are not positive, or if the layer - configuration is invalid (e.g., empty list, unknown `name`). - NotImplementedError - If `build_layer_from_config` encounters an unknown `name`. + If ``in_channels`` or ``input_height`` are not positive. 
+ KeyError + If a layer configuration specifies an unknown block type. """ if in_channels <= 0 or input_height <= 0: raise ValueError("in_channels and input_height must be positive.") diff --git a/src/batdetect2/models/heads.py b/src/batdetect2/models/heads.py index 1fd9f8e..65a2a40 100644 --- a/src/batdetect2/models/heads.py +++ b/src/batdetect2/models/heads.py @@ -1,20 +1,19 @@ -"""Prediction Head modules for BatDetect2 models. +"""Prediction heads attached to the backbone feature map. -This module defines simple `torch.nn.Module` subclasses that serve as -prediction heads, typically attached to the output feature map of a backbone -network +Each head is a lightweight ``torch.nn.Module`` that applies a 1×1 +convolution to map backbone feature channels to one specific type of +output required by BatDetect2: -Each head is responsible for generating one specific type of output required -by the BatDetect2 task: -- `DetectorHead`: Predicts the probability of sound event presence. -- `ClassifierHead`: Predicts the probability distribution over target classes. -- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding -box. +- ``DetectorHead``: single-channel detection probability heatmap (sigmoid + activation). +- ``ClassifierHead``: multi-class probability map over the target bat + species / call types (softmax activation). +- ``BBoxHead``: two-channel map of predicted call duration (time axis) and + bandwidth (frequency axis) at each location (no activation; raw + regression output). -These heads use 1x1 convolutions to map the backbone feature channels -to the desired number of output channels for each prediction task at each -spatial location, followed by an appropriate activation function (e.g., sigmoid -for detection, softmax for classification, none for size regression). +All three heads share the same input feature map produced by the backbone, +so they can be evaluated in parallel in a single forward pass. 
""" import torch @@ -28,42 +27,35 @@ __all__ = [ class ClassifierHead(nn.Module): - """Prediction head for multi-class classification probabilities. + """Prediction head for species / call-type classification probabilities. - Takes an input feature map and produces a probability map where each - channel corresponds to a specific target class. It uses a 1x1 convolution - to map input channels to `num_classes + 1` outputs (one for each target - class plus an assumed background/generic class), applies softmax across the - channels, and returns the probabilities for the specific target classes - (excluding the last background/generic channel). + Takes a backbone feature map and produces a probability map where each + channel corresponds to a target class. Internally the 1×1 convolution + maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel + represents a generic background / unknown category); a softmax is then + applied across the channel dimension and the background channel is + discarded before returning. Parameters ---------- num_classes : int - The number of specific target classes the model should predict - (excluding any background or generic category). Must be positive. + Number of target classes (bat species or call types) to predict, + excluding the background category. Must be positive. in_channels : int - Number of channels in the input feature map tensor from the backbone. - Must be positive. + Number of channels in the backbone feature map. Must be positive. Attributes ---------- num_classes : int - Number of specific output classes. + Number of specific output classes (background excluded). in_channels : int Number of input channels expected. classifier : nn.Conv2d - The 1x1 convolutional layer used for prediction. - Output channels = num_classes + 1. - - Raises - ------ - ValueError - If `num_classes` or `in_channels` are not positive. + 1×1 convolution with ``num_classes + 1`` output channels. 
""" def __init__(self, num_classes: int, in_channels: int): - """Initialize the ClassifierHead.""" + """Initialise the ClassifierHead.""" super().__init__() self.num_classes = num_classes @@ -76,20 +68,20 @@ class ClassifierHead(nn.Module): ) def forward(self, features: torch.Tensor) -> torch.Tensor: - """Compute class probabilities from input features. + """Compute per-class probabilities from backbone features. Parameters ---------- features : torch.Tensor - Input feature map tensor from the backbone, typically with shape - `(B, C_in, H, W)`. `C_in` must match `self.in_channels`. + Backbone feature map, shape ``(B, C_in, H, W)``. Returns ------- torch.Tensor - Class probability map tensor with shape `(B, num_classes, H, W)`. - Contains probabilities for the specific target classes after - softmax, excluding the implicit background/generic class channel. + Class probability map, shape ``(B, num_classes, H, W)``. + Values are softmax probabilities in the range [0, 1] and + sum to less than 1 per location (the background probability + is discarded). """ logits = self.classifier(features) probs = torch.softmax(logits, dim=1) @@ -97,36 +89,30 @@ class ClassifierHead(nn.Module): class DetectorHead(nn.Module): - """Prediction head for sound event detection probability. + """Prediction head for detection probability (is a call present here?). - Takes an input feature map and produces a single-channel heatmap where - each value represents the probability ([0, 1]) of a relevant sound event - (of any class) being present at that spatial location. + Produces a single-channel heatmap where each value indicates the + probability ([0, 1]) that a bat call of *any* species is present at + that time–frequency location in the spectrogram. - Uses a 1x1 convolution to map input channels to 1 output channel, followed - by a sigmoid activation function. + Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by + sigmoid activation. 
Parameters ---------- in_channels : int - Number of channels in the input feature map tensor from the backbone. - Must be positive. + Number of channels in the backbone feature map. Must be positive. Attributes ---------- in_channels : int Number of input channels expected. detector : nn.Conv2d - The 1x1 convolutional layer mapping to a single output channel. - - Raises - ------ - ValueError - If `in_channels` is not positive. + 1×1 convolution with a single output channel. """ def __init__(self, in_channels: int): - """Initialize the DetectorHead.""" + """Initialise the DetectorHead.""" super().__init__() self.in_channels = in_channels @@ -138,62 +124,49 @@ class DetectorHead(nn.Module): ) def forward(self, features: torch.Tensor) -> torch.Tensor: - """Compute detection probabilities from input features. + """Compute detection probabilities from backbone features. Parameters ---------- features : torch.Tensor - Input feature map tensor from the backbone, typically with shape - `(B, C_in, H, W)`. `C_in` must match `self.in_channels`. + Backbone feature map, shape ``(B, C_in, H, W)``. Returns ------- torch.Tensor - Detection probability heatmap tensor with shape `(B, 1, H, W)`. - Values are in the range [0, 1] due to the sigmoid activation. - - Raises - ------ - RuntimeError - If input channel count does not match `self.in_channels`. + Detection probability heatmap, shape ``(B, 1, H, W)``. + Values are in the range [0, 1]. """ return torch.sigmoid(self.detector(features)) class BBoxHead(nn.Module): - """Prediction head for bounding box size dimensions. + """Prediction head for bounding box size (duration and bandwidth). - Takes an input feature map and produces a two-channel map where each - channel represents a predicted size dimension (typically width/duration and - height/bandwidth) for a potential sound event at that spatial location. 
+ Produces a two-channel map where channel 0 predicts the scaled duration + (time-axis extent) and channel 1 predicts the scaled bandwidth + (frequency-axis extent) of the call at each spectrogram location. - Uses a 1x1 convolution to map input channels to 2 output channels. No - activation function is typically applied, as size prediction is often - treated as a direct regression task. The output values usually represent - *scaled* dimensions that need to be un-scaled during postprocessing. + Applies a 1×1 convolution mapping ``in_channels`` → 2 with no + activation function (raw regression output). The predicted values are + in a scaled space and must be converted to real units (seconds and Hz) + during postprocessing. Parameters ---------- in_channels : int - Number of channels in the input feature map tensor from the backbone. - Must be positive. + Number of channels in the backbone feature map. Must be positive. Attributes ---------- in_channels : int Number of input channels expected. bbox : nn.Conv2d - The 1x1 convolutional layer mapping to 2 output channels - (width, height). - - Raises - ------ - ValueError - If `in_channels` is not positive. + 1×1 convolution with 2 output channels (duration, bandwidth). """ def __init__(self, in_channels: int): - """Initialize the BBoxHead.""" + """Initialise the BBoxHead.""" super().__init__() self.in_channels = in_channels @@ -205,19 +178,19 @@ class BBoxHead(nn.Module): ) def forward(self, features: torch.Tensor) -> torch.Tensor: - """Compute predicted bounding box dimensions from input features. + """Predict call duration and bandwidth from backbone features. Parameters ---------- features : torch.Tensor - Input feature map tensor from the backbone, typically with shape - `(B, C_in, H, W)`. `C_in` must match `self.in_channels`. + Backbone feature map, shape ``(B, C_in, H, W)``. Returns ------- torch.Tensor - Predicted size tensor with shape `(B, 2, H, W)`. 
Channel 0 usually - represents scaled width, Channel 1 scaled height. These values - need to be un-scaled during postprocessing. + Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is + the predicted scaled duration; channel 1 is the predicted + scaled bandwidth. Values must be rescaled to real units during + postprocessing. """ return self.bbox(features) diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index f04fcfa..2f9c00a 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -1,14 +1,4 @@ -"""Tests for backbone configuration loading and the backbone registry. - -Covers: -- UNetBackboneConfig default construction and field values. -- build_backbone with default and explicit configs. -- load_backbone_config loading from a YAML file. -- load_backbone_config with a nested field path. -- load_backbone_config round-trip: YAML → config → build_backbone. -- Registry registration and dispatch for UNetBackbone. -- BackboneConfig discriminated union validation. 
-""" +"""Tests for backbone configuration loading and the backbone registry.""" from pathlib import Path from typing import Callable @@ -25,10 +15,6 @@ from batdetect2.models.backbones import ( ) from batdetect2.typing.models import BackboneModel -# --------------------------------------------------------------------------- -# UNetBackboneConfig -# --------------------------------------------------------------------------- - def test_unet_backbone_config_defaults(): """Default config has expected field values.""" @@ -57,11 +43,6 @@ def test_unet_backbone_config_extra_fields_ignored(): assert not hasattr(config, "unknown_field") -# --------------------------------------------------------------------------- -# build_backbone -# --------------------------------------------------------------------------- - - def test_build_backbone_default(): """Building with no config uses UNetBackbone defaults.""" backbone = build_backbone() @@ -83,15 +64,9 @@ def test_build_backbone_custom_config(): def test_build_backbone_returns_backbone_model(): """build_backbone always returns a BackboneModel instance.""" backbone = build_backbone() - assert isinstance(backbone, BackboneModel) -# --------------------------------------------------------------------------- -# Registry -# --------------------------------------------------------------------------- - - def test_registry_has_unet_backbone(): """The backbone registry has UNetBackbone registered.""" config_types = backbone_registry.get_config_types() @@ -124,11 +99,6 @@ def test_registry_build_unknown_name_raises(): backbone_registry.build(FakeConfig()) # type: ignore[arg-type] -# --------------------------------------------------------------------------- -# BackboneConfig discriminated union -# --------------------------------------------------------------------------- - - def test_backbone_config_validates_unet_from_dict(): """BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name.""" from pydantic import TypeAdapter @@ 
-151,11 +121,6 @@ def test_backbone_config_invalid_name_raises(): adapter.validate_python({"name": "NonExistentBackbone"}) -# --------------------------------------------------------------------------- -# load_backbone_config -# --------------------------------------------------------------------------- - - def test_load_backbone_config_from_yaml( create_temp_yaml: Callable[[str], Path], ): @@ -218,11 +183,6 @@ deprecated_field: 99 assert config.input_height == 128 -# --------------------------------------------------------------------------- -# Round-trip: YAML → config → build_backbone -# --------------------------------------------------------------------------- - - def test_round_trip_yaml_to_build_backbone( create_temp_yaml: Callable[[str], Path], ):