mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added better docstrings to types module
This commit is contained in:
parent 907e05ea48
commit e00674f628
"""Defines shared interfaces (ABCs) and data structures for models.

This module centralizes the definitions of core data structures, like the
standard model output container (`ModelOutput`), and establishes abstract base
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
for fundamental model components, ensuring modularity and consistent
interaction within the `batdetect2.models` package.

Key components:
- `ModelOutput`: Standard structure for outputs from detection models.
- `BackboneModel`: Generic interface for any feature extraction backbone.
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
  encoder-decoder stages (e.g., U-Net), providing access to intermediate
  features.
- `DetectionModel`: Interface for the complete end-to-end detection model.
"""

from abc import ABC, abstractmethod
from typing import NamedTuple

import torch
import torch.nn as nn
__all__ = [
    "ModelOutput",
    "BackboneModel",
    "DetectionModel",
]

class ModelOutput(NamedTuple):
    """Standard container for the outputs of a BatDetect2 detection model.

    This structure groups the different prediction tensors produced by the
    model for a batch of input spectrograms. All tensors typically share the
    same spatial dimensions (height H, width W) corresponding to the model's
    output resolution, and the same batch size (N).

    Attributes
    ----------
    detection_probs : torch.Tensor
        Tensor containing the probability of sound event presence at each
        location in the output grid.
        Shape: `(N, 1, H, W)`
    size_preds : torch.Tensor
        Tensor containing the predicted size dimensions (e.g., width and
        height) for a potential bounding box at each location.
        Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
    class_probs : torch.Tensor
        Tensor containing the predicted probabilities (or logits, depending on
        the final activation) for each target class at each location.
        The number of channels corresponds to the number of specific classes
        defined in the `Targets` configuration.
        Shape: `(N, num_classes, H, W)`
    features : torch.Tensor
        Tensor containing features extracted by the model's backbone. These
        might be used for downstream tasks or analysis. The number of channels
        depends on the specific model architecture.
        Shape: `(N, num_features, H, W)`
    """

    detection_probs: torch.Tensor
    size_preds: torch.Tensor
    class_probs: torch.Tensor
    features: torch.Tensor

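# Illustrative sketch (not part of the original module): constructs a
# `ModelOutput` with the documented tensor shapes for a hypothetical batch of
# 2 spectrograms, 6 target classes, and 32 backbone feature channels. All
# names and sizes here are arbitrary placeholders.
def _example_model_output() -> ModelOutput:
    batch, height, width = 2, 64, 512
    num_classes, num_features = 6, 32
    return ModelOutput(
        detection_probs=torch.rand(batch, 1, height, width),
        size_preds=torch.rand(batch, 2, height, width),
        class_probs=torch.rand(batch, num_classes, height, width),
        features=torch.rand(batch, num_features, height, width),
    )
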
class BackboneModel(ABC, nn.Module):
    """Abstract Base Class for generic feature extraction backbone models.

    Defines the minimal interface for a feature extractor network within a
    BatDetect2 model. Its primary role is to process an input spectrogram
    tensor and produce a spatially rich feature map tensor, which is then
    typically consumed by separate prediction heads (for detection,
    classification, size).

    This base class is agnostic to the specific internal architecture (e.g.,
    it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
    implementations must inherit from this class and `torch.nn.Module`,
    implement the `forward` method, and define the required attributes.

    Attributes
    ----------
    input_height : int
        Expected height (number of frequency bins) of the input spectrogram
        tensor that the backbone is designed to process.
    out_channels : int
        Number of channels in the final feature map tensor produced by the
        backbone's `forward` method.
    """

    input_height: int
    """Expected input spectrogram height (frequency bins)."""

    out_channels: int
    """Number of output channels in the final feature map."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Perform the forward pass to extract features from the spectrogram.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.
            `frequency_bins` should match `self.input_height`.

        Returns
        -------
        torch.Tensor
            Output feature map tensor, typically with shape
            `(batch_size, self.out_channels, output_height, output_width)`.
            The spatial dimensions (`output_height`, `output_width`) depend
            on the specific backbone architecture (e.g., they might match the
            input or be downsampled).
        """
        raise NotImplementedError

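# Illustrative sketch (not part of the original module): a minimal concrete
# backbone satisfying the `BackboneModel` contract, using a single convolution
# that preserves the spatial resolution. The class name and layer sizes are
# placeholders, not the real BatDetect2 architecture.
class _SketchConvBackbone(BackboneModel):
    def __init__(self, input_height: int = 128, out_channels: int = 32):
        super().__init__()
        self.input_height = input_height
        self.out_channels = out_channels
        self.conv = nn.Conv2d(1, out_channels, kernel_size=3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # (N, 1, freq_bins, time_bins) -> (N, out_channels, freq_bins, time_bins)
        return torch.relu(self.conv(spec))
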
class EncoderDecoderModel(BackboneModel):
    """Abstract Base Class for Encoder-Decoder style backbone models.

    This class specializes `BackboneModel` for architectures that have
    distinct encoder stages (downsampling path), a bottleneck, and decoder
    stages (upsampling path).

    It provides separate abstract methods for the `encode` and `decode` steps,
    allowing access to the intermediate "bottleneck" features produced by the
    encoder. This can be useful for tasks like transfer learning or
    specialized analyses.

    Attributes
    ----------
    input_height : int
        (Inherited from BackboneModel) Expected input spectrogram height.
    out_channels : int
        (Inherited from BackboneModel) Number of output channels in the final
        feature map produced by the decoder/forward pass.
    bottleneck_channels : int
        Number of channels in the feature map produced by the encoder at its
        deepest point (the bottleneck), before the decoder starts.
    """

    bottleneck_channels: int
    """Number of channels at the encoder's bottleneck."""

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        """Process the input spectrogram through the encoder part.

        Takes the input spectrogram and passes it through the downsampling
        path of the network up to the bottleneck layer.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        torch.Tensor
            The encoded feature map from the bottleneck layer, typically with
            shape `(batch_size, self.bottleneck_channels, bottleneck_height,
            bottleneck_width)`. The spatial dimensions are usually downsampled
            relative to the input.
        """
        ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        """Process the bottleneck features through the decoder part.

        Takes the encoded feature map from the bottleneck and passes it
        through the upsampling path (potentially using skip connections from
        the encoder) to produce the final output feature map.

        Parameters
        ----------
        encoded : torch.Tensor
            The bottleneck feature map tensor produced by the `encode` method.

        Returns
        -------
        torch.Tensor
            The final output feature map tensor, typically with shape
            `(batch_size, self.out_channels, output_height, output_width)`.
            This should match the output shape of the `forward` method.
        """
        ...

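# Illustrative sketch (not part of the original module): a toy encoder-decoder
# backbone implementing the `EncoderDecoderModel` contract with one strided
# convolution as the encoder and one transposed convolution as the decoder.
# Channel counts are placeholders; a real model (e.g., a U-Net style backbone)
# would have several stages and skip connections.
class _SketchEncoderDecoder(EncoderDecoderModel):
    def __init__(self, input_height: int = 128):
        super().__init__()
        self.input_height = input_height
        self.bottleneck_channels = 64
        self.out_channels = 32
        self.encoder = nn.Conv2d(
            1, self.bottleneck_channels, kernel_size=3, stride=2, padding=1
        )
        self.decoder = nn.ConvTranspose2d(
            self.bottleneck_channels, self.out_channels, kernel_size=2, stride=2
        )

    def encode(self, spec: torch.Tensor) -> torch.Tensor:
        # Downsample to the bottleneck: (N, 1, H, W) -> (N, 64, H/2, W/2).
        return torch.relu(self.encoder(spec))

    def decode(self, encoded: torch.Tensor) -> torch.Tensor:
        # Upsample back to the input resolution: (N, 64, H/2, W/2) -> (N, 32, H, W).
        return torch.relu(self.decoder(encoded))

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(spec))
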
class DetectionModel(ABC, nn.Module):
    """Abstract Base Class for complete BatDetect2 detection models.

    Defines the interface for the overall model that takes an input
    spectrogram and produces all necessary outputs for detection,
    classification, and size prediction, packaged within a `ModelOutput`
    object.

    Concrete implementations typically combine a `BackboneModel` for feature
    extraction with specific prediction heads for each output type. They must
    inherit from this class and `torch.nn.Module`, and implement the `forward`
    method.
    """

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput:
        """Perform the forward pass of the full detection model.

        Processes the input spectrogram through the backbone and prediction
        heads to generate all required output tensors.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, typically with shape
            `(batch_size, 1, frequency_bins, time_bins)`.

        Returns
        -------
        ModelOutput
            A NamedTuple containing the prediction tensors: `detection_probs`,
            `size_preds`, `class_probs`, and `features`.
        """
