mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add experiment name
This commit is contained in:
parent
cd4955d4f3
commit
3376be06a4
@ -26,6 +26,7 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@click.option("--val-workers", type=int)
|
@click.option("--val-workers", type=int)
|
||||||
|
@click.option("--experiment-name", type=str)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
@ -42,6 +43,7 @@ def train_command(
|
|||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
logger.remove()
|
logger.remove()
|
||||||
@ -87,6 +89,7 @@ def train_command(
|
|||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
|
experiment_name=experiment_name,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
checkpoint_dir=ckpt_dir,
|
checkpoint_dir=ckpt_dir,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -61,6 +61,7 @@ LoggerConfig = Annotated[
|
|||||||
def create_dvclive_logger(
|
def create_dvclive_logger(
|
||||||
config: DVCLiveConfig,
|
config: DVCLiveConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
@ -73,7 +74,9 @@ def create_dvclive_logger(
|
|||||||
|
|
||||||
return DVCLiveLogger(
|
return DVCLiveLogger(
|
||||||
dir=log_dir if log_dir is not None else config.dir,
|
dir=log_dir if log_dir is not None else config.dir,
|
||||||
run_name=config.run_name,
|
run_name=experiment_name
|
||||||
|
if experiment_name is not None
|
||||||
|
else config.run_name,
|
||||||
prefix=config.prefix,
|
prefix=config.prefix,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
monitor_system=config.monitor_system,
|
monitor_system=config.monitor_system,
|
||||||
@ -83,12 +86,13 @@ def create_dvclive_logger(
|
|||||||
def create_csv_logger(
|
def create_csv_logger(
|
||||||
config: CSVLoggerConfig,
|
config: CSVLoggerConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import CSVLogger
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
return CSVLogger(
|
return CSVLogger(
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=experiment_name if experiment_name is not None else config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
)
|
)
|
||||||
@ -97,12 +101,13 @@ def create_csv_logger(
|
|||||||
def create_tensorboard_logger(
|
def create_tensorboard_logger(
|
||||||
config: TensorBoardLoggerConfig,
|
config: TensorBoardLoggerConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
return TensorBoardLogger(
|
return TensorBoardLogger(
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=experiment_name if experiment_name is not None else config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
log_graph=config.log_graph,
|
log_graph=config.log_graph,
|
||||||
)
|
)
|
||||||
@ -111,6 +116,7 @@ def create_tensorboard_logger(
|
|||||||
def create_mlflow_logger(
|
def create_mlflow_logger(
|
||||||
config: MLFlowLoggerConfig,
|
config: MLFlowLoggerConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from lightning.pytorch.loggers import MLFlowLogger
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
@ -122,7 +128,9 @@ def create_mlflow_logger(
|
|||||||
) from error
|
) from error
|
||||||
|
|
||||||
return MLFlowLogger(
|
return MLFlowLogger(
|
||||||
experiment_name=config.experiment_name,
|
experiment_name=experiment_name
|
||||||
|
if experiment_name is not None
|
||||||
|
else config.experiment_name,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
tracking_uri=config.tracking_uri,
|
tracking_uri=config.tracking_uri,
|
||||||
@ -142,6 +150,7 @@ LOGGER_FACTORY = {
|
|||||||
def build_logger(
|
def build_logger(
|
||||||
config: LoggerConfig,
|
config: LoggerConfig,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
"""
|
"""
|
||||||
Creates a logger instance from a validated Pydantic config object.
|
Creates a logger instance from a validated Pydantic config object.
|
||||||
@ -157,7 +166,11 @@ def build_logger(
|
|||||||
|
|
||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(config, log_dir=log_dir)
|
return creation_func(
|
||||||
|
config,
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
def get_image_plotter(logger: Logger):
|
||||||
|
|||||||
@ -55,6 +55,7 @@ def train(
|
|||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
@ -107,6 +108,7 @@ def train(
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
@ -135,10 +137,14 @@ def build_trainer_callbacks(
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: EvaluationConfig,
|
config: EvaluationConfig,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
if checkpoint_dir is None:
|
if checkpoint_dir is None:
|
||||||
checkpoint_dir = "outputs/checkpoints"
|
checkpoint_dir = "outputs/checkpoints"
|
||||||
|
|
||||||
|
if experiment_name is not None:
|
||||||
|
checkpoint_dir = f"{checkpoint_dir}/{experiment_name}"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
@ -164,13 +170,18 @@ def build_trainer(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
train_logger = build_logger(
|
||||||
|
conf.train.logger,
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
)
|
||||||
|
|
||||||
train_logger.log_hyperparams(
|
train_logger.log_hyperparams(
|
||||||
conf.model_dump(
|
conf.model_dump(
|
||||||
@ -187,6 +198,7 @@ def build_trainer(
|
|||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
experiment_name=train_logger.name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user