mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Fix trainer init
This commit is contained in:
parent
16febed792
commit
ce15a0f152
@ -18,7 +18,11 @@ from batdetect2.targets import TargetProtocol
|
||||
from batdetect2.train.augmentations import build_augmentations
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
||||
from batdetect2.train.config import (
|
||||
FullTrainingConfig,
|
||||
PLTrainerConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
@ -96,24 +100,12 @@ def build_trainer(
|
||||
conf: FullTrainingConfig,
|
||||
targets: TargetProtocol,
|
||||
) -> Trainer:
|
||||
trainer_conf = PLTrainerConfig.model_validate(
|
||||
conf.train,
|
||||
from_attributes=True,
|
||||
)
|
||||
return Trainer(
|
||||
accelerator=conf.train.accelerator,
|
||||
accumulate_grad_batches=conf.train.accumulate_grad_batches,
|
||||
deterministic=conf.train.deterministic,
|
||||
check_val_every_n_epoch=conf.train.check_val_every_n_epoch,
|
||||
devices=conf.train.devices,
|
||||
enable_checkpointing=conf.train.enable_checkpointing,
|
||||
gradient_clip_val=conf.train.gradient_clip_val,
|
||||
limit_train_batches=conf.train.limit_train_batches,
|
||||
limit_test_batches=conf.train.limit_test_batches,
|
||||
limit_val_batches=conf.train.limit_val_batches,
|
||||
log_every_n_steps=conf.train.log_every_n_steps,
|
||||
max_epochs=conf.train.max_epochs,
|
||||
min_epochs=conf.train.min_epochs,
|
||||
max_steps=conf.train.max_steps,
|
||||
min_steps=conf.train.min_steps,
|
||||
max_time=conf.train.max_time,
|
||||
precision=conf.train.precision,
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
val_check_interval=conf.train.val_check_interval,
|
||||
logger=build_logger(conf.train.logger),
|
||||
callbacks=build_trainer_callbacks(targets),
|
||||
|
Loading…
Reference in New Issue
Block a user