From b8acd86c718d1965782a39a5cc85a3392c036ac8 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Balvanera Date: Thu, 19 Mar 2026 01:26:11 +0000 Subject: [PATCH] By default only save the last checkpoint --- src/batdetect2/train/checkpoints.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index 92db951..e93c7ca 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal from lightning.pytorch.callbacks import Callback, ModelCheckpoint @@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" class CheckpointConfig(BaseConfig): checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR - monitor: str = "classification/mean_average_precision" + monitor: str | None = None mode: str = "max" save_top_k: int = 1 filename: str | None = None - save_last: bool = False + save_last: bool | Literal["link"] = "link" + every_n_epochs: int | None = 1 def build_checkpoint_callback( @@ -38,6 +40,8 @@ def build_checkpoint_callback( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name + checkpoint_dir.mkdir(parents=True, exist_ok=True) + return ModelCheckpoint( dirpath=str(checkpoint_dir), save_top_k=config.save_top_k, @@ -45,4 +49,5 @@ def build_checkpoint_callback( mode=config.mode, filename=config.filename, save_last=config.save_last, + every_n_epochs=config.every_n_epochs, )