Add backbone registry

This commit is contained in:
mbsantiago 2026-03-08 15:02:56 +00:00
parent e393709258
commit 652d076b46
6 changed files with 92 additions and 170 deletions

View File

@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config, get_default_eval_config,
) )
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig from batdetect2.models.backbones import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig from batdetect2.targets.config import TargetConfig

View File

@ -30,7 +30,13 @@ from typing import List
import torch import torch
from batdetect2.models.backbones import Backbone, build_backbone from batdetect2.models.backbones import (
UNetBackbone,
UNetBackboneConfig,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
@ -43,7 +49,6 @@ from batdetect2.models.bottleneck import (
BottleneckConfig, BottleneckConfig,
build_bottleneck, build_bottleneck,
) )
from batdetect2.models.config import BackboneConfig, load_backbone_config
from batdetect2.models.decoder import ( from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG, DEFAULT_DECODER_CONFIG,
DecoderConfig, DecoderConfig,
@ -66,7 +71,7 @@ from batdetect2.typing import (
__all__ = [ __all__ = [
"BBoxHead", "BBoxHead",
"Backbone", "UNetBackbone",
"BackboneConfig", "BackboneConfig",
"Bottleneck", "Bottleneck",
"BottleneckConfig", "BottleneckConfig",
@ -128,7 +133,7 @@ def build_model(
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
config = config or BackboneConfig() config = config or UNetBackboneConfig()
targets = targets or build_targets() targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor() preprocessor = preprocessor or build_preprocessor()
postprocessor = postprocessor or build_postprocessor( postprocessor = postprocessor or build_postprocessor(

View File

@ -18,16 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor. network's total downsampling factor.
""" """
from typing import Tuple
from typing import Annotated, Literal, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from pydantic import Field
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from soundevent import data
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import build_decoder
from batdetect2.models.encoder import build_encoder
from batdetect2.typing.models import ( from batdetect2.typing.models import (
BackboneModel, BackboneModel,
BottleneckProtocol, BottleneckProtocol,
@ -35,13 +39,38 @@ from batdetect2.typing.models import (
EncoderProtocol, EncoderProtocol,
) )
class UNetBackboneConfig(BaseConfig):
name: Literal["UNetBackbone"] = "UNetBackbone"
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
BackboneConfig = Annotated[
UNetBackboneConfig,
Field(discriminator="name")
]
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
__all__ = [ __all__ = [
"Backbone", "UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone", "build_backbone",
] ]
class Backbone(BackboneModel): class UNetBackbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation. """Encoder-Decoder Backbone Network Implementation.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using Combines an Encoder, Bottleneck, and Decoder module sequentially, using
@ -149,66 +178,46 @@ class Backbone(BackboneModel):
return x return x
def build_backbone(config: BackboneConfig) -> BackboneModel: @backbone_registry.register(UNetBackboneConfig)
"""Factory function to build a Backbone from configuration. @staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on encoder = build_encoder(
the provided `BackboneConfig`, validates their compatibility, and assembles in_channels=config.in_channels,
them into a `Backbone` instance. input_height=config.input_height,
config=config.encoder,
Parameters
----------
config : BackboneConfig
The configuration object detailing the backbone architecture, including
input dimensions and configurations for encoder, bottleneck, and
decoder.
Returns
-------
BackboneModel
An initialized `Backbone` module ready for use.
Raises
------
ValueError
If sub-component configurations are incompatible
(e.g., channel mismatches, decoder output height doesn't match backbone
input height).
NotImplementedError
If an unknown block type is specified in sub-configs.
"""
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
config=config.encoder,
)
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
) )
return Backbone( bottleneck = build_bottleneck(
input_height=config.input_height, input_height=encoder.output_height,
encoder=encoder, in_channels=encoder.out_channels,
decoder=decoder, config=config.bottleneck,
bottleneck=bottleneck, )
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
return UNetBackbone(
input_height=config.input_height,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
config = config or UNetBackboneConfig()
return backbone_registry.build(config)
def _pad_adjust( def _pad_adjust(

View File

@ -1,96 +0,0 @@
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)

View File

@ -17,7 +17,11 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
import torch import torch
from loguru import logger from loguru import logger
from batdetect2.models.backbones import BackboneConfig, build_backbone from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
@ -152,7 +156,7 @@ def build_detector(
construction of the backbone or detector components (e.g., incompatible construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters). configurations, invalid parameters).
""" """
config = config or BackboneConfig() config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building model with config: \n{}", "Building model with config: \n{}",

View File

@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import torch import torch
from batdetect2.models.config import BackboneConfig from batdetect2.models.backbones import BackboneConfig
from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput from batdetect2.typing.models import ModelOutput