mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run formatter
This commit is contained in:
parent
652d076b46
commit
f2d5088bec
@ -18,20 +18,31 @@ automatic padding to handle input sizes not perfectly divisible by the
|
||||
network's total downsampling factor.
|
||||
"""
|
||||
|
||||
|
||||
from typing import Annotated, Literal, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from example import Union
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
from soundevent import data
|
||||
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
|
||||
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
|
||||
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
|
||||
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
build_bottleneck,
|
||||
)
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
DecoderConfig,
|
||||
build_decoder,
|
||||
)
|
||||
from batdetect2.models.encoder import (
|
||||
DEFAULT_ENCODER_CONFIG,
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.typing.models import (
|
||||
BackboneModel,
|
||||
BottleneckProtocol,
|
||||
@ -49,16 +60,6 @@ class UNetBackboneConfig(BaseConfig):
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
out_channels: int = 32
|
||||
|
||||
BackboneConfig = Annotated[
|
||||
UNetBackboneConfig,
|
||||
Field(discriminator="name")
|
||||
]
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
|
||||
@ -177,7 +178,6 @@ class UNetBackbone(BackboneModel):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@backbone_registry.register(UNetBackboneConfig)
|
||||
@staticmethod
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||
@ -215,6 +215,11 @@ class UNetBackbone(BackboneModel):
|
||||
)
|
||||
|
||||
|
||||
BackboneConfig = Annotated[
|
||||
Union[UNetBackboneConfig,], Field(discriminator="name")
|
||||
]
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
config = config or UNetBackboneConfig()
|
||||
return backbone_registry.build(config)
|
||||
@ -283,3 +288,10 @@ def _restore_pad(
|
||||
x = x[..., :-w_pad]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user