mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Create train config module
This commit is contained in:
parent
1338ae7431
commit
7689580a24
31
batdetect2/train/config.py
Normal file
31
batdetect2/train/config.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.train.losses import LossConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OptimizerConfig",
|
||||||
|
"TrainingConfig",
|
||||||
|
"load_train_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerConfig(BaseConfig):
|
||||||
|
learning_rate: float = 1e-3
|
||||||
|
t_max: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseConfig):
|
||||||
|
batch_size: int = 32
|
||||||
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
|
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_config(
|
||||||
|
path: PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> TrainingConfig:
|
||||||
|
return load_config(path, schema=TrainingConfig, field=field)
|
Loading…
Reference in New Issue
Block a user