diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 398adac..d48944c 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -62,6 +62,7 @@ class BatDetect2API: checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, log_dir: Optional[Path] = DEFAULT_LOGS_DIR, experiment_name: Optional[str] = None, + num_epochs: Optional[int] = None, run_name: Optional[str] = None, seed: Optional[int] = None, ): @@ -76,6 +77,7 @@ class BatDetect2API: val_workers=val_workers, checkpoint_dir=checkpoint_dir, log_dir=log_dir, + num_epochs=num_epochs, experiment_name=experiment_name, run_name=run_name, seed=seed, diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 76105f9..7cb7d98 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -20,6 +20,7 @@ __all__ = ["train_command"] @click.option("--config-field", type=str) @click.option("--train-workers", type=int) @click.option("--val-workers", type=int) +@click.option("--num-epochs", type=int) @click.option("--experiment-name", type=str) @click.option("--run-name", type=str) @click.option("--seed", type=int) @@ -33,6 +34,7 @@ def train_command( targets_config: Optional[Path] = None, config_field: Optional[str] = None, seed: Optional[int] = None, + num_epochs: Optional[int] = None, train_workers: int = 0, val_workers: int = 0, experiment_name: Optional[str] = None, @@ -95,6 +97,7 @@ def train_command( val_workers=val_workers, checkpoint_dir=ckpt_dir, log_dir=log_dir, + num_epochs=num_epochs, experiment_name=experiment_name, run_name=run_name, seed=seed, diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index a5c5a43..5673d0d 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -47,6 +47,7 @@ def train( checkpoint_dir: Optional[Path] = None, log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + num_epochs: Optional[int] = None, run_name: Optional[str] = None, seed: Optional[int] = None, ): @@ -107,6 +108,7 @@ def train( targets=targets, ), checkpoint_dir=checkpoint_dir, + num_epochs=num_epochs, log_dir=log_dir, experiment_name=experiment_name, run_name=run_name, @@ -128,6 +130,7 @@ def build_trainer( log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, + num_epochs: Optional[int] = None, ) -> Trainer: trainer_conf = config.train.trainer logger.opt(lazy=True).debug( @@ -149,8 +152,13 @@ def build_trainer( ) ) + train_config = trainer_conf.model_dump(exclude_none=True) + + if num_epochs is not None: + train_config["max_epochs"] = num_epochs + return Trainer( - **trainer_conf.model_dump(exclude_none=True), + **train_config, logger=train_logger, callbacks=[ build_checkpoint_callback(