From 5526ac99fce01aade9c4c8e7fdbfc3df7351a44c Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 5 May 2026 16:20:37 +0100 Subject: [PATCH] Remove stale dependencies --- pyproject.toml | 11 ++++------ src/batdetect2/train/checkpoints.py | 10 ++++++++- tests/test_train/test_checkpoints.py | 33 +++++++++++++++++++++++++--- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d73d86e..1e39184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,31 +7,27 @@ authors = [ { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }, ] dependencies = [ - "cf-xarray>=0.9.0", "click>=8.1.7", "deepmerge>=2.0", "hydra-core>=1.3.2", - "huggingface-hub>=0.32.0", "librosa>=0.10.1", "lightning[extra]==2.5.0", "loguru>=0.7.3", "matplotlib>=3.7.1", "netcdf4>=1.6.5", - "numba>=0.60", "numpy>=1.23.5", - "omegaconf>=2.3.0", - "onnx>=1.16.0", "pandas>=1.5.3", + "pydantic>=2.0.0", "pyyaml>=6.0.2", "scikit-learn>=1.2.2", "scipy>=1.10.1", "seaborn>=0.13.2", "soundevent[audio,geometry,plot]>=2.10.0", + "soundfile>=0.12.1", "tensorboard>=2.16.2", "torch>=1.13.1", "torchaudio>=1.13.1", - "torchvision>=0.14.0", - "tqdm>=4.66.2", + "xarray>=2024.0.0", ] requires-python = ">=3.10,<3.14" readme = "README.md" @@ -67,6 +63,7 @@ build-backend = "hatchling.build" batdetect2 = "batdetect2.cli:cli" [dependency-groups] +huggingface = ["huggingface-hub>=0.32.0"] jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"] marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"] dev = [ diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index 9ae997a..0badb9a 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -1,7 +1,6 @@ from pathlib import Path from typing import Literal -from huggingface_hub import hf_hub_download from lightning.pytorch.callbacks import Callback, ModelCheckpoint from soundevent.data import PathLike @@ -73,6 +72,15 @@ def resolve_checkpoint_path(path: PathLike | str) -> Path: Resolved local filesystem path to the checkpoint. """ if isinstance(path, str) and path.startswith("hf://"): + try: + from huggingface_hub import hf_hub_download + except ImportError as error: + raise ValueError( + "Hugging Face checkpoint support is not installed. " + "Install it with `uv sync --group huggingface` or " + "`pip install huggingface-hub`." + ) from error + repo_id, filename = _parse_huggingface_uri(path) return Path(hf_hub_download(repo_id=repo_id, filename=filename)) diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index 5b55533..2e5eaed 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -1,3 +1,5 @@ +import sys +import types from pathlib import Path import pytest @@ -116,9 +118,14 @@ def test_resolve_checkpoint_path_downloads_huggingface_checkpoint( assert filename == "weights/model.ckpt" return str(expected_path) - monkeypatch.setattr( - "batdetect2.train.checkpoints.hf_hub_download", - fake_hf_hub_download, + class FakeHuggingFaceHub(types.ModuleType): + hf_hub_download = staticmethod(fake_hf_hub_download) + + fake_module = FakeHuggingFaceHub("huggingface_hub") + monkeypatch.setitem( + sys.modules, + "huggingface_hub", + fake_module, ) resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt") @@ -126,6 +133,26 @@ def test_resolve_checkpoint_path_downloads_huggingface_checkpoint( assert resolved == expected_path +def test_resolve_checkpoint_path_requires_huggingface_dependency( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delitem(sys.modules, "huggingface_hub", raising=False) + + import builtins + + original_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "huggingface_hub": + raise ImportError("missing") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + with pytest.raises(ValueError, match="Hugging Face checkpoint support"): + resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt") + + def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None: with pytest.raises(ValueError, match="hf://owner/repo/path/to"): resolve_checkpoint_path("hf://owner/repo")