diff --git a/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt b/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt new file mode 100644 index 0000000..49b64a4 Binary files /dev/null and b/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt differ diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index 0badb9a..f9d851e 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -9,10 +9,22 @@ from batdetect2.core import BaseConfig __all__ = [ "CheckpointConfig", "build_checkpoint_callback", + "get_bundled_checkpoint_names", "resolve_checkpoint_path", ] DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" +_PACKAGE_ROOT = Path(__file__).resolve().parents[1] +_BUNDLED_CHECKPOINTS = { + "uk_same": _PACKAGE_ROOT + / "models" + / "checkpoints" + / "batdetect2_uk_same.ckpt", + "batdetect2_uk_same": _PACKAGE_ROOT + / "models" + / "checkpoints" + / "batdetect2_uk_same.ckpt", +} class CheckpointConfig(BaseConfig): @@ -20,6 +32,8 @@ class CheckpointConfig(BaseConfig): monitor: str | None = None mode: str = "max" save_top_k: int = 1 + # Default to distributable inference checkpoints, not full train resumes. + save_weights_only: bool = True filename: str | None = None save_last: bool | Literal["link"] = "link" every_n_epochs: int | None = 1 @@ -49,6 +63,7 @@ def build_checkpoint_callback( return ModelCheckpoint( dirpath=str(checkpoint_dir), save_top_k=config.save_top_k, + save_weights_only=config.save_weights_only, monitor=config.monitor, mode=config.mode, filename=config.filename, @@ -57,14 +72,20 @@ def build_checkpoint_callback( ) +def get_bundled_checkpoint_names() -> tuple[str, ...]: + """Return the supported bundled checkpoint aliases.""" + + return tuple(_BUNDLED_CHECKPOINTS) + + def resolve_checkpoint_path(path: PathLike | str) -> Path: - """Resolve a local path or Hugging Face checkpoint URI. + """Resolve a local path, bundled alias, or Hugging Face checkpoint URI. Parameters ---------- path : PathLike | str - Local checkpoint path or a Hugging Face URI of the form - ``hf://owner/repo/path/to/checkpoint.ckpt``. + Local checkpoint path, bundled checkpoint alias, or a Hugging Face + URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. Returns ------- @@ -87,10 +108,23 @@ def resolve_checkpoint_path(path: PathLike | str) -> Path: if not isinstance(path, Path): path = Path(path) - if not path.exists(): - raise FileNotFoundError(f"Checkpoint not found: {path}") + if path.exists(): + return path - return path + bundled_path = _BUNDLED_CHECKPOINTS.get(str(path)) + if bundled_path is not None: + if not bundled_path.exists(): + raise FileNotFoundError( + f"Bundled checkpoint is missing: {bundled_path}" + ) + return bundled_path + + bundled_names = ", ".join(get_bundled_checkpoint_names()) + raise FileNotFoundError( + "Checkpoint not found: " + f"{path}. Expected a local path, a bundled checkpoint alias " + f"({bundled_names}), or a Hugging Face URI." + ) def _parse_huggingface_uri(uri: str) -> tuple[str, str]: diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index 2e5eaed..7d99f0a 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -3,10 +3,14 @@ import types from pathlib import Path import pytest +import torch from soundevent import data from batdetect2.train import TrainingConfig, run_train -from batdetect2.train.checkpoints import resolve_checkpoint_path +from batdetect2.train.checkpoints import ( + get_bundled_checkpoint_names, + resolve_checkpoint_path, +) pytestmark = pytest.mark.slow @@ -97,6 +101,37 @@ def test_train_controls_which_checkpoints_are_kept( assert "epoch" in best_checkpoints[0].name +def test_train_saves_weights_only_checkpoints_by_default( + tmp_path: Path, + example_annotations: list[data.ClipAnnotation], +) -> None: + config = _build_fast_train_config() + + run_train( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + train_config=config, + num_epochs=1, + train_workers=0, + val_workers=0, + checkpoint_dir=tmp_path, + seed=0, + ) + + checkpoint_path = next(tmp_path.rglob("*.ckpt")) + checkpoint = torch.load( + checkpoint_path, + map_location="cpu", + weights_only=False, + ) + + assert "state_dict" in checkpoint + assert "hyper_parameters" in checkpoint + assert "pytorch-lightning_version" in checkpoint + assert "optimizer_states" not in checkpoint + assert "lr_schedulers" not in checkpoint + + def test_resolve_checkpoint_path_returns_local_path_unchanged( tmp_path: Path, ) -> None: @@ -107,6 +142,30 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged( assert resolve_checkpoint_path(str(local_path)) == local_path +def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: + assert get_bundled_checkpoint_names() == ( + "uk_same", + "batdetect2_uk_same", + ) + + +def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: + resolved = resolve_checkpoint_path("uk_same") + + assert resolved.name == "batdetect2_uk_same.ckpt" + assert resolved.exists() + + +def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias( + tmp_path: Path, +) -> None: + local_path = tmp_path / "uk_same" + local_path.write_bytes(b"checkpoint") + + assert resolve_checkpoint_path(local_path) == local_path + assert resolve_checkpoint_path(str(local_path)) == local_path + + def test_resolve_checkpoint_path_downloads_huggingface_checkpoint( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, @@ -159,5 +218,8 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None: def test_resolve_checkpoint_path_rejects_missing_local_path() -> None: - with pytest.raises(FileNotFoundError, match="Checkpoint not found"): + with pytest.raises( + FileNotFoundError, + match="bundled checkpoint alias", + ): resolve_checkpoint_path("missing.ckpt")