mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Rename BlockGroupConfig to LayerGroupConfig
This commit is contained in:
parent
cbb02cf69e
commit
166dad20bd
@ -27,9 +27,23 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.blocks import ConvBlock
|
from batdetect2.models.blocks import ConvBlock
|
||||||
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
|
from batdetect2.models.bottleneck import (
|
||||||
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
|
DEFAULT_BOTTLENECK_CONFIG,
|
||||||
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
|
BottleneckConfig,
|
||||||
|
build_bottleneck,
|
||||||
|
)
|
||||||
|
from batdetect2.models.decoder import (
|
||||||
|
DEFAULT_DECODER_CONFIG,
|
||||||
|
Decoder,
|
||||||
|
DecoderConfig,
|
||||||
|
build_decoder,
|
||||||
|
)
|
||||||
|
from batdetect2.models.encoder import (
|
||||||
|
DEFAULT_ENCODER_CONFIG,
|
||||||
|
Encoder,
|
||||||
|
EncoderConfig,
|
||||||
|
build_encoder,
|
||||||
|
)
|
||||||
from batdetect2.models.types import BackboneModel
|
from batdetect2.models.types import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig):
|
|||||||
|
|
||||||
input_height: int = 128
|
input_height: int = 128
|
||||||
in_channels: int = 1
|
in_channels: int = 1
|
||||||
encoder: Optional[EncoderConfig] = None
|
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||||
bottleneck: Optional[BottleneckConfig] = None
|
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||||
decoder: Optional[DecoderConfig] = None
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
out_channels: int = 32
|
out_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConvBlock",
|
"ConvBlock",
|
||||||
"BlockGroupConfig",
|
"LayerGroupConfig",
|
||||||
"VerticalConv",
|
"VerticalConv",
|
||||||
"FreqCoordConvDownBlock",
|
"FreqCoordConvDownBlock",
|
||||||
"StandardConvDownBlock",
|
"StandardConvDownBlock",
|
||||||
@ -654,16 +654,16 @@ LayerConfig = Annotated[
|
|||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
"BlockGroupConfig",
|
"LayerGroupConfig",
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
class BlockGroupConfig(BaseConfig):
|
class LayerGroupConfig(BaseConfig):
|
||||||
block_type: Literal["group"] = "group"
|
block_type: Literal["LayerGroup"] = "LayerGroup"
|
||||||
blocks: List[LayerConfig]
|
layers: List[LayerConfig]
|
||||||
|
|
||||||
|
|
||||||
def build_layer_from_config(
|
def build_layer_from_config(
|
||||||
@ -769,13 +769,13 @@ def build_layer_from_config(
|
|||||||
input_height * 2,
|
input_height * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.block_type == "group":
|
if config.block_type == "LayerGroup":
|
||||||
current_channels = in_channels
|
current_channels = in_channels
|
||||||
current_height = input_height
|
current_height = input_height
|
||||||
|
|
||||||
blocks = []
|
blocks = []
|
||||||
|
|
||||||
for block_config in config.blocks:
|
for block_config in config.layers:
|
||||||
block, current_channels, current_height = build_layer_from_config(
|
block, current_channels, current_height = build_layer_from_config(
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
|
@ -26,7 +26,7 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
BlockGroupConfig,
|
LayerGroupConfig,
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[
|
|||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
BlockGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
|||||||
layers=[
|
layers=[
|
||||||
FreqCoordConvUpConfig(out_channels=64),
|
FreqCoordConvUpConfig(out_channels=64),
|
||||||
FreqCoordConvUpConfig(out_channels=32),
|
FreqCoordConvUpConfig(out_channels=32),
|
||||||
BlockGroupConfig(
|
LayerGroupConfig(
|
||||||
blocks=[
|
layers=[
|
||||||
FreqCoordConvUpConfig(out_channels=32),
|
FreqCoordConvUpConfig(out_channels=32),
|
||||||
ConvConfig(out_channels=32),
|
ConvConfig(out_channels=32),
|
||||||
]
|
]
|
||||||
|
@ -28,7 +28,7 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
BlockGroupConfig,
|
LayerGroupConfig,
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[
|
|||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
StandardConvDownConfig,
|
StandardConvDownConfig,
|
||||||
BlockGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="block_type"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
|||||||
layers=[
|
layers=[
|
||||||
FreqCoordConvDownConfig(out_channels=32),
|
FreqCoordConvDownConfig(out_channels=32),
|
||||||
FreqCoordConvDownConfig(out_channels=64),
|
FreqCoordConvDownConfig(out_channels=64),
|
||||||
BlockGroupConfig(
|
LayerGroupConfig(
|
||||||
blocks=[
|
layers=[
|
||||||
FreqCoordConvDownConfig(out_channels=128),
|
FreqCoordConvDownConfig(out_channels=128),
|
||||||
ConvConfig(out_channels=256),
|
ConvConfig(out_channels=256),
|
||||||
]
|
]
|
||||||
|
@ -2,9 +2,8 @@ from typing import Annotated, Literal, Optional, Union
|
|||||||
|
|
||||||
from lightning.pytorch.loggers import Logger
|
from lightning.pytorch.loggers import Logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: str = "logs"
|
DEFAULT_LOGS_DIR: str = "logs"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user