From f5afa9881c991b3da03c075727f31dc254612c1e Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 5 May 2026 15:46:39 +0100
Subject: [PATCH] feat: load checkpoints from Hugging Face

---
Review notes (text between the "---" line and the first "diff --git" line
is commentary and is not part of the commit):

- Checkpoint arguments of the API, the CLI commands and
  load_model_from_checkpoint() now accept either a local path or a URI of
  the form hf://owner/repo/path/to/checkpoint.ckpt.
  resolve_checkpoint_path() passes hf:// strings to hf_hub_download() and
  raises FileNotFoundError for a local path that does not exist.
- Amended in review: _parse_huggingface_uri() now drops empty path segments
  before the length check. Previously hf://owner//model.ckpt passed the
  check and produced repo_id "owner/", and hf://owner/repo//x.ckpt produced
  filename "/x.ckpt". The change replaces one added line, so hunk sizes and
  the diffstat below are unchanged.
- The evaluate command no longer takes --targets / targets_config.
- NOTE(review): tests/test_train/test_checkpoints.py sets
  pytestmark = pytest.mark.slow at module level, so the new
  resolve_checkpoint_path unit tests inherit the slow mark. Confirm that is
  intended.
- NOTE(review): line breaks and indentation in this copy were reconstructed
  from a whitespace-flattened original using the hunk line counts. Verify
  the context lines against the tree before applying. The post-image blob id
  on the checkpoints.py index line predates the amendment above.

 pyproject.toml                       |  1 +
 src/batdetect2/api_v2.py             |  2 +-
 src/batdetect2/cli/evaluate.py       | 11 ++----
 src/batdetect2/cli/finetune.py       |  6 ++--
 src/batdetect2/cli/inference.py      | 16 ++++-----
 src/batdetect2/cli/train.py          |  8 ++---
 src/batdetect2/train/__init__.py     |  6 +++-
 src/batdetect2/train/checkpoints.py  | 51 ++++++++++++++++++++++++++++
 src/batdetect2/train/lightning.py    |  6 ++--
 tests/test_train/test_checkpoints.py | 42 +++++++++++++++++++++++
 10 files changed, 121 insertions(+), 28 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6096d78..d73d86e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,7 @@ dependencies = [
   "click>=8.1.7",
   "deepmerge>=2.0",
   "hydra-core>=1.3.2",
+  "huggingface-hub>=0.32.0",
   "librosa>=0.10.1",
   "lightning[extra]==2.5.0",
   "loguru>=0.7.3",
diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py
index 143dfe9..32d6f32 100644
--- a/src/batdetect2/api_v2.py
+++ b/src/batdetect2/api_v2.py
@@ -653,7 +653,7 @@ class BatDetect2API:
     @classmethod
     def from_checkpoint(
         cls,
-        path: data.PathLike,
+        path: data.PathLike | str,
         audio_config: AudioConfig | None = None,
         train_config: TrainingConfig | None = None,
         evaluation_config: EvaluationConfig | None = None,
diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py
index feaa9e1..f35fed6 100644
--- a/src/batdetect2/cli/evaluate.py
+++ b/src/batdetect2/cli/evaluate.py
@@ -12,14 +12,8 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 
 
 @cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
-@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("model_path", type=str)
 @click.argument("test_dataset", type=click.Path(exists=True))
-@click.option(
-    "--targets",
-    "targets_config",
-    type=click.Path(exists=True),
-    help="Path to targets config file.",
-)
 @click.option(
     "--audio-config",
     type=click.Path(exists=True),
@@ -80,10 +74,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
     default=0,
 )
 def evaluate_command(
-    model_path: Path,
+    model_path: str,
     test_dataset: Path,
     base_dir: Path,
-    targets_config: Path | None,
     audio_config: Path | None,
     evaluation_config: Path | None,
     inference_config: Path | None,
diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py
index 467b91d..b2d5619 100644
--- a/src/batdetect2/cli/finetune.py
+++ b/src/batdetect2/cli/finetune.py
@@ -17,8 +17,8 @@ __all__ = ["finetune_command"]
     "--model",
     "model_path",
     required=True,
-    type=click.Path(exists=True),
-    help="Path to a checkpoint to fine-tune from.",
+    type=str,
+    help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
 )
 @click.option(
     "--targets",
@@ -106,7 +106,7 @@ __all__ = ["finetune_command"]
 )
 def finetune_command(
     train_dataset: Path,
-    model_path: Path,
+    model_path: str,
     targets_config: Path,
     val_dataset: Path | None = None,
     ckpt_dir: Path | None = None,
diff --git a/src/batdetect2/cli/inference.py b/src/batdetect2/cli/inference.py
index 4ea7021..edec826 100644
--- a/src/batdetect2/cli/inference.py
+++ b/src/batdetect2/cli/inference.py
@@ -102,7 +102,7 @@ def common_predict_options(func):
 
 
 def _build_api(
-    model_path: Path,
+    model_path: str,
     audio_config: Path | None,
     inference_config: Path | None,
     outputs_config: Path | None,
@@ -144,7 +144,7 @@ def _build_api(
 
 
 def _run_prediction(
-    model_path: Path,
+    model_path: str,
     audio_files: list[Path],
     output_path: Path,
     audio_config: Path | None,
@@ -195,12 +195,12 @@ def _run_prediction(
     name="directory",
     short_help="Predict on audio files in a directory.",
 )
-@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("model_path", type=str)
 @click.argument("audio_dir", type=click.Path(exists=True))
 @click.argument("output_path", type=click.Path())
 @common_predict_options
 def predict_directory_command(
-    model_path: Path,
+    model_path: str,
     audio_dir: Path,
     output_path: Path,
     audio_config: Path | None,
@@ -239,12 +239,12 @@ def predict_directory_command(
     name="file_list",
     short_help="Predict on paths listed in a text file.",
 )
-@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("model_path", type=str)
 @click.argument("file_list", type=click.Path(exists=True))
 @click.argument("output_path", type=click.Path())
 @common_predict_options
 def predict_file_list_command(
-    model_path: Path,
+    model_path: str,
     file_list: Path,
     output_path: Path,
     audio_config: Path | None,
@@ -287,12 +287,12 @@ def predict_file_list_command(
     name="dataset",
     short_help="Predict on recordings from a dataset config.",
 )
-@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("model_path", type=str)
 @click.argument("dataset_path", type=click.Path(exists=True))
 @click.argument("output_path", type=click.Path())
 @common_predict_options
 def predict_dataset_command(
-    model_path: Path,
+    model_path: str,
     dataset_path: Path,
     output_path: Path,
     audio_config: Path | None,
diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py
index a23687e..0849382 100644
--- a/src/batdetect2/cli/train.py
+++ b/src/batdetect2/cli/train.py
@@ -18,10 +18,10 @@ __all__ = ["train_command"]
 @click.option(
     "--model",
     "model_path",
-    type=click.Path(exists=True),
+    type=str,
     help=(
-        "Path to a checkpoint to continue training from. If omitted, "
-        "training starts from a fresh model config."
+        "Path to a checkpoint or a Hugging Face URI to continue training "
+        "from. If omitted, training starts from a fresh model config."
     ),
 )
 @click.option(
@@ -118,7 +118,7 @@ __all__ = ["train_command"]
 def train_command(
     train_dataset: Path,
     val_dataset: Path | None = None,
-    model_path: Path | None = None,
+    model_path: str | None = None,
     ckpt_dir: Path | None = None,
     log_dir: Path | None = None,
     base_dir: Path | None = None,
diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py
index ee0731e..27f5539 100644
--- a/src/batdetect2/train/__init__.py
+++ b/src/batdetect2/train/__init__.py
@@ -1,4 +1,7 @@
-from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
+from batdetect2.train.checkpoints import (
+    DEFAULT_CHECKPOINT_DIR,
+    resolve_checkpoint_path,
+)
 from batdetect2.train.config import TrainingConfig
 from batdetect2.train.lightning import (
     TrainingModule,
@@ -26,5 +29,6 @@ __all__ = [
     "TrainingModule",
     "build_trainer",
     "load_model_from_checkpoint",
+    "resolve_checkpoint_path",
     "run_train",
 ]
diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py
index fa90973..9ae997a 100644
--- a/src/batdetect2/train/checkpoints.py
+++ b/src/batdetect2/train/checkpoints.py
@@ -1,13 +1,16 @@
 from pathlib import Path
 from typing import Literal
 
+from huggingface_hub import hf_hub_download
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+from soundevent.data import PathLike
 
 from batdetect2.core import BaseConfig
 
 __all__ = [
     "CheckpointConfig",
     "build_checkpoint_callback",
+    "resolve_checkpoint_path",
 ]
 
 DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@@ -53,3 +56,51 @@ def build_checkpoint_callback(
         save_last=config.save_last,
         every_n_epochs=config.every_n_epochs,
     )
+
+
+def resolve_checkpoint_path(path: PathLike | str) -> Path:
+    """Resolve a local path or Hugging Face checkpoint URI.
+
+    Parameters
+    ----------
+    path : PathLike | str
+        Local checkpoint path or a Hugging Face URI of the form
+        ``hf://owner/repo/path/to/checkpoint.ckpt``.
+
+    Returns
+    -------
+    Path
+        Resolved local filesystem path to the checkpoint.
+    """
+    if isinstance(path, str) and path.startswith("hf://"):
+        repo_id, filename = _parse_huggingface_uri(path)
+        return Path(hf_hub_download(repo_id=repo_id, filename=filename))
+
+    if not isinstance(path, Path):
+        path = Path(path)
+
+    if not path.exists():
+        raise FileNotFoundError(f"Checkpoint not found: {path}")
+
+    return path
+
+
+def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
+    prefix = "hf://"
+    if not uri.startswith(prefix):
+        raise ValueError(
+            "Hugging Face checkpoint URIs must start with 'hf://'."
+        )
+
+    without_prefix = uri.removeprefix(prefix).strip("/")
+    parts = [part for part in without_prefix.split("/") if part]
+
+    if len(parts) < 3:
+        raise ValueError(
+            "Hugging Face checkpoint URIs must be in the form "
+            "'hf://owner/repo/path/to/checkpoint.ckpt'."
+        )
+
+    repo_id = "/".join(parts[:2])
+    filename = "/".join(parts[2:])
+    return repo_id, filename
diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py
index 80a17a4..9f602f6 100644
--- a/src/batdetect2/train/lightning.py
+++ b/src/batdetect2/train/lightning.py
@@ -6,6 +6,7 @@ from soundevent.data import PathLike
 from batdetect2.models import Model, ModelConfig, build_model
 from batdetect2.models.types import ModelOutput
 from batdetect2.targets import TargetConfig
+from batdetect2.train.checkpoints import resolve_checkpoint_path
 from batdetect2.train.config import TrainingConfig
 from batdetect2.train.losses import build_loss
 from batdetect2.train.optimizers import build_optimizer
@@ -130,7 +131,7 @@ class StoredConfig:
 
 
 def load_model_from_checkpoint(
-    path: PathLike,
+    path: PathLike | str,
 ) -> tuple[Model, StoredConfig]:
     """Load a model and its configuration from a Lightning checkpoint.
 
@@ -146,7 +147,8 @@ def load_model_from_checkpoint(
         The restored ``Model`` instance and the ``ModelConfig`` that
         describes its architecture, preprocessing, and postprocessing.
     """
-    module = TrainingModule.load_from_checkpoint(path)  # type: ignore
+    resolved_path = resolve_checkpoint_path(path)
+    module = TrainingModule.load_from_checkpoint(resolved_path)
     training_config = TrainingConfig.model_validate(module.train_config)
     model_config = ModelConfig.model_validate(module.model_config)
     targets_config = TargetConfig.model_validate(module.targets_config)
diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py
index 3603cac..5b55533 100644
--- a/tests/test_train/test_checkpoints.py
+++ b/tests/test_train/test_checkpoints.py
@@ -4,6 +4,7 @@ import pytest
 from soundevent import data
 
 from batdetect2.train import TrainingConfig, run_train
+from batdetect2.train.checkpoints import resolve_checkpoint_path
 
 pytestmark = pytest.mark.slow
 
@@ -92,3 +93,44 @@ def test_train_controls_which_checkpoints_are_kept(
     assert last_checkpoints
     assert len(best_checkpoints) == 1
     assert "epoch" in best_checkpoints[0].name
+
+
+def test_resolve_checkpoint_path_returns_local_path_unchanged(
+    tmp_path: Path,
+) -> None:
+    local_path = tmp_path / "model.ckpt"
+    local_path.write_bytes(b"checkpoint")
+
+    assert resolve_checkpoint_path(local_path) == local_path
+    assert resolve_checkpoint_path(str(local_path)) == local_path
+
+
+def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
+    monkeypatch: pytest.MonkeyPatch,
+    tmp_path: Path,
+) -> None:
+    expected_path = tmp_path / "downloaded.ckpt"
+
+    def fake_hf_hub_download(repo_id: str, filename: str) -> str:
+        assert repo_id == "owner/repo"
+        assert filename == "weights/model.ckpt"
+        return str(expected_path)
+
+    monkeypatch.setattr(
+        "batdetect2.train.checkpoints.hf_hub_download",
+        fake_hf_hub_download,
+    )
+
+    resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
+
+    assert resolved == expected_path
+
+
+def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
+    with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
+        resolve_checkpoint_path("hf://owner/repo")
+
+
+def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
+    with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
+        resolve_checkpoint_path("missing.ckpt")