mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Update model config
This commit is contained in:
parent
dba6d2d918
commit
0b5ac96fe8
@ -88,7 +88,9 @@ model:
|
|||||||
out_channels: 256
|
out_channels: 256
|
||||||
bottleneck:
|
bottleneck:
|
||||||
channels: 256
|
channels: 256
|
||||||
self_attention: true
|
layers:
|
||||||
|
- block_type: SelfAttention
|
||||||
|
attention_channels: 256
|
||||||
decoder:
|
decoder:
|
||||||
layers:
|
layers:
|
||||||
- block_type: FreqCoordConvUp
|
- block_type: FreqCoordConvUp
|
||||||
|
|||||||
@ -26,13 +26,14 @@ for creating a standard BatDetect2 model instance is the `build_model` function
|
|||||||
provided here.
|
provided here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lightning import LightningModule
|
from lightning import LightningModule
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Backbone,
|
Backbone,
|
||||||
BackboneConfig,
|
BackboneConfig,
|
||||||
@ -66,8 +67,8 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
|||||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
from batdetect2.typing.models import DetectionModel, ModelOutput
|
from batdetect2.typing.models import DetectionModel
|
||||||
from batdetect2.typing.postprocess import PostprocessorProtocol
|
from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
@ -119,9 +120,12 @@ class Model(LightningModule):
|
|||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, wav: torch.Tensor) -> List[Detections]:
|
||||||
return self.detector(spec)
|
spec = self.preprocessor(wav)
|
||||||
|
outputs = self.detector(spec)
|
||||||
|
return self.postprocessor(outputs)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
class ModelConfig(BaseConfig):
|
||||||
@ -139,7 +143,6 @@ def build_model(config: Optional[ModelConfig] = None):
|
|||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
@ -153,3 +156,9 @@ def build_model(config: Optional[ModelConfig] = None):
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_config(
|
||||||
|
path: PathLike, field: Optional[str] = None
|
||||||
|
) -> ModelConfig:
|
||||||
|
return load_config(path, schema=ModelConfig, field=field)
|
||||||
|
|||||||
@ -55,6 +55,12 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttentionConfig(BaseConfig):
|
||||||
|
block_type: Literal["SelfAttention"] = "SelfAttention"
|
||||||
|
attention_channels: int
|
||||||
|
temperature: float = 1
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
"""Self-Attention mechanism operating along the time dimension.
|
"""Self-Attention mechanism operating along the time dimension.
|
||||||
|
|
||||||
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
|
|||||||
# Note, does not encode position information (absolute or relative)
|
# Note, does not encode position information (absolute or relative)
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.att_dim = attention_channels
|
self.att_dim = attention_channels
|
||||||
|
|
||||||
self.key_fun = nn.Linear(in_channels, attention_channels)
|
self.key_fun = nn.Linear(in_channels, attention_channels)
|
||||||
self.value_fun = nn.Linear(in_channels, attention_channels)
|
self.value_fun = nn.Linear(in_channels, attention_channels)
|
||||||
self.query_fun = nn.Linear(in_channels, attention_channels)
|
self.query_fun = nn.Linear(in_channels, attention_channels)
|
||||||
@ -654,6 +661,7 @@ LayerConfig = Annotated[
|
|||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
|
SelfAttentionConfig,
|
||||||
"LayerGroupConfig",
|
"LayerGroupConfig",
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
@ -769,6 +777,17 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.block_type == "SelfAttention":
|
||||||
|
return (
|
||||||
|
SelfAttention(
|
||||||
|
in_channels=in_channels,
|
||||||
|
attention_channels=config.attention_channels,
|
||||||
|
temperature=config.temperature,
|
||||||
|
),
|
||||||
|
config.attention_channels,
|
||||||
|
input_height,
|
||||||
|
)
|
||||||
|
|
||||||
if config.block_type == "LayerGroup":
|
if config.block_type == "LayerGroup":
|
||||||
current_channels = in_channels
|
current_channels = in_channels
|
||||||
current_height = input_height
|
current_height = input_height
|
||||||
|
|||||||
@ -14,47 +14,27 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
|
|||||||
module based on the provided configuration.
|
module based on the provided configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Annotated, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import SelfAttention, VerticalConv
|
from batdetect2.models.blocks import (
|
||||||
|
LayerConfig,
|
||||||
|
SelfAttentionConfig,
|
||||||
|
VerticalConv,
|
||||||
|
build_layer_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
"Bottleneck",
|
"Bottleneck",
|
||||||
"BottleneckAttn",
|
|
||||||
"build_bottleneck",
|
"build_bottleneck",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BottleneckConfig(BaseConfig):
|
|
||||||
"""Configuration for the bottleneck layer(s).
|
|
||||||
|
|
||||||
Defines the number of channels within the bottleneck and whether to include
|
|
||||||
a self-attention mechanism.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
channels : int
|
|
||||||
The number of output channels produced by the main convolutional layer
|
|
||||||
within the bottleneck. This often matches the number of channels coming
|
|
||||||
from the last encoder stage, but can be different. Must be positive.
|
|
||||||
This also defines the channel dimensions used within the optional
|
|
||||||
`SelfAttention` layer.
|
|
||||||
self_attention : bool
|
|
||||||
If True, includes a `SelfAttention` layer operating on the time
|
|
||||||
dimension after an initial `VerticalConv` layer within the bottleneck.
|
|
||||||
If False, only the initial `VerticalConv` (and height repetition) is
|
|
||||||
performed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
channels: int
|
|
||||||
self_attention: bool
|
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
class Bottleneck(nn.Module):
|
||||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
"""Base Bottleneck module for Encoder-Decoder architectures.
|
||||||
|
|
||||||
@ -99,16 +79,24 @@ class Bottleneck(nn.Module):
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
out_channels: int,
|
out_channels: int,
|
||||||
|
bottleneck_channels: Optional[int] = None,
|
||||||
|
layers: Optional[List[torch.nn.Module]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the base Bottleneck layer."""
|
"""Initialize the base Bottleneck layer."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
self.bottleneck_channels = (
|
||||||
|
bottleneck_channels
|
||||||
|
if bottleneck_channels is not None
|
||||||
|
else out_channels
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList(layers or [])
|
||||||
|
|
||||||
self.conv_vert = VerticalConv(
|
self.conv_vert = VerticalConv(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=self.bottleneck_channels,
|
||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -132,73 +120,52 @@ class Bottleneck(nn.Module):
|
|||||||
convolution.
|
convolution.
|
||||||
"""
|
"""
|
||||||
x = self.conv_vert(x)
|
x = self.conv_vert(x)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
return x.repeat([1, 1, self.input_height, 1])
|
return x.repeat([1, 1, self.input_height, 1])
|
||||||
|
|
||||||
|
|
||||||
class BottleneckAttn(Bottleneck):
|
BottleneckLayerConfig = Annotated[
|
||||||
"""Bottleneck module including a Self-Attention layer.
|
Union[SelfAttentionConfig,],
|
||||||
|
Field(discriminator="block_type"),
|
||||||
|
]
|
||||||
|
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
|
||||||
|
|
||||||
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
|
|
||||||
the initial `VerticalConv`. This allows the bottleneck to capture global
|
|
||||||
temporal dependencies in the summarized frequency features before passing
|
|
||||||
them to the decoder.
|
|
||||||
|
|
||||||
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
|
class BottleneckConfig(BaseConfig):
|
||||||
|
"""Configuration for the bottleneck layer(s).
|
||||||
|
|
||||||
Parameters
|
Defines the number of channels within the bottleneck and which optional
|
||||||
|
additional layers (such as self-attention) are applied within it.
|
||||||
|
|
||||||
|
Attributes
|
||||||
----------
|
----------
|
||||||
input_height : int
|
channels : int
|
||||||
Height (frequency bins) of the input tensor from the encoder.
|
The number of output channels produced by the main convolutional layer
|
||||||
in_channels : int
|
within the bottleneck. This often matches the number of channels coming
|
||||||
Number of channels in the input tensor from the encoder.
|
from the last encoder stage, but can be different. Must be positive.
|
||||||
out_channels : int
|
This also defines the channel dimensions used within the optional
|
||||||
Number of output channels produced by the `VerticalConv` and
|
`SelfAttention` layer.
|
||||||
subsequently processed and output by this bottleneck. Also determines
|
layers : List[BottleneckLayerConfig]
|
||||||
the input/output channels of the internal `SelfAttention` layer.
|
Configurations of the additional layers applied, in order, after the
|
||||||
attention : nn.Module
|
initial `VerticalConv` layer within the bottleneck. Currently only
|
||||||
An initialized `SelfAttention` module instance.
|
`SelfAttention` block configs are accepted. If empty (the default), only
|
||||||
|
the initial `VerticalConv` (and height repetition) is performed.
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `input_height`, `in_channels`, or `out_channels` are not positive.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
channels: int
|
||||||
self,
|
layers: List[BottleneckLayerConfig] = Field(
|
||||||
input_height: int,
|
default_factory=list,
|
||||||
in_channels: int,
|
)
|
||||||
out_channels: int,
|
|
||||||
attention: nn.Module,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the Bottleneck with Self-Attention."""
|
|
||||||
super().__init__(input_height, in_channels, out_channels)
|
|
||||||
self.attention = attention
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Process input tensor.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x : torch.Tensor
|
|
||||||
Input tensor from the encoder bottleneck, shape
|
|
||||||
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
|
|
||||||
`H_in` must match `self.input_height`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
|
|
||||||
and repeating the height dimension.
|
|
||||||
"""
|
|
||||||
x = self.conv_vert(x)
|
|
||||||
x = self.attention(x)
|
|
||||||
return x.repeat([1, 1, self.input_height, 1])
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
||||||
channels=256,
|
channels=256,
|
||||||
self_attention=True,
|
layers=[
|
||||||
|
SelfAttentionConfig(attention_channels=256),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -234,21 +201,25 @@ def build_bottleneck(
|
|||||||
"""
|
"""
|
||||||
config = config or DEFAULT_BOTTLENECK_CONFIG
|
config = config or DEFAULT_BOTTLENECK_CONFIG
|
||||||
|
|
||||||
if config.self_attention:
|
current_channels = in_channels
|
||||||
attention = SelfAttention(
|
current_height = input_height
|
||||||
in_channels=config.channels,
|
|
||||||
attention_channels=config.channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
return BottleneckAttn(
|
layers = []
|
||||||
input_height=input_height,
|
|
||||||
in_channels=in_channels,
|
for layer_config in config.layers:
|
||||||
out_channels=config.channels,
|
layer, current_channels, current_height = build_layer_from_config(
|
||||||
attention=attention,
|
input_height=current_height,
|
||||||
|
in_channels=current_channels,
|
||||||
|
config=layer_config,
|
||||||
)
|
)
|
||||||
|
assert current_height == input_height, (
|
||||||
|
"Bottleneck layers should not change the spectrogram height"
|
||||||
|
)
|
||||||
|
layers.append(layer)
|
||||||
|
|
||||||
return Bottleneck(
|
return Bottleneck(
|
||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=config.channels,
|
out_channels=config.channels,
|
||||||
|
layers=layers,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user