diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 8a6c8ef..e169e3e 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -24,6 +24,14 @@ __all__ = ["train_command"] "training starts from a fresh model config." ), ) +@click.option( + "--base-dir", + type=click.Path(exists=True), + help=( + "Base directory used to resolve relative paths inside the training " + "and validation dataset configs." + ), +) @click.option( "--targets", "targets_config", @@ -111,6 +119,7 @@ def train_command( model_path: Path | None = None, ckpt_dir: Path | None = None, log_dir: Path | None = None, + base_dir: Path | None = None, targets_config: Path | None = None, model_config: Path | None = None, training_config: Path | None = None, @@ -191,7 +200,10 @@ def train_command( model_conf = model_conf.model_copy(update={"targets": target_conf}) logger.info("Loading training dataset...") - train_annotations = load_dataset_from_config(train_dataset) + train_annotations = load_dataset_from_config( + train_dataset, + base_dir=base_dir, + ) logger.debug( "Loaded {num_annotations} training examples", num_annotations=len(train_annotations), @@ -199,7 +211,10 @@ def train_command( val_annotations = None if val_dataset is not None: - val_annotations = load_dataset_from_config(val_dataset) + val_annotations = load_dataset_from_config( + val_dataset, + base_dir=base_dir, + ) logger.debug( "Loaded {num_annotations} validation examples", num_annotations=len(val_annotations),