mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add image logging for mlflow
This commit is contained in:
parent
fdbb9c2b43
commit
a7301bcdc8
@ -1,8 +1,11 @@
|
|||||||
|
import io
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import Logger, TensorBoardLogger
|
||||||
|
from lightning.pytorch.loggers.mlflow import MLFlowLogger
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -55,15 +58,17 @@ class ValidationMetrics(Callback):
|
|||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
matches: List[MatchEvaluation],
|
matches: List[MatchEvaluation],
|
||||||
):
|
):
|
||||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
plotter = _get_image_plotter(pl_module.logger) # type: ignore
|
||||||
|
|
||||||
|
if plotter is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
for class_name, fig in plot_example_gallery(
|
for class_name, fig in plot_example_gallery(
|
||||||
matches,
|
matches,
|
||||||
preprocessor=pl_module.preprocessor,
|
preprocessor=pl_module.preprocessor,
|
||||||
n_examples=5,
|
n_examples=4,
|
||||||
):
|
):
|
||||||
pl_module.logger.experiment.add_figure(
|
plotter(
|
||||||
f"images/{class_name}_examples",
|
f"images/{class_name}_examples",
|
||||||
fig,
|
fig,
|
||||||
pl_module.global_step,
|
pl_module.global_step,
|
||||||
@ -210,3 +215,30 @@ def _get_subclip(
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_image_plotter(logger: Logger):
|
||||||
|
if isinstance(logger, TensorBoardLogger):
|
||||||
|
|
||||||
|
def plot_figure(name, figure, step):
|
||||||
|
return logger.experiment.add_figure(name, figure, step)
|
||||||
|
|
||||||
|
return plot_figure
|
||||||
|
|
||||||
|
if isinstance(logger, MLFlowLogger):
|
||||||
|
|
||||||
|
def plot_figure(name, figure, step):
|
||||||
|
image = _convert_figure_to_image(figure)
|
||||||
|
return logger.experiment.log_image(image, key=name, step=step)
|
||||||
|
|
||||||
|
return plot_figure
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_figure_to_image(figure):
|
||||||
|
with io.BytesIO() as buff:
|
||||||
|
figure.savefig(buff, format="raw")
|
||||||
|
buff.seek(0)
|
||||||
|
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
|
||||||
|
w, h = figure.canvas.get_width_height()
|
||||||
|
im = data.reshape((int(h), int(w), -1))
|
||||||
|
return im
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user