Create train config module

This commit is contained in:
mbsantiago 2025-04-03 16:49:47 +01:00
parent 1338ae7431
commit 7689580a24

View File

@ -0,0 +1,31 @@
from typing import Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.losses import LossConfig
__all__ = [
"OptimizerConfig",
"TrainingConfig",
"load_train_config",
]
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
class TrainingConfig(BaseConfig):
batch_size: int = 32
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
def load_train_config(
path: PathLike,
field: Optional[str] = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)