diff --git a/example_data/config.yaml b/example_data/config.yaml index e3cd7ff..e657b7c 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -136,9 +136,9 @@ train: weight: 0.1 logger: - logger_type: csv - # save_dir: outputs/log/ - # name: logs + name: mlflow + tracking_uri: http://10.20.20.211:9000 + log_model: true augmentations: enabled: true diff --git a/pyproject.toml b/pyproject.toml index 771930c..02f3043 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "tqdm>=4.66.2", "cf-xarray>=0.9.0", "onnx>=1.16.0", - "lightning[extra]>=2.2.2", + "lightning[extra]==2.5.0", "tensorboard>=2.16.2", "omegaconf>=2.3.0", "pyyaml>=6.0.2", diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 2911db3..8d5836a 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -27,6 +27,7 @@ __all__ = ["train_command"] @click.option("--train-workers", type=int) @click.option("--val-workers", type=int) @click.option("--experiment-name", type=str) +@click.option("--run-name", type=str) @click.option("--seed", type=int) @click.option( "-v", @@ -46,6 +47,7 @@ def train_command( train_workers: int = 0, val_workers: int = 0, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, verbose: int = 0, ): logger.remove() @@ -95,4 +97,5 @@ def train_command( log_dir=log_dir, checkpoint_dir=ckpt_dir, seed=seed, + run_name=run_name, ) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index d5bdf31..5108469 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -68,7 +68,7 @@ class ValidationMetrics(Callback): n_examples=4, ): plotter( - f"images/{class_name}_examples", + f"examples/{class_name}", fig, pl_module.global_step, ) diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index 517acb7..fb6a36f 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -1,5 +1,16 @@ import io -from typing import Annotated, Any, Literal, Optional, Union +from pathlib import Path +from typing import ( + Annotated, + Any, + Dict, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, +) import numpy as np from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger @@ -9,39 +20,34 @@ from soundevent import data from batdetect2.configs import BaseConfig -DEFAULT_LOGS_DIR: str = "outputs/logs" +DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" -class DVCLiveConfig(BaseConfig): - logger_type: Literal["dvclive"] = "dvclive" - dir: str = DEFAULT_LOGS_DIR +class BaseLoggerConfig(BaseConfig): + log_dir: Path = DEFAULT_LOGS_DIR + experiment_name: Optional[str] = None run_name: Optional[str] = None + + +class DVCLiveConfig(BaseLoggerConfig): + name: Literal["dvclive"] = "dvclive" prefix: str = "" log_model: Union[bool, Literal["all"]] = False monitor_system: bool = False -class CSVLoggerConfig(BaseConfig): - logger_type: Literal["csv"] = "csv" - save_dir: str = DEFAULT_LOGS_DIR - name: Optional[str] = "logs" - version: Optional[str] = None +class CSVLoggerConfig(BaseLoggerConfig): + name: Literal["csv"] = "csv" flush_logs_every_n_steps: int = 100 -class TensorBoardLoggerConfig(BaseConfig): - logger_type: Literal["tensorboard"] = "tensorboard" - save_dir: str = DEFAULT_LOGS_DIR - name: Optional[str] = "logs" - version: Optional[str] = None +class TensorBoardLoggerConfig(BaseLoggerConfig): + name: Literal["tensorboard"] = "tensorboard" log_graph: bool = False -class MLFlowLoggerConfig(BaseConfig): - logger_type: Literal["mlflow"] = "mlflow" - experiment_name: str = "default" - run_name: Optional[str] = None - save_dir: Optional[str] = "./mlruns" +class MLFlowLoggerConfig(BaseLoggerConfig): + name: Literal["mlflow"] = "mlflow" tracking_uri: Optional[str] = None tags: Optional[dict[str, Any]] = None log_model: bool = False @@ -54,14 +60,28 @@ LoggerConfig = Annotated[ TensorBoardLoggerConfig, MLFlowLoggerConfig, ], - Field(discriminator="logger_type"), + Field(discriminator="name"), ] +T = TypeVar("T", bound=LoggerConfig, contravariant=True) + + +class LoggerBuilder(Protocol, Generic[T]): + def __call__( + self, + config: T, + log_dir: Optional[Path] = None, + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, + ) -> Logger: ... + + def create_dvclive_logger( config: DVCLiveConfig, - log_dir: Optional[data.PathLike] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Logger: try: from dvclive.lightning import DVCLiveLogger # type: ignore @@ -73,10 +93,11 @@ def create_dvclive_logger( ) from error return DVCLiveLogger( - dir=log_dir if log_dir is not None else config.dir, - run_name=experiment_name + dir=log_dir if log_dir is not None else config.log_dir, + run_name=run_name if run_name is not None else config.run_name, + experiment=experiment_name if experiment_name is not None - else config.run_name, + else config.experiment_name, prefix=config.prefix, log_model=config.log_model, monitor_system=config.monitor_system, @@ -85,30 +106,58 @@ def create_dvclive_logger( def create_csv_logger( config: CSVLoggerConfig, - log_dir: Optional[data.PathLike] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Logger: from lightning.pytorch.loggers import CSVLogger + if log_dir is None: + log_dir = Path(config.log_dir) + + if run_name is None: + run_name = config.run_name + + if experiment_name is None: + experiment_name = config.experiment_name + + name = run_name + + if run_name is not None and experiment_name is not None: + name = str(Path(experiment_name) / run_name) + return CSVLogger( - save_dir=str(log_dir) if log_dir is not None else config.save_dir, - name=experiment_name if experiment_name is not None else config.name, - version=config.version, + save_dir=str(log_dir), + name=name, flush_logs_every_n_steps=config.flush_logs_every_n_steps, ) def create_tensorboard_logger( config: TensorBoardLoggerConfig, - log_dir: Optional[data.PathLike] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Logger: from lightning.pytorch.loggers import TensorBoardLogger + if log_dir is None: + log_dir = Path(config.log_dir) + + if run_name is None: + run_name = config.run_name + + if experiment_name is None: + experiment_name = config.experiment_name + + name = run_name + + if run_name is not None and experiment_name is not None: + name = str(Path(experiment_name) / run_name) + return TensorBoardLogger( - save_dir=str(log_dir) if log_dir is not None else config.save_dir, - name=experiment_name if experiment_name is not None else config.name, - version=config.version, + save_dir=str(log_dir), + name=name, log_graph=config.log_graph, ) @@ -117,6 +166,7 @@ def create_mlflow_logger( config: MLFlowLoggerConfig, log_dir: Optional[data.PathLike] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Logger: try: from lightning.pytorch.loggers import MLFlowLogger @@ -127,19 +177,25 @@ def create_mlflow_logger( "or `uv add mlflow`" ) from error + if experiment_name is None: + experiment_name = config.experiment_name or "Default" + + if log_dir is None: + log_dir = config.log_dir + return MLFlowLogger( experiment_name=experiment_name if experiment_name is not None else config.experiment_name, - run_name=config.run_name, - save_dir=str(log_dir) if log_dir is not None else config.save_dir, + run_name=run_name if run_name is not None else config.run_name, + save_dir=str(log_dir), tracking_uri=config.tracking_uri, tags=config.tags, log_model=config.log_model, ) -LOGGER_FACTORY = { +LOGGER_FACTORY: Dict[str, LoggerBuilder] = { "dvclive": create_dvclive_logger, "csv": create_csv_logger, "tensorboard": create_tensorboard_logger, @@ -149,8 +205,9 @@ LOGGER_FACTORY = { def build_logger( config: LoggerConfig, - log_dir: Optional[data.PathLike] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Logger: """ Creates a logger instance from a validated Pydantic config object. @@ -159,7 +216,7 @@ def build_logger( "Building logger with config: \n{}", lambda: config.to_yaml_string(), ) - logger_type = config.logger_type + logger_type = config.name if logger_type not in LOGGER_FACTORY: raise ValueError(f"Unknown logger type: {logger_type}") @@ -170,6 +227,7 @@ def build_logger( config, log_dir=log_dir, experiment_name=experiment_name, + run_name=run_name, ) @@ -186,8 +244,8 @@ def get_image_plotter(logger: Logger): def plot_figure(name, figure, step): image = _convert_figure_to_image(figure) return logger.experiment.log_image( - run_id=logger.run_id, - image=image, + logger.run_id, + image, key=name, step=step, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 24c54de..b35b2a4 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from pathlib import Path from typing import List, Optional import torch @@ -45,6 +46,8 @@ __all__ = [ "train", ] +DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" + def train( train_annotations: Sequence[data.ClipAnnotation], @@ -53,9 +56,10 @@ def train( model_path: Optional[data.PathLike] = None, train_workers: Optional[int] = None, val_workers: Optional[int] = None, - checkpoint_dir: Optional[data.PathLike] = None, - log_dir: Optional[data.PathLike] = None, + checkpoint_dir: Optional[Path] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, seed: Optional[int] = None, ): if seed is not None: @@ -113,6 +117,7 @@ def train( checkpoint_dir=checkpoint_dir, log_dir=log_dir, experiment_name=experiment_name, + run_name=run_name, ) logger.info("Starting main training loop...") @@ -140,21 +145,32 @@ def build_trainer_callbacks( targets: TargetProtocol, preprocessor: PreprocessorProtocol, config: EvaluationConfig, - checkpoint_dir: Optional[data.PathLike] = None, + checkpoint_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> List[Callback]: if checkpoint_dir is None: - checkpoint_dir = "outputs/checkpoints" + checkpoint_dir = DEFAULT_CHECKPOINT_DIR + + filename = "best-{epoch:02d}-{val_loss:.0f}" + + if run_name is not None: + filename = f"run_{run_name}_{filename}" if experiment_name is not None: - checkpoint_dir = f"{checkpoint_dir}/{experiment_name}" + filename = f"experiment_{experiment_name}_{filename}" + + model_checkpoint = ModelCheckpoint( + dirpath=str(checkpoint_dir), + save_top_k=1, + filename=filename, + monitor="total_loss/val", + ) + + model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore return [ - ModelCheckpoint( - dirpath=str(checkpoint_dir), - save_top_k=1, - monitor="total_loss/val", - ), + model_checkpoint, ValidationMetrics( metrics=[ DetectionAveragePrecision(), @@ -172,9 +188,10 @@ def build_trainer_callbacks( def build_trainer( conf: FullTrainingConfig, targets: TargetProtocol, - checkpoint_dir: Optional[data.PathLike] = None, - log_dir: Optional[data.PathLike] = None, + checkpoint_dir: Optional[Path] = None, + log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> Trainer: trainer_conf = conf.train.trainer logger.opt(lazy=True).debug( @@ -185,6 +202,7 @@ def build_trainer( conf.train.logger, log_dir=log_dir, experiment_name=experiment_name, + run_name=run_name, ) train_logger.log_hyperparams(