mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Add seed option to train
This commit is contained in:
parent
3376be06a4
commit
951dc59718
@ -109,7 +109,7 @@ train:
|
||||
sigma: 3
|
||||
|
||||
trainer:
|
||||
max_epochs: 40
|
||||
max_epochs: 5
|
||||
|
||||
dataloaders:
|
||||
train:
|
||||
|
||||
@ -27,6 +27,7 @@ __all__ = ["train_command"]
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--seed", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
@ -41,6 +42,7 @@ def train_command(
|
||||
log_dir: Optional[Path] = None,
|
||||
config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: Optional[str] = None,
|
||||
@ -92,4 +94,5 @@ def train_command(
|
||||
experiment_name=experiment_name,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
@ -9,7 +9,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
DEFAULT_LOGS_DIR: str = "outputs"
|
||||
DEFAULT_LOGS_DIR: str = "outputs/logs"
|
||||
|
||||
|
||||
class DVCLiveConfig(BaseConfig):
|
||||
|
||||
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from lightning import Trainer
|
||||
from lightning import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
@ -56,7 +56,11 @@ def train(
|
||||
checkpoint_dir: Optional[data.PathLike] = None,
|
||||
log_dir: Optional[data.PathLike] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
config = config or FullTrainingConfig()
|
||||
|
||||
targets = build_targets(config.targets)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user