diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 7b339b7..f101d93 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -96,14 +96,26 @@ def build_trainer( conf: FullTrainingConfig, targets: TargetProtocol, ) -> Trainer: - logger = build_logger(conf.train.logger) - - if logger and hasattr(logger, "log_hyperparams"): - logger.log_hyperparams(conf.model_dump(exclude_none=True)) - return Trainer( accelerator=conf.train.accelerator, - logger=logger, + accumulate_grad_batches=conf.train.accumulate_grad_batches, + deterministic=conf.train.deterministic, + check_val_every_n_epoch=conf.train.check_val_every_n_epoch, + devices=conf.train.devices, + enable_checkpointing=conf.train.enable_checkpointing, + gradient_clip_val=conf.train.gradient_clip_val, + limit_train_batches=conf.train.limit_train_batches, + limit_test_batches=conf.train.limit_test_batches, + limit_val_batches=conf.train.limit_val_batches, + log_every_n_steps=conf.train.log_every_n_steps, + max_epochs=conf.train.max_epochs, + min_epochs=conf.train.min_epochs, + max_steps=conf.train.max_steps, + min_steps=conf.train.min_steps, + max_time=conf.train.max_time, + precision=conf.train.precision, + val_check_interval=conf.train.val_check_interval, + logger=build_logger(conf.train.logger), callbacks=build_trainer_callbacks(targets), )