From ce15a0f1525bf6402c206a7bb2505136497a3008 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 19:35:23 -0600 Subject: [PATCH] Fix trainer init --- src/batdetect2/train/train.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index f101d93..c2410a0 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -18,7 +18,11 @@ from batdetect2.targets import TargetProtocol from batdetect2.train.augmentations import build_augmentations from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper -from batdetect2.train.config import FullTrainingConfig, TrainingConfig +from batdetect2.train.config import ( + FullTrainingConfig, + PLTrainerConfig, + TrainingConfig, +) from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, @@ -96,24 +100,12 @@ def build_trainer( conf: FullTrainingConfig, targets: TargetProtocol, ) -> Trainer: + trainer_conf = PLTrainerConfig.model_validate( + conf.train, + from_attributes=True, + ) return Trainer( - accelerator=conf.train.accelerator, - accumulate_grad_batches=conf.train.accumulate_grad_batches, - deterministic=conf.train.deterministic, - check_val_every_n_epoch=conf.train.check_val_every_n_epoch, - devices=conf.train.devices, - enable_checkpointing=conf.train.enable_checkpointing, - gradient_clip_val=conf.train.gradient_clip_val, - limit_train_batches=conf.train.limit_train_batches, - limit_test_batches=conf.train.limit_test_batches, - limit_val_batches=conf.train.limit_val_batches, - log_every_n_steps=conf.train.log_every_n_steps, - max_epochs=conf.train.max_epochs, - min_epochs=conf.train.min_epochs, - max_steps=conf.train.max_steps, - min_steps=conf.train.min_steps, - max_time=conf.train.max_time, - precision=conf.train.precision, + **trainer_conf.model_dump(exclude_none=True), val_check_interval=conf.train.val_check_interval, logger=build_logger(conf.train.logger), callbacks=build_trainer_callbacks(targets),