Add seed option to train

This commit is contained in:
mbsantiago 2025-09-09 13:23:56 +01:00
parent 3376be06a4
commit 951dc59718
4 changed files with 10 additions and 3 deletions

View File

@ -109,7 +109,7 @@ train:
sigma: 3 sigma: 3
trainer: trainer:
max_epochs: 40 max_epochs: 5
dataloaders: dataloaders:
train: train:

View File

@ -27,6 +27,7 @@ __all__ = ["train_command"]
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--seed", type=int)
@click.option( @click.option(
"-v", "-v",
"--verbose", "--verbose",
@ -41,6 +42,7 @@ def train_command(
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
config: Optional[Path] = None, config: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
seed: Optional[int] = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
@ -92,4 +94,5 @@ def train_command(
experiment_name=experiment_name, experiment_name=experiment_name,
log_dir=log_dir, log_dir=log_dir,
checkpoint_dir=ckpt_dir, checkpoint_dir=ckpt_dir,
seed=seed,
) )

View File

@ -9,7 +9,7 @@ from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "outputs" DEFAULT_LOGS_DIR: str = "outputs/logs"
class DVCLiveConfig(BaseConfig): class DVCLiveConfig(BaseConfig):

View File

@ -2,7 +2,7 @@ from collections.abc import Sequence
from typing import List, Optional from typing import List, Optional
import torch import torch
from lightning import Trainer from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
@ -56,7 +56,11 @@ def train(
checkpoint_dir: Optional[data.PathLike] = None, checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[data.PathLike] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
seed: Optional[int] = None,
): ):
if seed is not None:
seed_everything(seed)
config = config or FullTrainingConfig() config = config or FullTrainingConfig()
targets = build_targets(config.targets) targets = build_targets(config.targets)