Compare commits

..

No commits in common. "78a09758645ebfd2cdb5f5db0beaaa2a19d0e793" and "fdbb9c2b43b41629ea949a5c489da8cfeed4066f" have entirely different histories.

2 changed files with 6 additions and 38 deletions

View File

@ -1,11 +1,8 @@
import io
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger, TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loggers.mlflow import MLFlowLogger
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -58,17 +55,15 @@ class ValidationMetrics(Callback):
pl_module: LightningModule, pl_module: LightningModule,
matches: List[MatchEvaluation], matches: List[MatchEvaluation],
): ):
plotter = _get_image_plotter(pl_module.logger) # type: ignore if not isinstance(pl_module.logger, TensorBoardLogger):
if plotter is None:
return return
for class_name, fig in plot_example_gallery( for class_name, fig in plot_example_gallery(
matches, matches,
preprocessor=pl_module.preprocessor, preprocessor=pl_module.preprocessor,
n_examples=4, n_examples=5,
): ):
plotter( pl_module.logger.experiment.add_figure(
f"images/{class_name}_examples", f"images/{class_name}_examples",
fig, fig,
pl_module.global_step, pl_module.global_step,
@ -215,30 +210,3 @@ def _get_subclip(
) )
], ],
) )
def _get_image_plotter(logger: Logger):
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
return logger.experiment.log_image(image, key=name, step=step)
return plot_figure
def _convert_figure_to_image(figure):
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im

View File

@ -136,7 +136,7 @@ def build_train_loader(
loader_conf = config.dataloaders.train loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Training data loader config: \n{config}", "Training data loader config: \n{config}",
config=lambda: loader_conf.to_yaml_string(exclude_none=True), config=loader_conf.to_yaml_string(exclude_none=True),
) )
num_workers = num_workers or loader_conf.num_workers num_workers = num_workers or loader_conf.num_workers
return DataLoader( return DataLoader(
@ -161,7 +161,7 @@ def build_val_loader(
loader_conf = config.dataloaders.val loader_conf = config.dataloaders.val
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Validation data loader config: \n{config}", "Validation data loader config: \n{config}",
config=lambda: loader_conf.to_yaml_string(exclude_none=True), config=loader_conf.to_yaml_string(exclude_none=True),
) )
num_workers = num_workers or loader_conf.num_workers num_workers = num_workers or loader_conf.num_workers
return DataLoader( return DataLoader(