Compare commits

..

No commits in common. "78a09758645ebfd2cdb5f5db0beaaa2a19d0e793" and "fdbb9c2b43b41629ea949a5c489da8cfeed4066f" have entirely different histories.

2 changed files with 6 additions and 38 deletions

View File

@ -1,11 +1,8 @@
import io
from typing import List, Optional, Tuple
import numpy as np
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from lightning.pytorch.loggers import TensorBoardLogger
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
@ -58,17 +55,15 @@ class ValidationMetrics(Callback):
pl_module: LightningModule,
matches: List[MatchEvaluation],
):
plotter = _get_image_plotter(pl_module.logger) # type: ignore
if plotter is None:
if not isinstance(pl_module.logger, TensorBoardLogger):
return
for class_name, fig in plot_example_gallery(
matches,
preprocessor=pl_module.preprocessor,
n_examples=4,
n_examples=5,
):
plotter(
pl_module.logger.experiment.add_figure(
f"images/{class_name}_examples",
fig,
pl_module.global_step,
@ -215,30 +210,3 @@ def _get_subclip(
)
],
)
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im

View File

@ -136,7 +136,7 @@ def build_train_loader(
loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug(
"Training data loader config: \n{config}",
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
config=loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(
@ -161,7 +161,7 @@ def build_val_loader(
loader_conf = config.dataloaders.val
logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}",
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
config=loader_conf.to_yaml_string(exclude_none=True),
)
num_workers = num_workers or loader_conf.num_workers
return DataLoader(