Fix trainer init

This commit is contained in:
mbsantiago 2025-06-26 19:35:23 -06:00
parent 16febed792
commit ce15a0f152

View File

@ -18,7 +18,11 @@ from batdetect2.targets import TargetProtocol
from batdetect2.train.augmentations import build_augmentations from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig,
TrainingConfig,
)
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
RandomExampleSource, RandomExampleSource,
@ -96,24 +100,12 @@ def build_trainer(
conf: FullTrainingConfig, conf: FullTrainingConfig,
targets: TargetProtocol, targets: TargetProtocol,
) -> Trainer: ) -> Trainer:
trainer_conf = PLTrainerConfig.model_validate(
conf.train,
from_attributes=True,
)
return Trainer( return Trainer(
accelerator=conf.train.accelerator, **trainer_conf.model_dump(exclude_none=True),
accumulate_grad_batches=conf.train.accumulate_grad_batches,
deterministic=conf.train.deterministic,
check_val_every_n_epoch=conf.train.check_val_every_n_epoch,
devices=conf.train.devices,
enable_checkpointing=conf.train.enable_checkpointing,
gradient_clip_val=conf.train.gradient_clip_val,
limit_train_batches=conf.train.limit_train_batches,
limit_test_batches=conf.train.limit_test_batches,
limit_val_batches=conf.train.limit_val_batches,
log_every_n_steps=conf.train.log_every_n_steps,
max_epochs=conf.train.max_epochs,
min_epochs=conf.train.min_epochs,
max_steps=conf.train.max_steps,
min_steps=conf.train.min_steps,
max_time=conf.train.max_time,
precision=conf.train.precision,
val_check_interval=conf.train.val_check_interval, val_check_interval=conf.train.val_check_interval,
logger=build_logger(conf.train.logger), logger=build_logger(conf.train.logger),
callbacks=build_trainer_callbacks(targets), callbacks=build_trainer_callbacks(targets),