mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: streamline bundled checkpoint handling
Support packaged model aliases and save weights-only checkpoints by default so distributed models stay small while remaining easy to load.
This commit is contained in:
parent
d83f801515
commit
84918086c8
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
Binary file not shown.
@ -9,10 +9,22 @@ from batdetect2.core import BaseConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"CheckpointConfig",
|
"CheckpointConfig",
|
||||||
"build_checkpoint_callback",
|
"build_checkpoint_callback",
|
||||||
|
"get_bundled_checkpoint_names",
|
||||||
"resolve_checkpoint_path",
|
"resolve_checkpoint_path",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
|
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
_BUNDLED_CHECKPOINTS = {
|
||||||
|
"uk_same": _PACKAGE_ROOT
|
||||||
|
/ "models"
|
||||||
|
/ "checkpoints"
|
||||||
|
/ "batdetect2_uk_same.ckpt",
|
||||||
|
"batdetect2_uk_same": _PACKAGE_ROOT
|
||||||
|
/ "models"
|
||||||
|
/ "checkpoints"
|
||||||
|
/ "batdetect2_uk_same.ckpt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfig(BaseConfig):
|
class CheckpointConfig(BaseConfig):
|
||||||
@ -20,6 +32,8 @@ class CheckpointConfig(BaseConfig):
|
|||||||
monitor: str | None = None
|
monitor: str | None = None
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
|
# Default to distributable inference checkpoints, not full train resumes.
|
||||||
|
save_weights_only: bool = True
|
||||||
filename: str | None = None
|
filename: str | None = None
|
||||||
save_last: bool | Literal["link"] = "link"
|
save_last: bool | Literal["link"] = "link"
|
||||||
every_n_epochs: int | None = 1
|
every_n_epochs: int | None = 1
|
||||||
@ -49,6 +63,7 @@ def build_checkpoint_callback(
|
|||||||
return ModelCheckpoint(
|
return ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=config.save_top_k,
|
save_top_k=config.save_top_k,
|
||||||
|
save_weights_only=config.save_weights_only,
|
||||||
monitor=config.monitor,
|
monitor=config.monitor,
|
||||||
mode=config.mode,
|
mode=config.mode,
|
||||||
filename=config.filename,
|
filename=config.filename,
|
||||||
@ -57,14 +72,20 @@ def build_checkpoint_callback(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
||||||
|
"""Return the supported bundled checkpoint aliases."""
|
||||||
|
|
||||||
|
return tuple(_BUNDLED_CHECKPOINTS)
|
||||||
|
|
||||||
|
|
||||||
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
||||||
"""Resolve a local path or Hugging Face checkpoint URI.
|
"""Resolve a local path, bundled alias, or Hugging Face checkpoint URI.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : PathLike | str
|
path : PathLike | str
|
||||||
Local checkpoint path or a Hugging Face URI of the form
|
Local checkpoint path, bundled checkpoint alias, or a Hugging Face
|
||||||
``hf://owner/repo/path/to/checkpoint.ckpt``.
|
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -87,11 +108,24 @@ def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
|||||||
if not isinstance(path, Path):
|
if not isinstance(path, Path):
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
if not path.exists():
|
if path.exists():
|
||||||
raise FileNotFoundError(f"Checkpoint not found: {path}")
|
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
|
||||||
|
if bundled_path is not None:
|
||||||
|
if not bundled_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Bundled checkpoint is missing: {bundled_path}"
|
||||||
|
)
|
||||||
|
return bundled_path
|
||||||
|
|
||||||
|
bundled_names = ", ".join(get_bundled_checkpoint_names())
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Checkpoint not found: "
|
||||||
|
f"{path}. Expected a local path, a bundled checkpoint alias "
|
||||||
|
f"({bundled_names}), or a Hugging Face URI."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
||||||
prefix = "hf://"
|
prefix = "hf://"
|
||||||
|
|||||||
@ -3,10 +3,14 @@ import types
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.train import TrainingConfig, run_train
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
from batdetect2.train.checkpoints import (
|
||||||
|
get_bundled_checkpoint_names,
|
||||||
|
resolve_checkpoint_path,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.slow
|
pytestmark = pytest.mark.slow
|
||||||
|
|
||||||
@ -97,6 +101,37 @@ def test_train_controls_which_checkpoints_are_kept(
|
|||||||
assert "epoch" in best_checkpoints[0].name
|
assert "epoch" in best_checkpoints[0].name
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_saves_weights_only_checkpoints_by_default(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
) -> None:
|
||||||
|
config = _build_fast_train_config()
|
||||||
|
|
||||||
|
run_train(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
train_config=config,
|
||||||
|
num_epochs=1,
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=tmp_path,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_path = next(tmp_path.rglob("*.ckpt"))
|
||||||
|
checkpoint = torch.load(
|
||||||
|
checkpoint_path,
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "state_dict" in checkpoint
|
||||||
|
assert "hyper_parameters" in checkpoint
|
||||||
|
assert "pytorch-lightning_version" in checkpoint
|
||||||
|
assert "optimizer_states" not in checkpoint
|
||||||
|
assert "lr_schedulers" not in checkpoint
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -107,6 +142,30 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
|||||||
assert resolve_checkpoint_path(str(local_path)) == local_path
|
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||||
|
assert get_bundled_checkpoint_names() == (
|
||||||
|
"uk_same",
|
||||||
|
"batdetect2_uk_same",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||||
|
resolved = resolve_checkpoint_path("uk_same")
|
||||||
|
|
||||||
|
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||||
|
assert resolved.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
local_path = tmp_path / "uk_same"
|
||||||
|
local_path.write_bytes(b"checkpoint")
|
||||||
|
|
||||||
|
assert resolve_checkpoint_path(local_path) == local_path
|
||||||
|
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
|
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
@ -159,5 +218,8 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||||
with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
|
with pytest.raises(
|
||||||
|
FileNotFoundError,
|
||||||
|
match="bundled checkpoint alias",
|
||||||
|
):
|
||||||
resolve_checkpoint_path("missing.ckpt")
|
resolve_checkpoint_path("missing.ckpt")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user