From 136949c4e771b19f2fb4b9365dc5f2903b8f0e14 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 07:55:24 -0600 Subject: [PATCH] Add logging config --- src/batdetect2/train/config.py | 3 + src/batdetect2/train/logging.py | 104 ++++++++++++++++++++++++++++++++ src/batdetect2/train/train.py | 7 ++- 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 src/batdetect2/train/logging.py diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index b75d0fe..d854d2b 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -9,6 +9,7 @@ from batdetect2.train.augmentations import ( AugmentationsConfig, ) from batdetect2.train.clips import ClipingConfig +from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig from batdetect2.train.losses import LossConfig __all__ = [ @@ -59,6 +60,8 @@ class TrainingConfig(BaseConfig): trainer: TrainerConfig = Field(default_factory=TrainerConfig) + logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) + def load_train_config( path: data.PathLike, diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py new file mode 100644 index 0000000..a6cdae5 --- /dev/null +++ b/src/batdetect2/train/logging.py @@ -0,0 +1,104 @@ +from typing import Annotated, Literal, Optional, Union + +from lightning.pytorch.loggers import Logger +from pydantic import Field +from soundevent.data import PathLike + +from batdetect2.configs import BaseConfig, load_config + +DEFAULT_LOGS_DIR: str = "logs" + + +class DVCLiveConfig(BaseConfig): + logger_type: Literal["dvclive"] = "dvclive" + dir: str = DEFAULT_LOGS_DIR + run_name: Optional[str] = None + prefix: str = "" + log_model: Union[bool, Literal["all"]] = False + monitor_system: bool = False + + +class CSVLoggerConfig(BaseConfig): + logger_type: Literal["csv"] = "csv" + save_dir: str = DEFAULT_LOGS_DIR + name: Optional[str] = "logs" + version: Optional[str] = None + flush_logs_every_n_steps: int = 100 + + +class TensorBoardLoggerConfig(BaseConfig): + logger_type: Literal["tensorboard"] = "tensorboard" + save_dir: str = DEFAULT_LOGS_DIR + name: Optional[str] = "default" + version: Optional[str] = None + log_graph: bool = False + flush_logs_every_n_steps: Optional[int] = None + + +LoggerConfig = Annotated[ + Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig], + Field(discriminator="logger_type"), +] + + +def create_dvclive_logger(config: DVCLiveConfig) -> Logger: + try: + from dvclive.lightning import DVCLiveLogger # type: ignore + except ImportError as error: + raise ValueError( + "DVCLive is not installed and cannot be used for logging" + "Make sure you have it installed by running `pip install dvclive`" + "or `uv add dvclive`" + ) from error + + return DVCLiveLogger( + dir=config.dir, + run_name=config.run_name, + prefix=config.prefix, + log_model=config.log_model, + monitor_system=config.monitor_system, + ) + + +def create_csv_logger(config: CSVLoggerConfig) -> Logger: + from lightning.pytorch.loggers import CSVLogger + + return CSVLogger( + save_dir=config.save_dir, + name=config.name, + version=config.version, + flush_logs_every_n_steps=config.flush_logs_every_n_steps, + ) + + +def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger: + from lightning.pytorch.loggers import TensorBoardLogger + + return TensorBoardLogger( + save_dir=config.save_dir, + name=config.name, + version=config.version, + log_graph=config.log_graph, + flush_logs_every_n_steps=config.flush_logs_every_n_steps, + ) + + +LOGGER_FACTORY = { + "dvclive": create_dvclive_logger, + "csv": create_csv_logger, + "tensorboard": create_tensorboard_logger, +} + + +def build_logger(config: LoggerConfig) -> Logger: + """ + Creates a logger instance from a validated Pydantic config object. + """ + logger_type = config.logger_type + + if logger_type not in LOGGER_FACTORY: + raise ValueError(f"Unknown logger type: {logger_type}") + + creation_func = LOGGER_FACTORY[logger_type] + + return creation_func(config) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 6a71204..8778498 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -23,6 +23,7 @@ from batdetect2.train.dataset import ( collate_fn, ) from batdetect2.train.lightning import TrainingModule +from batdetect2.train.logging import build_logger from batdetect2.train.losses import build_loss __all__ = [ @@ -47,6 +48,7 @@ def train( **trainer_kwargs, ) -> None: config = config or TrainingConfig() + if model_path is None: if preprocessor is None: preprocessor = build_preprocessor() @@ -81,9 +83,12 @@ def train( config=config, ) + logger = build_logger(config.logger) + trainer = Trainer( - **config.trainer.model_dump(exclude_none=True), + **config.trainer.model_dump(exclude_none=True, exclude={"logger"}), callbacks=callbacks, + logger=logger, **trainer_kwargs, )