feat: load checkpoints from Hugging Face

This commit is contained in:
mbsantiago 2026-05-05 15:46:39 +01:00
parent 5a974711b0
commit f5afa9881c
10 changed files with 121 additions and 28 deletions

View File

@ -11,6 +11,7 @@ dependencies = [
"click>=8.1.7", "click>=8.1.7",
"deepmerge>=2.0", "deepmerge>=2.0",
"hydra-core>=1.3.2", "hydra-core>=1.3.2",
"huggingface-hub>=0.32.0",
"librosa>=0.10.1", "librosa>=0.10.1",
"lightning[extra]==2.5.0", "lightning[extra]==2.5.0",
"loguru>=0.7.3", "loguru>=0.7.3",

View File

@ -653,7 +653,7 @@ class BatDetect2API:
@classmethod @classmethod
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike | str,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,

View File

@ -12,14 +12,8 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.") @cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
@click.argument("model_path", type=click.Path(exists=True)) @click.argument("model_path", type=str)
@click.argument("test_dataset", type=click.Path(exists=True)) @click.argument("test_dataset", type=click.Path(exists=True))
@click.option(
"--targets",
"targets_config",
type=click.Path(exists=True),
help="Path to targets config file.",
)
@click.option( @click.option(
"--audio-config", "--audio-config",
type=click.Path(exists=True), type=click.Path(exists=True),
@ -80,10 +74,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
default=0, default=0,
) )
def evaluate_command( def evaluate_command(
model_path: Path, model_path: str,
test_dataset: Path, test_dataset: Path,
base_dir: Path, base_dir: Path,
targets_config: Path | None,
audio_config: Path | None, audio_config: Path | None,
evaluation_config: Path | None, evaluation_config: Path | None,
inference_config: Path | None, inference_config: Path | None,

View File

@ -17,8 +17,8 @@ __all__ = ["finetune_command"]
"--model", "--model",
"model_path", "model_path",
required=True, required=True,
type=click.Path(exists=True), type=str,
help="Path to a checkpoint to fine-tune from.", help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
) )
@click.option( @click.option(
"--targets", "--targets",
@ -106,7 +106,7 @@ __all__ = ["finetune_command"]
) )
def finetune_command( def finetune_command(
train_dataset: Path, train_dataset: Path,
model_path: Path, model_path: str,
targets_config: Path, targets_config: Path,
val_dataset: Path | None = None, val_dataset: Path | None = None,
ckpt_dir: Path | None = None, ckpt_dir: Path | None = None,

View File

@ -102,7 +102,7 @@ def common_predict_options(func):
def _build_api( def _build_api(
model_path: Path, model_path: str,
audio_config: Path | None, audio_config: Path | None,
inference_config: Path | None, inference_config: Path | None,
outputs_config: Path | None, outputs_config: Path | None,
@ -144,7 +144,7 @@ def _build_api(
def _run_prediction( def _run_prediction(
model_path: Path, model_path: str,
audio_files: list[Path], audio_files: list[Path],
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,
@ -195,12 +195,12 @@ def _run_prediction(
name="directory", name="directory",
short_help="Predict on audio files in a directory.", short_help="Predict on audio files in a directory.",
) )
@click.argument("model_path", type=click.Path(exists=True)) @click.argument("model_path", type=str)
@click.argument("audio_dir", type=click.Path(exists=True)) @click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_directory_command( def predict_directory_command(
model_path: Path, model_path: str,
audio_dir: Path, audio_dir: Path,
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,
@ -239,12 +239,12 @@ def predict_directory_command(
name="file_list", name="file_list",
short_help="Predict on paths listed in a text file.", short_help="Predict on paths listed in a text file.",
) )
@click.argument("model_path", type=click.Path(exists=True)) @click.argument("model_path", type=str)
@click.argument("file_list", type=click.Path(exists=True)) @click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_file_list_command( def predict_file_list_command(
model_path: Path, model_path: str,
file_list: Path, file_list: Path,
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,
@ -287,12 +287,12 @@ def predict_file_list_command(
name="dataset", name="dataset",
short_help="Predict on recordings from a dataset config.", short_help="Predict on recordings from a dataset config.",
) )
@click.argument("model_path", type=click.Path(exists=True)) @click.argument("model_path", type=str)
@click.argument("dataset_path", type=click.Path(exists=True)) @click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path()) @click.argument("output_path", type=click.Path())
@common_predict_options @common_predict_options
def predict_dataset_command( def predict_dataset_command(
model_path: Path, model_path: str,
dataset_path: Path, dataset_path: Path,
output_path: Path, output_path: Path,
audio_config: Path | None, audio_config: Path | None,

View File

@ -18,10 +18,10 @@ __all__ = ["train_command"]
@click.option( @click.option(
"--model", "--model",
"model_path", "model_path",
type=click.Path(exists=True), type=str,
help=( help=(
"Path to a checkpoint to continue training from. If omitted, " "Path to a checkpoint or a Hugging Face URI to continue training "
"training starts from a fresh model config." "from. If omitted, training starts from a fresh model config."
), ),
) )
@click.option( @click.option(
@ -118,7 +118,7 @@ __all__ = ["train_command"]
def train_command( def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Path | None = None, val_dataset: Path | None = None,
model_path: Path | None = None, model_path: str | None = None,
ckpt_dir: Path | None = None, ckpt_dir: Path | None = None,
log_dir: Path | None = None, log_dir: Path | None = None,
base_dir: Path | None = None, base_dir: Path | None = None,

View File

@ -1,4 +1,7 @@
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR from batdetect2.train.checkpoints import (
DEFAULT_CHECKPOINT_DIR,
resolve_checkpoint_path,
)
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
from batdetect2.train.lightning import ( from batdetect2.train.lightning import (
TrainingModule, TrainingModule,
@ -26,5 +29,6 @@ __all__ = [
"TrainingModule", "TrainingModule",
"build_trainer", "build_trainer",
"load_model_from_checkpoint", "load_model_from_checkpoint",
"resolve_checkpoint_path",
"run_train", "run_train",
] ]

View File

@ -1,13 +1,16 @@
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from huggingface_hub import hf_hub_download
from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from soundevent.data import PathLike
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
__all__ = [ __all__ = [
"CheckpointConfig", "CheckpointConfig",
"build_checkpoint_callback", "build_checkpoint_callback",
"resolve_checkpoint_path",
] ]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@ -53,3 +56,51 @@ def build_checkpoint_callback(
save_last=config.save_last, save_last=config.save_last,
every_n_epochs=config.every_n_epochs, every_n_epochs=config.every_n_epochs,
) )
def resolve_checkpoint_path(path: PathLike | str) -> Path:
"""Resolve a local path or Hugging Face checkpoint URI.
Parameters
----------
path : PathLike | str
Local checkpoint path or a Hugging Face URI of the form
``hf://owner/repo/path/to/checkpoint.ckpt``.
Returns
-------
Path
Resolved local filesystem path to the checkpoint.
"""
if isinstance(path, str) and path.startswith("hf://"):
repo_id, filename = _parse_huggingface_uri(path)
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
if not isinstance(path, Path):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Checkpoint not found: {path}")
return path
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
prefix = "hf://"
if not uri.startswith(prefix):
raise ValueError(
"Hugging Face checkpoint URIs must start with 'hf://'."
)
without_prefix = uri.removeprefix(prefix).strip("/")
parts = without_prefix.split("/")
if len(parts) < 3:
raise ValueError(
"Hugging Face checkpoint URIs must be in the form "
"'hf://owner/repo/path/to/checkpoint.ckpt'."
)
repo_id = "/".join(parts[:2])
filename = "/".join(parts[2:])
return repo_id, filename

View File

@ -6,6 +6,7 @@ from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.targets import TargetConfig from batdetect2.targets import TargetConfig
from batdetect2.train.checkpoints import resolve_checkpoint_path
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import build_loss from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer from batdetect2.train.optimizers import build_optimizer
@ -130,7 +131,7 @@ class StoredConfig:
def load_model_from_checkpoint( def load_model_from_checkpoint(
path: PathLike, path: PathLike | str,
) -> tuple[Model, StoredConfig]: ) -> tuple[Model, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint. """Load a model and its configuration from a Lightning checkpoint.
@ -146,7 +147,8 @@ def load_model_from_checkpoint(
The restored ``Model`` instance and the ``ModelConfig`` that The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, and postprocessing. describes its architecture, preprocessing, and postprocessing.
""" """
module = TrainingModule.load_from_checkpoint(path) # type: ignore resolved_path = resolve_checkpoint_path(path)
module = TrainingModule.load_from_checkpoint(resolved_path)
training_config = TrainingConfig.model_validate(module.train_config) training_config = TrainingConfig.model_validate(module.train_config)
model_config = ModelConfig.model_validate(module.model_config) model_config = ModelConfig.model_validate(module.model_config)
targets_config = TargetConfig.model_validate(module.targets_config) targets_config = TargetConfig.model_validate(module.targets_config)

View File

@ -4,6 +4,7 @@ import pytest
from soundevent import data from soundevent import data
from batdetect2.train import TrainingConfig, run_train from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import resolve_checkpoint_path
pytestmark = pytest.mark.slow pytestmark = pytest.mark.slow
@ -92,3 +93,44 @@ def test_train_controls_which_checkpoints_are_kept(
assert last_checkpoints assert last_checkpoints
assert len(best_checkpoints) == 1 assert len(best_checkpoints) == 1
assert "epoch" in best_checkpoints[0].name assert "epoch" in best_checkpoints[0].name
def test_resolve_checkpoint_path_returns_local_path_unchanged(
tmp_path: Path,
) -> None:
local_path = tmp_path / "model.ckpt"
local_path.write_bytes(b"checkpoint")
assert resolve_checkpoint_path(local_path) == local_path
assert resolve_checkpoint_path(str(local_path)) == local_path
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
expected_path = tmp_path / "downloaded.ckpt"
def fake_hf_hub_download(repo_id: str, filename: str) -> str:
assert repo_id == "owner/repo"
assert filename == "weights/model.ckpt"
return str(expected_path)
monkeypatch.setattr(
"batdetect2.train.checkpoints.hf_hub_download",
fake_hf_hub_download,
)
resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
assert resolved == expected_path
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
resolve_checkpoint_path("hf://owner/repo")
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
resolve_checkpoint_path("missing.ckpt")