Add targets to train cli

This commit is contained in:
mbsantiago 2025-09-09 15:45:00 +01:00
parent 115084fd2b
commit 16a0fa7b75

View File

@ -7,6 +7,7 @@ from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
@ -20,6 +21,7 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@ -42,6 +44,7 @@ def train_command(
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
train_workers: int = 0,
@ -62,12 +65,18 @@ def train_command(
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
conf = (
load_full_training_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
)
if targets is not None:
logger.info("Loading targets configuration...")
targets_config = load_target_config(targets)
conf = conf.model_copy(update=dict(targets=targets_config))
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
logger.debug(