From 166dad20bd0a8c2f821f9e097e867cc3eeb3f625 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 10:04:42 -0600 Subject: [PATCH] Rename BlockGroupConfig to LayerGroupConfig --- src/batdetect2/models/backbones.py | 26 ++++++++++++++++++++------ src/batdetect2/models/blocks.py | 14 +++++++------- src/batdetect2/models/decoder.py | 8 ++++---- src/batdetect2/models/encoder.py | 8 ++++---- src/batdetect2/train/logging.py | 3 +-- 5 files changed, 36 insertions(+), 23 deletions(-) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 1728b58..e7f1ada 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -27,9 +27,23 @@ from torch import nn from batdetect2.configs import BaseConfig, load_config from batdetect2.models.blocks import ConvBlock -from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck -from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder -from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder +from batdetect2.models.bottleneck import ( + DEFAULT_BOTTLENECK_CONFIG, + BottleneckConfig, + build_bottleneck, +) +from batdetect2.models.decoder import ( + DEFAULT_DECODER_CONFIG, + Decoder, + DecoderConfig, + build_decoder, +) +from batdetect2.models.encoder import ( + DEFAULT_ENCODER_CONFIG, + Encoder, + EncoderConfig, + build_encoder, +) from batdetect2.models.types import BackboneModel __all__ = [ @@ -186,9 +200,9 @@ class BackboneConfig(BaseConfig): input_height: int = 128 in_channels: int = 1 - encoder: Optional[EncoderConfig] = None - bottleneck: Optional[BottleneckConfig] = None - decoder: Optional[DecoderConfig] = None + encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG + bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG + decoder: DecoderConfig = DEFAULT_DECODER_CONFIG out_channels: int = 32 diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 4df17f1..39965b9 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -38,7 +38,7 @@ from batdetect2.configs import BaseConfig __all__ = [ "ConvBlock", - "BlockGroupConfig", + "LayerGroupConfig", "VerticalConv", "FreqCoordConvDownBlock", "StandardConvDownBlock", @@ -654,16 +654,16 @@ LayerConfig = Annotated[ StandardConvDownConfig, FreqCoordConvUpConfig, StandardConvUpConfig, - "BlockGroupConfig", + "LayerGroupConfig", ], Field(discriminator="block_type"), ] """Type alias for the discriminated union of block configuration models.""" -class BlockGroupConfig(BaseConfig): - block_type: Literal["group"] = "group" - blocks: List[LayerConfig] +class LayerGroupConfig(BaseConfig): + block_type: Literal["LayerGroup"] = "LayerGroup" + layers: List[LayerConfig] def build_layer_from_config( @@ -769,13 +769,13 @@ def build_layer_from_config( input_height * 2, ) - if config.block_type == "group": + if config.block_type == "LayerGroup": current_channels = in_channels current_height = input_height blocks = [] - for block_config in config.blocks: + for block_config in config.layers: block, current_channels, current_height = build_layer_from_config( input_height=current_height, in_channels=current_channels, diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index 91c1089..c3deed7 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -26,7 +26,7 @@ from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.blocks import ( - BlockGroupConfig, + LayerGroupConfig, ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig, @@ -45,7 +45,7 @@ DecoderLayerConfig = Annotated[ ConvConfig, FreqCoordConvUpConfig, StandardConvUpConfig, - BlockGroupConfig, + LayerGroupConfig, ], Field(discriminator="block_type"), ] @@ -197,8 +197,8 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig( layers=[ FreqCoordConvUpConfig(out_channels=64), FreqCoordConvUpConfig(out_channels=32), - BlockGroupConfig( - blocks=[ + LayerGroupConfig( + layers=[ FreqCoordConvUpConfig(out_channels=32), ConvConfig(out_channels=32), ] diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index 3a0c609..91087cb 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -28,7 +28,7 @@ from torch import nn from batdetect2.configs import BaseConfig from batdetect2.models.blocks import ( - BlockGroupConfig, + LayerGroupConfig, ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig, @@ -47,7 +47,7 @@ EncoderLayerConfig = Annotated[ ConvConfig, FreqCoordConvDownConfig, StandardConvDownConfig, - BlockGroupConfig, + LayerGroupConfig, ], Field(discriminator="block_type"), ] @@ -230,8 +230,8 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig( layers=[ FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownConfig(out_channels=64), - BlockGroupConfig( - blocks=[ + LayerGroupConfig( + layers=[ FreqCoordConvDownConfig(out_channels=128), ConvConfig(out_channels=256), ] diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index a6cdae5..fb70c78 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -2,9 +2,8 @@ from typing import Annotated, Literal, Optional, Union from lightning.pytorch.loggers import Logger from pydantic import Field -from soundevent.data import PathLike -from batdetect2.configs import BaseConfig, load_config +from batdetect2.configs import BaseConfig DEFAULT_LOGS_DIR: str = "logs"