diff --git a/example_data/config.yaml b/example_data/config.yaml index 90e1e42..e3cd7ff 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -109,7 +109,7 @@ train: sigma: 3 trainer: - max_epochs: 40 + max_epochs: 5 dataloaders: train: diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 864dfda..2911db3 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -27,6 +27,7 @@ __all__ = ["train_command"] @click.option("--train-workers", type=int) @click.option("--val-workers", type=int) @click.option("--experiment-name", type=str) +@click.option("--seed", type=int) @click.option( "-v", "--verbose", @@ -41,6 +42,7 @@ def train_command( log_dir: Optional[Path] = None, config: Optional[Path] = None, config_field: Optional[str] = None, + seed: Optional[int] = None, train_workers: int = 0, val_workers: int = 0, experiment_name: Optional[str] = None, @@ -92,4 +94,5 @@ def train_command( experiment_name=experiment_name, log_dir=log_dir, checkpoint_dir=ckpt_dir, + seed=seed, ) diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index da576f6..517acb7 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -9,7 +9,7 @@ from soundevent import data from batdetect2.configs import BaseConfig -DEFAULT_LOGS_DIR: str = "outputs" +DEFAULT_LOGS_DIR: str = "outputs/logs" class DVCLiveConfig(BaseConfig): diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index b4b3061..24c54de 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from typing import List, Optional import torch -from lightning import Trainer +from lightning import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, ModelCheckpoint from loguru import logger from soundevent import data @@ -56,7 +56,11 @@ def train( checkpoint_dir: Optional[data.PathLike] = None, log_dir: Optional[data.PathLike] = None, experiment_name: Optional[str] = None, + seed: Optional[int] = None, ): + if seed is not None: + seed_everything(seed) + config = config or FullTrainingConfig() targets = build_targets(config.targets)