diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index a9d91d9..83d8060 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -215,18 +215,15 @@ def build_trainer( run_name=run_name, ) + if num_epochs is not None: + trainer_conf.max_epochs = num_epochs + train_logger.log_hyperparams( - config.model_dump( - mode="json", - exclude_none=True, - ) + config.model_dump(mode="json", exclude_none=True) ) train_config = trainer_conf.model_dump(exclude_none=True) - if num_epochs is not None: - train_config["max_epochs"] = num_epochs - return Trainer( **train_config, logger=train_logger,