mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add arguments to train cli
This commit is contained in:
parent
3913d2d350
commit
5736421023
@ -62,6 +62,7 @@ class BatDetect2API:
|
|||||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
|
num_epochs: Optional[int] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
):
|
):
|
||||||
@ -76,6 +77,7 @@ class BatDetect2API:
|
|||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
|
num_epochs=num_epochs,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
|||||||
@ -20,6 +20,7 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@click.option("--val-workers", type=int)
|
@click.option("--val-workers", type=int)
|
||||||
|
@click.option("--num-epochs", type=int)
|
||||||
@click.option("--experiment-name", type=str)
|
@click.option("--experiment-name", type=str)
|
||||||
@click.option("--run-name", type=str)
|
@click.option("--run-name", type=str)
|
||||||
@click.option("--seed", type=int)
|
@click.option("--seed", type=int)
|
||||||
@ -33,6 +34,7 @@ def train_command(
|
|||||||
targets_config: Optional[Path] = None,
|
targets_config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
num_epochs: Optional[int] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
@ -95,6 +97,7 @@ def train_command(
|
|||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
checkpoint_dir=ckpt_dir,
|
checkpoint_dir=ckpt_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
|
num_epochs=num_epochs,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
|||||||
@ -47,6 +47,7 @@ def train(
|
|||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
|
num_epochs: Optional[int] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
):
|
):
|
||||||
@ -107,6 +108,7 @@ def train(
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
),
|
),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
num_epochs=num_epochs,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
@ -128,6 +130,7 @@ def build_trainer(
|
|||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
|
num_epochs: Optional[int] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = config.train.trainer
|
trainer_conf = config.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
@ -149,8 +152,13 @@ def build_trainer(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
if num_epochs is not None:
|
||||||
|
train_config["max_epochs"] = num_epochs
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**train_config,
|
||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
build_checkpoint_callback(
|
build_checkpoint_callback(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user