diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py
index cf6b089..864dfda 100644
--- a/src/batdetect2/cli/train.py
+++ b/src/batdetect2/cli/train.py
@@ -26,6 +26,7 @@ __all__ = ["train_command"]
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--experiment-name", type=str)
 @click.option(
     "-v",
     "--verbose",
@@ -42,6 +43,7 @@ def train_command(
     config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    experiment_name: Optional[str] = None,
     verbose: int = 0,
 ):
     logger.remove()
@@ -87,6 +89,7 @@ def train_command(
         model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
+        experiment_name=experiment_name,
         log_dir=log_dir,
         checkpoint_dir=ckpt_dir,
     )
diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py
index f482a1c..da576f6 100644
--- a/src/batdetect2/train/logging.py
+++ b/src/batdetect2/train/logging.py
@@ -61,6 +61,7 @@ LoggerConfig = Annotated[
 def create_dvclive_logger(
     config: DVCLiveConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     try:
         from dvclive.lightning import DVCLiveLogger  # type: ignore
@@ -73,7 +74,9 @@ def create_dvclive_logger(
 
     return DVCLiveLogger(
         dir=log_dir if log_dir is not None else config.dir,
-        run_name=config.run_name,
+        run_name=experiment_name
+        if experiment_name is not None
+        else config.run_name,
         prefix=config.prefix,
         log_model=config.log_model,
         monitor_system=config.monitor_system,
@@ -83,12 +86,13 @@
 def create_csv_logger(
     config: CSVLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import CSVLogger
 
     return CSVLogger(
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
+        name=experiment_name if experiment_name is not None else config.name,
         version=config.version,
         flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )
@@ -97,12 +101,13 @@
 def create_tensorboard_logger(
     config: TensorBoardLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger
 
     return TensorBoardLogger(
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
+        name=experiment_name if experiment_name is not None else config.name,
         version=config.version,
         log_graph=config.log_graph,
     )
@@ -111,6 +116,7 @@
 def create_mlflow_logger(
     config: MLFlowLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
@@ -122,7 +128,9 @@ def create_mlflow_logger(
         ) from error
 
     return MLFlowLogger(
-        experiment_name=config.experiment_name,
+        experiment_name=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
         run_name=config.run_name,
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
         tracking_uri=config.tracking_uri,
@@ -142,6 +150,7 @@ LOGGER_FACTORY = {
 def build_logger(
     config: LoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
@@ -157,7 +166,11 @@ def build_logger(
 
     creation_func = LOGGER_FACTORY[logger_type]
 
-    return creation_func(config, log_dir=log_dir)
+    return creation_func(
+        config,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+    )
 
 
 def get_image_plotter(logger: Logger):
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 0eb2d6f..b4b3061 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -55,6 +55,7 @@ def train(
     val_workers: Optional[int] = None,
     checkpoint_dir: Optional[data.PathLike] = None,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ):
     config = config or FullTrainingConfig()
 
@@ -107,6 +108,7 @@ def train(
         targets=targets,
         checkpoint_dir=checkpoint_dir,
         log_dir=log_dir,
+        experiment_name=experiment_name,
     )
 
     logger.info("Starting main training loop...")
@@ -135,10 +137,14 @@ def build_trainer_callbacks(
     preprocessor: PreprocessorProtocol,
     config: EvaluationConfig,
     checkpoint_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> List[Callback]:
     if checkpoint_dir is None:
         checkpoint_dir = "outputs/checkpoints"
 
+    if experiment_name is not None:
+        checkpoint_dir = f"{checkpoint_dir}/{experiment_name}"
+
     return [
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
@@ -164,13 +170,18 @@ def build_trainer(
     targets: TargetProtocol,
     checkpoint_dir: Optional[data.PathLike] = None,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Trainer:
     trainer_conf = conf.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
-    train_logger = build_logger(conf.train.logger, log_dir=log_dir)
+    train_logger = build_logger(
+        conf.train.logger,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+    )
 
     train_logger.log_hyperparams(
         conf.model_dump(
@@ -187,6 +198,7 @@ def build_trainer(
             config=conf.evaluation,
             preprocessor=build_preprocessor(conf.preprocess),
             checkpoint_dir=checkpoint_dir,
+            experiment_name=train_logger.name,
         ),
     )
 
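
Every factory patched above applies the same override rule: an explicitly supplied `experiment_name` wins, otherwise the value from the logger's own config is used. A minimal, self-contained sketch of that fallback follows; `CSVLoggerConfig` and `resolve_name` here are stand-ins for illustration, not batdetect2's real Pydantic model or API:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CSVLoggerConfig:
    """Stand-in for the real config; only `name` matters for this sketch."""

    name: str = "logs"


def resolve_name(
    config: CSVLoggerConfig,
    experiment_name: Optional[str] = None,
) -> str:
    # Same expression as the patched factories: an explicit experiment
    # name takes precedence over the configured default.
    return experiment_name if experiment_name is not None else config.name


assert resolve_name(CSVLoggerConfig()) == "logs"
assert resolve_name(CSVLoggerConfig(), "ablation-01") == "ablation-01"
```

Note the design choice in `build_trainer`: it forwards `train_logger.name` (whatever the instantiated Lightning logger reports as its `name`), not the raw CLI value, into `build_trainer_callbacks`. Checkpoints therefore land under `outputs/checkpoints/<name>/` using the name the logger actually resolved to, including the config-default fallback when `--experiment-name` is omitted. Assuming the click command is exposed as `batdetect2 train` (the subcommand name is an assumption; the flag itself is added in this patch), a run would pass `--experiment-name ablation-01`.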