diff --git a/example_data/config.yaml b/example_data/config.yaml index 46f34c3..08a98f2 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -117,7 +117,8 @@ train: size: weight: 0.1 logger: - logger_type: dvclive + logger_type: mlflow + tracking_uri: http://localhost:5000 augmentations: steps: - augmentation_type: mix_audio diff --git a/pyproject.toml b/pyproject.toml index 3bee649..8170836 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,9 @@ dev = [ dvclive = [ "dvclive>=3.48.2", ] +mlflow = [ + "mlflow>=3.1.1", +] [tool.ruff] line-length = 79 diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index fb70c78..26fa64d 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -1,4 +1,4 @@ -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union from lightning.pytorch.loggers import Logger from pydantic import Field @@ -34,8 +34,23 @@ class TensorBoardLoggerConfig(BaseConfig): flush_logs_every_n_steps: Optional[int] = None +class MLFlowLoggerConfig(BaseConfig): + logger_type: Literal["mlflow"] = "mlflow" + experiment_name: str = "default" + run_name: Optional[str] = None + save_dir: Optional[str] = "./mlruns" + tracking_uri: Optional[str] = None + tags: Optional[dict[str, Any]] = None + log_model: bool = False + + LoggerConfig = Annotated[ - Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig], + Union[ + DVCLiveConfig, + CSVLoggerConfig, + TensorBoardLoggerConfig, + MLFlowLoggerConfig, + ], Field(discriminator="logger_type"), ] @@ -82,10 +97,31 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger: ) +def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger: + try: + from lightning.pytorch.loggers import MLFlowLogger + except ImportError as error: + raise ValueError( + "MLFlow is not installed and cannot be used for logging. " + "Make sure you have it installed by running `pip install mlflow` " + "or `uv add mlflow`" + ) from error + + return MLFlowLogger( + experiment_name=config.experiment_name, + run_name=config.run_name, + save_dir=config.save_dir, + tracking_uri=config.tracking_uri, + tags=config.tags, + log_model=config.log_model, + ) + + LOGGER_FACTORY = { "dvclive": create_dvclive_logger, "csv": create_csv_logger, "tensorboard": create_tensorboard_logger, + "mlflow": create_mlflow_logger, } @@ -101,3 +137,4 @@ def build_logger(config: LoggerConfig) -> Logger: creation_func = LOGGER_FACTORY[logger_type] return creation_func(config) +