From 6c25787123e40050e08f217c7051e230bbfe0f1e Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 18 Sep 2025 09:27:40 +0100 Subject: [PATCH] Logging is not just for training --- src/batdetect2/{train => }/logging.py | 42 +++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) rename src/batdetect2/{train => }/logging.py (86%) diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/logging.py similarity index 86% rename from src/batdetect2/train/logging.py rename to src/batdetect2/logging.py index 5b1b8c6..eb96d44 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/logging.py @@ -15,6 +15,7 @@ from typing import ( ) import numpy as np +import pandas as pd from lightning.pytorch.loggers import ( CSVLogger, Logger, @@ -56,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig): class MLFlowLoggerConfig(BaseLoggerConfig): name: Literal["mlflow"] = "mlflow" - tracking_uri: Optional[str] = None + tracking_uri: Optional[str] = "http://localhost:5000" tags: Optional[dict[str, Any]] = None log_model: bool = False @@ -160,6 +161,9 @@ def create_tensorboard_logger( name = run_name + if name is None: + name = experiment_name + if run_name is not None and experiment_name is not None: name = str(Path(experiment_name) / run_name) @@ -239,10 +243,10 @@ def build_logger( ) -Plotter = Callable[[str, Figure, int], None] +PlotLogger = Callable[[str, Figure, int], None] -def get_image_plotter(logger: Logger) -> Optional[Plotter]: +def get_image_logger(logger: Logger) -> Optional[PlotLogger]: if isinstance(logger, TensorBoardLogger): return logger.experiment.add_figure @@ -250,6 +254,7 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]: def plot_figure(name, figure, step): image = _convert_figure_to_array(figure) + name = name.replace("/", "_") return logger.experiment.log_image( logger.run_id, image, @@ -263,6 +268,37 @@ def get_image_plotter(logger: Logger) -> Optional[Plotter]: return partial(save_figure, dir=Path(logger.log_dir)) +TableLogger = Callable[[str, pd.DataFrame, int], None] + + +def get_table_logger(logger: Logger) -> Optional[TableLogger]: + if isinstance(logger, TensorBoardLogger): + return partial(save_table, dir=Path(logger.log_dir)) + + if isinstance(logger, MLFlowLogger): + + def plot_figure(name: str, df: pd.DataFrame, step: int): + return logger.experiment.log_table( + logger.run_id, + data=df, + artifact_file=f"{name}_step_{step}.json", + ) + + return plot_figure + + if isinstance(logger, CSVLogger): + return partial(save_table, dir=Path(logger.log_dir)) + + +def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None: + path = dir / "tables" / f"{name}_step_{step}.csv" + + if not path.parent.exists(): + path.parent.mkdir(parents=True) + + df.to_csv(path, index=False) + + def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None: path = dir / "plots" / f"{name}_step_{step}.png"