mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: load checkpoints from Hugging Face
This commit is contained in:
parent
5a974711b0
commit
f5afa9881c
@ -11,6 +11,7 @@ dependencies = [
|
|||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"deepmerge>=2.0",
|
"deepmerge>=2.0",
|
||||||
"hydra-core>=1.3.2",
|
"hydra-core>=1.3.2",
|
||||||
|
"huggingface-hub>=0.32.0",
|
||||||
"librosa>=0.10.1",
|
"librosa>=0.10.1",
|
||||||
"lightning[extra]==2.5.0",
|
"lightning[extra]==2.5.0",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
|
|||||||
@ -653,7 +653,7 @@ class BatDetect2API:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike | str,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
evaluation_config: EvaluationConfig | None = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
|
|||||||
@ -12,14 +12,8 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
|
|
||||||
|
|
||||||
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
@click.option(
|
|
||||||
"--targets",
|
|
||||||
"targets_config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to targets config file.",
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
@ -80,10 +74,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
default=0,
|
default=0,
|
||||||
)
|
)
|
||||||
def evaluate_command(
|
def evaluate_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
test_dataset: Path,
|
test_dataset: Path,
|
||||||
base_dir: Path,
|
base_dir: Path,
|
||||||
targets_config: Path | None,
|
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
evaluation_config: Path | None,
|
evaluation_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
|
|||||||
@ -17,8 +17,8 @@ __all__ = ["finetune_command"]
|
|||||||
"--model",
|
"--model",
|
||||||
"model_path",
|
"model_path",
|
||||||
required=True,
|
required=True,
|
||||||
type=click.Path(exists=True),
|
type=str,
|
||||||
help="Path to a checkpoint to fine-tune from.",
|
help="Path to a checkpoint or a Hugging Face URI to fine-tune from.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--targets",
|
"--targets",
|
||||||
@ -106,7 +106,7 @@ __all__ = ["finetune_command"]
|
|||||||
)
|
)
|
||||||
def finetune_command(
|
def finetune_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
targets_config: Path,
|
targets_config: Path,
|
||||||
val_dataset: Path | None = None,
|
val_dataset: Path | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
|
|||||||
@ -102,7 +102,7 @@ def common_predict_options(func):
|
|||||||
|
|
||||||
|
|
||||||
def _build_api(
|
def _build_api(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None,
|
||||||
@ -144,7 +144,7 @@ def _build_api(
|
|||||||
|
|
||||||
|
|
||||||
def _run_prediction(
|
def _run_prediction(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_files: list[Path],
|
audio_files: list[Path],
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -195,12 +195,12 @@ def _run_prediction(
|
|||||||
name="directory",
|
name="directory",
|
||||||
short_help="Predict on audio files in a directory.",
|
short_help="Predict on audio files in a directory.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_directory_command(
|
def predict_directory_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_dir: Path,
|
audio_dir: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -239,12 +239,12 @@ def predict_directory_command(
|
|||||||
name="file_list",
|
name="file_list",
|
||||||
short_help="Predict on paths listed in a text file.",
|
short_help="Predict on paths listed in a text file.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("file_list", type=click.Path(exists=True))
|
@click.argument("file_list", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_file_list_command(
|
def predict_file_list_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
file_list: Path,
|
file_list: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -287,12 +287,12 @@ def predict_file_list_command(
|
|||||||
name="dataset",
|
name="dataset",
|
||||||
short_help="Predict on recordings from a dataset config.",
|
short_help="Predict on recordings from a dataset config.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_dataset_command(
|
def predict_dataset_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
dataset_path: Path,
|
dataset_path: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
|
|||||||
@ -18,10 +18,10 @@ __all__ = ["train_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--model",
|
"--model",
|
||||||
"model_path",
|
"model_path",
|
||||||
type=click.Path(exists=True),
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"Path to a checkpoint to continue training from. If omitted, "
|
"Path to a checkpoint or a Hugging Face URI to continue training "
|
||||||
"training starts from a fresh model config."
|
"from. If omitted, training starts from a fresh model config."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -118,7 +118,7 @@ __all__ = ["train_command"]
|
|||||||
def train_command(
|
def train_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Path | None = None,
|
val_dataset: Path | None = None,
|
||||||
model_path: Path | None = None,
|
model_path: str | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
base_dir: Path | None = None,
|
base_dir: Path | None = None,
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
|
from batdetect2.train.checkpoints import (
|
||||||
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
|
resolve_checkpoint_path,
|
||||||
|
)
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.lightning import (
|
from batdetect2.train.lightning import (
|
||||||
TrainingModule,
|
TrainingModule,
|
||||||
@ -26,5 +29,6 @@ __all__ = [
|
|||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"load_model_from_checkpoint",
|
"load_model_from_checkpoint",
|
||||||
|
"resolve_checkpoint_path",
|
||||||
"run_train",
|
"run_train",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,13 +1,16 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CheckpointConfig",
|
"CheckpointConfig",
|
||||||
"build_checkpoint_callback",
|
"build_checkpoint_callback",
|
||||||
|
"resolve_checkpoint_path",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
@ -53,3 +56,51 @@ def build_checkpoint_callback(
|
|||||||
save_last=config.save_last,
|
save_last=config.save_last,
|
||||||
every_n_epochs=config.every_n_epochs,
|
every_n_epochs=config.every_n_epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_checkpoint_path(path: PathLike | str) -> Path:
|
||||||
|
"""Resolve a local path or Hugging Face checkpoint URI.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike | str
|
||||||
|
Local checkpoint path or a Hugging Face URI of the form
|
||||||
|
``hf://owner/repo/path/to/checkpoint.ckpt``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Path
|
||||||
|
Resolved local filesystem path to the checkpoint.
|
||||||
|
"""
|
||||||
|
if isinstance(path, str) and path.startswith("hf://"):
|
||||||
|
repo_id, filename = _parse_huggingface_uri(path)
|
||||||
|
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||||
|
|
||||||
|
if not isinstance(path, Path):
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Checkpoint not found: {path}")
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
||||||
|
prefix = "hf://"
|
||||||
|
if not uri.startswith(prefix):
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint URIs must start with 'hf://'."
|
||||||
|
)
|
||||||
|
|
||||||
|
without_prefix = uri.removeprefix(prefix).strip("/")
|
||||||
|
parts = without_prefix.split("/")
|
||||||
|
|
||||||
|
if len(parts) < 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint URIs must be in the form "
|
||||||
|
"'hf://owner/repo/path/to/checkpoint.ckpt'."
|
||||||
|
)
|
||||||
|
|
||||||
|
repo_id = "/".join(parts[:2])
|
||||||
|
filename = "/".join(parts[2:])
|
||||||
|
return repo_id, filename
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from soundevent.data import PathLike
|
|||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
|
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.train.optimizers import build_optimizer
|
from batdetect2.train.optimizers import build_optimizer
|
||||||
@ -130,7 +131,7 @@ class StoredConfig:
|
|||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike | str,
|
||||||
) -> tuple[Model, StoredConfig]:
|
) -> tuple[Model, StoredConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
@ -146,7 +147,8 @@ def load_model_from_checkpoint(
|
|||||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
"""
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
resolved_path = resolve_checkpoint_path(path)
|
||||||
|
module = TrainingModule.load_from_checkpoint(resolved_path)
|
||||||
training_config = TrainingConfig.model_validate(module.train_config)
|
training_config = TrainingConfig.model_validate(module.train_config)
|
||||||
model_config = ModelConfig.model_validate(module.model_config)
|
model_config = ModelConfig.model_validate(module.model_config)
|
||||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import pytest
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.train import TrainingConfig, run_train
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
|
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||||
|
|
||||||
pytestmark = pytest.mark.slow
|
pytestmark = pytest.mark.slow
|
||||||
|
|
||||||
@ -92,3 +93,44 @@ def test_train_controls_which_checkpoints_are_kept(
|
|||||||
assert last_checkpoints
|
assert last_checkpoints
|
||||||
assert len(best_checkpoints) == 1
|
assert len(best_checkpoints) == 1
|
||||||
assert "epoch" in best_checkpoints[0].name
|
assert "epoch" in best_checkpoints[0].name
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
local_path = tmp_path / "model.ckpt"
|
||||||
|
local_path.write_bytes(b"checkpoint")
|
||||||
|
|
||||||
|
assert resolve_checkpoint_path(local_path) == local_path
|
||||||
|
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
expected_path = tmp_path / "downloaded.ckpt"
|
||||||
|
|
||||||
|
def fake_hf_hub_download(repo_id: str, filename: str) -> str:
|
||||||
|
assert repo_id == "owner/repo"
|
||||||
|
assert filename == "weights/model.ckpt"
|
||||||
|
return str(expected_path)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.train.checkpoints.hf_hub_download",
|
||||||
|
fake_hf_hub_download,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
|
||||||
|
|
||||||
|
assert resolved == expected_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
||||||
|
with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
|
||||||
|
resolve_checkpoint_path("hf://owner/repo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||||
|
with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
|
||||||
|
resolve_checkpoint_path("missing.ckpt")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user