Add image logging for mlflow

This commit is contained in:
mbsantiago 2025-08-14 10:50:00 +01:00
parent fdbb9c2b43
commit a7301bcdc8

View File

@ -1,8 +1,11 @@
import io
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -55,15 +58,17 @@ class ValidationMetrics(Callback):
pl_module: LightningModule, pl_module: LightningModule,
matches: List[MatchEvaluation], matches: List[MatchEvaluation],
): ):
if not isinstance(pl_module.logger, TensorBoardLogger): plotter = _get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
return return
for class_name, fig in plot_example_gallery( for class_name, fig in plot_example_gallery(
matches, matches,
preprocessor=pl_module.preprocessor, preprocessor=pl_module.preprocessor,
n_examples=5, n_examples=4,
): ):
pl_module.logger.experiment.add_figure( plotter(
f"images/{class_name}_examples", f"images/{class_name}_examples",
fig, fig,
pl_module.global_step, pl_module.global_step,
@ -210,3 +215,30 @@ def _get_subclip(
) )
], ],
) )
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im