Create dataloader config

This commit is contained in:
mbsantiago 2025-07-29 12:10:51 +01:00
parent 62f0c5c397
commit 9816985bb1
3 changed files with 41 additions and 40 deletions

View File

@@ -22,8 +22,8 @@ __all__ = ["train_command"]
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int, default=0)
@click.option("--val-workers", type=int, default=0)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option(
"-v",
"--verbose",

View File

@@ -45,10 +45,29 @@ class PLTrainerConfig(BaseConfig):
val_check_interval: Optional[Union[int, float]] = None
class TrainingConfig(PLTrainerConfig):
batch_size: int = 8
class DataLoaderConfig(BaseConfig):
    """Configuration for a single torch ``DataLoader``.

    Field names mirror the ``DataLoader`` arguments they are forwarded to
    (see the loader-building code in this commit, which reads
    ``batch_size``, ``shuffle`` and ``num_workers`` from this config).
    """

    # Examples per batch; no default, so each loader config must set it
    # explicitly (the DEFAULT_* constants below pick 8).
    batch_size: int
    # Whether to reshuffle the dataset every epoch (True for training,
    # False for validation in the defaults below).
    shuffle: bool
    # Worker subprocess count; 0 means load in the main process.
    num_workers: int = 0
# Shared default loader settings: training shuffles, validation does not.
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)


class LoadersConfig(BaseConfig):
    """Dataloader configs for the train and validation splits.

    Each field defaults to a private copy of the module-level default so
    that mutating one ``LoadersConfig`` instance cannot leak into others
    (hence ``model_copy()`` inside the ``default_factory``).
    """

    train: DataLoaderConfig = Field(
        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
    )
    val: DataLoaderConfig = Field(
        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
    )
class TrainingConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: Optional[AugmentationsConfig] = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG

View File

@@ -1,7 +1,6 @@
from collections.abc import Sequence
from typing import List, Optional
import yaml
from lightning import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
@@ -20,11 +19,7 @@ from batdetect2.targets import TargetProtocol
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig,
TrainingConfig,
)
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
@@ -48,8 +43,8 @@ def train(
val_examples: Optional[Sequence[data.PathLike]] = None,
config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None,
train_workers: int = 0,
val_workers: int = 0,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
):
conf = config or FullTrainingConfig()
@@ -110,16 +105,13 @@ def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
) -> Trainer:
trainer_conf = PLTrainerConfig.model_validate(
conf.train.model_dump(mode="python")
)
trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
val_check_interval=conf.train.val_check_interval,
logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets),
)
@@ -137,22 +129,17 @@ def build_train_loader(
preprocessor=preprocessor,
config=config,
)
loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug(
"Training data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": True,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
"Training data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers or 0,
batch_size=loader_conf.batch_size,
shuffle=loader_conf.shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
)
@@ -167,22 +154,17 @@ def build_val_loader(
val_examples,
config=config,
)
loader_conf = config.dataloaders.val
logger.opt(lazy=True).debug(
"Validation data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": False,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
"Validation data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=num_workers or 0,
batch_size=loader_conf.batch_size,
shuffle=loader_conf.shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
)