diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index 92db951..e93c7ca 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal from lightning.pytorch.callbacks import Callback, ModelCheckpoint @@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" class CheckpointConfig(BaseConfig): checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR - monitor: str = "classification/mean_average_precision" + monitor: str | None = None mode: str = "max" save_top_k: int = 1 filename: str | None = None - save_last: bool = False + save_last: bool | Literal["link"] = "link" + every_n_epochs: int | None = 1 def build_checkpoint_callback( @@ -38,6 +40,8 @@ def build_checkpoint_callback( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name + checkpoint_dir.mkdir(parents=True, exist_ok=True) + return ModelCheckpoint( dirpath=str(checkpoint_dir), save_top_k=config.save_top_k, @@ -45,4 +49,5 @@ def build_checkpoint_callback( mode=config.mode, filename=config.filename, save_last=config.save_last, + every_n_epochs=config.every_n_epochs, )