mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Logging is not just for training
This commit is contained in:
parent
8c80402f08
commit
6c25787123
@ -15,6 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
Logger,
|
||||
@ -56,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
|
||||
|
||||
class MLFlowLoggerConfig(BaseLoggerConfig):
|
||||
name: Literal["mlflow"] = "mlflow"
|
||||
tracking_uri: Optional[str] = None
|
||||
tracking_uri: Optional[str] = "http://localhost:5000"
|
||||
tags: Optional[dict[str, Any]] = None
|
||||
log_model: bool = False
|
||||
|
||||
@ -160,6 +161,9 @@ def create_tensorboard_logger(
|
||||
|
||||
name = run_name
|
||||
|
||||
if name is None:
|
||||
name = experiment_name
|
||||
|
||||
if run_name is not None and experiment_name is not None:
|
||||
name = str(Path(experiment_name) / run_name)
|
||||
|
||||
@ -239,10 +243,10 @@ def build_logger(
|
||||
)
|
||||
|
||||
|
||||
Plotter = Callable[[str, Figure, int], None]
|
||||
PlotLogger = Callable[[str, Figure, int], None]
|
||||
|
||||
|
||||
def get_image_plotter(logger: Logger) -> Optional[Plotter]:
|
||||
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return logger.experiment.add_figure
|
||||
|
||||
@ -250,6 +254,7 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]:
|
||||
|
||||
def plot_figure(name, figure, step):
|
||||
image = _convert_figure_to_array(figure)
|
||||
name = name.replace("/", "_")
|
||||
return logger.experiment.log_image(
|
||||
logger.run_id,
|
||||
image,
|
||||
@ -263,6 +268,37 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]:
|
||||
return partial(save_figure, dir=Path(logger.log_dir))
|
||||
|
||||
|
||||
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||
|
||||
|
||||
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return partial(save_table, dir=Path(logger.log_dir))
|
||||
|
||||
if isinstance(logger, MLFlowLogger):
|
||||
|
||||
def plot_figure(name: str, df: pd.DataFrame, step: int):
|
||||
return logger.experiment.log_table(
|
||||
logger.run_id,
|
||||
data=df,
|
||||
artifact_file=f"{name}_step_{step}.json",
|
||||
)
|
||||
|
||||
return plot_figure
|
||||
|
||||
if isinstance(logger, CSVLogger):
|
||||
return partial(save_table, dir=Path(logger.log_dir))
|
||||
|
||||
|
||||
def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
|
||||
path = dir / "tables" / f"{name}_step_{step}.csv"
|
||||
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
|
||||
df.to_csv(path, index=False)
|
||||
|
||||
|
||||
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
||||
path = dir / "plots" / f"{name}_step_{step}.png"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user