From 7689580a240623918a1ac1ef23f887e3d083eaea Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 3 Apr 2025 16:49:47 +0100 Subject: [PATCH] Create train config module --- batdetect2/train/config.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 batdetect2/train/config.py diff --git a/batdetect2/train/config.py b/batdetect2/train/config.py new file mode 100644 index 0000000..5663611 --- /dev/null +++ b/batdetect2/train/config.py @@ -0,0 +1,31 @@ +from typing import Optional + +from pydantic import Field +from soundevent.data import PathLike + +from batdetect2.configs import BaseConfig, load_config +from batdetect2.train.losses import LossConfig + +__all__ = [ + "OptimizerConfig", + "TrainingConfig", + "load_train_config", +] + + +class OptimizerConfig(BaseConfig): + learning_rate: float = 1e-3 + t_max: int = 100 + + +class TrainingConfig(BaseConfig): + batch_size: int = 32 + loss: LossConfig = Field(default_factory=LossConfig) + optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + + +def load_train_config( + path: PathLike, + field: Optional[str] = None, +) -> TrainingConfig: + return load_config(path, schema=TrainingConfig, field=field)