mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add arguments to train cli
This commit is contained in:
parent
3913d2d350
commit
5736421023
@ -62,6 +62,7 @@ class BatDetect2API:
|
||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -76,6 +77,7 @@ class BatDetect2API:
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
|
||||
@ -20,6 +20,7 @@ __all__ = ["train_command"]
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option("--num-epochs", type=int)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--run-name", type=str)
|
||||
@click.option("--seed", type=int)
|
||||
@ -33,6 +34,7 @@ def train_command(
|
||||
targets_config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: Optional[str] = None,
|
||||
@ -95,6 +97,7 @@ def train_command(
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
|
||||
@ -47,6 +47,7 @@ def train(
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -107,6 +108,7 @@ def train(
|
||||
targets=targets,
|
||||
),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
@ -128,6 +130,7 @@ def build_trainer(
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = config.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
@ -149,8 +152,13 @@ def build_trainer(
|
||||
)
|
||||
)
|
||||
|
||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||
|
||||
if num_epochs is not None:
|
||||
train_config["max_epochs"] = num_epochs
|
||||
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
**train_config,
|
||||
logger=train_logger,
|
||||
callbacks=[
|
||||
build_checkpoint_callback(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user