Add arguments to train cli

This commit is contained in:
mbsantiago 2025-10-14 18:19:33 +01:00
parent 3913d2d350
commit 5736421023
3 changed files with 14 additions and 1 deletions

View File

@ -62,6 +62,7 @@ class BatDetect2API:
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR, log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
): ):
@ -76,6 +77,7 @@ class BatDetect2API:
val_workers=val_workers, val_workers=val_workers,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
log_dir=log_dir, log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,

View File

@ -20,6 +20,7 @@ __all__ = ["train_command"]
@click.option("--config-field", type=str) @click.option("--config-field", type=str)
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str) @click.option("--run-name", type=str)
@click.option("--seed", type=int) @click.option("--seed", type=int)
@ -33,6 +34,7 @@ def train_command(
targets_config: Optional[Path] = None, targets_config: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
num_epochs: Optional[int] = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
@ -95,6 +97,7 @@ def train_command(
val_workers=val_workers, val_workers=val_workers,
checkpoint_dir=ckpt_dir, checkpoint_dir=ckpt_dir,
log_dir=log_dir, log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,

View File

@ -47,6 +47,7 @@ def train(
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
): ):
@ -107,6 +108,7 @@ def train(
targets=targets, targets=targets,
), ),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
@ -128,6 +130,7 @@ def build_trainer(
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
num_epochs: Optional[int] = None,
) -> Trainer: ) -> Trainer:
trainer_conf = config.train.trainer trainer_conf = config.train.trainer
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
@ -149,8 +152,13 @@ def build_trainer(
) )
) )
train_config = trainer_conf.model_dump(exclude_none=True)
if num_epochs is not None:
train_config["max_epochs"] = num_epochs
return Trainer( return Trainer(
**trainer_conf.model_dump(exclude_none=True), **train_config,
logger=train_logger, logger=train_logger,
callbacks=[ callbacks=[
build_checkpoint_callback( build_checkpoint_callback(