Add targets to train cli

This commit is contained in:
mbsantiago 2025-09-09 15:45:00 +01:00
parent 115084fd2b
commit 16a0fa7b75

View File

@ -7,6 +7,7 @@ from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import ( from batdetect2.train import (
FullTrainingConfig, FullTrainingConfig,
load_full_training_config, load_full_training_config,
@ -20,6 +21,7 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True)) @click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True))
@ -42,6 +44,7 @@ def train_command(
ckpt_dir: Optional[Path] = None, ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
config: Optional[Path] = None, config: Optional[Path] = None,
targets: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
train_workers: int = 0, train_workers: int = 0,
@ -62,12 +65,18 @@ def train_command(
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading training configuration...") logger.info("Loading training configuration...")
conf = ( conf = (
load_full_training_config(config, field=config_field) load_full_training_config(config, field=config_field)
if config is not None if config is not None
else FullTrainingConfig() else FullTrainingConfig()
) )
if targets is not None:
logger.info("Loading targets configuration...")
targets_config = load_target_config(targets)
conf = conf.model_copy(update=dict(targets=targets_config))
logger.info("Loading training dataset...") logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset) train_annotations = load_dataset_from_config(train_dataset)
logger.debug( logger.debug(