From 7c05fb8577b01c6bd163871fb489cdf732b12c9d Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 6 May 2026 10:33:04 +0100 Subject: [PATCH] feat: default to bundled checkpoint Fall back to the bundled uk_same model when no checkpoint is provided in the shared loader and fine-tune CLI. Keep tests aligned with the new default resolution behavior. --- src/batdetect2/api_v2.py | 2 +- src/batdetect2/cli/finetune.py | 10 ++-- src/batdetect2/train/checkpoints.py | 73 ++++++++++++++-------------- src/batdetect2/train/lightning.py | 6 +-- tests/test_api_v2/test_api_v2.py | 6 +++ tests/test_cli/test_finetune.py | 43 ++++++++++++++-- tests/test_train/test_checkpoints.py | 11 ++++- 7 files changed, 102 insertions(+), 49 deletions(-) diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 442e8b9..037bf15 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -656,7 +656,7 @@ class BatDetect2API: @classmethod def from_checkpoint( cls, - path: data.PathLike | str, + path: data.PathLike | str | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, evaluation_config: EvaluationConfig | None = None, diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py index b2d5619..7521274 100644 --- a/src/batdetect2/cli/finetune.py +++ b/src/batdetect2/cli/finetune.py @@ -5,6 +5,7 @@ import click from loguru import logger from batdetect2.cli.base import cli +from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT __all__ = ["finetune_command"] @@ -16,9 +17,12 @@ __all__ = ["finetune_command"] @click.option( "--model", "model_path", - required=True, type=str, - help="Path to a checkpoint or a Hugging Face URI to fine-tune from.", + help=( + "Path to a checkpoint, bundled checkpoint alias, or a Hugging Face " + "URI to fine-tune from. Defaults to " + f"'{DEFAULT_CHECKPOINT}'." 
+ ), ) @click.option( "--targets", @@ -106,7 +110,7 @@ __all__ = ["finetune_command"] ) def finetune_command( train_dataset: Path, - model_path: str, + model_path: str | None, targets_config: Path, val_dataset: Path | None = None, ckpt_dir: Path | None = None, diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index f9d851e..a443743 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -8,19 +8,21 @@ from batdetect2.core import BaseConfig __all__ = [ "CheckpointConfig", + "DEFAULT_CHECKPOINT", "build_checkpoint_callback", "get_bundled_checkpoint_names", "resolve_checkpoint_path", ] +PACKAGE_ROOT = Path(__file__).resolve().parents[1] DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" -_PACKAGE_ROOT = Path(__file__).resolve().parents[1] -_BUNDLED_CHECKPOINTS = { - "uk_same": _PACKAGE_ROOT +DEFAULT_CHECKPOINT = "uk_same" +CHECKPOINT_ALIASES = { + DEFAULT_CHECKPOINT: PACKAGE_ROOT / "models" / "checkpoints" / "batdetect2_uk_same.ckpt", - "batdetect2_uk_same": _PACKAGE_ROOT + "batdetect2_uk_same": PACKAGE_ROOT / "models" / "checkpoints" / "batdetect2_uk_same.ckpt", @@ -32,7 +34,6 @@ class CheckpointConfig(BaseConfig): monitor: str | None = None mode: str = "max" save_top_k: int = 1 - # Default to distributable inference checkpoints, not full train resumes. save_weights_only: bool = True filename: str | None = None save_last: bool | Literal["link"] = "link" @@ -74,55 +75,55 @@ def build_checkpoint_callback( def get_bundled_checkpoint_names() -> tuple[str, ...]: """Return the supported bundled checkpoint aliases.""" - - return tuple(_BUNDLED_CHECKPOINTS) + return tuple(CHECKPOINT_ALIASES.keys()) -def resolve_checkpoint_path(path: PathLike | str) -> Path: - """Resolve a local path, bundled alias, or Hugging Face checkpoint URI. 
+def resolve_checkpoint_from_huggingface(path: str) -> Path: + """Resolve a Hugging Face checkpoint URI.""" + try: + from huggingface_hub import hf_hub_download + except ImportError as error: + raise ValueError( + "Hugging Face checkpoint support is not installed. " + "Install it with `pip install batdetect2[huggingface]`." + ) from error + + repo_id, filename = _parse_huggingface_uri(path) + return Path(hf_hub_download(repo_id=repo_id, filename=filename)) + + +def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path: + """Resolve a local path, alias, or Hugging Face checkpoint URI. Parameters ---------- - path : PathLike | str - Local checkpoint path, bundled checkpoint alias, or a Hugging Face - URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. + path : PathLike | str | None + Local checkpoint path, checkpoint alias, or a Hugging Face + URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If + omitted, the default alias checkpoint is used. Returns ------- Path Resolved local filesystem path to the checkpoint. """ + if path is None: + path = DEFAULT_CHECKPOINT + if isinstance(path, str) and path.startswith("hf://"): - try: - from huggingface_hub import hf_hub_download - except ImportError as error: - raise ValueError( - "Hugging Face checkpoint support is not installed. " - "Install it with `uv sync --group huggingface` or " - "`pip install huggingface-hub`." 
- ) from error + return resolve_checkpoint_from_huggingface(path) - repo_id, filename = _parse_huggingface_uri(path) - return Path(hf_hub_download(repo_id=repo_id, filename=filename)) - - if not isinstance(path, Path): - path = Path(path) + if isinstance(path, str) and path in CHECKPOINT_ALIASES: + return Path(CHECKPOINT_ALIASES[path]) + path = Path(path) if path.exists(): - return path - - bundled_path = _BUNDLED_CHECKPOINTS.get(str(path)) - if bundled_path is not None: - if not bundled_path.exists(): - raise FileNotFoundError( - f"Bundled checkpoint is missing: {bundled_path}" - ) - return bundled_path + return path.resolve() bundled_names = ", ".join(get_bundled_checkpoint_names()) raise FileNotFoundError( - "Checkpoint not found: " - f"{path}. Expected a local path, a bundled checkpoint alias " + f"Checkpoint not found: {path}. " + "Expected a local path, a checkpoint alias " f"({bundled_names}), or a Hugging Face URI." ) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 0ccf5b0..30fe1af 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -132,15 +132,15 @@ class StoredConfig: def load_model_from_checkpoint( - path: PathLike | str, + path: PathLike | str | None = None, ) -> tuple[Model, StoredConfig]: """Load a model and its configuration from a Lightning checkpoint. Parameters ---------- - path : PathLike + path : PathLike | str | None Path to a ``.ckpt`` file produced by the BatDetect2 training - pipeline. + pipeline. If omitted, the default bundled checkpoint is used. 
Returns ------- diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index 19e0a40..a4b2758 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -307,6 +307,12 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( ) +def test_api_from_checkpoint_defaults_to_bundled_model() -> None: + api = BatDetect2API.from_checkpoint() + + assert api.model.class_names + + @pytest.mark.slow def test_user_can_evaluate_small_dataset_and_get_metrics( api_v2: BatDetect2API, diff --git a/tests/test_cli/test_finetune.py b/tests/test_cli/test_finetune.py index 270ca58..cdc02aa 100644 --- a/tests/test_cli/test_finetune.py +++ b/tests/test_cli/test_finetune.py @@ -1,6 +1,7 @@ """CLI tests for finetune command.""" from pathlib import Path +from types import SimpleNamespace import pytest from click.testing import CliRunner @@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None: assert "--outputs-config" not in result.output -def test_cli_finetune_requires_model() -> None: - """User story: finetune requires a checkpoint argument.""" +def test_cli_finetune_defaults_to_bundled_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """User story: finetune can use the bundled checkpoint by default.""" + + called = {} + + class FakeAPI: + def finetune(self, **kwargs): + called["finetune"] = kwargs + return None + + class FakeBatDetect2API: + @classmethod + def from_checkpoint(cls, path=None, **kwargs): + called["path"] = path + called["from_checkpoint_kwargs"] = kwargs + return FakeAPI() + + monkeypatch.setattr( + "batdetect2.api_v2.BatDetect2API", + FakeBatDetect2API, + ) + monkeypatch.setattr( + "batdetect2.data.load_dataset_config", + lambda path: SimpleNamespace(path=path), + ) + monkeypatch.setattr( + "batdetect2.data.load_dataset", + lambda config, base_dir=None: [], + ) + monkeypatch.setattr( + "batdetect2.targets.TargetConfig.load", + lambda path: SimpleNamespace(path=path), + ) result = CliRunner().invoke( 
 cli, @@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None: ], ) - assert result.exit_code != 0 - assert "--model" in result.output + assert result.exit_code == 0 + assert called["path"] is None + assert "finetune" in called def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None: diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index 7d99f0a..77a2856 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -8,6 +8,7 @@ from soundevent import data from batdetect2.train import TrainingConfig, run_train from batdetect2.train.checkpoints import ( + DEFAULT_CHECKPOINT, get_bundled_checkpoint_names, resolve_checkpoint_path, ) @@ -144,13 +145,19 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged( def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: assert get_bundled_checkpoint_names() == ( - "uk_same", + DEFAULT_CHECKPOINT, "batdetect2_uk_same", ) +def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None: + resolved = resolve_checkpoint_path() + + assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT) + + def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: - resolved = resolve_checkpoint_path("uk_same") + resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT) assert resolved.name == "batdetect2_uk_same.ckpt" assert resolved.exists()