mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added BlockGroups
This commit is contained in:
parent
bfcab0331e
commit
6498b6ca37
@ -27,7 +27,7 @@ A unified factory function `build_layer_from_config` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
"""
|
||||
|
||||
from typing import Annotated, Literal, Tuple, Union
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -38,6 +38,7 @@ from batdetect2.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
"BlockGroupConfig",
|
||||
"VerticalConv",
|
||||
"FreqCoordConvDownBlock",
|
||||
"StandardConvDownBlock",
|
||||
@ -653,12 +654,18 @@ LayerConfig = Annotated[
|
||||
StandardConvDownConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
"BlockGroupConfig",
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
class BlockGroupConfig(BaseConfig):
|
||||
block_type: Literal["group"] = "group"
|
||||
blocks: List[LayerConfig]
|
||||
|
||||
|
||||
def build_layer_from_config(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
@ -762,4 +769,20 @@ def build_layer_from_config(
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.block_type == "group":
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
blocks = []
|
||||
|
||||
for block_config in config.blocks:
|
||||
block, current_channels, current_height = build_layer_from_config(
|
||||
input_height=current_height,
|
||||
in_channels=current_channels,
|
||||
config=block_config,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
return nn.Sequential(*blocks), current_channels, current_height
|
||||
|
||||
raise NotImplementedError(f"Unknown block type {config.block_type}")
|
||||
|
@ -26,6 +26,7 @@ from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
@ -40,7 +41,12 @@ __all__ = [
|
||||
]
|
||||
|
||||
DecoderLayerConfig = Annotated[
|
||||
Union[ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig],
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
BlockGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
@ -191,8 +197,12 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvUpConfig(out_channels=64),
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
ConvConfig(out_channels=32),
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
"""A default configuration for the Decoder's *layer sequence*.
|
||||
|
@ -28,6 +28,7 @@ from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
@ -42,7 +43,12 @@ __all__ = [
|
||||
]
|
||||
|
||||
EncoderLayerConfig = Annotated[
|
||||
Union[ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig],
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
BlockGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
@ -224,8 +230,12 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvDownConfig(out_channels=32),
|
||||
FreqCoordConvDownConfig(out_channels=64),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
FreqCoordConvDownConfig(out_channels=128),
|
||||
ConvConfig(out_channels=256),
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
"""Default configuration for the Encoder.
|
||||
|
Loading…
Reference in New Issue
Block a user