Add mlflow logger

This commit is contained in:
mbsantiago 2025-06-28 11:08:19 -06:00
parent e8db1d4050
commit bafb9a3622
3 changed files with 44 additions and 3 deletions

View File

@ -117,7 +117,8 @@ train:
size:
weight: 0.1
logger:
logger_type: dvclive
logger_type: mlflow
tracking_uri: http://localhost:5000
augmentations:
steps:
- augmentation_type: mix_audio

View File

@ -91,6 +91,9 @@ dev = [
dvclive = [
"dvclive>=3.48.2",
]
mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff]
line-length = 79

View File

@ -1,4 +1,4 @@
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
from pydantic import Field
@ -34,8 +34,23 @@ class TensorBoardLoggerConfig(BaseConfig):
flush_logs_every_n_steps: Optional[int] = None
class MLFlowLoggerConfig(BaseConfig):
logger_type: Literal["mlflow"] = "mlflow"
experiment_name: str = "default"
run_name: Optional[str] = None
save_dir: Optional[str] = "./mlruns"
tracking_uri: Optional[str] = None
tags: Optional[dict[str, Any]] = None
log_model: bool = False
LoggerConfig = Annotated[
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="logger_type"),
]
@ -82,10 +97,31 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
)
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
raise ValueError(
"MLFlow is not installed and cannot be used for logging. "
"Make sure you have it installed by running `pip install mlflow` "
"or `uv add mlflow`"
) from error
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
)
LOGGER_FACTORY = {
"dvclive": create_dvclive_logger,
"csv": create_csv_logger,
"tensorboard": create_tensorboard_logger,
"mlflow": create_mlflow_logger,
}
@ -101,3 +137,4 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config)