diff --git a/example_data/config.yaml b/example_data/config.yaml
index 266746a..42c3cea 100644
--- a/example_data/config.yaml
+++ b/example_data/config.yaml
@@ -88,7 +88,9 @@ model:
     out_channels: 256
   bottleneck:
     channels: 256
-    self_attention: true
+    layers:
+      - block_type: SelfAttention
+        attention_channels: 256
   decoder:
     layers:
       - block_type: FreqCoordConvUp
diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py
index 47d6328..5d35512 100644
--- a/src/batdetect2/models/__init__.py
+++ b/src/batdetect2/models/__init__.py
@@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the
 `build_model` function provided here.
 """

-from typing import Optional
+from typing import List, Optional

 import torch
 from lightning import LightningModule
 from pydantic import Field
+from soundevent.data import PathLike

-from batdetect2.configs import BaseConfig
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
 from batdetect2.postprocess import PostprocessConfig, build_postprocessor
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.targets import TargetConfig, build_targets
-from batdetect2.typing.models import DetectionModel, ModelOutput
-from batdetect2.typing.postprocess import PostprocessorProtocol
+from batdetect2.typing.models import DetectionModel
+from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.targets import TargetProtocol

@@ -119,9 +120,12 @@ class Model(LightningModule):
         self.preprocessor = preprocessor
         self.postprocessor = postprocessor
         self.targets = targets
+        self.save_hyperparameters()

-    def forward(self, spec: torch.Tensor) -> ModelOutput:
-        return self.detector(spec)
+    def forward(self, wav: torch.Tensor) -> List[Detections]:
+        spec = self.preprocessor(wav)
+        outputs = self.detector(spec)
+        return self.postprocessor(outputs)


 class ModelConfig(BaseConfig):
@@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
     targets = build_targets(config=config.targets)
     preprocessor = build_preprocessor(config=config.preprocess)
     postprocessor = build_postprocessor(
-        targets=targets,
         preprocessor=preprocessor,
         config=config.postprocess,
     )
@@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
         preprocessor=preprocessor,
         targets=targets,
     )
+
+
+def load_model_config(
+    path: PathLike, field: Optional[str] = None
+) -> ModelConfig:
+    return load_config(path, schema=ModelConfig, field=field)
diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py
index 39965b9..b1bbec1 100644
--- a/src/batdetect2/models/blocks.py
+++ b/src/batdetect2/models/blocks.py
@@ -55,6 +55,12 @@ __all__ = [
 ]


+class SelfAttentionConfig(BaseConfig):
+    block_type: Literal["SelfAttention"] = "SelfAttention"
+    attention_channels: int
+    temperature: float = 1
+
+
 class SelfAttention(nn.Module):
     """Self-Attention mechanism operating along the time dimension.

@@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
         # Note, does not encode position information (absolute or relative)
         self.temperature = temperature
         self.att_dim = attention_channels
+        self.key_fun = nn.Linear(in_channels, attention_channels)
         self.value_fun = nn.Linear(in_channels, attention_channels)
         self.query_fun = nn.Linear(in_channels, attention_channels)

@@ -654,6 +661,7 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
+        SelfAttentionConfig,
         "LayerGroupConfig",
     ],
     Field(discriminator="block_type"),
@@ -769,6 +777,17 @@ def build_layer_from_config(
         input_height * 2,
     )

+    if config.block_type == "SelfAttention":
+        return (
+            SelfAttention(
+                in_channels=in_channels,
+                attention_channels=config.attention_channels,
+                temperature=config.temperature,
+            ),
+            config.attention_channels,
+            input_height,
+        )
+
     if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height
diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py
index d93ea55..7a38a11 100644
--- a/src/batdetect2/models/bottleneck.py
+++ b/src/batdetect2/models/bottleneck.py
@@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
 module based on the provided configuration.
 """

-from typing import Optional
+from typing import Annotated, List, Optional, Union

 import torch
+from pydantic import Field
 from torch import nn

 from batdetect2.configs import BaseConfig
-from batdetect2.models.blocks import SelfAttention, VerticalConv
+from batdetect2.models.blocks import (
+    LayerConfig,
+    SelfAttentionConfig,
+    VerticalConv,
+    build_layer_from_config,
+)

 __all__ = [
     "BottleneckConfig",
     "Bottleneck",
-    "BottleneckAttn",
     "build_bottleneck",
 ]


-class BottleneckConfig(BaseConfig):
-    """Configuration for the bottleneck layer(s).
-
-    Defines the number of channels within the bottleneck and whether to include
-    a self-attention mechanism.
-
-    Attributes
-    ----------
-    channels : int
-        The number of output channels produced by the main convolutional layer
-        within the bottleneck. This often matches the number of channels coming
-        from the last encoder stage, but can be different. Must be positive.
-        This also defines the channel dimensions used within the optional
-        `SelfAttention` layer.
-    self_attention : bool
-        If True, includes a `SelfAttention` layer operating on the time
-        dimension after an initial `VerticalConv` layer within the bottleneck.
-        If False, only the initial `VerticalConv` (and height repetition) is
-        performed.
-    """
-
-    channels: int
-    self_attention: bool
-
-
 class Bottleneck(nn.Module):
     """Base Bottleneck module for Encoder-Decoder architectures.

@@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
         input_height: int,
         in_channels: int,
         out_channels: int,
+        bottleneck_channels: Optional[int] = None,
+        layers: Optional[List[torch.nn.Module]] = None,
     ) -> None:
         """Initialize the base Bottleneck layer."""
         super().__init__()
         self.in_channels = in_channels
         self.input_height = input_height
         self.out_channels = out_channels
+        self.bottleneck_channels = (
+            bottleneck_channels
+            if bottleneck_channels is not None
+            else out_channels
+        )
+        self.layers = nn.ModuleList(layers or [])

         self.conv_vert = VerticalConv(
             in_channels=in_channels,
-            out_channels=out_channels,
+            out_channels=self.bottleneck_channels,
             input_height=input_height,
         )

@@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
             convolution.
""" x = self.conv_vert(x) + + for layer in self.layers: + x = layer(x) + return x.repeat([1, 1, self.input_height, 1]) -class BottleneckAttn(Bottleneck): - """Bottleneck module including a Self-Attention layer. +BottleneckLayerConfig = Annotated[ + Union[SelfAttentionConfig,], + Field(discriminator="block_type"), +] +"""Type alias for the discriminated union of block configs usable in Decoder.""" - Extends the base `Bottleneck` by inserting a `SelfAttention` layer after - the initial `VerticalConv`. This allows the bottleneck to capture global - temporal dependencies in the summarized frequency features before passing - them to the decoder. - Sequence: VerticalConv -> SelfAttention -> Repeat Height. +class BottleneckConfig(BaseConfig): + """Configuration for the bottleneck layer(s). - Parameters + Defines the number of channels within the bottleneck and whether to include + a self-attention mechanism. + + Attributes ---------- - input_height : int - Height (frequency bins) of the input tensor from the encoder. - in_channels : int - Number of channels in the input tensor from the encoder. - out_channels : int - Number of output channels produced by the `VerticalConv` and - subsequently processed and output by this bottleneck. Also determines - the input/output channels of the internal `SelfAttention` layer. - attention : nn.Module - An initialized `SelfAttention` module instance. - - Raises - ------ - ValueError - If `input_height`, `in_channels`, or `out_channels` are not positive. + channels : int + The number of output channels produced by the main convolutional layer + within the bottleneck. This often matches the number of channels coming + from the last encoder stage, but can be different. Must be positive. + This also defines the channel dimensions used within the optional + `SelfAttention` layer. + self_attention : bool + If True, includes a `SelfAttention` layer operating on the time + dimension after an initial `VerticalConv` layer within the bottleneck. + If False, only the initial `VerticalConv` (and height repetition) is + performed. """ - def __init__( - self, - input_height: int, - in_channels: int, - out_channels: int, - attention: nn.Module, - ) -> None: - """Initialize the Bottleneck with Self-Attention.""" - super().__init__(input_height, in_channels, out_channels) - self.attention = attention - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Process input tensor. - - Parameters - ---------- - x : torch.Tensor - Input tensor from the encoder bottleneck, shape - `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`, - `H_in` must match `self.input_height`. - - Returns - ------- - torch.Tensor - Output tensor, shape `(B, C_out, H_in, W)`, after applying attention - and repeating the height dimension. 
- """ - x = self.conv_vert(x) - x = self.attention(x) - return x.repeat([1, 1, self.input_height, 1]) + channels: int + layers: List[BottleneckLayerConfig] = Field( + default_factory=list, + ) DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( channels=256, - self_attention=True, + layers=[ + SelfAttentionConfig(attention_channels=256), + ], ) @@ -234,21 +201,25 @@ def build_bottleneck( """ config = config or DEFAULT_BOTTLENECK_CONFIG - if config.self_attention: - attention = SelfAttention( - in_channels=config.channels, - attention_channels=config.channels, - ) + current_channels = in_channels + current_height = input_height - return BottleneckAttn( - input_height=input_height, - in_channels=in_channels, - out_channels=config.channels, - attention=attention, + layers = [] + + for layer_config in config.layers: + layer, current_channels, current_height = build_layer_from_config( + input_height=current_height, + in_channels=current_channels, + config=layer_config, ) + assert current_height == input_height, ( + "Bottleneck layers should not change the spectrogram height" + ) + layers.append(layer) return Bottleneck( input_height=input_height, in_channels=in_channels, out_channels=config.channels, + layers=layers, )