Logging is not just for training

This commit is contained in:
mbsantiago 2025-09-18 09:27:40 +01:00
parent 8c80402f08
commit 6c25787123

View File

@ -15,6 +15,7 @@ from typing import (
)
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
@ -56,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = None
tracking_uri: Optional[str] = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None
log_model: bool = False
@ -160,6 +161,9 @@ def create_tensorboard_logger(
name = run_name
if name is None:
name = experiment_name
if run_name is not None and experiment_name is not None:
name = str(Path(experiment_name) / run_name)
@ -239,10 +243,10 @@ def build_logger(
)
Plotter = Callable[[str, Figure, int], None]
PlotLogger = Callable[[str, Figure, int], None]
def get_image_plotter(logger: Logger) -> Optional[Plotter]:
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
@ -250,6 +254,7 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]:
def plot_figure(name, figure, step):
image = _convert_figure_to_array(figure)
name = name.replace("/", "_")
return logger.experiment.log_image(
logger.run_id,
image,
@ -263,6 +268,37 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]:
return partial(save_figure, dir=Path(logger.log_dir))
TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))
if isinstance(logger, MLFlowLogger):
def plot_figure(name: str, df: pd.DataFrame, step: int):
return logger.experiment.log_table(
logger.run_id,
data=df,
artifact_file=f"{name}_step_{step}.json",
)
return plot_figure
if isinstance(logger, CSVLogger):
return partial(save_table, dir=Path(logger.log_dir))
def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
path = dir / "tables" / f"{name}_step_{step}.csv"
if not path.parent.exists():
path.parent.mkdir(parents=True)
df.to_csv(path, index=False)
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
path = dir / "plots" / f"{name}_step_{step}.png"