mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: default to bundled checkpoint
Fall back to the bundled uk_same model when no checkpoint is provided in the shared loader and fine-tune CLI. Keep tests aligned with the new default resolution behavior.
This commit is contained in:
parent
31054f64f6
commit
7c05fb8577
@ -656,7 +656,7 @@ class BatDetect2API:
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike | str,
|
||||
path: data.PathLike | str | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
|
||||
@ -5,6 +5,7 @@ import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.train.checkpoints import DEFAULT_BUNDLED_CHECKPOINT
|
||||
|
||||
__all__ = ["finetune_command"]
|
||||
|
||||
@ -16,9 +17,12 @@ __all__ = ["finetune_command"]
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
|
||||
help=(
|
||||
"Path to a checkpoint, bundled checkpoint alias, or a Hugging Face "
|
||||
"URI to fine-tune from. Defaults to "
|
||||
f"'{DEFAULT_BUNDLED_CHECKPOINT}'."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--targets",
|
||||
@ -106,7 +110,7 @@ __all__ = ["finetune_command"]
|
||||
)
|
||||
def finetune_command(
|
||||
train_dataset: Path,
|
||||
model_path: str,
|
||||
model_path: str | None,
|
||||
targets_config: Path,
|
||||
val_dataset: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
|
||||
@ -8,19 +8,21 @@ from batdetect2.core import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"CheckpointConfig",
|
||||
"DEFAULT_CHECKPOINT",
|
||||
"build_checkpoint_callback",
|
||||
"get_bundled_checkpoint_names",
|
||||
"resolve_checkpoint_path",
|
||||
]
|
||||
|
||||
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||
_BUNDLED_CHECKPOINTS = {
|
||||
"uk_same": _PACKAGE_ROOT
|
||||
DEFAULT_CHECKPOINT = "uk_same"
|
||||
CHECKPOINT_ALIASES = {
|
||||
DEFAULT_CHECKPOINT: PACKAGE_ROOT
|
||||
/ "models"
|
||||
/ "checkpoints"
|
||||
/ "batdetect2_uk_same.ckpt",
|
||||
"batdetect2_uk_same": _PACKAGE_ROOT
|
||||
"batdetect2_uk_same": PACKAGE_ROOT
|
||||
/ "models"
|
||||
/ "checkpoints"
|
||||
/ "batdetect2_uk_same.ckpt",
|
||||
@ -32,7 +34,6 @@ class CheckpointConfig(BaseConfig):
|
||||
monitor: str | None = None
|
||||
mode: str = "max"
|
||||
save_top_k: int = 1
|
||||
# Default to distributable inference checkpoints, not full train resumes.
|
||||
save_weights_only: bool = True
|
||||
filename: str | None = None
|
||||
save_last: bool | Literal["link"] = "link"
|
||||
@ -74,55 +75,55 @@ def build_checkpoint_callback(
|
||||
|
||||
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
||||
"""Return the supported bundled checkpoint aliases."""
|
||||
|
||||
return tuple(_BUNDLED_CHECKPOINTS)
|
||||
return tuple(CHECKPOINT_ALIASES.keys())
|
||||
|
||||
|
||||
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
||||
"""Resolve a local path, bundled alias, or Hugging Face checkpoint URI.
|
||||
def resolve_checkpoint_from_huggingface(path: str) -> Path:
|
||||
"""Resolve a Hugging Face checkpoint URI."""
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
except ImportError as error:
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint support is not installed. "
|
||||
"Install it with `pip install batdetect2[huggingface]`."
|
||||
) from error
|
||||
|
||||
repo_id, filename = _parse_huggingface_uri(path)
|
||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||
|
||||
|
||||
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
|
||||
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike | str
|
||||
Local checkpoint path, bundled checkpoint alias, or a Hugging Face
|
||||
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
|
||||
path : PathLike | str | None
|
||||
Local checkpoint path, checkpoint alias, or a Hugging Face
|
||||
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
|
||||
omitted, the default alias checkpoint is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
Resolved local filesystem path to the checkpoint.
|
||||
"""
|
||||
if path is None:
|
||||
path = DEFAULT_CHECKPOINT
|
||||
|
||||
if isinstance(path, str) and path.startswith("hf://"):
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
except ImportError as error:
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint support is not installed. "
|
||||
"Install it with `uv sync --group huggingface` or "
|
||||
"`pip install huggingface-hub`."
|
||||
) from error
|
||||
return resolve_checkpoint_from_huggingface(path)
|
||||
|
||||
repo_id, filename = _parse_huggingface_uri(path)
|
||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if isinstance(path, str) and path in CHECKPOINT_ALIASES:
|
||||
return Path(CHECKPOINT_ALIASES[path])
|
||||
|
||||
path = Path(path)
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
|
||||
if bundled_path is not None:
|
||||
if not bundled_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Bundled checkpoint is missing: {bundled_path}"
|
||||
)
|
||||
return bundled_path
|
||||
return path.resolve()
|
||||
|
||||
bundled_names = ", ".join(get_bundled_checkpoint_names())
|
||||
raise FileNotFoundError(
|
||||
"Checkpoint not found: "
|
||||
f"{path}. Expected a local path, a bundled checkpoint alias "
|
||||
f"Checkpoint not found: {path}. "
|
||||
"Expected a local path, a checkpoint alias "
|
||||
f"({bundled_names}), or a Hugging Face URI."
|
||||
)
|
||||
|
||||
|
||||
@ -132,15 +132,15 @@ class StoredConfig:
|
||||
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike | str,
|
||||
path: PathLike | str | None = None,
|
||||
) -> tuple[Model, StoredConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
path : PathLike | str | None
|
||||
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||
pipeline.
|
||||
pipeline. If omitted, the default bundled checkpoint is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@ -307,6 +307,12 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
)
|
||||
|
||||
|
||||
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
|
||||
api = BatDetect2API.from_checkpoint()
|
||||
|
||||
assert api.model.class_names
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
api_v2: BatDetect2API,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""CLI tests for finetune command."""
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
|
||||
assert "--outputs-config" not in result.output
|
||||
|
||||
|
||||
def test_cli_finetune_requires_model() -> None:
|
||||
"""User story: finetune requires a checkpoint argument."""
|
||||
def test_cli_finetune_defaults_to_bundled_model(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""User story: finetune can use the bundled checkpoint by default."""
|
||||
|
||||
called = {}
|
||||
|
||||
class FakeAPI:
|
||||
def finetune(self, **kwargs):
|
||||
called["finetune"] = kwargs
|
||||
return None
|
||||
|
||||
class FakeBatDetect2API:
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path=None, **kwargs):
|
||||
called["path"] = path
|
||||
called["from_checkpoint_kwargs"] = kwargs
|
||||
return FakeAPI()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.api_v2.BatDetect2API",
|
||||
FakeBatDetect2API,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.data.load_dataset_config",
|
||||
lambda path: SimpleNamespace(path=path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.data.load_dataset",
|
||||
lambda config, base_dir=None: [],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.targets.TargetConfig.load",
|
||||
lambda path: SimpleNamespace(path=path),
|
||||
)
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model" in result.output
|
||||
assert result.exit_code == 0
|
||||
assert called["path"] is None
|
||||
assert "finetune" in called
|
||||
|
||||
|
||||
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||
|
||||
@ -8,6 +8,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
from batdetect2.train.checkpoints import (
|
||||
DEFAULT_BUNDLED_CHECKPOINT,
|
||||
get_bundled_checkpoint_names,
|
||||
resolve_checkpoint_path,
|
||||
)
|
||||
@ -144,13 +145,19 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||
|
||||
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||
assert get_bundled_checkpoint_names() == (
|
||||
"uk_same",
|
||||
DEFAULT_BUNDLED_CHECKPOINT,
|
||||
"batdetect2_uk_same",
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path()
|
||||
|
||||
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path("uk_same")
|
||||
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||
|
||||
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||
assert resolved.exists()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user