Add arguments to train cli

mbsantiago 2025-10-14 18:19:33 +01:00
parent 3913d2d350
commit 5736421023
3 changed files with 14 additions and 1 deletion

View File

@@ -62,6 +62,7 @@ class BatDetect2API:
         checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
         log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
         experiment_name: Optional[str] = None,
+        num_epochs: Optional[int] = None,
         run_name: Optional[str] = None,
         seed: Optional[int] = None,
     ):
@@ -76,6 +77,7 @@ class BatDetect2API:
             val_workers=val_workers,
             checkpoint_dir=checkpoint_dir,
             log_dir=log_dir,
+            num_epochs=num_epochs,
             experiment_name=experiment_name,
             run_name=run_name,
             seed=seed,

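The API-layer change above just threads the new optional value through to the underlying training call. A minimal sketch of that forwarding pattern, with hypothetical names (API and run_training stand in for the real BatDetect2API method and the lower-level train function; this is not the project's actual code):

from pathlib import Path
from typing import Optional


def run_training(num_epochs: Optional[int] = None, **kwargs) -> None:
    # Stand-in for the lower-level train() entry point that the API wraps.
    label = num_epochs if num_epochs is not None else "configured"
    print(f"training for {label} epochs")


class API:
    # Hypothetical stand-in for BatDetect2API; only the forwarding is shown.
    def train(
        self,
        num_epochs: Optional[int] = None,
        checkpoint_dir: Optional[Path] = None,
        seed: Optional[int] = None,
    ) -> None:
        # The API layer adds no logic of its own: None means "keep the
        # epoch count from the training config", any int overrides it.
        run_training(
            num_epochs=num_epochs,
            checkpoint_dir=checkpoint_dir,
            seed=seed,
        )


API().train(num_epochs=5)  # -> training for 5 epochs
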
View File

@@ -20,6 +20,7 @@ __all__ = ["train_command"]
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--num-epochs", type=int)
 @click.option("--experiment-name", type=str)
 @click.option("--run-name", type=str)
 @click.option("--seed", type=int)
@@ -33,6 +34,7 @@ def train_command(
     targets_config: Optional[Path] = None,
     config_field: Optional[str] = None,
     seed: Optional[int] = None,
+    num_epochs: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
     experiment_name: Optional[str] = None,
@@ -95,6 +97,7 @@ def train_command(
         val_workers=val_workers,
         checkpoint_dir=ckpt_dir,
         log_dir=log_dir,
+        num_epochs=num_epochs,
         experiment_name=experiment_name,
         run_name=run_name,
         seed=seed,

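In the CLI file, --num-epochs follows the same convention as the existing optional flags: omitted means None, which leaves the configured epoch count untouched. A standalone sketch of just the option handling (not the project's actual train command, which also takes dataset and config arguments not shown here):

from typing import Optional

import click
from click.testing import CliRunner


@click.command()
@click.option("--num-epochs", type=int)  # no default, so omitting the flag yields None
def train_command(num_epochs: Optional[int] = None) -> None:
    # None -> keep max_epochs from the trainer config; an int overrides it downstream.
    click.echo(f"num_epochs={num_epochs}")


runner = CliRunner()
print(runner.invoke(train_command, ["--num-epochs", "5"]).output)  # num_epochs=5
print(runner.invoke(train_command, []).output)                     # num_epochs=None
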
View File

@@ -47,6 +47,7 @@ def train(
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
+    num_epochs: Optional[int] = None,
     run_name: Optional[str] = None,
     seed: Optional[int] = None,
 ):
@@ -107,6 +108,7 @@
             targets=targets,
         ),
         checkpoint_dir=checkpoint_dir,
+        num_epochs=num_epochs,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
@@ -128,6 +130,7 @@ def build_trainer(
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
+    num_epochs: Optional[int] = None,
 ) -> Trainer:
     trainer_conf = config.train.trainer
     logger.opt(lazy=True).debug(
@@ -149,8 +152,13 @@
         )
     )
+    train_config = trainer_conf.model_dump(exclude_none=True)
+    if num_epochs is not None:
+        train_config["max_epochs"] = num_epochs
     return Trainer(
-        **trainer_conf.model_dump(exclude_none=True),
+        **train_config,
         logger=train_logger,
         callbacks=[
             build_checkpoint_callback(
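
The last hunk is where the value is consumed: the trainer settings are dumped to a dict with None fields dropped, and an explicit num_epochs replaces max_epochs before the kwargs reach the Trainer constructor. A small sketch of that override, assuming a pydantic v2 config model as model_dump suggests (field names other than max_epochs are made up):

from typing import Optional

from pydantic import BaseModel


class TrainerConfig(BaseModel):
    # Hypothetical subset of the trainer settings.
    max_epochs: Optional[int] = 200
    accelerator: Optional[str] = None


def trainer_kwargs(conf: TrainerConfig, num_epochs: Optional[int] = None) -> dict:
    # Drop unset fields so the Trainer's own defaults apply, then let an
    # explicit --num-epochs value win over the configured max_epochs.
    kwargs = conf.model_dump(exclude_none=True)
    if num_epochs is not None:
        kwargs["max_epochs"] = num_epochs
    return kwargs


print(trainer_kwargs(TrainerConfig()))                # {'max_epochs': 200}
print(trainer_kwargs(TrainerConfig(), num_epochs=5))  # {'max_epochs': 5}

Applying the override to the dumped dict rather than mutating the config object keeps the original trainer configuration untouched.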