mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Expose model functions
This commit is contained in:
parent
bfa6049adc
commit
213b6dfd29
@ -1,7 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.backbones import (
|
||||
Net2DFast,
|
||||
Net2DFastNoAttn,
|
||||
@ -12,13 +14,15 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.typing import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
"get_backbone",
|
||||
"BBoxHead",
|
||||
"ClassifierHead",
|
||||
"ModelConfig",
|
||||
"ModelType",
|
||||
"Net2DFast",
|
||||
"Net2DFastNoAttn",
|
||||
"Net2DFastNoCoordConv",
|
||||
"ModelType",
|
||||
"BBoxHead",
|
||||
"ClassifierHead",
|
||||
"build_architecture",
|
||||
"load_model_config",
|
||||
]
|
||||
|
||||
|
||||
@ -38,7 +42,13 @@ class ModelConfig(BaseConfig):
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
def get_backbone(
|
||||
def load_model_config(
|
||||
path: PathLike, field: Optional[str] = None
|
||||
) -> ModelConfig:
|
||||
return load_config(path, schema=ModelConfig, field=field)
|
||||
|
||||
|
||||
def build_architecture(
|
||||
config: Optional[ModelConfig] = None,
|
||||
) -> BackboneModel:
|
||||
config = config or ModelConfig()
|
||||
|
@ -165,6 +165,7 @@ def pad_adjust(
|
||||
spec: torch.Tensor,
|
||||
factor: int = 32,
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
print(spec.shape)
|
||||
h, w = spec.shape[2:]
|
||||
h_pad = -h % factor
|
||||
w_pad = -w % factor
|
||||
|
Loading…
Reference in New Issue
Block a user