Run formatter

This commit is contained in:
mbsantiago 2026-03-08 15:18:21 +00:00
parent 652d076b46
commit f2d5088bec

View File

@ -18,20 +18,31 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Annotated, Literal, Tuple
import torch
import torch.nn.functional as F
from example import Union
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
from soundevent import data
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.typing.models import (
BackboneModel,
BottleneckProtocol,
@ -49,16 +60,6 @@ class UNetBackboneConfig(BaseConfig):
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
BackboneConfig = Annotated[
UNetBackboneConfig,
Field(discriminator="name")
]
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@ -177,7 +178,6 @@ class UNetBackbone(BackboneModel):
return x
@backbone_registry.register(UNetBackboneConfig)
@staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
@ -215,6 +215,11 @@ class UNetBackbone(BackboneModel):
)
BackboneConfig = Annotated[
Union[UNetBackboneConfig,], Field(discriminator="name")
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
config = config or UNetBackboneConfig()
return backbone_registry.build(config)
@ -283,3 +288,10 @@ def _restore_pad(
x = x[..., :-w_pad]
return x
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)