mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added better docstrings to types module
This commit is contained in:
parent
907e05ea48
commit
e00674f628
@@ -1,5 +1,22 @@
"""Defines shared interfaces (ABCs) and data structures for models.

This module centralizes the definitions of core data structures, like the
standard model output container (`ModelOutput`), and establishes abstract base
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
for fundamental model components, ensuring modularity and consistent
interaction within the `batdetect2.models` package.

Key components:
- `ModelOutput`: Standard structure for outputs from detection models.
- `BackboneModel`: Generic interface for any feature extraction backbone.
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
  encoder-decoder stages (e.g., U-Net), providing access to intermediate
  features.
- `DetectionModel`: Interface for the complete end-to-end detection model.
"""

from abc import ABC, abstractmethod
from typing import NamedTuple

import torch
import torch.nn as nn
@@ -7,68 +24,202 @@ import torch.nn as nn
__all__ = [
    "ModelOutput",
    "BackboneModel",
    "DetectionModel",
]
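For downstream code, a typical import of these interfaces might look like the following. The module path is an assumption based on the commit message, which calls this the "types module" of `batdetect2.models`:

from batdetect2.models.types import DetectionModel, ModelOutput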
class ModelOutput(NamedTuple):
    """Standard container for the outputs of a BatDetect2 detection model.

    This structure groups the different prediction tensors produced by the
    model for a batch of input spectrograms. All tensors typically share the
    same spatial dimensions (height H, width W) corresponding to the model's
    output resolution, and the same batch size (N).

    Attributes
    ----------
    detection_probs : torch.Tensor
        Tensor containing the probability of sound event presence at each
        location in the output grid.
        Shape: `(N, 1, H, W)`
    size_preds : torch.Tensor
        Tensor containing the predicted size dimensions (e.g., width and
        height) for a potential bounding box at each location.
        Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
    class_probs : torch.Tensor
        Tensor containing the predicted probabilities (or logits, depending
        on the final activation) for each target class at each location.
        The number of channels corresponds to the number of specific classes
        defined in the `Targets` configuration.
        Shape: `(N, num_classes, H, W)`
    features : torch.Tensor
        Tensor containing features extracted by the model's backbone. These
        might be used for downstream tasks or analysis. The number of
        channels depends on the specific model architecture.
        Shape: `(N, num_features, H, W)`
    """

    detection_probs: torch.Tensor
    """Tensor with predicted detection probabilities."""

    size_preds: torch.Tensor
    """Tensor with predicted bounding box sizes."""

    class_probs: torch.Tensor
    """Tensor with predicted class probabilities."""

    features: torch.Tensor
    """Tensor with intermediate features."""
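For illustration only, a `ModelOutput` for a batch of 8 spectrograms could be built and unpacked as sketched below. The class count (17), feature channels (32), and output resolution (128 x 512) are arbitrary placeholder values, not values prescribed by the library:

# Hypothetical shapes, chosen only to match the documented layout.
output = ModelOutput(
    detection_probs=torch.rand(8, 1, 128, 512),
    size_preds=torch.rand(8, 2, 128, 512),
    class_probs=torch.rand(8, 17, 128, 512),
    features=torch.rand(8, 32, 128, 512),
)

# NamedTuple fields are accessible by name or by positional unpacking.
detections, sizes, classes, feats = output
assert detections.shape == output.detection_probs.shape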
class BackboneModel(ABC, nn.Module):
    """Abstract Base Class for generic feature extraction backbone models.

    Defines the minimal interface for a feature extractor network within a
    BatDetect2 model. Its primary role is to process an input spectrogram
    tensor and produce a spatially rich feature map tensor, which is then
    typically consumed by separate prediction heads (for detection,
    classification, size).

    This base class is agnostic to the specific internal architecture (e.g.,
    it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
    implementations must inherit from this class and `torch.nn.Module`,
    implement the `forward` method, and define the required attributes.

    Attributes
    ----------
    input_height : int
        Expected height (number of frequency bins) of the input spectrogram
        tensor that the backbone is designed to process.
    out_channels : int
        Number of channels in the final feature map tensor produced by the
        backbone's `forward` method.
    """

    input_height: int
    """Expected input spectrogram height (frequency bins)."""

    out_channels: int
    """Number of output channels in the final feature map."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Perform the forward pass to extract features from the spectrogram.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.
            `frequency_bins` should match `self.input_height`.

        Returns
        -------
        torch.Tensor
            Output feature map tensor, typically with shape
            `(batch_size, self.out_channels, output_height, output_width)`.
            The spatial dimensions (`output_height`, `output_width`) depend
            on the specific backbone architecture (e.g., they might match
            the input or be downsampled).
        """
        raise NotImplementedError
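As a sketch of how this contract might be satisfied, the hypothetical subclass below (not part of batdetect2) sets the two required attributes and implements `forward` with a single resolution-preserving convolution:

class TinyConvBackbone(BackboneModel):
    """Toy backbone: one 3x3 convolution that keeps the input resolution."""

    def __init__(self, input_height: int = 128, out_channels: int = 32):
        super().__init__()
        self.input_height = input_height
        self.out_channels = out_channels
        self.conv = nn.Conv2d(1, out_channels, kernel_size=3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # (N, 1, H, W) -> (N, out_channels, H, W)
        return torch.relu(self.conv(spec))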
class EncoderDecoderModel(BackboneModel):
    """Abstract Base Class for Encoder-Decoder style backbone models.

    This class specializes `BackboneModel` for architectures that have
    distinct encoder stages (downsampling path), a bottleneck, and decoder
    stages (upsampling path).

    It provides separate abstract methods for the `encode` and `decode`
    steps, allowing access to the intermediate "bottleneck" features
    produced by the encoder. This can be useful for tasks like transfer
    learning or specialized analyses.

    Attributes
    ----------
    input_height : int
        (Inherited from BackboneModel) Expected input spectrogram height.
    out_channels : int
        (Inherited from BackboneModel) Number of output channels in the
        final feature map produced by the decoder/forward pass.
    bottleneck_channels : int
        Number of channels in the feature map produced by the encoder at its
        deepest point (the bottleneck), before the decoder starts.
    """

    bottleneck_channels: int
    """Number of channels at the encoder's bottleneck."""

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        """Process the input spectrogram through the encoder part.

        Takes the input spectrogram and passes it through the downsampling
        path of the network up to the bottleneck layer.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        torch.Tensor
            The encoded feature map from the bottleneck layer, typically
            with shape `(batch_size, self.bottleneck_channels,
            bottleneck_height, bottleneck_width)`. The spatial dimensions
            are usually downsampled relative to the input.
        """
        ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        """Process the bottleneck features through the decoder part.

        Takes the encoded feature map from the bottleneck and passes it
        through the upsampling path (potentially using skip connections
        from the encoder) to produce the final output feature map.

        Parameters
        ----------
        encoded : torch.Tensor
            The bottleneck feature map tensor produced by the `encode`
            method.

        Returns
        -------
        torch.Tensor
            The final output feature map tensor, typically with shape
            `(batch_size, self.out_channels, output_height, output_width)`.
            This should match the output shape of the `forward` method.
        """
        ...
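A minimal sketch of a concrete subclass, again purely hypothetical, shows how `forward` is naturally composed from the two abstract steps (here the encoder halves the spatial resolution and the decoder restores it; an actual backbone would be more elaborate):

class ToyEncoderDecoder(EncoderDecoderModel):
    """Toy encoder-decoder: downsample by 2, then upsample back."""

    def __init__(
        self,
        input_height: int = 128,
        out_channels: int = 32,
        bottleneck_channels: int = 64,
    ):
        super().__init__()
        self.input_height = input_height
        self.out_channels = out_channels
        self.bottleneck_channels = bottleneck_channels
        self.down = nn.Conv2d(1, bottleneck_channels, 3, stride=2, padding=1)
        self.up = nn.ConvTranspose2d(
            bottleneck_channels, out_channels, 2, stride=2
        )

    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        # (N, 1, H, W) -> (N, bottleneck_channels, H/2, W/2) for even H, W.
        return torch.relu(self.down(spec))

    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        # (N, bottleneck_channels, H/2, W/2) -> (N, out_channels, H, W)
        return torch.relu(self.up(encoded))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # The full backbone pass is encode followed by decode.
        return self.decode(self.encode(spec))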
class DetectionModel(ABC, nn.Module):
    """Abstract Base Class for complete BatDetect2 detection models.

    Defines the interface for the overall model that takes an input
    spectrogram and produces all necessary outputs for detection,
    classification, and size prediction, packaged within a `ModelOutput`
    object.

    Concrete implementations typically combine a `BackboneModel` for feature
    extraction with specific prediction heads for each output type. They
    must inherit from this class and `torch.nn.Module`, and implement the
    `forward` method.
    """

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput:
        """Perform the forward pass of the full detection model.

        Processes the input spectrogram through the backbone and prediction
        heads to generate all required output tensors.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        ModelOutput
            A NamedTuple containing the prediction tensors:
            `detection_probs`, `size_preds`, `class_probs`, and `features`.
        """