Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
78a0975864 Same issue 2025-08-14 11:07:27 +01:00
mbsantiago
c9848deebf Fix logging issue 2025-08-14 11:05:21 +01:00
mbsantiago
a7301bcdc8 Add image logging for mlflow 2025-08-14 10:50:00 +01:00
2 changed files with 38 additions and 6 deletions

View File

@ -1,8 +1,11 @@
import io
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -55,15 +58,17 @@ class ValidationMetrics(Callback):
pl_module: LightningModule, pl_module: LightningModule,
matches: List[MatchEvaluation], matches: List[MatchEvaluation],
): ):
if not isinstance(pl_module.logger, TensorBoardLogger): plotter = _get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
return return
for class_name, fig in plot_example_gallery( for class_name, fig in plot_example_gallery(
matches, matches,
preprocessor=pl_module.preprocessor, preprocessor=pl_module.preprocessor,
n_examples=5, n_examples=4,
): ):
pl_module.logger.experiment.add_figure( plotter(
f"images/{class_name}_examples", f"images/{class_name}_examples",
fig, fig,
pl_module.global_step, pl_module.global_step,
@ -210,3 +215,30 @@ def _get_subclip(
) )
], ],
) )
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im

View File

@ -136,7 +136,7 @@ def build_train_loader(
loader_conf = config.dataloaders.train loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Training data loader config: \n{config}", "Training data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True), config=lambda: loader_conf.to_yaml_string(exclude_none=True),
) )
num_workers = num_workers or loader_conf.num_workers num_workers = num_workers or loader_conf.num_workers
return DataLoader( return DataLoader(
@ -161,7 +161,7 @@ def build_val_loader(
loader_conf = config.dataloaders.val loader_conf = config.dataloaders.val
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}", "Validation data loader config: \n{config}",
config=loader_conf.to_yaml_string(exclude_none=True), config=lambda: loader_conf.to_yaml_string(exclude_none=True),
) )
num_workers = num_workers or loader_conf.num_workers num_workers = num_workers or loader_conf.num_workers
return DataLoader( return DataLoader(