Update model config

This commit is contained in:
mbsantiago 2025-08-27 23:58:07 +01:00
parent dba6d2d918
commit 0b5ac96fe8
4 changed files with 102 additions and 101 deletions

View File

@ -88,7 +88,9 @@ model:
out_channels: 256 out_channels: 256
bottleneck: bottleneck:
channels: 256 channels: 256
self_attention: true layers:
- block_type: SelfAttention
attention_channels: 256
decoder: decoder:
layers: layers:
- block_type: FreqCoordConvUp - block_type: FreqCoordConvUp

View File

@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here. provided here.
""" """
from typing import Optional from typing import List, Optional
import torch import torch
from lightning import LightningModule from lightning import LightningModule
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
Backbone, Backbone,
BackboneConfig, BackboneConfig,
@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets from batdetect2.targets import TargetConfig, build_targets
from batdetect2.typing.models import DetectionModel, ModelOutput from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import PostprocessorProtocol from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol from batdetect2.typing.targets import TargetProtocol
@ -119,9 +120,12 @@ class Model(LightningModule):
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, wav: torch.Tensor) -> List[Detections]:
return self.detector(spec) spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
class ModelConfig(BaseConfig): class ModelConfig(BaseConfig):
@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess) preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
targets=targets,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.postprocess, config=config.postprocess,
) )
@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
preprocessor=preprocessor, preprocessor=preprocessor,
targets=targets, targets=targets,
) )
def load_model_config(
    path: PathLike,
    field: Optional[str] = None,
) -> ModelConfig:
    """Load a model configuration using the `ModelConfig` schema.

    Thin wrapper around `load_config` that fixes the schema argument to
    `ModelConfig`.

    Parameters
    ----------
    path : PathLike
        Forwarded unchanged to `load_config` as its first positional
        argument (presumably the configuration file location).
    field : str, optional
        Forwarded unchanged to `load_config`. Presumably selects a nested
        section of the file to parse; confirm against `load_config`.

    Returns
    -------
    ModelConfig
        Whatever `load_config` returns for the `ModelConfig` schema.
    """
    model_config = load_config(path, schema=ModelConfig, field=field)
    return model_config

View File

@ -55,6 +55,12 @@ __all__ = [
] ]
class SelfAttentionConfig(BaseConfig):
    """Configuration for a `SelfAttention` block.

    Attributes
    ----------
    block_type : Literal["SelfAttention"]
        Tag identifying this config within the layer config union, which is
        discriminated on `block_type`.
    attention_channels : int
        Passed to `SelfAttention` as `attention_channels`.
        `build_layer_from_config` also reports this value as the number of
        channels produced by the built layer.
    temperature : float
        Passed to `SelfAttention` as `temperature`. Defaults to 1.0.
    """

    block_type: Literal["SelfAttention"] = "SelfAttention"
    attention_channels: int
    # Float literal so the default matches the declared field type.
    temperature: float = 1.0
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension. """Self-Attention mechanism operating along the time dimension.
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative) # Note, does not encode position information (absolute or relative)
self.temperature = temperature self.temperature = temperature
self.att_dim = attention_channels self.att_dim = attention_channels
self.key_fun = nn.Linear(in_channels, attention_channels) self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels) self.value_fun = nn.Linear(in_channels, attention_channels)
self.query_fun = nn.Linear(in_channels, attention_channels) self.query_fun = nn.Linear(in_channels, attention_channels)
@ -654,6 +661,7 @@ LayerConfig = Annotated[
StandardConvDownConfig, StandardConvDownConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
StandardConvUpConfig, StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig", "LayerGroupConfig",
], ],
Field(discriminator="block_type"), Field(discriminator="block_type"),
@ -769,6 +777,17 @@ def build_layer_from_config(
input_height * 2, input_height * 2,
) )
if config.block_type == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.block_type == "LayerGroup": if config.block_type == "LayerGroup":
current_channels = in_channels current_channels = in_channels
current_height = input_height current_height = input_height

View File

@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration. module based on the provided configuration.
""" """
from typing import Optional from typing import Annotated, List, Optional, Union
import torch import torch
from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv from batdetect2.models.blocks import (
LayerConfig,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
)
__all__ = [ __all__ = [
"BottleneckConfig", "BottleneckConfig",
"Bottleneck", "Bottleneck",
"BottleneckAttn",
"build_bottleneck", "build_bottleneck",
] ]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures. """Base Bottleneck module for Encoder-Decoder architectures.
@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
input_height: int, input_height: int,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
) -> None: ) -> None:
"""Initialize the base Bottleneck layer.""" """Initialize the base Bottleneck layer."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.input_height = input_height self.input_height = input_height
self.out_channels = out_channels self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
else out_channels
)
self.layers = nn.ModuleList(layers or [])
self.conv_vert = VerticalConv( self.conv_vert = VerticalConv(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=self.bottleneck_channels,
input_height=input_height, input_height=input_height,
) )
@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
convolution. convolution.
""" """
x = self.conv_vert(x) x = self.conv_vert(x)
for layer in self.layers:
x = layer(x)
return x.repeat([1, 1, self.input_height, 1]) return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck): BottleneckLayerConfig = Annotated[
"""Bottleneck module including a Self-Attention layer. Union[SelfAttentionConfig,],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height. class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Parameters Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
---------- ----------
input_height : int channels : int
Height (frequency bins) of the input tensor from the encoder. The number of output channels produced by the main convolutional layer
in_channels : int within the bottleneck. This often matches the number of channels coming
Number of channels in the input tensor from the encoder. from the last encoder stage, but can be different. Must be positive.
out_channels : int This also defines the channel dimensions used within the optional
Number of output channels produced by the `VerticalConv` and `SelfAttention` layer.
subsequently processed and output by this bottleneck. Also determines self_attention : bool
the input/output channels of the internal `SelfAttention` layer. If True, includes a `SelfAttention` layer operating on the time
attention : nn.Module dimension after an initial `VerticalConv` layer within the bottleneck.
An initialized `SelfAttention` module instance. If False, only the initial `VerticalConv` (and height repetition) is
performed.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
""" """
def __init__( channels: int
self, layers: List[BottleneckLayerConfig] = Field(
input_height: int, default_factory=list,
in_channels: int, )
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256, channels=256,
self_attention=True, layers=[
SelfAttentionConfig(attention_channels=256),
],
) )
@ -234,21 +201,25 @@ def build_bottleneck(
""" """
config = config or DEFAULT_BOTTLENECK_CONFIG config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention: current_channels = in_channels
attention = SelfAttention( current_height = input_height
in_channels=config.channels,
attention_channels=config.channels,
)
return BottleneckAttn( layers = []
input_height=input_height,
in_channels=in_channels, for layer_config in config.layers:
out_channels=config.channels, layer, current_channels, current_height = build_layer_from_config(
attention=attention, input_height=current_height,
in_channels=current_channels,
config=layer_config,
) )
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)
layers.append(layer)
return Bottleneck( return Bottleneck(
input_height=input_height, input_height=input_height,
in_channels=in_channels, in_channels=in_channels,
out_channels=config.channels, out_channels=config.channels,
layers=layers,
) )