mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add backbone registry
This commit is contained in:
parent
e393709258
commit
652d076b46
@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
|
|||||||
get_default_eval_config,
|
get_default_eval_config,
|
||||||
)
|
)
|
||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.models.config import BackboneConfig
|
from batdetect2.models.backbones import BackboneConfig
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
from batdetect2.targets.config import TargetConfig
|
from batdetect2.targets.config import TargetConfig
|
||||||
|
|||||||
@ -30,7 +30,13 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.backbones import Backbone, build_backbone
|
from batdetect2.models.backbones import (
|
||||||
|
UNetBackbone,
|
||||||
|
UNetBackboneConfig,
|
||||||
|
BackboneConfig,
|
||||||
|
build_backbone,
|
||||||
|
load_backbone_config,
|
||||||
|
)
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
@ -43,7 +49,6 @@ from batdetect2.models.bottleneck import (
|
|||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
build_bottleneck,
|
build_bottleneck,
|
||||||
)
|
)
|
||||||
from batdetect2.models.config import BackboneConfig, load_backbone_config
|
|
||||||
from batdetect2.models.decoder import (
|
from batdetect2.models.decoder import (
|
||||||
DEFAULT_DECODER_CONFIG,
|
DEFAULT_DECODER_CONFIG,
|
||||||
DecoderConfig,
|
DecoderConfig,
|
||||||
@ -66,7 +71,7 @@ from batdetect2.typing import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
"Backbone",
|
"UNetBackbone",
|
||||||
"BackboneConfig",
|
"BackboneConfig",
|
||||||
"Bottleneck",
|
"Bottleneck",
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
@ -128,7 +133,7 @@ def build_model(
|
|||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
|
||||||
config = config or BackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
preprocessor = preprocessor or build_preprocessor()
|
preprocessor = preprocessor or build_preprocessor()
|
||||||
postprocessor = postprocessor or build_postprocessor(
|
postprocessor = postprocessor or build_postprocessor(
|
||||||
|
|||||||
@ -18,16 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
|
|||||||
network's total downsampling factor.
|
network's total downsampling factor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
from typing import Annotated, Literal, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
from soundevent import data
|
||||||
|
from batdetect2.models.bottleneck import BottleneckConfig, DEFAULT_BOTTLENECK_CONFIG, build_bottleneck
|
||||||
|
from batdetect2.models.decoder import DecoderConfig, DEFAULT_DECODER_CONFIG, build_decoder
|
||||||
|
from batdetect2.models.encoder import EncoderConfig, DEFAULT_ENCODER_CONFIG, build_encoder
|
||||||
|
|
||||||
from batdetect2.models.bottleneck import build_bottleneck
|
|
||||||
from batdetect2.models.config import BackboneConfig
|
|
||||||
from batdetect2.models.decoder import build_decoder
|
|
||||||
from batdetect2.models.encoder import build_encoder
|
|
||||||
from batdetect2.typing.models import (
|
from batdetect2.typing.models import (
|
||||||
BackboneModel,
|
BackboneModel,
|
||||||
BottleneckProtocol,
|
BottleneckProtocol,
|
||||||
@ -35,13 +39,38 @@ from batdetect2.typing.models import (
|
|||||||
EncoderProtocol,
|
EncoderProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UNetBackboneConfig(BaseConfig):
|
||||||
|
name: Literal["UNetBackbone"] = "UNetBackbone"
|
||||||
|
input_height: int = 128
|
||||||
|
in_channels: int = 1
|
||||||
|
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||||
|
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||||
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
|
out_channels: int = 32
|
||||||
|
|
||||||
|
BackboneConfig = Annotated[
|
||||||
|
UNetBackboneConfig,
|
||||||
|
Field(discriminator="name")
|
||||||
|
]
|
||||||
|
|
||||||
|
def load_backbone_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: str | None = None,
|
||||||
|
) -> BackboneConfig:
|
||||||
|
return load_config(path, schema=BackboneConfig, field=field)
|
||||||
|
|
||||||
|
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Backbone",
|
"UNetBackbone",
|
||||||
|
"BackboneConfig",
|
||||||
|
"load_backbone_config",
|
||||||
"build_backbone",
|
"build_backbone",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Backbone(BackboneModel):
|
class UNetBackbone(BackboneModel):
|
||||||
"""Encoder-Decoder Backbone Network Implementation.
|
"""Encoder-Decoder Backbone Network Implementation.
|
||||||
|
|
||||||
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
||||||
@ -149,66 +178,46 @@ class Backbone(BackboneModel):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig) -> BackboneModel:
|
@backbone_registry.register(UNetBackboneConfig)
|
||||||
"""Factory function to build a Backbone from configuration.
|
@staticmethod
|
||||||
|
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||||
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
|
encoder = build_encoder(
|
||||||
the provided `BackboneConfig`, validates their compatibility, and assembles
|
in_channels=config.in_channels,
|
||||||
them into a `Backbone` instance.
|
input_height=config.input_height,
|
||||||
|
config=config.encoder,
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : BackboneConfig
|
|
||||||
The configuration object detailing the backbone architecture, including
|
|
||||||
input dimensions and configurations for encoder, bottleneck, and
|
|
||||||
decoder.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
BackboneModel
|
|
||||||
An initialized `Backbone` module ready for use.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If sub-component configurations are incompatible
|
|
||||||
(e.g., channel mismatches, decoder output height doesn't match backbone
|
|
||||||
input height).
|
|
||||||
NotImplementedError
|
|
||||||
If an unknown block type is specified in sub-configs.
|
|
||||||
"""
|
|
||||||
encoder = build_encoder(
|
|
||||||
in_channels=config.in_channels,
|
|
||||||
input_height=config.input_height,
|
|
||||||
config=config.encoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
bottleneck = build_bottleneck(
|
|
||||||
input_height=encoder.output_height,
|
|
||||||
in_channels=encoder.out_channels,
|
|
||||||
config=config.bottleneck,
|
|
||||||
)
|
|
||||||
|
|
||||||
decoder = build_decoder(
|
|
||||||
in_channels=bottleneck.out_channels,
|
|
||||||
input_height=encoder.output_height,
|
|
||||||
config=config.decoder,
|
|
||||||
)
|
|
||||||
|
|
||||||
if decoder.output_height != config.input_height:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid configuration: Decoder output height "
|
|
||||||
f"({decoder.output_height}) must match the Backbone input height "
|
|
||||||
f"({config.input_height}). Check encoder/decoder layer "
|
|
||||||
"configurations and input/bottleneck heights."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return Backbone(
|
bottleneck = build_bottleneck(
|
||||||
input_height=config.input_height,
|
input_height=encoder.output_height,
|
||||||
encoder=encoder,
|
in_channels=encoder.out_channels,
|
||||||
decoder=decoder,
|
config=config.bottleneck,
|
||||||
bottleneck=bottleneck,
|
)
|
||||||
)
|
|
||||||
|
decoder = build_decoder(
|
||||||
|
in_channels=bottleneck.out_channels,
|
||||||
|
input_height=encoder.output_height,
|
||||||
|
config=config.decoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
if decoder.output_height != config.input_height:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid configuration: Decoder output height "
|
||||||
|
f"({decoder.output_height}) must match the Backbone input height "
|
||||||
|
f"({config.input_height}). Check encoder/decoder layer "
|
||||||
|
"configurations and input/bottleneck heights."
|
||||||
|
)
|
||||||
|
|
||||||
|
return UNetBackbone(
|
||||||
|
input_height=config.input_height,
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||||
|
config = config or UNetBackboneConfig()
|
||||||
|
return backbone_registry.build(config)
|
||||||
|
|
||||||
|
|
||||||
def _pad_adjust(
|
def _pad_adjust(
|
||||||
|
|||||||
@ -1,96 +0,0 @@
|
|||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.models.bottleneck import (
|
|
||||||
DEFAULT_BOTTLENECK_CONFIG,
|
|
||||||
BottleneckConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.models.decoder import (
|
|
||||||
DEFAULT_DECODER_CONFIG,
|
|
||||||
DecoderConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.models.encoder import (
|
|
||||||
DEFAULT_ENCODER_CONFIG,
|
|
||||||
EncoderConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BackboneConfig",
|
|
||||||
"load_backbone_config",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class BackboneConfig(BaseConfig):
|
|
||||||
"""Configuration for the Encoder-Decoder Backbone network.
|
|
||||||
|
|
||||||
Aggregates configurations for the encoder, bottleneck, and decoder
|
|
||||||
components, along with defining the input and final output dimensions
|
|
||||||
for the complete backbone.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
input_height : int, default=128
|
|
||||||
Expected height (frequency bins) of the input spectrograms to the
|
|
||||||
backbone. Must be positive.
|
|
||||||
in_channels : int, default=1
|
|
||||||
Expected number of channels in the input spectrograms (e.g., 1 for
|
|
||||||
mono). Must be positive.
|
|
||||||
encoder : EncoderConfig, optional
|
|
||||||
Configuration for the encoder. If None or omitted,
|
|
||||||
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
|
|
||||||
encoder module) will be used.
|
|
||||||
bottleneck : BottleneckConfig, optional
|
|
||||||
Configuration for the bottleneck layer connecting encoder and decoder.
|
|
||||||
If None or omitted, the default bottleneck configuration will be used.
|
|
||||||
decoder : DecoderConfig, optional
|
|
||||||
Configuration for the decoder. If None or omitted,
|
|
||||||
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
|
|
||||||
decoder module) will be used.
|
|
||||||
out_channels : int, default=32
|
|
||||||
Desired number of channels in the final feature map output by the
|
|
||||||
backbone. Must be positive.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_height: int = 128
|
|
||||||
in_channels: int = 1
|
|
||||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
|
||||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
|
||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
|
||||||
out_channels: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
def load_backbone_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: str | None = None,
|
|
||||||
) -> BackboneConfig:
|
|
||||||
"""Load the backbone configuration from a file.
|
|
||||||
|
|
||||||
Reads a configuration file (YAML) and validates it against the
|
|
||||||
`BackboneConfig` schema, potentially extracting data from a nested field.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : PathLike
|
|
||||||
Path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing the
|
|
||||||
backbone configuration (e.g., "model.backbone"). If None, the entire
|
|
||||||
file content is used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
BackboneConfig
|
|
||||||
The loaded and validated backbone configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
yaml.YAMLError
|
|
||||||
If the file content is not valid YAML.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the loaded config data does not conform to `BackboneConfig`.
|
|
||||||
KeyError, TypeError
|
|
||||||
If `field` specifies an invalid path.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=BackboneConfig, field=field)
|
|
||||||
@ -17,7 +17,11 @@ the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
|||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
from batdetect2.models.backbones import (
|
||||||
|
BackboneConfig,
|
||||||
|
UNetBackboneConfig,
|
||||||
|
build_backbone,
|
||||||
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
|
|
||||||
@ -152,7 +156,7 @@ def build_detector(
|
|||||||
construction of the backbone or detector components (e.g., incompatible
|
construction of the backbone or detector components (e.g., incompatible
|
||||||
configurations, invalid parameters).
|
configurations, invalid parameters).
|
||||||
"""
|
"""
|
||||||
config = config or BackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building model with config: \n{}",
|
"Building model with config: \n{}",
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.config import BackboneConfig
|
from batdetect2.models.backbones import BackboneConfig
|
||||||
from batdetect2.models.detectors import Detector, build_detector
|
from batdetect2.models.detectors import Detector, build_detector
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user