Expose model functions

This commit is contained in:
mbsantiago 2025-04-03 16:50:20 +01:00
parent bfa6049adc
commit 213b6dfd29
2 changed files with 17 additions and 6 deletions

View File

@ -1,7 +1,9 @@
from enum import Enum
from typing import Optional, Tuple
from batdetect2.configs import BaseConfig
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
@ -12,13 +14,15 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
__all__ = [
"get_backbone",
"BBoxHead",
"ClassifierHead",
"ModelConfig",
"ModelType",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"ModelType",
"BBoxHead",
"ClassifierHead",
"build_architecture",
"load_model_config",
]
@ -38,7 +42,13 @@ class ModelConfig(BaseConfig):
out_channels: int = 32
def get_backbone(
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
def build_architecture(
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()

View File

@ -165,6 +165,7 @@ def pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
print(spec.shape)
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor