feat: default to bundled checkpoint

Fall back to the bundled uk_same checkpoint when no checkpoint is passed to the shared loader or the fine-tune CLI. Update the tests to match the new default-resolution behavior.
This commit is contained in:
mbsantiago 2026-05-06 10:33:04 +01:00
parent 31054f64f6
commit 7c05fb8577
7 changed files with 102 additions and 49 deletions

View File

@ -656,7 +656,7 @@ class BatDetect2API:
@classmethod @classmethod
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike | str, path: data.PathLike | str | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,

View File

@ -5,6 +5,7 @@ import click
from loguru import logger from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.train.checkpoints import DEFAULT_BUNDLED_CHECKPOINT
__all__ = ["finetune_command"] __all__ = ["finetune_command"]
@ -16,9 +17,12 @@ __all__ = ["finetune_command"]
@click.option( @click.option(
"--model", "--model",
"model_path", "model_path",
required=True,
type=str, type=str,
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.", help=(
"Path to a checkpoint, bundled checkpoint alias, or a Hugging Face "
"URI to fine-tune from. Defaults to "
f"'{DEFAULT_BUNDLED_CHECKPOINT}'."
),
) )
@click.option( @click.option(
"--targets", "--targets",
@ -106,7 +110,7 @@ __all__ = ["finetune_command"]
) )
def finetune_command( def finetune_command(
train_dataset: Path, train_dataset: Path,
model_path: str, model_path: str | None,
targets_config: Path, targets_config: Path,
val_dataset: Path | None = None, val_dataset: Path | None = None,
ckpt_dir: Path | None = None, ckpt_dir: Path | None = None,

View File

@ -8,19 +8,21 @@ from batdetect2.core import BaseConfig
__all__ = [ __all__ = [
"CheckpointConfig", "CheckpointConfig",
"DEFAULT_CHECKPOINT",
"build_checkpoint_callback", "build_checkpoint_callback",
"get_bundled_checkpoint_names", "get_bundled_checkpoint_names",
"resolve_checkpoint_path", "resolve_checkpoint_path",
] ]
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
_PACKAGE_ROOT = Path(__file__).resolve().parents[1] DEFAULT_CHECKPOINT = "uk_same"
_BUNDLED_CHECKPOINTS = { CHECKPOINT_ALIASES = {
"uk_same": _PACKAGE_ROOT DEFAULT_CHECKPOINT: PACKAGE_ROOT
/ "models" / "models"
/ "checkpoints" / "checkpoints"
/ "batdetect2_uk_same.ckpt", / "batdetect2_uk_same.ckpt",
"batdetect2_uk_same": _PACKAGE_ROOT "batdetect2_uk_same": PACKAGE_ROOT
/ "models" / "models"
/ "checkpoints" / "checkpoints"
/ "batdetect2_uk_same.ckpt", / "batdetect2_uk_same.ckpt",
@ -32,7 +34,6 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None monitor: str | None = None
mode: str = "max" mode: str = "max"
save_top_k: int = 1 save_top_k: int = 1
# Default to distributable inference checkpoints, not full train resumes.
save_weights_only: bool = True save_weights_only: bool = True
filename: str | None = None filename: str | None = None
save_last: bool | Literal["link"] = "link" save_last: bool | Literal["link"] = "link"
@ -74,55 +75,55 @@ def build_checkpoint_callback(
def get_bundled_checkpoint_names() -> tuple[str, ...]: def get_bundled_checkpoint_names() -> tuple[str, ...]:
"""Return the supported bundled checkpoint aliases.""" """Return the supported bundled checkpoint aliases."""
return tuple(CHECKPOINT_ALIASES.keys())
return tuple(_BUNDLED_CHECKPOINTS)
def resolve_checkpoint_path(path: PathLike | str) -> Path: def resolve_checkpoint_from_huggingface(path: str) -> Path:
"""Resolve a local path, bundled alias, or Hugging Face checkpoint URI. """Resolve a Hugging Face checkpoint URI."""
try:
from huggingface_hub import hf_hub_download
except ImportError as error:
raise ValueError(
"Hugging Face checkpoint support is not installed. "
"Install it with `pip install batdetect2[huggingface]`."
) from error
repo_id, filename = _parse_huggingface_uri(path)
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
Parameters Parameters
---------- ----------
path : PathLike | str path : PathLike | str | None
Local checkpoint path, bundled checkpoint alias, or a Hugging Face Local checkpoint path, checkpoint alias, or a Hugging Face
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
omitted, the default alias checkpoint is used.
Returns Returns
------- -------
Path Path
Resolved local filesystem path to the checkpoint. Resolved local filesystem path to the checkpoint.
""" """
if path is None:
path = DEFAULT_CHECKPOINT
if isinstance(path, str) and path.startswith("hf://"): if isinstance(path, str) and path.startswith("hf://"):
try: return resolve_checkpoint_from_huggingface(path)
from huggingface_hub import hf_hub_download
except ImportError as error:
raise ValueError(
"Hugging Face checkpoint support is not installed. "
"Install it with `uv sync --group huggingface` or "
"`pip install huggingface-hub`."
) from error
repo_id, filename = _parse_huggingface_uri(path) if isinstance(path, str) and path in CHECKPOINT_ALIASES:
return Path(hf_hub_download(repo_id=repo_id, filename=filename)) return Path(CHECKPOINT_ALIASES[path])
if not isinstance(path, Path):
path = Path(path) path = Path(path)
if path.exists(): if path.exists():
return path return path.resolve()
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
if bundled_path is not None:
if not bundled_path.exists():
raise FileNotFoundError(
f"Bundled checkpoint is missing: {bundled_path}"
)
return bundled_path
bundled_names = ", ".join(get_bundled_checkpoint_names()) bundled_names = ", ".join(get_bundled_checkpoint_names())
raise FileNotFoundError( raise FileNotFoundError(
"Checkpoint not found: " f"Checkpoint not found: {path}. "
f"{path}. Expected a local path, a bundled checkpoint alias " "Expected a local path, a checkpoint alias "
f"({bundled_names}), or a Hugging Face URI." f"({bundled_names}), or a Hugging Face URI."
) )

View File

@ -132,15 +132,15 @@ class StoredConfig:
def load_model_from_checkpoint( def load_model_from_checkpoint(
path: PathLike | str, path: PathLike | str | None = None,
) -> tuple[Model, StoredConfig]: ) -> tuple[Model, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint. """Load a model and its configuration from a Lightning checkpoint.
Parameters Parameters
---------- ----------
path : PathLike path : PathLike | str | None
Path to a ``.ckpt`` file produced by the BatDetect2 training Path to a ``.ckpt`` file produced by the BatDetect2 training
pipeline. pipeline. If omitted, the default bundled checkpoint is used.
Returns Returns
------- -------

View File

@ -307,6 +307,12 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
) )
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
api = BatDetect2API.from_checkpoint()
assert api.model.class_names
@pytest.mark.slow @pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics( def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API, api_v2: BatDetect2API,

View File

@ -1,6 +1,7 @@
"""CLI tests for finetune command.""" """CLI tests for finetune command."""
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
assert "--outputs-config" not in result.output assert "--outputs-config" not in result.output
def test_cli_finetune_requires_model() -> None: def test_cli_finetune_defaults_to_bundled_model(
"""User story: finetune requires a checkpoint argument.""" monkeypatch: pytest.MonkeyPatch,
) -> None:
"""User story: finetune can use the bundled checkpoint by default."""
called = {}
class FakeAPI:
def finetune(self, **kwargs):
called["finetune"] = kwargs
return None
class FakeBatDetect2API:
@classmethod
def from_checkpoint(cls, path=None, **kwargs):
called["path"] = path
called["from_checkpoint_kwargs"] = kwargs
return FakeAPI()
monkeypatch.setattr(
"batdetect2.api_v2.BatDetect2API",
FakeBatDetect2API,
)
monkeypatch.setattr(
"batdetect2.data.load_dataset_config",
lambda path: SimpleNamespace(path=path),
)
monkeypatch.setattr(
"batdetect2.data.load_dataset",
lambda config, base_dir=None: [],
)
monkeypatch.setattr(
"batdetect2.targets.TargetConfig.load",
lambda path: SimpleNamespace(path=path),
)
result = CliRunner().invoke( result = CliRunner().invoke(
cli, cli,
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
], ],
) )
assert result.exit_code != 0 assert result.exit_code == 0
assert "--model" in result.output assert called["path"] is None
assert "finetune" in called
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None: def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:

View File

@ -8,6 +8,7 @@ from soundevent import data
from batdetect2.train import TrainingConfig, run_train from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import ( from batdetect2.train.checkpoints import (
DEFAULT_BUNDLED_CHECKPOINT,
get_bundled_checkpoint_names, get_bundled_checkpoint_names,
resolve_checkpoint_path, resolve_checkpoint_path,
) )
@ -144,13 +145,19 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == ( assert get_bundled_checkpoint_names() == (
"uk_same", DEFAULT_BUNDLED_CHECKPOINT,
"batdetect2_uk_same", "batdetect2_uk_same",
) )
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
resolved = resolve_checkpoint_path()
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path("uk_same") resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
assert resolved.name == "batdetect2_uk_same.ckpt" assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists() assert resolved.exists()