Update model config

This commit is contained in:
mbsantiago 2025-08-27 23:58:07 +01:00
parent dba6d2d918
commit 0b5ac96fe8
4 changed files with 102 additions and 101 deletions

View File

@ -88,7 +88,9 @@ model:
out_channels: 256
bottleneck:
channels: 256
self_attention: true
layers:
- block_type: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp

View File

@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import Optional
from typing import List, Optional
import torch
from lightning import LightningModule
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel, ModelOutput
from batdetect2.typing.postprocess import PostprocessorProtocol
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
@ -119,9 +120,12 @@ class Model(LightningModule):
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
def forward(self, wav: torch.Tensor) -> List[Detections]:
    """Run the full pipeline on `wav` and return the postprocessed result.

    The input is passed through `self.preprocessor`, that result through
    `self.detector`, and the detector output through `self.postprocessor`,
    whose return value is handed back unchanged.
    """
    return self.postprocessor(self.detector(self.preprocessor(wav)))
class ModelConfig(BaseConfig):
@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
targets=targets,
preprocessor=preprocessor,
config=config.postprocess,
)
@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
preprocessor=preprocessor,
targets=targets,
)
def load_model_config(
    path: PathLike,
    field: Optional[str] = None,
) -> ModelConfig:
    """Load a `ModelConfig` from `path`.

    Thin wrapper around `load_config` that fixes the schema to
    `ModelConfig`. `field` is forwarded unchanged (presumably it selects a
    sub-section of the file -- confirm against `load_config`).
    """
    return load_config(path, field=field, schema=ModelConfig)

View File

@ -55,6 +55,12 @@ __all__ = [
]
class SelfAttentionConfig(BaseConfig):
    """Configuration for a `SelfAttention` block.

    Selected in a layer config union by `block_type == "SelfAttention"`.
    """

    # Discriminator value identifying this config within a layer union.
    block_type: Literal["SelfAttention"] = "SelfAttention"

    # Passed to `SelfAttention` as `attention_channels`.
    attention_channels: int

    # Passed to `SelfAttention` as `temperature`.
    temperature: float = 1
class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension.
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
self.att_dim = attention_channels
self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels)
self.query_fun = nn.Linear(in_channels, attention_channels)
@ -654,6 +661,7 @@ LayerConfig = Annotated[
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig",
],
Field(discriminator="block_type"),
@ -769,6 +777,17 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.block_type == "LayerGroup":
current_channels = in_channels
current_height = input_height

View File

@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Optional
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
from batdetect2.models.blocks import (
LayerConfig,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
)
__all__ = [
"BottleneckConfig",
"Bottleneck",
"BottleneckAttn",
"build_bottleneck",
]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
input_height: int,
in_channels: int,
out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
else out_channels
)
self.layers = nn.ModuleList(layers or [])
self.conv_vert = VerticalConv(
in_channels=in_channels,
out_channels=out_channels,
out_channels=self.bottleneck_channels,
input_height=input_height,
)
@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
convolution.
"""
x = self.conv_vert(x)
for layer in self.layers:
x = layer(x)
return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck):
"""Bottleneck module including a Self-Attention layer.
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Parameters
Defines the number of channels within the bottleneck and the additional
layers (such as self-attention) applied after the initial convolution.
Attributes
----------
input_height : int
Height (frequency bins) of the input tensor from the encoder.
in_channels : int
Number of channels in the input tensor from the encoder.
out_channels : int
Number of output channels produced by the `VerticalConv` and
subsequently processed and output by this bottleneck. Also determines
the input/output channels of the internal `SelfAttention` layer.
attention : nn.Module
An initialized `SelfAttention` module instance.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
The channel dimensions of any additional layers are set by their own
entries in `layers`.
layers : List[BottleneckLayerConfig]
Configurations of the additional layers applied, in order, after the
initial `VerticalConv` and before the height repetition. Currently only
`SelfAttention` blocks are accepted. Defaults to an empty list, in which
case only the initial `VerticalConv` (and height repetition) is performed.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
channels: int
layers: List[BottleneckLayerConfig] = Field(
default_factory=list,
)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256,
self_attention=True,
layers=[
SelfAttentionConfig(attention_channels=256),
],
)
@ -234,21 +201,25 @@ def build_bottleneck(
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention:
attention = SelfAttention(
in_channels=config.channels,
attention_channels=config.channels,
)
current_channels = in_channels
current_height = input_height
return BottleneckAttn(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
attention=attention,
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=layer_config,
)
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)
layers.append(layer)
return Bottleneck(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
layers=layers,
)