mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: default to bundled checkpoint
Fall back to the bundled `uk_same` checkpoint when no checkpoint is passed to the shared loader or to the fine-tune CLI. Update the tests to match the new default-resolution behavior.
This commit is contained in:
parent
31054f64f6
commit
7c05fb8577
@ -656,7 +656,7 @@ class BatDetect2API:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike | str,
|
path: data.PathLike | str | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
evaluation_config: EvaluationConfig | None = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
|
from batdetect2.train.checkpoints import DEFAULT_BUNDLED_CHECKPOINT
|
||||||
|
|
||||||
__all__ = ["finetune_command"]
|
__all__ = ["finetune_command"]
|
||||||
|
|
||||||
@ -16,9 +17,12 @@ __all__ = ["finetune_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--model",
|
"--model",
|
||||||
"model_path",
|
"model_path",
|
||||||
required=True,
|
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
|
help=(
|
||||||
|
"Path to a checkpoint, bundled checkpoint alias, or a Hugging Face "
|
||||||
|
"URI to fine-tune from. Defaults to "
|
||||||
|
f"'{DEFAULT_BUNDLED_CHECKPOINT}'."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--targets",
|
"--targets",
|
||||||
@ -106,7 +110,7 @@ __all__ = ["finetune_command"]
|
|||||||
)
|
)
|
||||||
def finetune_command(
|
def finetune_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
model_path: str,
|
model_path: str | None,
|
||||||
targets_config: Path,
|
targets_config: Path,
|
||||||
val_dataset: Path | None = None,
|
val_dataset: Path | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
|
|||||||
@ -8,19 +8,21 @@ from batdetect2.core import BaseConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CheckpointConfig",
|
"CheckpointConfig",
|
||||||
|
"DEFAULT_CHECKPOINT",
|
||||||
"build_checkpoint_callback",
|
"build_checkpoint_callback",
|
||||||
"get_bundled_checkpoint_names",
|
"get_bundled_checkpoint_names",
|
||||||
"resolve_checkpoint_path",
|
"resolve_checkpoint_path",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
DEFAULT_CHECKPOINT = "uk_same"
|
||||||
_BUNDLED_CHECKPOINTS = {
|
CHECKPOINT_ALIASES = {
|
||||||
"uk_same": _PACKAGE_ROOT
|
DEFAULT_CHECKPOINT: PACKAGE_ROOT
|
||||||
/ "models"
|
/ "models"
|
||||||
/ "checkpoints"
|
/ "checkpoints"
|
||||||
/ "batdetect2_uk_same.ckpt",
|
/ "batdetect2_uk_same.ckpt",
|
||||||
"batdetect2_uk_same": _PACKAGE_ROOT
|
"batdetect2_uk_same": PACKAGE_ROOT
|
||||||
/ "models"
|
/ "models"
|
||||||
/ "checkpoints"
|
/ "checkpoints"
|
||||||
/ "batdetect2_uk_same.ckpt",
|
/ "batdetect2_uk_same.ckpt",
|
||||||
@ -32,7 +34,6 @@ class CheckpointConfig(BaseConfig):
|
|||||||
monitor: str | None = None
|
monitor: str | None = None
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
# Default to distributable inference checkpoints, not full train resumes.
|
|
||||||
save_weights_only: bool = True
|
save_weights_only: bool = True
|
||||||
filename: str | None = None
|
filename: str | None = None
|
||||||
save_last: bool | Literal["link"] = "link"
|
save_last: bool | Literal["link"] = "link"
|
||||||
@ -74,55 +75,55 @@ def build_checkpoint_callback(
|
|||||||
|
|
||||||
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
||||||
"""Return the supported bundled checkpoint aliases."""
|
"""Return the supported bundled checkpoint aliases."""
|
||||||
|
return tuple(CHECKPOINT_ALIASES.keys())
|
||||||
return tuple(_BUNDLED_CHECKPOINTS)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
def resolve_checkpoint_from_huggingface(path: str) -> Path:
|
||||||
"""Resolve a local path, bundled alias, or Hugging Face checkpoint URI.
|
"""Resolve a Hugging Face checkpoint URI."""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
except ImportError as error:
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint support is not installed. "
|
||||||
|
"Install it with `pip install batdetect2[huggingface]`."
|
||||||
|
) from error
|
||||||
|
|
||||||
|
repo_id, filename = _parse_huggingface_uri(path)
|
||||||
|
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
|
||||||
|
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : PathLike | str
|
path : PathLike | str | None
|
||||||
Local checkpoint path, bundled checkpoint alias, or a Hugging Face
|
Local checkpoint path, checkpoint alias, or a Hugging Face
|
||||||
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
|
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
|
||||||
|
omitted, the default alias checkpoint is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Path
|
Path
|
||||||
Resolved local filesystem path to the checkpoint.
|
Resolved local filesystem path to the checkpoint.
|
||||||
"""
|
"""
|
||||||
|
if path is None:
|
||||||
|
path = DEFAULT_CHECKPOINT
|
||||||
|
|
||||||
if isinstance(path, str) and path.startswith("hf://"):
|
if isinstance(path, str) and path.startswith("hf://"):
|
||||||
try:
|
return resolve_checkpoint_from_huggingface(path)
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
except ImportError as error:
|
|
||||||
raise ValueError(
|
|
||||||
"Hugging Face checkpoint support is not installed. "
|
|
||||||
"Install it with `uv sync --group huggingface` or "
|
|
||||||
"`pip install huggingface-hub`."
|
|
||||||
) from error
|
|
||||||
|
|
||||||
repo_id, filename = _parse_huggingface_uri(path)
|
if isinstance(path, str) and path in CHECKPOINT_ALIASES:
|
||||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
return Path(CHECKPOINT_ALIASES[path])
|
||||||
|
|
||||||
if not isinstance(path, Path):
|
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return path
|
return path.resolve()
|
||||||
|
|
||||||
bundled_path = _BUNDLED_CHECKPOINTS.get(str(path))
|
|
||||||
if bundled_path is not None:
|
|
||||||
if not bundled_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Bundled checkpoint is missing: {bundled_path}"
|
|
||||||
)
|
|
||||||
return bundled_path
|
|
||||||
|
|
||||||
bundled_names = ", ".join(get_bundled_checkpoint_names())
|
bundled_names = ", ".join(get_bundled_checkpoint_names())
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
"Checkpoint not found: "
|
f"Checkpoint not found: {path}. "
|
||||||
f"{path}. Expected a local path, a bundled checkpoint alias "
|
"Expected a local path, a checkpoint alias "
|
||||||
f"({bundled_names}), or a Hugging Face URI."
|
f"({bundled_names}), or a Hugging Face URI."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -132,15 +132,15 @@ class StoredConfig:
|
|||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike | str,
|
path: PathLike | str | None = None,
|
||||||
) -> tuple[Model, StoredConfig]:
|
) -> tuple[Model, StoredConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : PathLike
|
path : PathLike | str | None
|
||||||
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||||
pipeline.
|
pipeline. If omitted, the default bundled checkpoint is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
|||||||
@ -307,6 +307,12 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
|
||||||
|
api = BatDetect2API.from_checkpoint()
|
||||||
|
|
||||||
|
assert api.model.class_names
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""CLI tests for finetune command."""
|
"""CLI tests for finetune command."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
|
|||||||
assert "--outputs-config" not in result.output
|
assert "--outputs-config" not in result.output
|
||||||
|
|
||||||
|
|
||||||
def test_cli_finetune_requires_model() -> None:
|
def test_cli_finetune_defaults_to_bundled_model(
|
||||||
"""User story: finetune requires a checkpoint argument."""
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""User story: finetune can use the bundled checkpoint by default."""
|
||||||
|
|
||||||
|
called = {}
|
||||||
|
|
||||||
|
class FakeAPI:
|
||||||
|
def finetune(self, **kwargs):
|
||||||
|
called["finetune"] = kwargs
|
||||||
|
return None
|
||||||
|
|
||||||
|
class FakeBatDetect2API:
|
||||||
|
@classmethod
|
||||||
|
def from_checkpoint(cls, path=None, **kwargs):
|
||||||
|
called["path"] = path
|
||||||
|
called["from_checkpoint_kwargs"] = kwargs
|
||||||
|
return FakeAPI()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.api_v2.BatDetect2API",
|
||||||
|
FakeBatDetect2API,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.data.load_dataset_config",
|
||||||
|
lambda path: SimpleNamespace(path=path),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.data.load_dataset",
|
||||||
|
lambda config, base_dir=None: [],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.targets.TargetConfig.load",
|
||||||
|
lambda path: SimpleNamespace(path=path),
|
||||||
|
)
|
||||||
|
|
||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code != 0
|
assert result.exit_code == 0
|
||||||
assert "--model" in result.output
|
assert called["path"] is None
|
||||||
|
assert "finetune" in called
|
||||||
|
|
||||||
|
|
||||||
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.train import TrainingConfig, run_train
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
from batdetect2.train.checkpoints import (
|
from batdetect2.train.checkpoints import (
|
||||||
|
DEFAULT_BUNDLED_CHECKPOINT,
|
||||||
get_bundled_checkpoint_names,
|
get_bundled_checkpoint_names,
|
||||||
resolve_checkpoint_path,
|
resolve_checkpoint_path,
|
||||||
)
|
)
|
||||||
@ -144,13 +145,19 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
|||||||
|
|
||||||
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||||
assert get_bundled_checkpoint_names() == (
|
assert get_bundled_checkpoint_names() == (
|
||||||
"uk_same",
|
DEFAULT_BUNDLED_CHECKPOINT,
|
||||||
"batdetect2_uk_same",
|
"batdetect2_uk_same",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||||
|
resolved = resolve_checkpoint_path()
|
||||||
|
|
||||||
|
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||||
resolved = resolve_checkpoint_path("uk_same")
|
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||||
|
|
||||||
assert resolved.name == "batdetect2_uk_same.ckpt"
|
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||||
assert resolved.exists()
|
assert resolved.exists()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user