diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 8d5836a..af2c1b6 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -7,6 +7,7 @@ from loguru import logger from batdetect2.cli.base import cli from batdetect2.data import load_dataset_from_config +from batdetect2.targets import load_target_config from batdetect2.train import ( FullTrainingConfig, load_full_training_config, @@ -20,6 +21,7 @@ __all__ = ["train_command"] @click.argument("train_dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True)) +@click.option("--targets", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True)) @@ -42,6 +44,7 @@ def train_command( ckpt_dir: Optional[Path] = None, log_dir: Optional[Path] = None, config: Optional[Path] = None, + targets: Optional[Path] = None, config_field: Optional[str] = None, seed: Optional[int] = None, train_workers: int = 0, @@ -62,12 +65,18 @@ def train_command( logger.info("Initiating training process...") logger.info("Loading training configuration...") + conf = ( load_full_training_config(config, field=config_field) if config is not None else FullTrainingConfig() ) + if targets is not None: + logger.info("Loading targets configuration...") + targets_config = load_target_config(targets) + conf = conf.model_copy(update=dict(targets=targets_config)) + logger.info("Loading training dataset...") train_annotations = load_dataset_from_config(train_dataset) logger.debug(