Run formatter

This commit is contained in:
mbsantiago 2026-03-08 15:18:21 +00:00
parent 652d076b46
commit f2d5088bec

View File

@ -18,20 +18,31 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor. network's total downsampling factor.
""" """
from typing import Annotated, Literal, Tuple from typing import Annotated, Literal, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from example import Union
from pydantic import Field from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry from batdetect2.core.registries import Registry
from soundevent import data from batdetect2.models.bottleneck import (
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck DEFAULT_BOTTLENECK_CONFIG,
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder BottleneckConfig,
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.typing.models import ( from batdetect2.typing.models import (
BackboneModel, BackboneModel,
BottleneckProtocol, BottleneckProtocol,
@ -49,16 +60,6 @@ class UNetBackboneConfig(BaseConfig):
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32 out_channels: int = 32
BackboneConfig = Annotated[
UNetBackboneConfig,
Field(discriminator="name")
]
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)
backbone_registry: Registry[BackboneModel, []] = Registry("backbone") backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@ -177,7 +178,6 @@ class UNetBackbone(BackboneModel):
return x return x
@backbone_registry.register(UNetBackboneConfig) @backbone_registry.register(UNetBackboneConfig)
@staticmethod @staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel: def from_config(config: UNetBackboneConfig) -> BackboneModel:
@ -215,6 +215,11 @@ class UNetBackbone(BackboneModel):
) )
BackboneConfig = Annotated[
Union[UNetBackboneConfig,], Field(discriminator="name")
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
config = config or UNetBackboneConfig() config = config or UNetBackboneConfig()
return backbone_registry.build(config) return backbone_registry.build(config)
@ -283,3 +288,10 @@ def _restore_pad(
x = x[..., :-w_pad] x = x[..., :-w_pad]
return x return x
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)