From 160cb6ae30e45b465a59e380b735180dcc8da224 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Sat, 4 Oct 2025 15:08:34 +0100
Subject: [PATCH] Add checkpoint config

---
 src/batdetect2/train/checkpoints.py | 86 +++++++++++++++++++++++++++++
 src/batdetect2/train/config.py      |  2 +
 src/batdetect2/train/train.py       | 64 ++++++---------------
 3 files changed, 105 insertions(+), 47 deletions(-)
 create mode 100644 src/batdetect2/train/checkpoints.py

diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py
new file mode 100644
index 0000000..48b7432
--- /dev/null
+++ b/src/batdetect2/train/checkpoints.py
@@ -0,0 +1,86 @@
+"""Checkpoint configuration and callback construction for training."""
+
+from pathlib import Path
+from typing import Literal, Optional
+
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+
+from batdetect2.core import BaseConfig
+
+__all__ = [
+    "CheckpointConfig",
+    "build_checkpoint_callback",
+]
+
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
+
+class CheckpointConfig(BaseConfig):
+    """Settings forwarded to Lightning's ``ModelCheckpoint`` callback.
+
+    Attributes
+    ----------
+    checkpoint_dir : Path
+        Base directory for checkpoints, used when the caller does not
+        supply one explicitly.
+    monitor : str
+        Name of the metric passed to ``ModelCheckpoint`` as ``monitor``.
+    mode : {"min", "max"}
+        Whether lower or higher values of ``monitor`` count as better.
+    save_top_k : int
+        Passed to ``ModelCheckpoint`` as ``save_top_k``.
+    filename : str, optional
+        Passed to ``ModelCheckpoint`` as ``filename``.
+    """
+
+    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
+    monitor: str = "classification/mean_average_precision"
+    # ModelCheckpoint accepts only these two modes; constraining the type
+    # lets config validation reject anything else before trainer setup.
+    mode: Literal["min", "max"] = "max"
+    save_top_k: int = 1
+    filename: Optional[str] = None
+
+
+def build_checkpoint_callback(
+    config: Optional[CheckpointConfig] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Callback:
+    """Build the ``ModelCheckpoint`` callback for a training run.
+
+    Parameters
+    ----------
+    config : CheckpointConfig, optional
+        Checkpoint settings. Defaults to ``CheckpointConfig()``.
+    checkpoint_dir : Path, optional
+        Base directory; overrides ``config.checkpoint_dir`` when given.
+    experiment_name : str, optional
+        Appended to the base directory as a sub-directory when given.
+    run_name : str, optional
+        Appended as a further sub-directory when given.
+
+    Returns
+    -------
+    Callback
+        A ``ModelCheckpoint`` writing to the resolved directory.
+    """
+    config = config or CheckpointConfig()
+
+    if checkpoint_dir is None:
+        checkpoint_dir = config.checkpoint_dir
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    return ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=config.save_top_k,
+        monitor=config.monitor,
+        mode=config.mode,
+        filename=config.filename,
+    )
diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py
index 699b791..38db5ef 100644
--- a/src/batdetect2/train/config.py
+++ b/src/batdetect2/train/config.py
@@ -6,6 +6,7 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)
 
 
 def load_train_config(
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index b0f1642..a5c5a43 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
 
 from lightning import Trainer, seed_everything
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
 
@@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,8 +32,6 @@ __all__ = [
     "train",
 ]
 
-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-
 
 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -104,7 +102,6 @@ def train(
 
     trainer = trainer or build_trainer(
         config,
-        targets=targets,
         evaluator=build_evaluator(
             config.train.validation,
             targets=targets,
@@ -124,58 +121,29 @@ def train(
     logger.info("Training complete.")
 
 
-def build_trainer_callbacks(
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> List[Callback]:
-    if checkpoint_dir is None:
-        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    evaluator = evaluator or build_evaluator(targets=targets)
-
-    return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="classification/mean_average_precision",
-            mode="max",
-        ),
-        ValidationMetrics(evaluator),
-    ]
-
-
 def build_trainer(
-    conf: "BatDetect2Config",
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
+    config: "BatDetect2Config",
+    evaluator: "EvaluatorProtocol",
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = config.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
+
     train_logger = build_logger(
-        conf.train.logger,
+        config.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )
 
     train_logger.log_hyperparams(
-        conf.model_dump(
+        config.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -184,11 +152,13 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=build_trainer_callbacks(
-            targets,
-            evaluator=evaluator,
-            checkpoint_dir=checkpoint_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        ),
+        callbacks=[
+            build_checkpoint_callback(
+                config=config.train.checkpoints,
+                checkpoint_dir=checkpoint_dir,
+                experiment_name=experiment_name,
+                run_name=run_name,
+            ),
+            ValidationMetrics(evaluator),
+        ],
     )