mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add seed option to train
This commit is contained in:
parent
3376be06a4
commit
951dc59718
@ -109,7 +109,7 @@ train:
|
|||||||
sigma: 3
|
sigma: 3
|
||||||
|
|
||||||
trainer:
|
trainer:
|
||||||
max_epochs: 40
|
max_epochs: 5
|
||||||
|
|
||||||
dataloaders:
|
dataloaders:
|
||||||
train:
|
train:
|
||||||
|
|||||||
@ -27,6 +27,7 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@click.option("--val-workers", type=int)
|
@click.option("--val-workers", type=int)
|
||||||
@click.option("--experiment-name", type=str)
|
@click.option("--experiment-name", type=str)
|
||||||
|
@click.option("--seed", type=int)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
@ -41,6 +42,7 @@ def train_command(
|
|||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
@ -92,4 +94,5 @@ def train_command(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
checkpoint_dir=ckpt_dir,
|
checkpoint_dir=ckpt_dir,
|
||||||
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: str = "outputs"
|
DEFAULT_LOGS_DIR: str = "outputs/logs"
|
||||||
|
|
||||||
|
|
||||||
class DVCLiveConfig(BaseConfig):
|
class DVCLiveConfig(BaseConfig):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from collections.abc import Sequence
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lightning import Trainer
|
from lightning import Trainer, seed_everything
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -56,7 +56,11 @@ def train(
|
|||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
if seed is not None:
|
||||||
|
seed_everything(seed)
|
||||||
|
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
targets = build_targets(config.targets)
|
targets = build_targets(config.targets)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user