Updat lightning version

This commit is contained in:
mbsantiago 2025-09-09 15:31:40 +01:00
parent 951dc59718
commit 115084fd2b
6 changed files with 136 additions and 57 deletions

View File

@ -136,9 +136,9 @@ train:
weight: 0.1 weight: 0.1
logger: logger:
logger_type: csv name: mlflow
# save_dir: outputs/log/ tracking_uri: http://10.20.20.211:9000
# name: logs log_model: true
augmentations: augmentations:
enabled: true enabled: true

View File

@ -23,7 +23,7 @@ dependencies = [
"tqdm>=4.66.2", "tqdm>=4.66.2",
"cf-xarray>=0.9.0", "cf-xarray>=0.9.0",
"onnx>=1.16.0", "onnx>=1.16.0",
"lightning[extra]>=2.2.2", "lightning[extra]==2.5.0",
"tensorboard>=2.16.2", "tensorboard>=2.16.2",
"omegaconf>=2.3.0", "omegaconf>=2.3.0",
"pyyaml>=6.0.2", "pyyaml>=6.0.2",

View File

@ -27,6 +27,7 @@ __all__ = ["train_command"]
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--seed", type=int) @click.option("--seed", type=int)
@click.option( @click.option(
"-v", "-v",
@ -46,6 +47,7 @@ def train_command(
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0, verbose: int = 0,
): ):
logger.remove() logger.remove()
@ -95,4 +97,5 @@ def train_command(
log_dir=log_dir, log_dir=log_dir,
checkpoint_dir=ckpt_dir, checkpoint_dir=ckpt_dir,
seed=seed, seed=seed,
run_name=run_name,
) )

View File

@ -68,7 +68,7 @@ class ValidationMetrics(Callback):
n_examples=4, n_examples=4,
): ):
plotter( plotter(
f"images/{class_name}_examples", f"examples/{class_name}",
fig, fig,
pl_module.global_step, pl_module.global_step,
) )

View File

@ -1,5 +1,16 @@
import io import io
from typing import Annotated, Any, Literal, Optional, Union from pathlib import Path
from typing import (
Annotated,
Any,
Dict,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
)
import numpy as np import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
@ -9,39 +20,34 @@ from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "outputs/logs" DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
class DVCLiveConfig(BaseConfig): class BaseLoggerConfig(BaseConfig):
logger_type: Literal["dvclive"] = "dvclive" log_dir: Path = DEFAULT_LOGS_DIR
dir: str = DEFAULT_LOGS_DIR experiment_name: Optional[str] = None
run_name: Optional[str] = None run_name: Optional[str] = None
class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive"
prefix: str = "" prefix: str = ""
log_model: Union[bool, Literal["all"]] = False log_model: Union[bool, Literal["all"]] = False
monitor_system: bool = False monitor_system: bool = False
class CSVLoggerConfig(BaseConfig): class CSVLoggerConfig(BaseLoggerConfig):
logger_type: Literal["csv"] = "csv" name: Literal["csv"] = "csv"
save_dir: str = DEFAULT_LOGS_DIR
name: Optional[str] = "logs"
version: Optional[str] = None
flush_logs_every_n_steps: int = 100 flush_logs_every_n_steps: int = 100
class TensorBoardLoggerConfig(BaseConfig): class TensorBoardLoggerConfig(BaseLoggerConfig):
logger_type: Literal["tensorboard"] = "tensorboard" name: Literal["tensorboard"] = "tensorboard"
save_dir: str = DEFAULT_LOGS_DIR
name: Optional[str] = "logs"
version: Optional[str] = None
log_graph: bool = False log_graph: bool = False
class MLFlowLoggerConfig(BaseConfig): class MLFlowLoggerConfig(BaseLoggerConfig):
logger_type: Literal["mlflow"] = "mlflow" name: Literal["mlflow"] = "mlflow"
experiment_name: str = "default"
run_name: Optional[str] = None
save_dir: Optional[str] = "./mlruns"
tracking_uri: Optional[str] = None tracking_uri: Optional[str] = None
tags: Optional[dict[str, Any]] = None tags: Optional[dict[str, Any]] = None
log_model: bool = False log_model: bool = False
@ -54,14 +60,28 @@ LoggerConfig = Annotated[
TensorBoardLoggerConfig, TensorBoardLoggerConfig,
MLFlowLoggerConfig, MLFlowLoggerConfig,
], ],
Field(discriminator="logger_type"), Field(discriminator="name"),
] ]
T = TypeVar("T", bound=LoggerConfig, contravariant=True)
class LoggerBuilder(Protocol, Generic[T]):
def __call__(
self,
config: T,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ...
def create_dvclive_logger( def create_dvclive_logger(
config: DVCLiveConfig, config: DVCLiveConfig,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
try: try:
from dvclive.lightning import DVCLiveLogger # type: ignore from dvclive.lightning import DVCLiveLogger # type: ignore
@ -73,10 +93,11 @@ def create_dvclive_logger(
) from error ) from error
return DVCLiveLogger( return DVCLiveLogger(
dir=log_dir if log_dir is not None else config.dir, dir=log_dir if log_dir is not None else config.log_dir,
run_name=experiment_name run_name=run_name if run_name is not None else config.run_name,
experiment=experiment_name
if experiment_name is not None if experiment_name is not None
else config.run_name, else config.experiment_name,
prefix=config.prefix, prefix=config.prefix,
log_model=config.log_model, log_model=config.log_model,
monitor_system=config.monitor_system, monitor_system=config.monitor_system,
@ -85,30 +106,58 @@ def create_dvclive_logger(
def create_csv_logger( def create_csv_logger(
config: CSVLoggerConfig, config: CSVLoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers import CSVLogger
if log_dir is None:
log_dir = Path(config.log_dir)
if run_name is None:
run_name = config.run_name
if experiment_name is None:
experiment_name = config.experiment_name
name = run_name
if run_name is not None and experiment_name is not None:
name = str(Path(experiment_name) / run_name)
return CSVLogger( return CSVLogger(
save_dir=str(log_dir) if log_dir is not None else config.save_dir, save_dir=str(log_dir),
name=experiment_name if experiment_name is not None else config.name, name=name,
version=config.version,
flush_logs_every_n_steps=config.flush_logs_every_n_steps, flush_logs_every_n_steps=config.flush_logs_every_n_steps,
) )
def create_tensorboard_logger( def create_tensorboard_logger(
config: TensorBoardLoggerConfig, config: TensorBoardLoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
if log_dir is None:
log_dir = Path(config.log_dir)
if run_name is None:
run_name = config.run_name
if experiment_name is None:
experiment_name = config.experiment_name
name = run_name
if run_name is not None and experiment_name is not None:
name = str(Path(experiment_name) / run_name)
return TensorBoardLogger( return TensorBoardLogger(
save_dir=str(log_dir) if log_dir is not None else config.save_dir, save_dir=str(log_dir),
name=experiment_name if experiment_name is not None else config.name, name=name,
version=config.version,
log_graph=config.log_graph, log_graph=config.log_graph,
) )
@ -117,6 +166,7 @@ def create_mlflow_logger(
config: MLFlowLoggerConfig, config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[data.PathLike] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
try: try:
from lightning.pytorch.loggers import MLFlowLogger from lightning.pytorch.loggers import MLFlowLogger
@ -127,19 +177,25 @@ def create_mlflow_logger(
"or `uv add mlflow`" "or `uv add mlflow`"
) from error ) from error
if experiment_name is None:
experiment_name = config.experiment_name or "Default"
if log_dir is None:
log_dir = config.log_dir
return MLFlowLogger( return MLFlowLogger(
experiment_name=experiment_name experiment_name=experiment_name
if experiment_name is not None if experiment_name is not None
else config.experiment_name, else config.experiment_name,
run_name=config.run_name, run_name=run_name if run_name is not None else config.run_name,
save_dir=str(log_dir) if log_dir is not None else config.save_dir, save_dir=str(log_dir),
tracking_uri=config.tracking_uri, tracking_uri=config.tracking_uri,
tags=config.tags, tags=config.tags,
log_model=config.log_model, log_model=config.log_model,
) )
LOGGER_FACTORY = { LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
"dvclive": create_dvclive_logger, "dvclive": create_dvclive_logger,
"csv": create_csv_logger, "csv": create_csv_logger,
"tensorboard": create_tensorboard_logger, "tensorboard": create_tensorboard_logger,
@ -149,8 +205,9 @@ LOGGER_FACTORY = {
def build_logger( def build_logger(
config: LoggerConfig, config: LoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
""" """
Creates a logger instance from a validated Pydantic config object. Creates a logger instance from a validated Pydantic config object.
@ -159,7 +216,7 @@ def build_logger(
"Building logger with config: \n{}", "Building logger with config: \n{}",
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
logger_type = config.logger_type logger_type = config.name
if logger_type not in LOGGER_FACTORY: if logger_type not in LOGGER_FACTORY:
raise ValueError(f"Unknown logger type: {logger_type}") raise ValueError(f"Unknown logger type: {logger_type}")
@ -170,6 +227,7 @@ def build_logger(
config, config,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name,
) )
@ -186,8 +244,8 @@ def get_image_plotter(logger: Logger):
def plot_figure(name, figure, step): def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure) image = _convert_figure_to_image(figure)
return logger.experiment.log_image( return logger.experiment.log_image(
run_id=logger.run_id, logger.run_id,
image=image, image,
key=name, key=name,
step=step, step=step,
) )

View File

@ -1,4 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional from typing import List, Optional
import torch import torch
@ -45,6 +46,8 @@ __all__ = [
"train", "train",
] ]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train( def train(
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
@ -53,9 +56,10 @@ def train(
model_path: Optional[data.PathLike] = None, model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None, train_workers: Optional[int] = None,
val_workers: Optional[int] = None, val_workers: Optional[int] = None,
checkpoint_dir: Optional[data.PathLike] = None, checkpoint_dir: Optional[Path] = None,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
): ):
if seed is not None: if seed is not None:
@ -113,6 +117,7 @@ def train(
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name,
) )
logger.info("Starting main training loop...") logger.info("Starting main training loop...")
@ -140,21 +145,32 @@ def build_trainer_callbacks(
targets: TargetProtocol, targets: TargetProtocol,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: EvaluationConfig, config: EvaluationConfig,
checkpoint_dir: Optional[data.PathLike] = None, checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> List[Callback]: ) -> List[Callback]:
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = "outputs/checkpoints" checkpoint_dir = DEFAULT_CHECKPOINT_DIR
filename = "best-{epoch:02d}-{val_loss:.0f}"
if run_name is not None:
filename = f"run_{run_name}_{filename}"
if experiment_name is not None: if experiment_name is not None:
checkpoint_dir = f"{checkpoint_dir}/{experiment_name}" filename = f"experiment_{experiment_name}_{filename}"
model_checkpoint = ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
filename=filename,
monitor="total_loss/val",
)
model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore
return [ return [
ModelCheckpoint( model_checkpoint,
dirpath=str(checkpoint_dir),
save_top_k=1,
monitor="total_loss/val",
),
ValidationMetrics( ValidationMetrics(
metrics=[ metrics=[
DetectionAveragePrecision(), DetectionAveragePrecision(),
@ -172,9 +188,10 @@ def build_trainer_callbacks(
def build_trainer( def build_trainer(
conf: FullTrainingConfig, conf: FullTrainingConfig,
targets: TargetProtocol, targets: TargetProtocol,
checkpoint_dir: Optional[data.PathLike] = None, checkpoint_dir: Optional[Path] = None,
log_dir: Optional[data.PathLike] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Trainer: ) -> Trainer:
trainer_conf = conf.train.trainer trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
@ -185,6 +202,7 @@ def build_trainer(
conf.train.logger, conf.train.logger,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name,
) )
train_logger.log_hyperparams( train_logger.log_hyperparams(