mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
By default only save the last checkpoint
This commit is contained in:
parent
875751d340
commit
b8acd86c71
@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
|
||||
@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
|
||||
class CheckpointConfig(BaseConfig):
|
||||
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
|
||||
monitor: str = "classification/mean_average_precision"
|
||||
monitor: str | None = None
|
||||
mode: str = "max"
|
||||
save_top_k: int = 1
|
||||
filename: str | None = None
|
||||
save_last: bool = False
|
||||
save_last: bool | Literal["link"] = "link"
|
||||
every_n_epochs: int | None = 1
|
||||
|
||||
|
||||
def build_checkpoint_callback(
|
||||
@ -38,6 +40,8 @@ def build_checkpoint_callback(
|
||||
if run_name is not None:
|
||||
checkpoint_dir = checkpoint_dir / run_name
|
||||
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return ModelCheckpoint(
|
||||
dirpath=str(checkpoint_dir),
|
||||
save_top_k=config.save_top_k,
|
||||
@ -45,4 +49,5 @@ def build_checkpoint_callback(
|
||||
mode=config.mode,
|
||||
filename=config.filename,
|
||||
save_last=config.save_last,
|
||||
every_n_epochs=config.every_n_epochs,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user