mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Run formatter
This commit is contained in:
parent
652d076b46
commit
f2d5088bec
@ -18,20 +18,31 @@ automatic padding to handle input sizes not perfectly divisible by the
|
|||||||
network's total downsampling factor.
|
network's total downsampling factor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
from typing import Annotated, Literal, Tuple
|
from typing import Annotated, Literal, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from example import Union
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.core.registries import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from soundevent import data
|
from batdetect2.models.bottleneck import (
|
||||||
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
|
DEFAULT_BOTTLENECK_CONFIG,
|
||||||
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
|
BottleneckConfig,
|
||||||
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
|
build_bottleneck,
|
||||||
|
)
|
||||||
|
from batdetect2.models.decoder import (
|
||||||
|
DEFAULT_DECODER_CONFIG,
|
||||||
|
DecoderConfig,
|
||||||
|
build_decoder,
|
||||||
|
)
|
||||||
|
from batdetect2.models.encoder import (
|
||||||
|
DEFAULT_ENCODER_CONFIG,
|
||||||
|
EncoderConfig,
|
||||||
|
build_encoder,
|
||||||
|
)
|
||||||
from batdetect2.typing.models import (
|
from batdetect2.typing.models import (
|
||||||
BackboneModel,
|
BackboneModel,
|
||||||
BottleneckProtocol,
|
BottleneckProtocol,
|
||||||
@ -49,16 +60,6 @@ class UNetBackboneConfig(BaseConfig):
|
|||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
out_channels: int = 32
|
out_channels: int = 32
|
||||||
|
|
||||||
BackboneConfig = Annotated[
|
|
||||||
UNetBackboneConfig,
|
|
||||||
Field(discriminator="name")
|
|
||||||
]
|
|
||||||
|
|
||||||
def load_backbone_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: str | None = None,
|
|
||||||
) -> BackboneConfig:
|
|
||||||
return load_config(path, schema=BackboneConfig, field=field)
|
|
||||||
|
|
||||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||||
|
|
||||||
@ -177,7 +178,6 @@ class UNetBackbone(BackboneModel):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@backbone_registry.register(UNetBackboneConfig)
|
@backbone_registry.register(UNetBackboneConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||||
@ -215,6 +215,11 @@ class UNetBackbone(BackboneModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BackboneConfig = Annotated[
|
||||||
|
Union[UNetBackboneConfig,], Field(discriminator="name")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
return backbone_registry.build(config)
|
return backbone_registry.build(config)
|
||||||
@ -283,3 +288,10 @@ def _restore_pad(
|
|||||||
x = x[..., :-w_pad]
|
x = x[..., :-w_pad]
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def load_backbone_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: str | None = None,
|
||||||
|
) -> BackboneConfig:
|
||||||
|
return load_config(path, schema=BackboneConfig, field=field)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user