mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Create train config module
This commit is contained in:
parent
1338ae7431
commit
7689580a24
31
batdetect2/train/config.py
Normal file
31
batdetect2/train/config.py
Normal file
@ -0,0 +1,31 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.train.losses import LossConfig
|
||||
|
||||
__all__ = [
|
||||
"OptimizerConfig",
|
||||
"TrainingConfig",
|
||||
"load_train_config",
|
||||
]
|
||||
|
||||
|
||||
class OptimizerConfig(BaseConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
batch_size: int = 32
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
|
||||
|
||||
def load_train_config(
|
||||
path: PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> TrainingConfig:
|
||||
return load_config(path, schema=TrainingConfig, field=field)
|
Loading…
Reference in New Issue
Block a user