mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Add image logging for mlflow
This commit is contained in:
parent
fdbb9c2b43
commit
a7301bcdc8
@ -1,8 +1,11 @@
|
||||
import io
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from lightning.pytorch.loggers import Logger, TensorBoardLogger
|
||||
from lightning.pytorch.loggers.mlflow import MLFlowLogger
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
@ -55,15 +58,17 @@ class ValidationMetrics(Callback):
|
||||
pl_module: LightningModule,
|
||||
matches: List[MatchEvaluation],
|
||||
):
|
||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||
plotter = _get_image_plotter(pl_module.logger) # type: ignore
|
||||
|
||||
if plotter is None:
|
||||
return
|
||||
|
||||
for class_name, fig in plot_example_gallery(
|
||||
matches,
|
||||
preprocessor=pl_module.preprocessor,
|
||||
n_examples=5,
|
||||
n_examples=4,
|
||||
):
|
||||
pl_module.logger.experiment.add_figure(
|
||||
plotter(
|
||||
f"images/{class_name}_examples",
|
||||
fig,
|
||||
pl_module.global_step,
|
||||
@ -210,3 +215,30 @@ def _get_subclip(
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _get_image_plotter(logger: Logger):
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
return logger.experiment.add_figure(name, figure, step)
|
||||
|
||||
return plot_figure
|
||||
|
||||
if isinstance(logger, MLFlowLogger):
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
image = _convert_figure_to_image(figure)
|
||||
return logger.experiment.log_image(image, key=name, step=step)
|
||||
|
||||
return plot_figure
|
||||
|
||||
|
||||
def _convert_figure_to_image(figure):
|
||||
with io.BytesIO() as buff:
|
||||
figure.savefig(buff, format="raw")
|
||||
buff.seek(0)
|
||||
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
|
||||
w, h = figure.canvas.get_width_height()
|
||||
im = data.reshape((int(h), int(w), -1))
|
||||
return im
|
||||
|
||||
Loading…
Reference in New Issue
Block a user