mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: load checkpoints from Hugging Face
This commit is contained in:
parent
5a974711b0
commit
f5afa9881c
@ -11,6 +11,7 @@ dependencies = [
|
||||
"click>=8.1.7",
|
||||
"deepmerge>=2.0",
|
||||
"hydra-core>=1.3.2",
|
||||
"huggingface-hub>=0.32.0",
|
||||
"librosa>=0.10.1",
|
||||
"lightning[extra]==2.5.0",
|
||||
"loguru>=0.7.3",
|
||||
|
||||
@ -653,7 +653,7 @@ class BatDetect2API:
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
path: data.PathLike | str,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
|
||||
@ -12,14 +12,8 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
|
||||
|
||||
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--targets",
|
||||
"targets_config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to targets config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
@ -80,10 +74,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
default=0,
|
||||
)
|
||||
def evaluate_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
test_dataset: Path,
|
||||
base_dir: Path,
|
||||
targets_config: Path | None,
|
||||
audio_config: Path | None,
|
||||
evaluation_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
|
||||
@ -17,8 +17,8 @@ __all__ = ["finetune_command"]
|
||||
"--model",
|
||||
"model_path",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Path to a checkpoint to fine-tune from.",
|
||||
type=str,
|
||||
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
|
||||
)
|
||||
@click.option(
|
||||
"--targets",
|
||||
@ -106,7 +106,7 @@ __all__ = ["finetune_command"]
|
||||
)
|
||||
def finetune_command(
|
||||
train_dataset: Path,
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
targets_config: Path,
|
||||
val_dataset: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
|
||||
@ -102,7 +102,7 @@ def common_predict_options(func):
|
||||
|
||||
|
||||
def _build_api(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
@ -144,7 +144,7 @@ def _build_api(
|
||||
|
||||
|
||||
def _run_prediction(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_files: list[Path],
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -195,12 +195,12 @@ def _run_prediction(
|
||||
name="directory",
|
||||
short_help="Predict on audio files in a directory.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_directory_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_dir: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -239,12 +239,12 @@ def predict_directory_command(
|
||||
name="file_list",
|
||||
short_help="Predict on paths listed in a text file.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("file_list", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_file_list_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
file_list: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -287,12 +287,12 @@ def predict_file_list_command(
|
||||
name="dataset",
|
||||
short_help="Predict on recordings from a dataset config.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_dataset_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
dataset_path: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
|
||||
@ -18,10 +18,10 @@ __all__ = ["train_command"]
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
type=click.Path(exists=True),
|
||||
type=str,
|
||||
help=(
|
||||
"Path to a checkpoint to continue training from. If omitted, "
|
||||
"training starts from a fresh model config."
|
||||
"Path to a checkpoint or a Hugging Face URI to continue training "
|
||||
"from. If omitted, training starts from a fresh model config."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -118,7 +118,7 @@ __all__ = ["train_command"]
|
||||
def train_command(
|
||||
train_dataset: Path,
|
||||
val_dataset: Path | None = None,
|
||||
model_path: Path | None = None,
|
||||
model_path: str | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
base_dir: Path | None = None,
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
|
||||
from batdetect2.train.checkpoints import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
resolve_checkpoint_path,
|
||||
)
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.lightning import (
|
||||
TrainingModule,
|
||||
@ -26,5 +29,6 @@ __all__ = [
|
||||
"TrainingModule",
|
||||
"build_trainer",
|
||||
"load_model_from_checkpoint",
|
||||
"resolve_checkpoint_path",
|
||||
"run_train",
|
||||
]
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"CheckpointConfig",
|
||||
"build_checkpoint_callback",
|
||||
"resolve_checkpoint_path",
|
||||
]
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
@ -53,3 +56,51 @@ def build_checkpoint_callback(
|
||||
save_last=config.save_last,
|
||||
every_n_epochs=config.every_n_epochs,
|
||||
)
|
||||
|
||||
|
||||
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
||||
"""Resolve a local path or Hugging Face checkpoint URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike | str
|
||||
Local checkpoint path or a Hugging Face URI of the form
|
||||
``hf://owner/repo/path/to/checkpoint.ckpt``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
Resolved local filesystem path to the checkpoint.
|
||||
"""
|
||||
if isinstance(path, str) and path.startswith("hf://"):
|
||||
repo_id, filename = _parse_huggingface_uri(path)
|
||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint not found: {path}")
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
||||
prefix = "hf://"
|
||||
if not uri.startswith(prefix):
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint URIs must start with 'hf://'."
|
||||
)
|
||||
|
||||
without_prefix = uri.removeprefix(prefix).strip("/")
|
||||
parts = without_prefix.split("/")
|
||||
|
||||
if len(parts) < 3:
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint URIs must be in the form "
|
||||
"'hf://owner/repo/path/to/checkpoint.ckpt'."
|
||||
)
|
||||
|
||||
repo_id = "/".join(parts[:2])
|
||||
filename = "/".join(parts[2:])
|
||||
return repo_id, filename
|
||||
|
||||
@ -6,6 +6,7 @@ from soundevent.data import PathLike
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.train.optimizers import build_optimizer
|
||||
@ -130,7 +131,7 @@ class StoredConfig:
|
||||
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
path: PathLike | str,
|
||||
) -> tuple[Model, StoredConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
@ -146,7 +147,8 @@ def load_model_from_checkpoint(
|
||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||
describes its architecture, preprocessing, and postprocessing.
|
||||
"""
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
resolved_path = resolve_checkpoint_path(path)
|
||||
module = TrainingModule.load_from_checkpoint(resolved_path)
|
||||
training_config = TrainingConfig.model_validate(module.train_config)
|
||||
model_config = ModelConfig.model_validate(module.model_config)
|
||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
@ -92,3 +93,44 @@ def test_train_controls_which_checkpoints_are_kept(
|
||||
assert last_checkpoints
|
||||
assert len(best_checkpoints) == 1
|
||||
assert "epoch" in best_checkpoints[0].name
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_returns_local_path_unchanged(
    tmp_path: Path,
) -> None:
    """An existing local checkpoint resolves to itself, given as Path or str."""
    checkpoint = tmp_path / "model.ckpt"
    checkpoint.write_bytes(b"checkpoint")

    resolved_from_path = resolve_checkpoint_path(checkpoint)
    assert resolved_from_path == checkpoint

    resolved_from_str = resolve_checkpoint_path(str(checkpoint))
    assert resolved_from_str == checkpoint
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """A ``hf://`` URI is split into repo id and filename and downloaded."""
    download_target = tmp_path / "downloaded.ckpt"

    # Stand-in for the Hub client: checks the parsed arguments and returns
    # a string path, so no network access is needed.
    def fake_hf_hub_download(repo_id: str, filename: str) -> str:
        assert repo_id == "owner/repo"
        assert filename == "weights/model.ckpt"
        return str(download_target)

    monkeypatch.setattr(
        "batdetect2.train.checkpoints.hf_hub_download",
        fake_hf_hub_download,
    )

    uri = "hf://owner/repo/weights/model.ckpt"
    assert resolve_checkpoint_path(uri) == download_target
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
    """A URI that names only owner and repository raises ``ValueError``."""
    incomplete_uri = "hf://owner/repo"

    with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
        resolve_checkpoint_path(incomplete_uri)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_rejects_missing_local_path(
    tmp_path: Path,
) -> None:
    """A local path that does not exist raises ``FileNotFoundError``."""
    # Anchor the missing file in pytest's fresh temporary directory so the
    # outcome does not depend on the current working directory's contents.
    missing = tmp_path / "missing.ckpt"

    with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
        resolve_checkpoint_path(str(missing))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user