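From 652d076b463f7e485609825df7935cfc14659b35 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 8 Mar 2026 15:02:56 +0000 Subject: [PATCH] Add backbone registry --- Notes (placed after the `---` cut line, so not part of the commit message): this patch moves `BackboneConfig` and `load_backbone_config` from `src/batdetect2/models/config.py` (deleted) into `src/batdetect2/models/backbones.py`, renames the `Backbone` class to `UNetBackbone`, adds `UNetBackboneConfig` (with a `name` literal of "UNetBackbone"), and makes `build_backbone` call `backbone_registry.build(config)`, defaulting to `UNetBackboneConfig()` when no config is given. `BackboneConfig` becomes an `Annotated[UNetBackboneConfig, Field(discriminator="name")]` alias instead of a class, and "Backbone" is replaced by "UNetBackbone" in `batdetect2.models.__all__`, so code importing `Backbone` or `batdetect2.models.config` must be updated. NOTE(review): confirm that pydantic accepts a discriminator on a single non-Union type, and that callers which instantiate `BackboneConfig()` directly still behave as intended (tests/test_models/test_detectors.py still imports `BackboneConfig`). The docstrings of `BackboneConfig`, `load_backbone_config` and `build_backbone` are dropped by the move. End of notes. src/batdetect2/config.py | 2 +- src/batdetect2/models/__init__.py | 13 ++- src/batdetect2/models/backbones.py | 141 +++++++++++++++------------- src/batdetect2/models/config.py | 96 ------------------- src/batdetect2/models/detectors.py | 8 +- tests/test_models/test_detectors.py | 2 +- 6 files changed, 92 insertions(+), 170 deletions(-) delete mode 100644 src/batdetect2/models/config.py diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 8412ede..f0274af 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -12,7 +12,7 @@ from batdetect2.evaluate.config import ( get_default_eval_config, ) from batdetect2.inference.config import InferenceConfig -from batdetect2.models.config import BackboneConfig +from batdetect2.models.backbones import BackboneConfig from batdetect2.postprocess.config import PostprocessConfig from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.targets.config import TargetConfig diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 90ab1dc..80e00ba 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -30,7 +30,13 @@ from typing import List import torch -from batdetect2.models.backbones import Backbone, build_backbone +from batdetect2.models.backbones import ( + UNetBackbone, + UNetBackboneConfig, + BackboneConfig, + build_backbone, + load_backbone_config, +) from batdetect2.models.blocks import ( ConvConfig, FreqCoordConvDownConfig, @@ -43,7 +49,6 @@ from batdetect2.models.bottleneck import ( BottleneckConfig, build_bottleneck, ) -from batdetect2.models.config import BackboneConfig, load_backbone_config from batdetect2.models.decoder import ( DEFAULT_DECODER_CONFIG, DecoderConfig, @@ -66,7 +71,7 @@ from batdetect2.typing import ( __all__ = [ "BBoxHead", 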
- "Backbone", + "UNetBackbone", "BackboneConfig", "Bottleneck", "BottleneckConfig", @@ -128,7 +133,7 @@ def build_model( from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets - config = config or BackboneConfig() + config = config or UNetBackboneConfig() targets = targets or build_targets() preprocessor = preprocessor or build_preprocessor() postprocessor = postprocessor or build_postprocessor( diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 98548f8..bde9f64 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -18,16 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the network's total downsampling factor. """ -from typing import Tuple + +from typing import Annotated, Literal, Tuple import torch import torch.nn.functional as F -from torch import nn +from pydantic import Field + +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.core.registries import Registry +from soundevent import data +from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck +from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder +from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder -from batdetect2.models.bottleneck import build_bottleneck -from batdetect2.models.config import BackboneConfig -from batdetect2.models.decoder import build_decoder -from batdetect2.models.encoder import build_encoder from batdetect2.typing.models import ( BackboneModel, BottleneckProtocol, @@ -35,13 +39,38 @@ from batdetect2.typing.models import ( EncoderProtocol, ) + +class UNetBackboneConfig(BaseConfig): + name: Literal["UNetBackbone"] = "UNetBackbone" + input_height: int = 128 + in_channels: int = 1 + encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG + bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG + decoder: 
DecoderConfig = DEFAULT_DECODER_CONFIG + out_channels: int = 32 + +BackboneConfig = Annotated[ + UNetBackboneConfig, + Field(discriminator="name") +] + +def load_backbone_config( + path: data.PathLike, + field: str | None = None, +) -> BackboneConfig: + return load_config(path, schema=BackboneConfig, field=field) + +backbone_registry: Registry[BackboneModel, []] = Registry("backbone") + __all__ = [ - "Backbone", + "UNetBackbone", + "BackboneConfig", + "load_backbone_config", "build_backbone", ] -class Backbone(BackboneModel): +class UNetBackbone(BackboneModel): """Encoder-Decoder Backbone Network Implementation. Combines an Encoder, Bottleneck, and Decoder module sequentially, using @@ -149,66 +178,46 @@ class Backbone(BackboneModel): return x -def build_backbone(config: BackboneConfig) -> BackboneModel: - """Factory function to build a Backbone from configuration. - - Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on - the provided `BackboneConfig`, validates their compatibility, and assembles - them into a `Backbone` instance. - - Parameters - ---------- - config : BackboneConfig - The configuration object detailing the backbone architecture, including - input dimensions and configurations for encoder, bottleneck, and - decoder. - - Returns - ------- - BackboneModel - An initialized `Backbone` module ready for use. - - Raises - ------ - ValueError - If sub-component configurations are incompatible - (e.g., channel mismatches, decoder output height doesn't match backbone - input height). - NotImplementedError - If an unknown block type is specified in sub-configs. 
- """ - encoder = build_encoder( - in_channels=config.in_channels, - input_height=config.input_height, - config=config.encoder, - ) - - bottleneck = build_bottleneck( - input_height=encoder.output_height, - in_channels=encoder.out_channels, - config=config.bottleneck, - ) - - decoder = build_decoder( - in_channels=bottleneck.out_channels, - input_height=encoder.output_height, - config=config.decoder, - ) - - if decoder.output_height != config.input_height: - raise ValueError( - "Invalid configuration: Decoder output height " - f"({decoder.output_height}) must match the Backbone input height " - f"({config.input_height}). Check encoder/decoder layer " - "configurations and input/bottleneck heights." + @backbone_registry.register(UNetBackboneConfig) + @staticmethod + def from_config(config: UNetBackboneConfig) -> BackboneModel: + encoder = build_encoder( + in_channels=config.in_channels, + input_height=config.input_height, + config=config.encoder, ) - return Backbone( - input_height=config.input_height, - encoder=encoder, - decoder=decoder, - bottleneck=bottleneck, - ) + bottleneck = build_bottleneck( + input_height=encoder.output_height, + in_channels=encoder.out_channels, + config=config.bottleneck, + ) + + decoder = build_decoder( + in_channels=bottleneck.out_channels, + input_height=encoder.output_height, + config=config.decoder, + ) + + if decoder.output_height != config.input_height: + raise ValueError( + "Invalid configuration: Decoder output height " + f"({decoder.output_height}) must match the Backbone input height " + f"({config.input_height}). Check encoder/decoder layer " + "configurations and input/bottleneck heights." 
+ ) + + return UNetBackbone( + input_height=config.input_height, + encoder=encoder, + decoder=decoder, + bottleneck=bottleneck, + ) + + +def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: + config = config or UNetBackboneConfig() + return backbone_registry.build(config) def _pad_adjust( diff --git a/src/batdetect2/models/config.py b/src/batdetect2/models/config.py deleted file mode 100644 index 999b777..0000000 --- a/src/batdetect2/models/config.py +++ /dev/null @@ -1,96 +0,0 @@ -from soundevent import data - -from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.models.bottleneck import ( - DEFAULT_BOTTLENECK_CONFIG, - BottleneckConfig, -) -from batdetect2.models.decoder import ( - DEFAULT_DECODER_CONFIG, - DecoderConfig, -) -from batdetect2.models.encoder import ( - DEFAULT_ENCODER_CONFIG, - EncoderConfig, -) - -__all__ = [ - "BackboneConfig", - "load_backbone_config", -] - - -class BackboneConfig(BaseConfig): - """Configuration for the Encoder-Decoder Backbone network. - - Aggregates configurations for the encoder, bottleneck, and decoder - components, along with defining the input and final output dimensions - for the complete backbone. - - Attributes - ---------- - input_height : int, default=128 - Expected height (frequency bins) of the input spectrograms to the - backbone. Must be positive. - in_channels : int, default=1 - Expected number of channels in the input spectrograms (e.g., 1 for - mono). Must be positive. - encoder : EncoderConfig, optional - Configuration for the encoder. If None or omitted, - the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the - encoder module) will be used. - bottleneck : BottleneckConfig, optional - Configuration for the bottleneck layer connecting encoder and decoder. - If None or omitted, the default bottleneck configuration will be used. - decoder : DecoderConfig, optional - Configuration for the decoder. 
If None or omitted, - the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the - decoder module) will be used. - out_channels : int, default=32 - Desired number of channels in the final feature map output by the - backbone. Must be positive. - """ - - input_height: int = 128 - in_channels: int = 1 - encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG - bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG - decoder: DecoderConfig = DEFAULT_DECODER_CONFIG - out_channels: int = 32 - - -def load_backbone_config( - path: data.PathLike, - field: str | None = None, -) -> BackboneConfig: - """Load the backbone configuration from a file. - - Reads a configuration file (YAML) and validates it against the - `BackboneConfig` schema, potentially extracting data from a nested field. - - Parameters - ---------- - path : PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - backbone configuration (e.g., "model.backbone"). If None, the entire - file content is used. - - Returns - ------- - BackboneConfig - The loaded and validated backbone configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. - pydantic.ValidationError - If the loaded config data does not conform to `BackboneConfig`. - KeyError, TypeError - If `field` specifies an invalid path. - """ - return load_config(path, schema=BackboneConfig, field=field) diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 37eeec6..5c22ba8 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -17,7 +17,11 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. 
import torch from loguru import logger -from batdetect2.models.backbones import BackboneConfig, build_backbone +from batdetect2.models.backbones import ( + BackboneConfig, + UNetBackboneConfig, + build_backbone, +) from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput @@ -152,7 +156,7 @@ def build_detector( construction of the backbone or detector components (e.g., incompatible configurations, invalid parameters). """ - config = config or BackboneConfig() + config = config or UNetBackboneConfig() logger.opt(lazy=True).debug( "Building model with config: \n{}", diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index 4b95f82..a2b4b69 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -2,7 +2,7 @@ import numpy as np import pytest import torch -from batdetect2.models.config import BackboneConfig +from batdetect2.models.backbones import BackboneConfig from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.typing.models import ModelOutput