Add mlflow logger

This commit is contained in:
mbsantiago 2025-06-28 11:08:19 -06:00
parent e8db1d4050
commit bafb9a3622
3 changed files with 44 additions and 3 deletions

View File

@ -117,7 +117,8 @@ train:
size: size:
weight: 0.1 weight: 0.1
logger: logger:
logger_type: dvclive logger_type: mlflow
tracking_uri: http://localhost:5000
augmentations: augmentations:
steps: steps:
- augmentation_type: mix_audio - augmentation_type: mix_audio

View File

@ -91,6 +91,9 @@ dev = [
dvclive = [ dvclive = [
"dvclive>=3.48.2", "dvclive>=3.48.2",
] ]
mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Any, Literal, Optional, Union
from lightning.pytorch.loggers import Logger from lightning.pytorch.loggers import Logger
from pydantic import Field from pydantic import Field
@ -34,8 +34,23 @@ class TensorBoardLoggerConfig(BaseConfig):
flush_logs_every_n_steps: Optional[int] = None flush_logs_every_n_steps: Optional[int] = None
class MLFlowLoggerConfig(BaseConfig):
logger_type: Literal["mlflow"] = "mlflow"
experiment_name: str = "default"
run_name: Optional[str] = None
save_dir: Optional[str] = "./mlruns"
tracking_uri: Optional[str] = None
tags: Optional[dict[str, Any]] = None
log_model: bool = False
LoggerConfig = Annotated[ LoggerConfig = Annotated[
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig], Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="logger_type"), Field(discriminator="logger_type"),
] ]
@ -82,10 +97,31 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
) )
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
raise ValueError(
"MLFlow is not installed and cannot be used for logging. "
"Make sure you have it installed by running `pip install mlflow` "
"or `uv add mlflow`"
) from error
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
)
LOGGER_FACTORY = { LOGGER_FACTORY = {
"dvclive": create_dvclive_logger, "dvclive": create_dvclive_logger,
"csv": create_csv_logger, "csv": create_csv_logger,
"tensorboard": create_tensorboard_logger, "tensorboard": create_tensorboard_logger,
"mlflow": create_mlflow_logger,
} }
@ -101,3 +137,4 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type] creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config) return creation_func(config)