diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index b35b2a4..98b8808 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -152,6 +152,12 @@ def build_trainer_callbacks( if checkpoint_dir is None: checkpoint_dir = DEFAULT_CHECKPOINT_DIR + if experiment_name is not None: + checkpoint_dir = checkpoint_dir / experiment_name + + if run_name is not None: + checkpoint_dir = checkpoint_dir / run_name + filename = "best-{epoch:02d}-{val_loss:.0f}" if run_name is not None: