Add experiment name

mbsantiago 2025-09-09 09:02:25 +01:00
parent cd4955d4f3
commit 3376be06a4
3 changed files with 34 additions and 6 deletions

File 1 of 3

@@ -26,6 +26,7 @@ __all__ = ["train_command"]
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--experiment-name", type=str)
 @click.option(
     "-v",
     "--verbose",
@@ -42,6 +43,7 @@ def train_command(
     config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    experiment_name: Optional[str] = None,
     verbose: int = 0,
 ):
     logger.remove()
@@ -87,6 +89,7 @@ def train_command(
         model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
+        experiment_name=experiment_name,
         log_dir=log_dir,
         checkpoint_dir=ckpt_dir,
     )
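The new flag threads straight from the CLI into train(). A minimal sketch of exercising it with click's test runner; the import path and option values are hypothetical, and invoking the command for real would of course start training:

from click.testing import CliRunner

from mypackage.cli import train_command  # hypothetical import path

runner = CliRunner()
# Shown only to illustrate the flag wiring: --experiment-name is
# forwarded to train() as the experiment_name keyword argument.
result = runner.invoke(
    train_command,
    ["--experiment-name", "baseline", "--train-workers", "2"],
)
print(result.exit_code, result.output)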

File 2 of 3

@@ -61,6 +61,7 @@ LoggerConfig = Annotated[
 def create_dvclive_logger(
     config: DVCLiveConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     try:
         from dvclive.lightning import DVCLiveLogger  # type: ignore
@@ -73,7 +74,9 @@ def create_dvclive_logger(
     return DVCLiveLogger(
         dir=log_dir if log_dir is not None else config.dir,
-        run_name=config.run_name,
+        run_name=experiment_name
+        if experiment_name is not None
+        else config.run_name,
         prefix=config.prefix,
         log_model=config.log_model,
         monitor_system=config.monitor_system,
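Each logger factory applies the same precedence rule: an explicitly passed experiment_name overrides the backend-specific config field, which remains the fallback. A standalone sketch of that rule, with hypothetical helper and variable names:

from typing import Optional

def resolve_run_name(
    experiment_name: Optional[str],
    config_run_name: Optional[str],
) -> Optional[str]:
    # Mirrors the conditional used in the factories: the
    # caller-supplied name wins, the config value is the default.
    return experiment_name if experiment_name is not None else config_run_name

assert resolve_run_name("exp-a", "default") == "exp-a"
assert resolve_run_name(None, "default") == "default"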
@@ -83,12 +86,13 @@ def create_dvclive_logger(
 def create_csv_logger(
     config: CSVLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import CSVLogger

     return CSVLogger(
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
+        name=experiment_name if experiment_name is not None else config.name,
         version=config.version,
         flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )
@@ -97,12 +101,13 @@ def create_csv_logger(
 def create_tensorboard_logger(
     config: TensorBoardLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger

     return TensorBoardLogger(
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
+        name=experiment_name if experiment_name is not None else config.name,
         version=config.version,
         log_graph=config.log_graph,
     )
@@ -111,6 +116,7 @@ def create_tensorboard_logger(
 def create_mlflow_logger(
     config: MLFlowLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
@@ -122,7 +128,9 @@ def create_mlflow_logger(
         ) from error

     return MLFlowLogger(
-        experiment_name=config.experiment_name,
+        experiment_name=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
         run_name=config.run_name,
         save_dir=str(log_dir) if log_dir is not None else config.save_dir,
         tracking_uri=config.tracking_uri,
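For the MLflow backend the override targets the experiment name (MLflow's grouping of runs), while run_name still comes from config. A small illustration with Lightning's MLFlowLogger, assuming mlflow is installed; the values are placeholders:

from lightning.pytorch.loggers import MLFlowLogger

# experiment_name groups runs in the MLflow UI; run_name labels the
# individual run inside that experiment.
mlf_logger = MLFlowLogger(
    experiment_name="baseline",   # what the new argument overrides
    run_name="2025-09-09-a",      # still taken from the config
    tracking_uri="file:./mlruns",
)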
@@ -142,6 +150,7 @@ LOGGER_FACTORY = {
 def build_logger(
     config: LoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
@@ -157,7 +166,11 @@ def build_logger(
     creation_func = LOGGER_FACTORY[logger_type]
-    return creation_func(config, log_dir=log_dir)
+    return creation_func(
+        config,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+    )

 def get_image_plotter(logger: Logger):
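Because every factory now shares the (config, log_dir, experiment_name) keyword signature, the dictionary dispatch in build_logger stays uniform. A reduced sketch of the pattern, with type dispatch simplified to an explicit key and all names as stand-ins:

from typing import Any, Callable

def create_fake_logger(config: Any, log_dir=None, experiment_name=None):
    # Stand-in for one backend factory; the override-then-fallback
    # rule lives inside each factory, not in the dispatcher.
    return {"dir": log_dir, "name": experiment_name or config["name"]}

LOGGER_FACTORY: dict[str, Callable] = {"fake": create_fake_logger}

def build_logger(logger_type: str, config: Any, log_dir=None, experiment_name=None):
    creation_func = LOGGER_FACTORY[logger_type]
    # Uniform keyword forwarding: no per-backend branching here.
    return creation_func(config, log_dir=log_dir, experiment_name=experiment_name)

print(build_logger("fake", {"name": "default"}, experiment_name="exp-a"))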

File 3 of 3

@@ -55,6 +55,7 @@ def train(
     val_workers: Optional[int] = None,
     checkpoint_dir: Optional[data.PathLike] = None,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ):
     config = config or FullTrainingConfig()
@@ -107,6 +108,7 @@ def train(
         targets=targets,
         checkpoint_dir=checkpoint_dir,
         log_dir=log_dir,
+        experiment_name=experiment_name,
     )

     logger.info("Starting main training loop...")
@@ -135,10 +137,14 @@ def build_trainer_callbacks(
     preprocessor: PreprocessorProtocol,
     config: EvaluationConfig,
     checkpoint_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> List[Callback]:
     if checkpoint_dir is None:
         checkpoint_dir = "outputs/checkpoints"

+    if experiment_name is not None:
+        checkpoint_dir = f"{checkpoint_dir}/{experiment_name}"
+
     return [
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
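With a name supplied, checkpoints land in a per-experiment subdirectory of the default path instead of all runs sharing one directory. A standalone sketch of the resulting layout logic, with a hypothetical helper name:

from typing import Optional

def resolve_checkpoint_dir(
    checkpoint_dir: Optional[str],
    experiment_name: Optional[str],
) -> str:
    # Mirrors the logic added above: apply the default first, then
    # nest a per-experiment subdirectory when a name is available.
    if checkpoint_dir is None:
        checkpoint_dir = "outputs/checkpoints"
    if experiment_name is not None:
        checkpoint_dir = f"{checkpoint_dir}/{experiment_name}"
    return checkpoint_dir

assert resolve_checkpoint_dir(None, "exp-a") == "outputs/checkpoints/exp-a"
assert resolve_checkpoint_dir(None, None) == "outputs/checkpoints"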
@@ -164,13 +170,18 @@ def build_trainer(
     targets: TargetProtocol,
     checkpoint_dir: Optional[data.PathLike] = None,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
 ) -> Trainer:
     trainer_conf = conf.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
-    train_logger = build_logger(conf.train.logger, log_dir=log_dir)
+    train_logger = build_logger(
+        conf.train.logger,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+    )

     train_logger.log_hyperparams(
         conf.model_dump(
@@ -187,6 +198,7 @@ def build_trainer(
             config=conf.evaluation,
             preprocessor=build_preprocessor(conf.preprocess),
             checkpoint_dir=checkpoint_dir,
+            experiment_name=train_logger.name,
         ),
     )
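Note that the callbacks are handed train_logger.name rather than the raw experiment_name argument: Lightning loggers expose a name property, so the checkpoint subdirectory follows whatever name the logger actually resolved, including the config fallback when no override was given. A quick check with the CSV backend, assuming lightning is installed:

from lightning.pytorch.loggers import CSVLogger

# The resolved logger name is what gets reused for the checkpoint
# subdirectory, keeping logs and checkpoints under the same label.
csv_logger = CSVLogger(save_dir="outputs/logs", name="baseline")
print(csv_logger.name)  # "baseline"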