From 9816985bb17ddff3cfd082a0f2dfc4f92b8023b8 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 29 Jul 2025 12:10:51 +0100 Subject: [PATCH] Create dataloader config --- src/batdetect2/cli/train.py | 4 +-- src/batdetect2/train/config.py | 23 +++++++++++++-- src/batdetect2/train/train.py | 54 ++++++++++++---------------------- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index c735576..3dcc96b 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -22,8 +22,8 @@ __all__ = ["train_command"] @click.option("--model-path", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True)) @click.option("--config-field", type=str) -@click.option("--train-workers", type=int, default=0) -@click.option("--val-workers", type=int, default=0) +@click.option("--train-workers", type=int) +@click.option("--val-workers", type=int) @click.option( "-v", "--verbose", diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index c6c17af..0a013aa 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -45,10 +45,29 @@ class PLTrainerConfig(BaseConfig): val_check_interval: Optional[Union[int, float]] = None -class TrainingConfig(PLTrainerConfig): - batch_size: int = 8 +class DataLoaderConfig(BaseConfig): + batch_size: int + shuffle: bool + num_workers: int = 0 + + +DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True) +DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False) + + +class LoadersConfig(BaseConfig): + train: DataLoaderConfig = Field( + default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy() + ) + val: DataLoaderConfig = Field( + default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy() + ) + + +class TrainingConfig(BaseConfig): learning_rate: float = 1e-3 t_max: int = 100 + dataloaders: LoadersConfig = Field(default_factory=LoadersConfig) loss: LossConfig = Field(default_factory=LossConfig) augmentations: Optional[AugmentationsConfig] = Field( default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 18ad393..caaa768 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from typing import List, Optional -import yaml from lightning import Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint from loguru import logger @@ -20,11 +19,7 @@ from batdetect2.targets import TargetProtocol from batdetect2.train.augmentations import build_augmentations from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper -from batdetect2.train.config import ( - FullTrainingConfig, - PLTrainerConfig, - TrainingConfig, -) +from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, @@ -48,8 +43,8 @@ def train( val_examples: Optional[Sequence[data.PathLike]] = None, config: Optional[FullTrainingConfig] = None, model_path: Optional[data.PathLike] = None, - train_workers: int = 0, - val_workers: int = 0, + train_workers: Optional[int] = None, + val_workers: Optional[int] = None, ): conf = config or FullTrainingConfig() @@ -110,16 +105,13 @@ def build_trainer( conf: FullTrainingConfig, targets: TargetProtocol, ) -> Trainer: - trainer_conf = PLTrainerConfig.model_validate( - conf.train.model_dump(mode="python") - ) + trainer_conf = conf.train.trainer logger.opt(lazy=True).debug( "Building trainer with config: \n{config}", config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) return Trainer( **trainer_conf.model_dump(exclude_none=True), - val_check_interval=conf.train.val_check_interval, logger=build_logger(conf.train.logger), callbacks=build_trainer_callbacks(targets), ) @@ -137,22 +129,17 @@ def build_train_loader( preprocessor=preprocessor, config=config, ) + loader_conf = config.dataloaders.train logger.opt(lazy=True).debug( - "Training data loader config: \n{}", - lambda: yaml.dump( - { - "batch_size": config.batch_size, - "shuffle": True, - "num_workers": num_workers or 0, - "collate_fn": str(collate_fn), - } - ), + "Training data loader config: \n{config}", + config=loader_conf.to_yaml_string(exclude_none=True), ) + num_workers = num_workers or loader_conf.num_workers return DataLoader( train_dataset, - batch_size=config.batch_size, - shuffle=True, - num_workers=num_workers or 0, + batch_size=loader_conf.batch_size, + shuffle=loader_conf.shuffle, + num_workers=num_workers, collate_fn=collate_fn, ) @@ -167,22 +154,17 @@ def build_val_loader( val_examples, config=config, ) + loader_conf = config.dataloaders.val logger.opt(lazy=True).debug( - "Validation data loader config: \n{}", - lambda: yaml.dump( - { - "batch_size": config.batch_size, - "shuffle": False, - "num_workers": num_workers or 0, - "collate_fn": str(collate_fn), - } - ), + "Validation data loader config: \n{config}", + config=loader_conf.to_yaml_string(exclude_none=True), ) + num_workers = num_workers or loader_conf.num_workers return DataLoader( val_dataset, - batch_size=config.batch_size, - shuffle=False, - num_workers=num_workers or 0, + batch_size=loader_conf.batch_size, + shuffle=loader_conf.shuffle, + num_workers=num_workers, collate_fn=collate_fn, )