By default only save the last checkpoint

This commit is contained in:
Santiago Martinez Balvanera 2026-03-19 01:26:11 +00:00
parent 875751d340
commit b8acd86c71

View File

@ -1,4 +1,5 @@
from pathlib import Path
from typing import Literal
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
class CheckpointConfig(BaseConfig):
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
monitor: str = "classification/mean_average_precision"
monitor: str | None = None
mode: str = "max"
save_top_k: int = 1
filename: str | None = None
save_last: bool = False
save_last: bool | Literal["link"] = "link"
every_n_epochs: int | None = 1
def build_checkpoint_callback(
@ -38,6 +40,8 @@ def build_checkpoint_callback(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
checkpoint_dir.mkdir(parents=True, exist_ok=True)
return ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=config.save_top_k,
@ -45,4 +49,5 @@ def build_checkpoint_callback(
mode=config.mode,
filename=config.filename,
save_last=config.save_last,
every_n_epochs=config.every_n_epochs,
)