mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added BlockGroups
This commit is contained in:
parent
bfcab0331e
commit
6498b6ca37
@ -27,7 +27,7 @@ A unified factory function `build_layer_from_config` allows creating instances
|
|||||||
of these blocks based on configuration objects.
|
of these blocks based on configuration objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, Literal, Tuple, Union
|
from typing import Annotated, List, Literal, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -38,6 +38,7 @@ from batdetect2.configs import BaseConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConvBlock",
|
"ConvBlock",
|
||||||
|
"BlockGroupConfig",
|
||||||
"VerticalConv",
|
"VerticalConv",
|
||||||
"FreqCoordConvDownBlock",
|
"FreqCoordConvDownBlock",
|
||||||
"StandardConvDownBlock",
|
"StandardConvDownBlock",
|
||||||
@ -653,12 +654,18 @@ LayerConfig = Annotated[
|
|||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
|
"BlockGroupConfig",
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
|
class BlockGroupConfig(BaseConfig):
|
||||||
|
block_type: Literal["group"] = "group"
|
||||||
|
blocks: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
def build_layer_from_config(
|
def build_layer_from_config(
|
||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
@ -762,4 +769,20 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.block_type == "group":
|
||||||
|
current_channels = in_channels
|
||||||
|
current_height = input_height
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
|
||||||
|
for block_config in config.blocks:
|
||||||
|
block, current_channels, current_height = build_layer_from_config(
|
||||||
|
input_height=current_height,
|
||||||
|
in_channels=current_channels,
|
||||||
|
config=block_config,
|
||||||
|
)
|
||||||
|
blocks.append(block)
|
||||||
|
|
||||||
|
return nn.Sequential(*blocks), current_channels, current_height
|
||||||
|
|
||||||
raise NotImplementedError(f"Unknown block type {config.block_type}")
|
raise NotImplementedError(f"Unknown block type {config.block_type}")
|
||||||
|
@ -26,6 +26,7 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
|
BlockGroupConfig,
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
@ -40,7 +41,12 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
DecoderLayerConfig = Annotated[
|
DecoderLayerConfig = Annotated[
|
||||||
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig],
|
Union[
|
||||||
|
ConvConfig,
|
||||||
|
FreqCoordConvUpConfig,
|
||||||
|
StandardConvUpConfig,
|
||||||
|
BlockGroupConfig,
|
||||||
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
@ -191,8 +197,12 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
|||||||
layers=[
|
layers=[
|
||||||
FreqCoordConvUpConfig(out_channels=64),
|
FreqCoordConvUpConfig(out_channels=64),
|
||||||
FreqCoordConvUpConfig(out_channels=32),
|
FreqCoordConvUpConfig(out_channels=32),
|
||||||
FreqCoordConvUpConfig(out_channels=32),
|
BlockGroupConfig(
|
||||||
ConvConfig(out_channels=32),
|
blocks=[
|
||||||
|
FreqCoordConvUpConfig(out_channels=32),
|
||||||
|
ConvConfig(out_channels=32),
|
||||||
|
]
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""A default configuration for the Decoder's *layer sequence*.
|
"""A default configuration for the Decoder's *layer sequence*.
|
||||||
|
@ -28,6 +28,7 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
|
BlockGroupConfig,
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
@ -42,7 +43,12 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
EncoderLayerConfig = Annotated[
|
EncoderLayerConfig = Annotated[
|
||||||
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig],
|
Union[
|
||||||
|
ConvConfig,
|
||||||
|
FreqCoordConvDownConfig,
|
||||||
|
StandardConvDownConfig,
|
||||||
|
BlockGroupConfig,
|
||||||
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||||
@ -224,8 +230,12 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
|||||||
layers=[
|
layers=[
|
||||||
FreqCoordConvDownConfig(out_channels=32),
|
FreqCoordConvDownConfig(out_channels=32),
|
||||||
FreqCoordConvDownConfig(out_channels=64),
|
FreqCoordConvDownConfig(out_channels=64),
|
||||||
FreqCoordConvDownConfig(out_channels=128),
|
BlockGroupConfig(
|
||||||
ConvConfig(out_channels=256),
|
blocks=[
|
||||||
|
FreqCoordConvDownConfig(out_channels=128),
|
||||||
|
ConvConfig(out_channels=256),
|
||||||
|
]
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""Default configuration for the Encoder.
|
"""Default configuration for the Encoder.
|
||||||
|
Loading…
Reference in New Issue
Block a user