By default only save the last checkpoint

This commit is contained in:
Santiago Martinez Balvanera 2026-03-19 01:26:11 +00:00
parent 875751d340
commit b8acd86c71

View File

@ -1,4 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Literal
from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.callbacks import Callback, ModelCheckpoint
@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
class CheckpointConfig(BaseConfig): class CheckpointConfig(BaseConfig):
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
monitor: str = "classification/mean_average_precision" monitor: str | None = None
mode: str = "max" mode: str = "max"
save_top_k: int = 1 save_top_k: int = 1
filename: str | None = None filename: str | None = None
save_last: bool = False save_last: bool | Literal["link"] = "link"
every_n_epochs: int | None = 1
def build_checkpoint_callback( def build_checkpoint_callback(
@ -38,6 +40,8 @@ def build_checkpoint_callback(
if run_name is not None: if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name checkpoint_dir = checkpoint_dir / run_name
checkpoint_dir.mkdir(parents=True, exist_ok=True)
return ModelCheckpoint( return ModelCheckpoint(
dirpath=str(checkpoint_dir), dirpath=str(checkpoint_dir),
save_top_k=config.save_top_k, save_top_k=config.save_top_k,
@ -45,4 +49,5 @@ def build_checkpoint_callback(
mode=config.mode, mode=config.mode,
filename=config.filename, filename=config.filename,
save_last=config.save_last, save_last=config.save_last,
every_n_epochs=config.every_n_epochs,
) )