Update build trainer

This commit is contained in:
mbsantiago 2025-06-26 17:43:56 -06:00
parent 0c8fae4a72
commit 15de168a20

View File

@ -96,14 +96,26 @@ def build_trainer(
conf: FullTrainingConfig, conf: FullTrainingConfig,
targets: TargetProtocol, targets: TargetProtocol,
) -> Trainer: ) -> Trainer:
logger = build_logger(conf.train.logger)
if logger and hasattr(logger, "log_hyperparams"):
logger.log_hyperparams(conf.model_dump(exclude_none=True))
return Trainer( return Trainer(
accelerator=conf.train.accelerator, accelerator=conf.train.accelerator,
logger=logger, accumulate_grad_batches=conf.train.accumulate_grad_batches,
deterministic=conf.train.deterministic,
check_val_every_n_epoch=conf.train.check_val_every_n_epoch,
devices=conf.train.devices,
enable_checkpointing=conf.train.enable_checkpointing,
gradient_clip_val=conf.train.gradient_clip_val,
limit_train_batches=conf.train.limit_train_batches,
limit_test_batches=conf.train.limit_test_batches,
limit_val_batches=conf.train.limit_val_batches,
log_every_n_steps=conf.train.log_every_n_steps,
max_epochs=conf.train.max_epochs,
min_epochs=conf.train.min_epochs,
max_steps=conf.train.max_steps,
min_steps=conf.train.min_steps,
max_time=conf.train.max_time,
precision=conf.train.precision,
val_check_interval=conf.train.val_check_interval,
logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets), callbacks=build_trainer_callbacks(targets),
) )