mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Update build trainer
This commit is contained in:
parent
0c8fae4a72
commit
15de168a20
@ -96,14 +96,26 @@ def build_trainer(
|
|||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
logger = build_logger(conf.train.logger)
|
|
||||||
|
|
||||||
if logger and hasattr(logger, "log_hyperparams"):
|
|
||||||
logger.log_hyperparams(conf.model_dump(exclude_none=True))
|
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
accelerator=conf.train.accelerator,
|
accelerator=conf.train.accelerator,
|
||||||
logger=logger,
|
accumulate_grad_batches=conf.train.accumulate_grad_batches,
|
||||||
|
deterministic=conf.train.deterministic,
|
||||||
|
check_val_every_n_epoch=conf.train.check_val_every_n_epoch,
|
||||||
|
devices=conf.train.devices,
|
||||||
|
enable_checkpointing=conf.train.enable_checkpointing,
|
||||||
|
gradient_clip_val=conf.train.gradient_clip_val,
|
||||||
|
limit_train_batches=conf.train.limit_train_batches,
|
||||||
|
limit_test_batches=conf.train.limit_test_batches,
|
||||||
|
limit_val_batches=conf.train.limit_val_batches,
|
||||||
|
log_every_n_steps=conf.train.log_every_n_steps,
|
||||||
|
max_epochs=conf.train.max_epochs,
|
||||||
|
min_epochs=conf.train.min_epochs,
|
||||||
|
max_steps=conf.train.max_steps,
|
||||||
|
min_steps=conf.train.min_steps,
|
||||||
|
max_time=conf.train.max_time,
|
||||||
|
precision=conf.train.precision,
|
||||||
|
val_check_interval=conf.train.val_check_interval,
|
||||||
|
logger=build_logger(conf.train.logger),
|
||||||
callbacks=build_trainer_callbacks(targets),
|
callbacks=build_trainer_callbacks(targets),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user