feat: load checkpoints from Hugging Face

This commit is contained in:
mbsantiago 2026-05-05 15:46:39 +01:00
parent 5a974711b0
commit f5afa9881c
10 changed files with 121 additions and 28 deletions

View File

@ -11,6 +11,7 @@ dependencies = [
"click>=8.1.7",
"deepmerge>=2.0",
"hydra-core>=1.3.2",
"huggingface-hub>=0.32.0",
"librosa>=0.10.1",
"lightning[extra]==2.5.0",
"loguru>=0.7.3",

View File

@ -653,7 +653,7 @@ class BatDetect2API:
@classmethod
def from_checkpoint(
cls,
path: data.PathLike,
path: data.PathLike | str,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,

View File

@ -12,14 +12,8 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option(
"--targets",
"targets_config",
type=click.Path(exists=True),
help="Path to targets config file.",
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
@ -80,10 +74,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
default=0,
)
def evaluate_command(
model_path: Path,
model_path: str,
test_dataset: Path,
base_dir: Path,
targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,

View File

@ -17,8 +17,8 @@ __all__ = ["finetune_command"]
"--model",
"model_path",
required=True,
type=click.Path(exists=True),
help="Path to a checkpoint to fine-tune from.",
type=str,
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
)
@click.option(
"--targets",
@ -106,7 +106,7 @@ __all__ = ["finetune_command"]
)
def finetune_command(
train_dataset: Path,
model_path: Path,
model_path: str,
targets_config: Path,
val_dataset: Path | None = None,
ckpt_dir: Path | None = None,

View File

@ -102,7 +102,7 @@ def common_predict_options(func):
def _build_api(
model_path: Path,
model_path: str,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
@ -144,7 +144,7 @@ def _build_api(
def _run_prediction(
model_path: Path,
model_path: str,
audio_files: list[Path],
output_path: Path,
audio_config: Path | None,
@ -195,12 +195,12 @@ def _run_prediction(
name="directory",
short_help="Predict on audio files in a directory.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_directory_command(
model_path: Path,
model_path: str,
audio_dir: Path,
output_path: Path,
audio_config: Path | None,
@ -239,12 +239,12 @@ def predict_directory_command(
name="file_list",
short_help="Predict on paths listed in a text file.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_file_list_command(
model_path: Path,
model_path: str,
file_list: Path,
output_path: Path,
audio_config: Path | None,
@ -287,12 +287,12 @@ def predict_file_list_command(
name="dataset",
short_help="Predict on recordings from a dataset config.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_dataset_command(
model_path: Path,
model_path: str,
dataset_path: Path,
output_path: Path,
audio_config: Path | None,

View File

@ -18,10 +18,10 @@ __all__ = ["train_command"]
@click.option(
"--model",
"model_path",
type=click.Path(exists=True),
type=str,
help=(
"Path to a checkpoint to continue training from. If omitted, "
"training starts from a fresh model config."
"Path to a checkpoint or a Hugging Face URI to continue training "
"from. If omitted, training starts from a fresh model config."
),
)
@click.option(
@ -118,7 +118,7 @@ __all__ = ["train_command"]
def train_command(
train_dataset: Path,
val_dataset: Path | None = None,
model_path: Path | None = None,
model_path: str | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
base_dir: Path | None = None,

View File

@ -1,4 +1,7 @@
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
from batdetect2.train.checkpoints import (
DEFAULT_CHECKPOINT_DIR,
resolve_checkpoint_path,
)
from batdetect2.train.config import TrainingConfig
from batdetect2.train.lightning import (
TrainingModule,
@ -26,5 +29,6 @@ __all__ = [
"TrainingModule",
"build_trainer",
"load_model_from_checkpoint",
"resolve_checkpoint_path",
"run_train",
]

View File

@ -1,13 +1,16 @@
from pathlib import Path
from typing import Literal
from huggingface_hub import hf_hub_download
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from soundevent.data import PathLike
from batdetect2.core import BaseConfig
__all__ = [
"CheckpointConfig",
"build_checkpoint_callback",
"resolve_checkpoint_path",
]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@ -53,3 +56,51 @@ def build_checkpoint_callback(
save_last=config.save_last,
every_n_epochs=config.every_n_epochs,
)
def resolve_checkpoint_path(path: PathLike | str) -> Path:
"""Resolve a local path or Hugging Face checkpoint URI.
Parameters
----------
path : PathLike | str
Local checkpoint path or a Hugging Face URI of the form
``hf://owner/repo/path/to/checkpoint.ckpt``.
Returns
-------
Path
Resolved local filesystem path to the checkpoint.
"""
if isinstance(path, str) and path.startswith("hf://"):
repo_id, filename = _parse_huggingface_uri(path)
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
if not isinstance(path, Path):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Checkpoint not found: {path}")
return path
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
    """Split a Hugging Face checkpoint URI into a repo id and a filename.

    Parameters
    ----------
    uri : str
        URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
        Leading and trailing slashes after the prefix are ignored.

    Returns
    -------
    tuple[str, str]
        The repository id (``owner/repo``) and the path of the file
        inside the repository (``path/to/checkpoint.ckpt``).

    Raises
    ------
    ValueError
        If the URI does not start with ``hf://``, has fewer than three
        path segments, or contains an empty segment (for example a
        doubled slash as in ``hf://owner//model.ckpt``).
    """
    prefix = "hf://"
    if not uri.startswith(prefix):
        raise ValueError(
            "Hugging Face checkpoint URIs must start with 'hf://'."
        )

    parts = uri.removeprefix(prefix).strip("/").split("/")

    # An empty segment would otherwise leak into the result as a repo id
    # like "owner/" or a filename like "/model.ckpt"; reject it here with
    # the expected URI shape instead of passing it on to the download.
    if len(parts) < 3 or not all(parts):
        raise ValueError(
            "Hugging Face checkpoint URIs must be in the form "
            "'hf://owner/repo/path/to/checkpoint.ckpt'."
        )

    owner, repo, *file_parts = parts
    return f"{owner}/{repo}", "/".join(file_parts)

View File

@ -6,6 +6,7 @@ from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput
from batdetect2.targets import TargetConfig
from batdetect2.train.checkpoints import resolve_checkpoint_path
from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer
@ -130,7 +131,7 @@ class StoredConfig:
def load_model_from_checkpoint(
path: PathLike,
path: PathLike | str,
) -> tuple[Model, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
@ -146,7 +147,8 @@ def load_model_from_checkpoint(
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, and postprocessing.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore
resolved_path = resolve_checkpoint_path(path)
module = TrainingModule.load_from_checkpoint(resolved_path)
training_config = TrainingConfig.model_validate(module.train_config)
model_config = ModelConfig.model_validate(module.model_config)
targets_config = TargetConfig.model_validate(module.targets_config)

View File

@ -4,6 +4,7 @@ import pytest
from soundevent import data
from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import resolve_checkpoint_path
pytestmark = pytest.mark.slow
@ -92,3 +93,44 @@ def test_train_controls_which_checkpoints_are_kept(
assert last_checkpoints
assert len(best_checkpoints) == 1
assert "epoch" in best_checkpoints[0].name
def test_resolve_checkpoint_path_returns_local_path_unchanged(
    tmp_path: Path,
) -> None:
    """An existing local checkpoint resolves to itself as Path or str."""
    checkpoint = tmp_path / "model.ckpt"
    checkpoint.write_bytes(b"checkpoint")

    for candidate in (checkpoint, str(checkpoint)):
        assert resolve_checkpoint_path(candidate) == checkpoint
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """An ``hf://`` URI is split into repo id and filename for download."""
    download_target = tmp_path / "downloaded.ckpt"

    def stub_download(repo_id: str, filename: str) -> str:
        # Stands in for hf_hub_download so no network access is needed;
        # checks the arguments derived from the URI.
        assert (repo_id, filename) == ("owner/repo", "weights/model.ckpt")
        return str(download_target)

    monkeypatch.setattr(
        "batdetect2.train.checkpoints.hf_hub_download",
        stub_download,
    )

    assert (
        resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
        == download_target
    )
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
    """A URI that names a repo but no file inside it is rejected."""
    incomplete_uri = "hf://owner/repo"
    with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
        resolve_checkpoint_path(incomplete_uri)
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
    """A local path that does not exist raises ``FileNotFoundError``."""
    missing = "missing.ckpt"
    with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
        resolve_checkpoint_path(missing)