mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add backbone registry
This commit is contained in:
parent
e393709258
commit
652d076b46
@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.models.backbones import BackboneConfig
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
|
||||
@ -30,7 +30,13 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.models.backbones import Backbone, build_backbone
|
||||
from batdetect2.models.backbones import (
|
||||
UNetBackbone,
|
||||
UNetBackboneConfig,
|
||||
BackboneConfig,
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
@ -43,7 +49,6 @@ from batdetect2.models.bottleneck import (
|
||||
BottleneckConfig,
|
||||
build_bottleneck,
|
||||
)
|
||||
from batdetect2.models.config import BackboneConfig, load_backbone_config
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
DecoderConfig,
|
||||
@ -66,7 +71,7 @@ from batdetect2.typing import (
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
"Backbone",
|
||||
"UNetBackbone",
|
||||
"BackboneConfig",
|
||||
"Bottleneck",
|
||||
"BottleneckConfig",
|
||||
@ -128,7 +133,7 @@ def build_model(
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
config = config or BackboneConfig()
|
||||
config = config or UNetBackboneConfig()
|
||||
targets = targets or build_targets()
|
||||
preprocessor = preprocessor or build_preprocessor()
|
||||
postprocessor = postprocessor or build_postprocessor(
|
||||
|
||||
@ -18,16 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
|
||||
network's total downsampling factor.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from typing import Annotated, Literal, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
from soundevent import data
|
||||
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
|
||||
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
|
||||
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
|
||||
|
||||
from batdetect2.models.bottleneck import build_bottleneck
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.models.decoder import build_decoder
|
||||
from batdetect2.models.encoder import build_encoder
|
||||
from batdetect2.typing.models import (
|
||||
BackboneModel,
|
||||
BottleneckProtocol,
|
||||
@ -35,13 +39,38 @@ from batdetect2.typing.models import (
|
||||
EncoderProtocol,
|
||||
)
|
||||
|
||||
|
||||
class UNetBackboneConfig(BaseConfig):
|
||||
name: Literal["UNetBackbone"] = "UNetBackbone"
|
||||
input_height: int = 128
|
||||
in_channels: int = 1
|
||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
out_channels: int = 32
|
||||
|
||||
BackboneConfig = Annotated[
|
||||
UNetBackboneConfig,
|
||||
Field(discriminator="name")
|
||||
]
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
|
||||
__all__ = [
|
||||
"Backbone",
|
||||
"UNetBackbone",
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
"build_backbone",
|
||||
]
|
||||
|
||||
|
||||
class Backbone(BackboneModel):
|
||||
class UNetBackbone(BackboneModel):
|
||||
"""Encoder-Decoder Backbone Network Implementation.
|
||||
|
||||
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
||||
@ -149,66 +178,46 @@ class Backbone(BackboneModel):
|
||||
return x
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||
"""Factory function to build a Backbone from configuration.
|
||||
|
||||
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
|
||||
the provided `BackboneConfig`, validates their compatibility, and assembles
|
||||
them into a `Backbone` instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BackboneConfig
|
||||
The configuration object detailing the backbone architecture, including
|
||||
input dimensions and configurations for encoder, bottleneck, and
|
||||
decoder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneModel
|
||||
An initialized `Backbone` module ready for use.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If sub-component configurations are incompatible
|
||||
(e.g., channel mismatches, decoder output height doesn't match backbone
|
||||
input height).
|
||||
NotImplementedError
|
||||
If an unknown block type is specified in sub-configs.
|
||||
"""
|
||||
encoder = build_encoder(
|
||||
in_channels=config.in_channels,
|
||||
input_height=config.input_height,
|
||||
config=config.encoder,
|
||||
)
|
||||
|
||||
bottleneck = build_bottleneck(
|
||||
input_height=encoder.output_height,
|
||||
in_channels=encoder.out_channels,
|
||||
config=config.bottleneck,
|
||||
)
|
||||
|
||||
decoder = build_decoder(
|
||||
in_channels=bottleneck.out_channels,
|
||||
input_height=encoder.output_height,
|
||||
config=config.decoder,
|
||||
)
|
||||
|
||||
if decoder.output_height != config.input_height:
|
||||
raise ValueError(
|
||||
"Invalid configuration: Decoder output height "
|
||||
f"({decoder.output_height}) must match the Backbone input height "
|
||||
f"({config.input_height}). Check encoder/decoder layer "
|
||||
"configurations and input/bottleneck heights."
|
||||
@backbone_registry.register(UNetBackboneConfig)
|
||||
@staticmethod
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||
encoder = build_encoder(
|
||||
in_channels=config.in_channels,
|
||||
input_height=config.input_height,
|
||||
config=config.encoder,
|
||||
)
|
||||
|
||||
return Backbone(
|
||||
input_height=config.input_height,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
bottleneck=bottleneck,
|
||||
)
|
||||
bottleneck = build_bottleneck(
|
||||
input_height=encoder.output_height,
|
||||
in_channels=encoder.out_channels,
|
||||
config=config.bottleneck,
|
||||
)
|
||||
|
||||
decoder = build_decoder(
|
||||
in_channels=bottleneck.out_channels,
|
||||
input_height=encoder.output_height,
|
||||
config=config.decoder,
|
||||
)
|
||||
|
||||
if decoder.output_height != config.input_height:
|
||||
raise ValueError(
|
||||
"Invalid configuration: Decoder output height "
|
||||
f"({decoder.output_height}) must match the Backbone input height "
|
||||
f"({config.input_height}). Check encoder/decoder layer "
|
||||
"configurations and input/bottleneck heights."
|
||||
)
|
||||
|
||||
return UNetBackbone(
|
||||
input_height=config.input_height,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
bottleneck=bottleneck,
|
||||
)
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
config = config or UNetBackboneConfig()
|
||||
return backbone_registry.build(config)
|
||||
|
||||
|
||||
def _pad_adjust(
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
)
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
DecoderConfig,
|
||||
)
|
||||
from batdetect2.models.encoder import (
|
||||
DEFAULT_ENCODER_CONFIG,
|
||||
EncoderConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
]
|
||||
|
||||
|
||||
class BackboneConfig(BaseConfig):
|
||||
"""Configuration for the Encoder-Decoder Backbone network.
|
||||
|
||||
Aggregates configurations for the encoder, bottleneck, and decoder
|
||||
components, along with defining the input and final output dimensions
|
||||
for the complete backbone.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int, default=128
|
||||
Expected height (frequency bins) of the input spectrograms to the
|
||||
backbone. Must be positive.
|
||||
in_channels : int, default=1
|
||||
Expected number of channels in the input spectrograms (e.g., 1 for
|
||||
mono). Must be positive.
|
||||
encoder : EncoderConfig, optional
|
||||
Configuration for the encoder. If None or omitted,
|
||||
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
|
||||
encoder module) will be used.
|
||||
bottleneck : BottleneckConfig, optional
|
||||
Configuration for the bottleneck layer connecting encoder and decoder.
|
||||
If None or omitted, the default bottleneck configuration will be used.
|
||||
decoder : DecoderConfig, optional
|
||||
Configuration for the decoder. If None or omitted,
|
||||
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
|
||||
decoder module) will be used.
|
||||
out_channels : int, default=32
|
||||
Desired number of channels in the final feature map output by the
|
||||
backbone. Must be positive.
|
||||
"""
|
||||
|
||||
input_height: int = 128
|
||||
in_channels: int = 1
|
||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
"""Load the backbone configuration from a file.
|
||||
|
||||
Reads a configuration file (YAML) and validates it against the
|
||||
`BackboneConfig` schema, potentially extracting data from a nested field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
backbone configuration (e.g., "model.backbone"). If None, the entire
|
||||
file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneConfig
|
||||
The loaded and validated backbone configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded config data does not conform to `BackboneConfig`.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path.
|
||||
"""
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
@ -17,7 +17,11 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackboneConfig,
|
||||
build_backbone,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
@ -152,7 +156,7 @@ def build_detector(
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
"""
|
||||
config = config or BackboneConfig()
|
||||
config = config or UNetBackboneConfig()
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
|
||||
@ -2,7 +2,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.models.backbones import BackboneConfig
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
|
||||
Loading…
Reference in New Issue
Block a user