Compare commits


3 Commits

Author      SHA1        Message                   Date
mbsantiago  115084fd2b  Update lightning version  2025-09-09 15:31:40 +01:00
mbsantiago  951dc59718  Add seed option to train  2025-09-09 13:23:56 +01:00
mbsantiago  3376be06a4  Add experiment name       2025-09-09 09:02:25 +01:00
6 changed files with 174 additions and 60 deletions

View File

@@ -109,7 +109,7 @@ train:
     sigma: 3
   trainer:
-    max_epochs: 40
+    max_epochs: 5
   dataloaders:
     train:
@@ -136,9 +136,9 @@ train:
     weight: 0.1
   logger:
-    logger_type: csv
-    # save_dir: outputs/log/
-    # name: logs
+    name: mlflow
+    tracking_uri: http://10.20.20.211:9000
+    log_model: true
   augmentations:
     enabled: true
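For reference, the new logger block is selected by the `name` discriminator rather than the old `logger_type`. Below is a minimal, self-contained sketch of how such a block is validated, using stripped-down stand-ins for the real `BaseConfig` subclasses defined later in this diff (assumes pydantic v2 and PyYAML; only the field names come from the diff):

from typing import Annotated, Literal, Optional, Union

import yaml
from pydantic import BaseModel, Field, TypeAdapter


class MLFlowLoggerConfig(BaseModel):
    name: Literal["mlflow"] = "mlflow"
    tracking_uri: Optional[str] = None
    log_model: bool = False


class CSVLoggerConfig(BaseModel):
    name: Literal["csv"] = "csv"
    flush_logs_every_n_steps: int = 100


# The "name" field picks the concrete config class during validation.
LoggerConfig = Annotated[
    Union[MLFlowLoggerConfig, CSVLoggerConfig],
    Field(discriminator="name"),
]

raw = yaml.safe_load(
    """
name: mlflow
tracking_uri: http://10.20.20.211:9000
log_model: true
"""
)

config = TypeAdapter(LoggerConfig).validate_python(raw)
assert isinstance(config, MLFlowLoggerConfig)

A nice side effect of the discriminated union is that an unknown `name` fails validation up front, before any Trainer is constructed.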

View File

@@ -23,7 +23,7 @@ dependencies = [
     "tqdm>=4.66.2",
     "cf-xarray>=0.9.0",
     "onnx>=1.16.0",
-    "lightning[extra]>=2.2.2",
+    "lightning[extra]==2.5.0",
     "tensorboard>=2.16.2",
     "omegaconf>=2.3.0",
     "pyyaml>=6.0.2",

View File

@@ -26,6 +26,9 @@ __all__ = ["train_command"]
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--seed", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -40,8 +43,11 @@ def train_command(
     log_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
+    seed: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
     logger.remove()
@@ -87,6 +93,9 @@ def train_command(
         model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
+        experiment_name=experiment_name,
         log_dir=log_dir,
         checkpoint_dir=ckpt_dir,
+        seed=seed,
+        run_name=run_name,
     )
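Together these options expose the new train() parameters on the command line, e.g. `batdetect2 train --seed 42 --experiment-name bats-2025 --run-name baseline` (the `batdetect2 train` entry point is an assumption; only the option names appear in this diff). `--experiment-name` and `--run-name` flow through to the logger and checkpoint naming, while `--seed` makes the run reproducible via `seed_everything`.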

View File

@@ -68,7 +68,7 @@ class ValidationMetrics(Callback):
             n_examples=4,
         ):
             plotter(
-                f"images/{class_name}_examples",
+                f"examples/{class_name}",
                 fig,
                 pl_module.global_step,
             )
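The renamed key gathers all per-class example figures under a single `examples/` prefix; logger UIs such as TensorBoard group slash-separated names into one collapsible section, which is presumably the motivation (the diff itself does not say).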

View File

@@ -1,5 +1,16 @@
 import io
-from typing import Annotated, Any, Literal, Optional, Union
+from pathlib import Path
+from typing import (
+    Annotated,
+    Any,
+    Dict,
+    Generic,
+    Literal,
+    Optional,
+    Protocol,
+    TypeVar,
+    Union,
+)

 import numpy as np
 from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
@@ -9,39 +20,34 @@ from soundevent import data
 from batdetect2.configs import BaseConfig

-DEFAULT_LOGS_DIR: str = "outputs"
+DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"


-class DVCLiveConfig(BaseConfig):
-    logger_type: Literal["dvclive"] = "dvclive"
-    dir: str = DEFAULT_LOGS_DIR
+class BaseLoggerConfig(BaseConfig):
+    log_dir: Path = DEFAULT_LOGS_DIR
+    experiment_name: Optional[str] = None
     run_name: Optional[str] = None
+
+
+class DVCLiveConfig(BaseLoggerConfig):
+    name: Literal["dvclive"] = "dvclive"
     prefix: str = ""
     log_model: Union[bool, Literal["all"]] = False
     monitor_system: bool = False


-class CSVLoggerConfig(BaseConfig):
-    logger_type: Literal["csv"] = "csv"
-    save_dir: str = DEFAULT_LOGS_DIR
-    name: Optional[str] = "logs"
-    version: Optional[str] = None
+class CSVLoggerConfig(BaseLoggerConfig):
+    name: Literal["csv"] = "csv"
     flush_logs_every_n_steps: int = 100


-class TensorBoardLoggerConfig(BaseConfig):
-    logger_type: Literal["tensorboard"] = "tensorboard"
-    save_dir: str = DEFAULT_LOGS_DIR
-    name: Optional[str] = "logs"
-    version: Optional[str] = None
+class TensorBoardLoggerConfig(BaseLoggerConfig):
+    name: Literal["tensorboard"] = "tensorboard"
     log_graph: bool = False


-class MLFlowLoggerConfig(BaseConfig):
-    logger_type: Literal["mlflow"] = "mlflow"
-    experiment_name: str = "default"
-    run_name: Optional[str] = None
-    save_dir: Optional[str] = "./mlruns"
+class MLFlowLoggerConfig(BaseLoggerConfig):
+    name: Literal["mlflow"] = "mlflow"
     tracking_uri: Optional[str] = None
     tags: Optional[dict[str, Any]] = None
     log_model: bool = False
@@ -54,13 +60,28 @@ LoggerConfig = Annotated[
         TensorBoardLoggerConfig,
         MLFlowLoggerConfig,
     ],
-    Field(discriminator="logger_type"),
+    Field(discriminator="name"),
 ]

+T = TypeVar("T", bound=LoggerConfig, contravariant=True)
+
+
+class LoggerBuilder(Protocol, Generic[T]):
+    def __call__(
+        self,
+        config: T,
+        log_dir: Optional[Path] = None,
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ) -> Logger: ...
+

 def create_dvclive_logger(
     config: DVCLiveConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
         from dvclive.lightning import DVCLiveLogger  # type: ignore
@@ -72,8 +93,11 @@ def create_dvclive_logger(
         ) from error

     return DVCLiveLogger(
-        dir=log_dir if log_dir is not None else config.dir,
-        run_name=config.run_name,
+        dir=log_dir if log_dir is not None else config.log_dir,
+        run_name=run_name if run_name is not None else config.run_name,
+        experiment=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
         prefix=config.prefix,
         log_model=config.log_model,
         monitor_system=config.monitor_system,
@@ -82,28 +106,58 @@ def create_dvclive_logger(

 def create_csv_logger(
     config: CSVLoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import CSVLogger

+    if log_dir is None:
+        log_dir = Path(config.log_dir)
+
+    if run_name is None:
+        run_name = config.run_name
+
+    if experiment_name is None:
+        experiment_name = config.experiment_name
+
+    name = run_name
+    if run_name is not None and experiment_name is not None:
+        name = str(Path(experiment_name) / run_name)
+
     return CSVLogger(
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
-        version=config.version,
+        save_dir=str(log_dir),
+        name=name,
         flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )


 def create_tensorboard_logger(
     config: TensorBoardLoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger

+    if log_dir is None:
+        log_dir = Path(config.log_dir)
+
+    if run_name is None:
+        run_name = config.run_name
+
+    if experiment_name is None:
+        experiment_name = config.experiment_name
+
+    name = run_name
+    if run_name is not None and experiment_name is not None:
+        name = str(Path(experiment_name) / run_name)
+
     return TensorBoardLogger(
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
-        version=config.version,
+        save_dir=str(log_dir),
+        name=name,
         log_graph=config.log_graph,
     )
@@ -111,6 +165,8 @@ def create_tensorboard_logger(
 def create_mlflow_logger(
     config: MLFlowLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
@@ -121,17 +177,25 @@ def create_mlflow_logger(
             "or `uv add mlflow`"
         ) from error

+    if experiment_name is None:
+        experiment_name = config.experiment_name or "Default"
+
+    if log_dir is None:
+        log_dir = config.log_dir
+
     return MLFlowLogger(
-        experiment_name=config.experiment_name,
-        run_name=config.run_name,
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
+        experiment_name=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
+        run_name=run_name if run_name is not None else config.run_name,
+        save_dir=str(log_dir),
         tracking_uri=config.tracking_uri,
         tags=config.tags,
         log_model=config.log_model,
     )


-LOGGER_FACTORY = {
+LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
     "dvclive": create_dvclive_logger,
     "csv": create_csv_logger,
     "tensorboard": create_tensorboard_logger,
@@ -141,7 +205,9 @@ LOGGER_FACTORY = {

 def build_logger(
     config: LoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
@@ -150,14 +216,19 @@ def build_logger(
         "Building logger with config: \n{}",
         lambda: config.to_yaml_string(),
     )

-    logger_type = config.logger_type
+    logger_type = config.name

     if logger_type not in LOGGER_FACTORY:
         raise ValueError(f"Unknown logger type: {logger_type}")

     creation_func = LOGGER_FACTORY[logger_type]
-    return creation_func(config, log_dir=log_dir)
+    return creation_func(
+        config,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
+    )


 def get_image_plotter(logger: Logger):
@@ -173,8 +244,8 @@ def get_image_plotter(logger: Logger):
     def plot_figure(name, figure, step):
         image = _convert_figure_to_image(figure)
         return logger.experiment.log_image(
-            run_id=logger.run_id,
-            image=image,
+            logger.run_id,
+            image,
             key=name,
             step=step,
         )
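The net effect of the refactor is one uniform builder signature with a clear precedence: explicit arguments beat config fields, and for the CSV and TensorBoard backends an experiment/run pair is nested as `experiment/run` under the log directory. A small sketch, assuming the module is importable as `batdetect2.train.logging` (the path is not shown in this diff):

from pathlib import Path

from batdetect2.train.logging import CSVLoggerConfig, build_logger  # hypothetical path

config = CSVLoggerConfig(experiment_name="bats-2025", run_name="baseline")

# run_name passed here overrides config.run_name, while experiment_name
# falls back to the config value, so the CSV logs land under
# outputs/logs/bats-2025/sweep-01/
csv_logger = build_logger(
    config,
    log_dir=Path("outputs/logs"),
    run_name="sweep-01",
)
print(csv_logger.name)  # bats-2025/sweep-01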

View File

@@ -1,8 +1,9 @@
 from collections.abc import Sequence
+from pathlib import Path
 from typing import List, Optional

 import torch
-from lightning import Trainer
+from lightning import Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
@@ -45,6 +46,8 @@ __all__ = [
     "train",
 ]

+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -53,9 +56,15 @@ def train(
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
-    checkpoint_dir: Optional[data.PathLike] = None,
-    log_dir: Optional[data.PathLike] = None,
+    checkpoint_dir: Optional[Path] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+    seed: Optional[int] = None,
 ):
+    if seed is not None:
+        seed_everything(seed)
+
     config = config or FullTrainingConfig()
     targets = build_targets(config.targets)
@@ -107,6 +116,8 @@ def train(
         targets=targets,
         checkpoint_dir=checkpoint_dir,
         log_dir=log_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
     )

     logger.info("Starting main training loop...")
@@ -134,17 +145,32 @@ def build_trainer_callbacks(
     targets: TargetProtocol,
     preprocessor: PreprocessorProtocol,
     config: EvaluationConfig,
-    checkpoint_dir: Optional[data.PathLike] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> List[Callback]:
     if checkpoint_dir is None:
-        checkpoint_dir = "outputs/checkpoints"
+        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
+
+    filename = "best-{epoch:02d}-{val_loss:.0f}"
+
+    if run_name is not None:
+        filename = f"run_{run_name}_{filename}"
+
+    if experiment_name is not None:
+        filename = f"experiment_{experiment_name}_{filename}"
+
+    model_checkpoint = ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=1,
+        filename=filename,
+        monitor="total_loss/val",
+    )
+    model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_"  # type: ignore

     return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="total_loss/val",
-        ),
+        model_checkpoint,
         ValidationMetrics(
             metrics=[
                 DetectionAveragePrecision(),
@@ -162,15 +188,22 @@ def build_trainer_callbacks(
 def build_trainer(
     conf: FullTrainingConfig,
     targets: TargetProtocol,
-    checkpoint_dir: Optional[data.PathLike] = None,
-    log_dir: Optional[data.PathLike] = None,
+    checkpoint_dir: Optional[Path] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Trainer:
     trainer_conf = conf.train.trainer

     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

-    train_logger = build_logger(conf.train.logger, log_dir=log_dir)
+    train_logger = build_logger(
+        conf.train.logger,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
+    )

     train_logger.log_hyperparams(
         conf.model_dump(
@@ -187,6 +220,7 @@ def build_trainer(
             config=conf.evaluation,
             preprocessor=build_preprocessor(conf.preprocess),
             checkpoint_dir=checkpoint_dir,
+            experiment_name=train_logger.name,
         ),
     )
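Putting the three commits together, a seeded and fully named run could be launched programmatically as below. The import path and the positional `val_annotations` argument are assumptions (the diff elides part of the signature); the keyword names come from the diff. With `CHECKPOINT_EQUALS_CHAR` set to `_`, the best checkpoint is saved under a name like `experiment_bats-2025_run_baseline_best-epoch_12-val_loss_4.ckpt`.

from pathlib import Path
from typing import List

from soundevent import data

from batdetect2.train import train  # hypothetical import path

# Annotation loading is project-specific and omitted here.
train_annotations: List[data.ClipAnnotation] = []
val_annotations: List[data.ClipAnnotation] = []

train(
    train_annotations,
    val_annotations,              # assumed positional argument
    seed=42,                      # forwarded to lightning.seed_everything
    experiment_name="bats-2025",  # passed through to the logger backend
    run_name="baseline",          # also embedded in the checkpoint filename
    checkpoint_dir=Path("outputs/checkpoints"),
    log_dir=Path("outputs/logs"),
)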