mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
By default only save the last checkpoint
This commit is contained in:
parent
875751d340
commit
b8acd86c71
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
|
|
||||||
@ -14,11 +15,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
|||||||
|
|
||||||
class CheckpointConfig(BaseConfig):
|
class CheckpointConfig(BaseConfig):
|
||||||
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
|
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
|
||||||
monitor: str = "classification/mean_average_precision"
|
monitor: str | None = None
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
filename: str | None = None
|
filename: str | None = None
|
||||||
save_last: bool = False
|
save_last: bool | Literal["link"] = "link"
|
||||||
|
every_n_epochs: int | None = 1
|
||||||
|
|
||||||
|
|
||||||
def build_checkpoint_callback(
|
def build_checkpoint_callback(
|
||||||
@ -38,6 +40,8 @@ def build_checkpoint_callback(
|
|||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
checkpoint_dir = checkpoint_dir / run_name
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
return ModelCheckpoint(
|
return ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=config.save_top_k,
|
save_top_k=config.save_top_k,
|
||||||
@ -45,4 +49,5 @@ def build_checkpoint_callback(
|
|||||||
mode=config.mode,
|
mode=config.mode,
|
||||||
filename=config.filename,
|
filename=config.filename,
|
||||||
save_last=config.save_last,
|
save_last=config.save_last,
|
||||||
|
every_n_epochs=config.every_n_epochs,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user