mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Add mlflow logger
This commit is contained in:
parent
e8db1d4050
commit
bafb9a3622
@ -117,7 +117,8 @@ train:
|
|||||||
size:
|
size:
|
||||||
weight: 0.1
|
weight: 0.1
|
||||||
logger:
|
logger:
|
||||||
logger_type: dvclive
|
logger_type: mlflow
|
||||||
|
tracking_uri: http://localhost:5000
|
||||||
augmentations:
|
augmentations:
|
||||||
steps:
|
steps:
|
||||||
- augmentation_type: mix_audio
|
- augmentation_type: mix_audio
|
||||||
|
|||||||
@ -91,6 +91,9 @@ dev = [
|
|||||||
dvclive = [
|
dvclive = [
|
||||||
"dvclive>=3.48.2",
|
"dvclive>=3.48.2",
|
||||||
]
|
]
|
||||||
|
mlflow = [
|
||||||
|
"mlflow>=3.1.1",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, Literal, Optional, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
from lightning.pytorch.loggers import Logger
|
from lightning.pytorch.loggers import Logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -34,8 +34,23 @@ class TensorBoardLoggerConfig(BaseConfig):
|
|||||||
flush_logs_every_n_steps: Optional[int] = None
|
flush_logs_every_n_steps: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MLFlowLoggerConfig(BaseConfig):
|
||||||
|
logger_type: Literal["mlflow"] = "mlflow"
|
||||||
|
experiment_name: str = "default"
|
||||||
|
run_name: Optional[str] = None
|
||||||
|
save_dir: Optional[str] = "./mlruns"
|
||||||
|
tracking_uri: Optional[str] = None
|
||||||
|
tags: Optional[dict[str, Any]] = None
|
||||||
|
log_model: bool = False
|
||||||
|
|
||||||
|
|
||||||
LoggerConfig = Annotated[
|
LoggerConfig = Annotated[
|
||||||
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
|
Union[
|
||||||
|
DVCLiveConfig,
|
||||||
|
CSVLoggerConfig,
|
||||||
|
TensorBoardLoggerConfig,
|
||||||
|
MLFlowLoggerConfig,
|
||||||
|
],
|
||||||
Field(discriminator="logger_type"),
|
Field(discriminator="logger_type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -82,10 +97,31 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
||||||
|
try:
|
||||||
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
|
except ImportError as error:
|
||||||
|
raise ValueError(
|
||||||
|
"MLFlow is not installed and cannot be used for logging. "
|
||||||
|
"Make sure you have it installed by running `pip install mlflow` "
|
||||||
|
"or `uv add mlflow`"
|
||||||
|
) from error
|
||||||
|
|
||||||
|
return MLFlowLogger(
|
||||||
|
experiment_name=config.experiment_name,
|
||||||
|
run_name=config.run_name,
|
||||||
|
save_dir=config.save_dir,
|
||||||
|
tracking_uri=config.tracking_uri,
|
||||||
|
tags=config.tags,
|
||||||
|
log_model=config.log_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LOGGER_FACTORY = {
|
LOGGER_FACTORY = {
|
||||||
"dvclive": create_dvclive_logger,
|
"dvclive": create_dvclive_logger,
|
||||||
"csv": create_csv_logger,
|
"csv": create_csv_logger,
|
||||||
"tensorboard": create_tensorboard_logger,
|
"tensorboard": create_tensorboard_logger,
|
||||||
|
"mlflow": create_mlflow_logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -101,3 +137,4 @@ def build_logger(config: LoggerConfig) -> Logger:
|
|||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(config)
|
return creation_func(config)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user