From a7301bcdc849b64a810435e64eb42319cdbe3509 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 14 Aug 2025 10:50:00 +0100 Subject: [PATCH] Add image logging for mlflow --- src/batdetect2/train/callbacks.py | 40 +++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index d973c37..74d93ab 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,8 +1,11 @@ +import io from typing import List, Optional, Tuple +import numpy as np from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback -from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers import Logger, TensorBoardLogger +from lightning.pytorch.loggers.mlflow import MLFlowLogger from loguru import logger from soundevent import data from torch.utils.data import DataLoader @@ -55,15 +58,17 @@ class ValidationMetrics(Callback): pl_module: LightningModule, matches: List[MatchEvaluation], ): - if not isinstance(pl_module.logger, TensorBoardLogger): + plotter = _get_image_plotter(pl_module.logger) # type: ignore + + if plotter is None: return for class_name, fig in plot_example_gallery( matches, preprocessor=pl_module.preprocessor, - n_examples=5, + n_examples=4, ): - pl_module.logger.experiment.add_figure( + plotter( f"images/{class_name}_examples", fig, pl_module.global_step, @@ -210,3 +215,30 @@ def _get_subclip( ) ], ) + + +def _get_image_plotter(logger: Logger): + if isinstance(logger, TensorBoardLogger): + + def plot_figure(name, figure, step): + return logger.experiment.add_figure(name, figure, step) + + return plot_figure + + if isinstance(logger, MLFlowLogger): + + def plot_figure(name, figure, step): + image = _convert_figure_to_image(figure) + return logger.experiment.log_image(image, key=name, step=step) + + return plot_figure + + +def _convert_figure_to_image(figure): + with io.BytesIO() as buff: + figure.savefig(buff, format="raw") + buff.seek(0) + data = np.frombuffer(buff.getvalue(), dtype=np.uint8) + w, h = figure.canvas.get_width_height() + im = data.reshape((int(h), int(w), -1)) + return im