Add checkpoint config

mbsantiago 2025-10-04 15:08:34 +01:00
parent 2d796394f6
commit 160cb6ae30
3 changed files with 66 additions and 47 deletions

batdetect2/train/checkpoints.py View File

@@ -0,0 +1,47 @@
+from pathlib import Path
+from typing import Optional
+
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+
+from batdetect2.core import BaseConfig
+
+__all__ = [
+    "CheckpointConfig",
+    "build_checkpoint_callback",
+]
+
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
+
+class CheckpointConfig(BaseConfig):
+    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
+    monitor: str = "classification/mean_average_precision"
+    mode: str = "max"
+    save_top_k: int = 1
+    filename: Optional[str] = None
+
+
+def build_checkpoint_callback(
+    config: Optional[CheckpointConfig] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Callback:
+    config = config or CheckpointConfig()
+
+    if checkpoint_dir is None:
+        checkpoint_dir = config.checkpoint_dir
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    return ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=config.save_top_k,
+        monitor=config.monitor,
+        mode=config.mode,
+        filename=config.filename,
+    )
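For reference, a minimal usage sketch of the new helper; the experiment and run names here are hypothetical, everything else follows from the code above:

from batdetect2.train.checkpoints import (
    CheckpointConfig,
    build_checkpoint_callback,
)

# Hypothetical names, for illustration only.
callback = build_checkpoint_callback(
    config=CheckpointConfig(save_top_k=3),
    experiment_name="uk-bats",
    run_name="baseline",
)
# Checkpoints are written to outputs/checkpoints/uk-bats/baseline/,
# keeping the 3 best by classification/mean_average_precision.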

View File

@@ -6,6 +6,7 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)


 def load_train_config(
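Because the new field uses default_factory, training configs written before this change still parse. A quick sketch; the module path batdetect2.train.config is an assumption, not confirmed by the diff:

from batdetect2.train.config import TrainingConfig  # assumed module path

config = TrainingConfig()

# An omitted checkpoints block falls back to the defaults defined above:
assert config.checkpoints.monitor == "classification/mean_average_precision"
assert config.checkpoints.save_top_k == 1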

View File

@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from lightning import Trainer, seed_everything
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
@@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,8 +32,6 @@ __all__ = [
     "train",
 ]

-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -104,7 +102,6 @@ def train(
     trainer = trainer or build_trainer(
         config,
-        targets=targets,
         evaluator=build_evaluator(
             config.train.validation,
             targets=targets,
@@ -124,58 +121,29 @@ def train(
     logger.info("Training complete.")

-def build_trainer_callbacks(
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> List[Callback]:
-    if checkpoint_dir is None:
-        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    evaluator = evaluator or build_evaluator(targets=targets)
-
-    return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="classification/mean_average_precision",
-            mode="max",
-        ),
-        ValidationMetrics(evaluator),
-    ]
-
 def build_trainer(
-    conf: "BatDetect2Config",
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
+    config: "BatDetect2Config",
+    evaluator: "EvaluatorProtocol",
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = config.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
     train_logger = build_logger(
-        conf.train.logger,
+        config.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )
     train_logger.log_hyperparams(
-        conf.model_dump(
+        config.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -184,11 +152,13 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=build_trainer_callbacks(
-            targets,
-            evaluator=evaluator,
-            checkpoint_dir=checkpoint_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        ),
+        callbacks=[
+            build_checkpoint_callback(
+                config=config.train.checkpoints,
+                checkpoint_dir=checkpoint_dir,
+                experiment_name=experiment_name,
+                run_name=run_name,
+            ),
+            ValidationMetrics(evaluator),
+        ],
     )
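Callers now pass checkpoint settings through the config object instead of a targets argument. A sketch of the new call shape, with illustrative names; only the keyword parameters are taken from the diff:

trainer = build_trainer(
    config,                     # a BatDetect2Config with train.checkpoints set
    evaluator=evaluator,
    checkpoint_dir=None,        # falls back to config.train.checkpoints.checkpoint_dir
    experiment_name="uk-bats",  # hypothetical
    run_name="baseline",        # hypothetical
)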