Compare commits

..

3 Commits

Author SHA1 Message Date
mbsantiago
78a0975864 Same issue 2025-08-14 11:07:27 +01:00
mbsantiago
c9848deebf Fix logging issue 2025-08-14 11:05:21 +01:00
mbsantiago
a7301bcdc8 Add image logging for mlflow 2025-08-14 10:50:00 +01:00
2 changed files with 38 additions and 6 deletions

View File

@ -1,8 +1,11 @@
import io
from typing import List, Optional, Tuple
import numpy as np
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
@ -55,15 +58,17 @@ class ValidationMetrics(Callback):
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
if not isinstance(pl_module.logger, TensorBoardLogger):
plotter = _get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
return
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.preprocessor,
n_examples=5,
n_examples=4,
):
pl_module.logger.experiment.add_figure(
plotter(
f"images/{class_name}_examples",
fig,
pl_module.global_step,
@ -210,3 +215,30 @@ def _get_subclip(
)
],
)
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im

View File

@ -136,7 +136,7 @@ def build_train_loader(
loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug(
"Training data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True),
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(
@ -161,7 +161,7 @@ def build_val_loader(
loader_conf = config.dataloaders.val
logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True),
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(