mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Expose model functions
This commit is contained in:
parent
bfa6049adc
commit
213b6dfd29
@ -1,7 +1,9 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Net2DFast,
|
Net2DFast,
|
||||||
Net2DFastNoAttn,
|
Net2DFastNoAttn,
|
||||||
@ -12,13 +14,15 @@ from batdetect2.models.heads import BBoxHead, ClassifierHead
|
|||||||
from batdetect2.models.typing import BackboneModel
|
from batdetect2.models.typing import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_backbone",
|
"BBoxHead",
|
||||||
|
"ClassifierHead",
|
||||||
|
"ModelConfig",
|
||||||
|
"ModelType",
|
||||||
"Net2DFast",
|
"Net2DFast",
|
||||||
"Net2DFastNoAttn",
|
"Net2DFastNoAttn",
|
||||||
"Net2DFastNoCoordConv",
|
"Net2DFastNoCoordConv",
|
||||||
"ModelType",
|
"build_architecture",
|
||||||
"BBoxHead",
|
"load_model_config",
|
||||||
"ClassifierHead",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +42,13 @@ class ModelConfig(BaseConfig):
|
|||||||
out_channels: int = 32
|
out_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
def get_backbone(
|
def load_model_config(
|
||||||
|
path: PathLike, field: Optional[str] = None
|
||||||
|
) -> ModelConfig:
|
||||||
|
return load_config(path, schema=ModelConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def build_architecture(
|
||||||
config: Optional[ModelConfig] = None,
|
config: Optional[ModelConfig] = None,
|
||||||
) -> BackboneModel:
|
) -> BackboneModel:
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
@ -165,6 +165,7 @@ def pad_adjust(
|
|||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
factor: int = 32,
|
factor: int = 32,
|
||||||
) -> Tuple[torch.Tensor, int, int]:
|
) -> Tuple[torch.Tensor, int, int]:
|
||||||
|
print(spec.shape)
|
||||||
h, w = spec.shape[2:]
|
h, w = spec.shape[2:]
|
||||||
h_pad = -h % factor
|
h_pad = -h % factor
|
||||||
w_pad = -w % factor
|
w_pad = -w % factor
|
||||||
|
Loading…
Reference in New Issue
Block a user