feat: streamline bundled checkpoint handling

Support packaged model aliases and save weights-only checkpoints by default so distributed models stay small while remaining easy to load.
This commit is contained in:
mbsantiago 2026-05-05 21:34:54 +01:00
parent d83f801515
commit 84918086c8
3 changed files with 104 additions and 8 deletions

View File

@ -9,10 +9,22 @@ from batdetect2.core import BaseConfig
__all__ = [ __all__ = [
"CheckpointConfig", "CheckpointConfig",
"build_checkpoint_callback", "build_checkpoint_callback",
"get_bundled_checkpoint_names",
"resolve_checkpoint_path", "resolve_checkpoint_path",
] ]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
_BUNDLED_CHECKPOINTS = {
"uk_same": _PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
"batdetect2_uk_same": _PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
}
class CheckpointConfig(BaseConfig): class CheckpointConfig(BaseConfig):
@ -20,6 +32,8 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None monitor: str | None = None
mode: str = "max" mode: str = "max"
save_top_k: int = 1 save_top_k: int = 1
# Default to distributable inference checkpoints, not full train resumes.
save_weights_only: bool = True
filename: str | None = None filename: str | None = None
save_last: bool | Literal["link"] = "link" save_last: bool | Literal["link"] = "link"
every_n_epochs: int | None = 1 every_n_epochs: int | None = 1
@ -49,6 +63,7 @@ def build_checkpoint_callback(
return ModelCheckpoint( return ModelCheckpoint(
dirpath=str(checkpoint_dir), dirpath=str(checkpoint_dir),
save_top_k=config.save_top_k, save_top_k=config.save_top_k,
save_weights_only=config.save_weights_only,
monitor=config.monitor, monitor=config.monitor,
mode=config.mode, mode=config.mode,
filename=config.filename, filename=config.filename,
@ -57,14 +72,20 @@ def build_checkpoint_callback(
) )
def get_bundled_checkpoint_names() -> tuple[str, ...]:
"""Return the supported bundled checkpoint aliases."""
return tuple(_BUNDLED_CHECKPOINTS)
def resolve_checkpoint_path(path: PathLike | str) -> Path: def resolve_checkpoint_path(path: PathLike | str) -> Path:
"""Resolve a local path or Hugging Face checkpoint URI. """Resolve a local path, bundled alias, or Hugging Face checkpoint URI.
Parameters Parameters
---------- ----------
path : PathLike | str path : PathLike | str
Local checkpoint path or a Hugging Face URI of the form Local checkpoint path, bundled checkpoint alias, or a Hugging Face
``hf://owner/repo/path/to/checkpoint.ckpt``. URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
Returns Returns
------- -------
@ -87,11 +108,24 @@ def resolve_checkpoint_path(path: PathLike | str) -> Path:
if not isinstance(path, Path): if not isinstance(path, Path):
path = Path(path) path = Path(path)
if not path.exists(): if path.exists():
raise FileNotFoundError(f"Checkpoint not found: {path}")
return path return path
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
if bundled_path is not None:
if not bundled_path.exists():
raise FileNotFoundError(
f"Bundled checkpoint is missing: {bundled_path}"
)
return bundled_path
bundled_names = ", ".join(get_bundled_checkpoint_names())
raise FileNotFoundError(
"Checkpoint not found: "
f"{path}. Expected a local path, a bundled checkpoint alias "
f"({bundled_names}), or a Hugging Face URI."
)
def _parse_huggingface_uri(uri: str) -> tuple[str, str]: def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
prefix = "hf://" prefix = "hf://"

View File

@ -3,10 +3,14 @@ import types
from pathlib import Path from pathlib import Path
import pytest import pytest
import torch
from soundevent import data from soundevent import data
from batdetect2.train import TrainingConfig, run_train from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import resolve_checkpoint_path from batdetect2.train.checkpoints import (
get_bundled_checkpoint_names,
resolve_checkpoint_path,
)
pytestmark = pytest.mark.slow pytestmark = pytest.mark.slow
@ -97,6 +101,37 @@ def test_train_controls_which_checkpoints_are_kept(
assert "epoch" in best_checkpoints[0].name assert "epoch" in best_checkpoints[0].name
def test_train_saves_weights_only_checkpoints_by_default(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],
) -> None:
config = _build_fast_train_config()
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config,
num_epochs=1,
train_workers=0,
val_workers=0,
checkpoint_dir=tmp_path,
seed=0,
)
checkpoint_path = next(tmp_path.rglob("*.ckpt"))
checkpoint = torch.load(
checkpoint_path,
map_location="cpu",
weights_only=False,
)
assert "state_dict" in checkpoint
assert "hyper_parameters" in checkpoint
assert "pytorch-lightning_version" in checkpoint
assert "optimizer_states" not in checkpoint
assert "lr_schedulers" not in checkpoint
def test_resolve_checkpoint_path_returns_local_path_unchanged( def test_resolve_checkpoint_path_returns_local_path_unchanged(
tmp_path: Path, tmp_path: Path,
) -> None: ) -> None:
@ -107,6 +142,30 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
assert resolve_checkpoint_path(str(local_path)) == local_path assert resolve_checkpoint_path(str(local_path)) == local_path
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == (
"uk_same",
"batdetect2_uk_same",
)
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path("uk_same")
assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists()
def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias(
tmp_path: Path,
) -> None:
local_path = tmp_path / "uk_same"
local_path.write_bytes(b"checkpoint")
assert resolve_checkpoint_path(local_path) == local_path
assert resolve_checkpoint_path(str(local_path)) == local_path
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint( def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
tmp_path: Path, tmp_path: Path,
@ -159,5 +218,8 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None: def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
with pytest.raises(FileNotFoundError, match="Checkpoint not found"): with pytest.raises(
FileNotFoundError,
match="bundled checkpoint alias",
):
resolve_checkpoint_path("missing.ckpt") resolve_checkpoint_path("missing.ckpt")