Expose model functions

This commit is contained in:
mbsantiago 2025-04-03 16:50:20 +01:00
parent bfa6049adc
commit 213b6dfd29
2 changed files with 17 additions and 6 deletions

View File

@ -1,7 +1,9 @@
from enum import Enum from enum import Enum
from typing import Optional, Tuple from typing import Optional, Tuple
from batdetect2.configs import BaseConfig from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
Net2DFast, Net2DFast,
Net2DFastNoAttn, Net2DFastNoAttn,
@ -12,13 +14,15 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel from batdetect2.models.typing import BackboneModel
__all__ = [ __all__ = [
"get_backbone", "BBoxHead",
"ClassifierHead",
"ModelConfig",
"ModelType",
"Net2DFast", "Net2DFast",
"Net2DFastNoAttn", "Net2DFastNoAttn",
"Net2DFastNoCoordConv", "Net2DFastNoCoordConv",
"ModelType", "build_architecture",
"BBoxHead", "load_model_config",
"ClassifierHead",
] ]
@ -38,7 +42,13 @@ class ModelConfig(BaseConfig):
out_channels: int = 32 out_channels: int = 32
def get_backbone( def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
def build_architecture(
config: Optional[ModelConfig] = None, config: Optional[ModelConfig] = None,
) -> BackboneModel: ) -> BackboneModel:
config = config or ModelConfig() config = config or ModelConfig()

View File

@ -165,6 +165,7 @@ def pad_adjust(
spec: torch.Tensor, spec: torch.Tensor,
factor: int = 32, factor: int = 32,
) -> Tuple[torch.Tensor, int, int]: ) -> Tuple[torch.Tensor, int, int]:
print(spec.shape)
h, w = spec.shape[2:] h, w = spec.shape[2:]
h_pad = -h % factor h_pad = -h % factor
w_pad = -w % factor w_pad = -w % factor