feat: default to bundled checkpoint

Fall back to the bundled uk_same model when no checkpoint is provided in the shared loader and fine-tune CLI. Keep tests aligned with the new default resolution behavior.
This commit is contained in:
mbsantiago 2026-05-06 10:33:04 +01:00
parent 31054f64f6
commit 7c05fb8577
7 changed files with 102 additions and 49 deletions

View File

@ -656,7 +656,7 @@ class BatDetect2API:
@classmethod
def from_checkpoint(
cls,
path: data.PathLike | str,
path: data.PathLike | str | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,

View File

@ -5,6 +5,7 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.train.checkpoints import DEFAULT_BUNDLED_CHECKPOINT
__all__ = ["finetune_command"]
@ -16,9 +17,12 @@ __all__ = ["finetune_command"]
@click.option(
"--model",
"model_path",
required=True,
type=str,
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
help=(
"Path to a checkpoint, bundled checkpoint alias, or a Hugging Face "
"URI to fine-tune from. Defaults to "
f"'{DEFAULT_BUNDLED_CHECKPOINT}'."
),
)
@click.option(
"--targets",
@ -106,7 +110,7 @@ __all__ = ["finetune_command"]
)
def finetune_command(
train_dataset: Path,
model_path: str,
model_path: str | None,
targets_config: Path,
val_dataset: Path | None = None,
ckpt_dir: Path | None = None,

View File

@ -8,19 +8,21 @@ from batdetect2.core import BaseConfig
__all__ = [
"CheckpointConfig",
"DEFAULT_CHECKPOINT",
"build_checkpoint_callback",
"get_bundled_checkpoint_names",
"resolve_checkpoint_path",
]
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
_BUNDLED_CHECKPOINTS = {
"uk_same": _PACKAGE_ROOT
DEFAULT_CHECKPOINT = "uk_same"
CHECKPOINT_ALIASES = {
DEFAULT_CHECKPOINT: PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
"batdetect2_uk_same": _PACKAGE_ROOT
"batdetect2_uk_same": PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
@ -32,7 +34,6 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None
mode: str = "max"
save_top_k: int = 1
# Default to distributable inference checkpoints, not full train resumes.
save_weights_only: bool = True
filename: str | None = None
save_last: bool | Literal["link"] = "link"
@ -74,55 +75,55 @@ def build_checkpoint_callback(
def get_bundled_checkpoint_names() -> tuple[str, ...]:
"""Return the supported bundled checkpoint aliases."""
return tuple(_BUNDLED_CHECKPOINTS)
return tuple(CHECKPOINT_ALIASES.keys())
def resolve_checkpoint_path(path: PathLike | str) -> Path:
"""Resolve a local path, bundled alias, or Hugging Face checkpoint URI.
def resolve_checkpoint_from_huggingface(path: str) -> Path:
    """Fetch the checkpoint referenced by a Hugging Face URI.

    ``huggingface_hub`` is imported inside the function, so the package is
    only required when a Hugging Face URI is actually resolved.

    Parameters
    ----------
    path : str
        Hugging Face checkpoint URI, e.g.
        ``hf://owner/repo/path/to/checkpoint.ckpt``. It is passed to
        ``_parse_huggingface_uri`` to obtain the repo id and filename.

    Returns
    -------
    Path
        The value returned by ``hf_hub_download`` wrapped in a ``Path``.

    Raises
    ------
    ValueError
        If ``huggingface_hub`` cannot be imported. The original
        ``ImportError`` is attached as the cause.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as error:
        message = (
            "Hugging Face checkpoint support is not installed. "
            "Install it with `pip install batdetect2[huggingface]`."
        )
        raise ValueError(message) from error

    repo_id, filename = _parse_huggingface_uri(path)
    downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
    return Path(downloaded)
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
Parameters
----------
path : PathLike | str
Local checkpoint path, bundled checkpoint alias, or a Hugging Face
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
path : PathLike | str | None
Local checkpoint path, checkpoint alias, or a Hugging Face
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
omitted, the default alias checkpoint is used.
Returns
-------
Path
Resolved local filesystem path to the checkpoint.
"""
if path is None:
path = DEFAULT_CHECKPOINT
if isinstance(path, str) and path.startswith("hf://"):
try:
from huggingface_hub import hf_hub_download
except ImportError as error:
raise ValueError(
"Hugging Face checkpoint support is not installed. "
"Install it with `uv sync --group huggingface` or "
"`pip install huggingface-hub`."
) from error
return resolve_checkpoint_from_huggingface(path)
repo_id, filename = _parse_huggingface_uri(path)
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
if isinstance(path, str) and path in CHECKPOINT_ALIASES:
return Path(CHECKPOINT_ALIASES[path])
if not isinstance(path, Path):
path = Path(path)
if path.exists():
return path
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
if bundled_path is not None:
if not bundled_path.exists():
raise FileNotFoundError(
f"Bundled checkpoint is missing: {bundled_path}"
)
return bundled_path
return path.resolve()
bundled_names = ", ".join(get_bundled_checkpoint_names())
raise FileNotFoundError(
"Checkpoint not found: "
f"{path}. Expected a local path, a bundled checkpoint alias "
f"Checkpoint not found: {path}. "
"Expected a local path, a checkpoint alias "
f"({bundled_names}), or a Hugging Face URI."
)

View File

@ -132,15 +132,15 @@ class StoredConfig:
def load_model_from_checkpoint(
path: PathLike | str,
path: PathLike | str | None = None,
) -> tuple[Model, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
----------
path : PathLike
path : PathLike | str | None
Path to a ``.ckpt`` file produced by the BatDetect2 training
pipeline.
pipeline. If omitted, the default bundled checkpoint is used.
Returns
-------

View File

@ -307,6 +307,12 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
)
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
    """``from_checkpoint`` called with no arguments returns a usable API."""
    default_api = BatDetect2API.from_checkpoint()

    # The loaded model must expose a non-empty collection of class names.
    class_names = default_api.model.class_names
    assert class_names
@pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,

View File

@ -1,6 +1,7 @@
"""CLI tests for finetune command."""
from pathlib import Path
from types import SimpleNamespace
import pytest
from click.testing import CliRunner
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
assert "--outputs-config" not in result.output
def test_cli_finetune_requires_model() -> None:
"""User story: finetune requires a checkpoint argument."""
def test_cli_finetune_defaults_to_bundled_model(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""User story: finetune can use the bundled checkpoint by default."""
called = {}
class FakeAPI:
def finetune(self, **kwargs):
called["finetune"] = kwargs
return None
class FakeBatDetect2API:
@classmethod
def from_checkpoint(cls, path=None, **kwargs):
called["path"] = path
called["from_checkpoint_kwargs"] = kwargs
return FakeAPI()
monkeypatch.setattr(
"batdetect2.api_v2.BatDetect2API",
FakeBatDetect2API,
)
monkeypatch.setattr(
"batdetect2.data.load_dataset_config",
lambda path: SimpleNamespace(path=path),
)
monkeypatch.setattr(
"batdetect2.data.load_dataset",
lambda config, base_dir=None: [],
)
monkeypatch.setattr(
"batdetect2.targets.TargetConfig.load",
lambda path: SimpleNamespace(path=path),
)
result = CliRunner().invoke(
cli,
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
],
)
assert result.exit_code != 0
assert "--model" in result.output
assert result.exit_code == 0
assert called["path"] is None
assert "finetune" in called
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:

View File

@ -8,6 +8,7 @@ from soundevent import data
from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import (
DEFAULT_BUNDLED_CHECKPOINT,
get_bundled_checkpoint_names,
resolve_checkpoint_path,
)
@ -144,13 +145,19 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == (
"uk_same",
DEFAULT_BUNDLED_CHECKPOINT,
"batdetect2_uk_same",
)
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
    """Omitting the path resolves to the same file as the default alias."""
    # NOTE(review): the checkpoints hunk in this change defines and exports
    # ``DEFAULT_CHECKPOINT``; confirm ``DEFAULT_BUNDLED_CHECKPOINT`` is also
    # exported from ``batdetect2.train.checkpoints``.
    expected = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
    assert resolve_checkpoint_path() == expected
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path("uk_same")
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists()