diff --git a/batdetect2/models/__init__.py b/batdetect2/models/__init__.py index 71caad1..5a88a09 100644 --- a/batdetect2/models/__init__.py +++ b/batdetect2/models/__init__.py @@ -1,7 +1,9 @@ from enum import Enum from typing import Optional, Tuple -from batdetect2.configs import BaseConfig +from soundevent.data import PathLike + +from batdetect2.configs import BaseConfig, load_config from batdetect2.models.backbones import ( Net2DFast, Net2DFastNoAttn, @@ -12,13 +14,15 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.typing import BackboneModel __all__ = [ - "get_backbone", + "BBoxHead", + "ClassifierHead", + "ModelConfig", + "ModelType", "Net2DFast", "Net2DFastNoAttn", "Net2DFastNoCoordConv", - "ModelType", - "BBoxHead", - "ClassifierHead", + "build_architecture", + "load_model_config", ] @@ -38,7 +42,13 @@ class ModelConfig(BaseConfig): out_channels: int = 32 -def get_backbone( +def load_model_config( + path: PathLike, field: Optional[str] = None +) -> ModelConfig: + return load_config(path, schema=ModelConfig, field=field) + + +def build_architecture( config: Optional[ModelConfig] = None, ) -> BackboneModel: config = config or ModelConfig() diff --git a/batdetect2/models/backbones.py b/batdetect2/models/backbones.py index 5083016..2e53f3b 100644 --- a/batdetect2/models/backbones.py +++ b/batdetect2/models/backbones.py @@ -165,6 +165,7 @@ def pad_adjust( spec: torch.Tensor, factor: int = 32, ) -> Tuple[torch.Tensor, int, int]: + print(spec.shape) h, w = spec.shape[2:] h_pad = -h % factor w_pad = -w % factor