diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index bde9f64..31feee0 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -18,20 +18,31 @@ automatic padding to handle input sizes not perfectly divisible by the network's total downsampling factor. """ - from typing import Annotated, Literal, Tuple import torch import torch.nn.functional as F +from example import Union from pydantic import Field +from soundevent import data from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.registries import Registry -from soundevent import data -from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck -from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder -from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder - +from batdetect2.models.bottleneck import ( + DEFAULT_BOTTLENECK_CONFIG, + BottleneckConfig, + build_bottleneck, +) +from batdetect2.models.decoder import ( + DEFAULT_DECODER_CONFIG, + DecoderConfig, + build_decoder, +) +from batdetect2.models.encoder import ( + DEFAULT_ENCODER_CONFIG, + EncoderConfig, + build_encoder, +) from batdetect2.typing.models import ( BackboneModel, BottleneckProtocol, @@ -49,16 +60,6 @@ class UNetBackboneConfig(BaseConfig): decoder: DecoderConfig = DEFAULT_DECODER_CONFIG out_channels: int = 32 -BackboneConfig = Annotated[ - UNetBackboneConfig, - Field(discriminator="name") -] - -def load_backbone_config( - path: data.PathLike, - field: str | None = None, -) -> BackboneConfig: - return load_config(path, schema=BackboneConfig, field=field) backbone_registry: Registry[BackboneModel, []] = Registry("backbone") @@ -177,7 +178,6 @@ class UNetBackbone(BackboneModel): return x - @backbone_registry.register(UNetBackboneConfig) @staticmethod def from_config(config: UNetBackboneConfig) -> BackboneModel: @@ -215,6 +215,11 @@ class UNetBackbone(BackboneModel): ) +BackboneConfig = Annotated[ + Union[UNetBackboneConfig,], Field(discriminator="name") +] + + def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: config = config or UNetBackboneConfig() return backbone_registry.build(config) @@ -283,3 +288,10 @@ def _restore_pad( x = x[..., :-w_pad] return x + + +def load_backbone_config( + path: data.PathLike, + field: str | None = None, +) -> BackboneConfig: + return load_config(path, schema=BackboneConfig, field=field)