From 038d58ed99e3a19cad8e34b3f9862bb5d7e553eb Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 16 Mar 2026 09:30:23 +0000 Subject: [PATCH] Add config for dynamic imports, and tests --- src/batdetect2/core/__init__.py | 8 +- src/batdetect2/core/registries.py | 90 +++++++++++++- src/batdetect2/data/annotations/registry.py | 26 +++- src/batdetect2/logging.py | 2 +- tests/test_core/test_registry.py | 127 +++++++++++++++++++- 5 files changed, 246 insertions(+), 7 deletions(-) diff --git a/src/batdetect2/core/__init__.py b/src/batdetect2/core/__init__.py index 19acaca..23b77c5 100644 --- a/src/batdetect2/core/__init__.py +++ b/src/batdetect2/core/__init__.py @@ -1,8 +1,14 @@ from batdetect2.core.configs import BaseConfig, load_config, merge_configs -from batdetect2.core.registries import Registry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) __all__ = [ + "add_import_config", "BaseConfig", + "ImportConfig", "load_config", "Registry", "merge_configs", diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index 33a5d41..3fb170e 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -1,4 +1,5 @@ from typing import ( + Any, Callable, Concatenate, Generic, @@ -7,18 +8,20 @@ from typing import ( TypeVar, ) -from pydantic import BaseModel +from hydra.utils import instantiate +from pydantic import BaseModel, Field __all__ = [ + "add_import_config", + "ImportConfig", "Registry", "SimpleRegistry", ] + T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) T_Type = TypeVar("T_Type", covariant=True) P_Type = ParamSpec("P_Type") - - T = TypeVar("T") @@ -114,3 +117,84 @@ class Registry(Generic[T_Type, P_Type]): ) return self._registry[name](config, *args, **kwargs) + + +class ImportConfig(BaseModel): + """Base config for dynamic instantiation via Hydra. + + Subclass this to create a registry-specific import escape hatch. + The subclass must add a discriminator field whose name matches the + registry's own discriminator key, with its value fixed to + ``Literal["import"]``. + + Attributes + ---------- + target : str + Fully-qualified dotted path to the callable to instantiate, + e.g. ``"mypackage.module.MyClass"``. + arguments : dict[str, Any] + Base keyword arguments forwarded to the callable. When the + same key also appears in ``kwargs`` passed to ``build()``, + the ``kwargs`` value takes priority. + """ + + target: str + arguments: dict[str, Any] = Field(default_factory=dict) + + +T_Import = TypeVar("T_Import", bound=ImportConfig) + + +def add_import_config( + registry: Registry[T_Type, P_Type], +) -> Callable[[Type[T_Import]], Type[T_Import]]: + """Decorator that registers an ImportConfig subclass as an escape hatch. + + Wraps the decorated class in a builder that calls + ``hydra.utils.instantiate`` using ``config.target`` and + ``config.arguments``. The builder is registered on *registry* + under the discriminator value ``"import"``. + + Parameters + ---------- + registry : Registry + The registry instance on which the config should be registered. + + Returns + ------- + Callable[[type[ImportConfig]], type[ImportConfig]] + A class decorator that registers the class and returns it + unchanged. + + Examples + -------- + Define a per-registry import escape hatch:: + + @add_import_config(my_registry) + class MyRegistryImportConfig(ImportConfig): + name: Literal["import"] = "import" + """ + + def decorator(config_cls: Type[T_Import]) -> Type[T_Import]: + def builder( + config: T_Import, + *args: P_Type.args, + **kwargs: P_Type.kwargs, + ) -> T_Type: + if len(args) > 0: + raise ValueError( + "Positional arguments are not supported " + "for import escape hatch." + ) + + hydra_cfg = { + "_target_": config.target, + **config.arguments, + **kwargs, + } + return instantiate(hydra_cfg) + + registry.register(config_cls)(builder) + return config_cls + + return decorator diff --git a/src/batdetect2/data/annotations/registry.py b/src/batdetect2/data/annotations/registry.py index 6200cdc..5f3fa77 100644 --- a/src/batdetect2/data/annotations/registry.py +++ b/src/batdetect2/data/annotations/registry.py @@ -1,7 +1,10 @@ -from batdetect2.core import Registry +from typing import Literal + +from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.data.annotations.types import AnnotationLoader __all__ = [ + "AnnotationFormatImportConfig", "annotation_format_registry", ] @@ -9,3 +12,24 @@ annotation_format_registry: Registry[AnnotationLoader, []] = Registry( "annotation_format", discriminator="format", ) + + +@add_import_config(annotation_format_registry) +class AnnotationFormatImportConfig(ImportConfig): + """Import escape hatch for the annotation format registry. + + Use this config to dynamically instantiate any callable as an + annotation loader without registering it in + ``annotation_format_registry`` ahead of time. + + Parameters + ---------- + format : Literal["import"] + Discriminator value; must always be ``"import"``. + target : str + Fully-qualified dotted path to the callable to instantiate. + arguments : dict[str, Any] + Keyword arguments forwarded to the callable. + """ + + format: Literal["import"] = "import" diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index 62dcedd..83d1d6f 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -104,7 +104,7 @@ def create_dvclive_logger( run_name: str | None = None, ) -> Logger: try: - from dvclive.lightning import DVCLiveLogger # type: ignore + from dvclive.lightning import DVCLiveLogger except ImportError as error: raise ValueError( "DVCLive is not installed and cannot be used for logging" diff --git a/tests/test_core/test_registry.py b/tests/test_core/test_registry.py index 505dbcb..3c42292 100644 --- a/tests/test_core/test_registry.py +++ b/tests/test_core/test_registry.py @@ -4,6 +4,8 @@ Covers: - SimpleRegistry: registration, retrieval, and membership checks. - Registry: decorator-based registration, config type tracking, discriminator-based dispatch, and error handling. +- ImportConfig base class and add_import_config decorator utility. +- AnnotationFormatImportConfig: concrete per-registry escape hatch. """ from typing import Literal @@ -11,7 +13,12 @@ from typing import Literal import pytest from pydantic import BaseModel -from batdetect2.core.registries import Registry, SimpleRegistry +from batdetect2.core.registries import ( + ImportConfig, + Registry, + SimpleRegistry, + add_import_config, +) class TestSimpleRegistry: @@ -292,3 +299,121 @@ class TestRegistryBuild: with pytest.raises(NotImplementedError, match="my_registry"): registry.build(UnknownConfig()) + + +class TestAddImportConfig: + def test_decorator_returns_config_class_unchanged(self): + """add_import_config returns the decorated class as-is.""" + + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + registry: Registry[object, []] = Registry("test") + result = add_import_config(registry)(MyImportConfig) + assert result is MyImportConfig + + def test_registered_config_type_is_discoverable(self): + """After decoration, get_config_type('import') returns the class.""" + registry: Registry[object, []] = Registry("test") + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + assert registry.get_config_type("import") is MyImportConfig + + def test_build_instantiates_target(self): + """build() with a registered import config instantiates the target.""" + import collections + + registry: Registry[object, []] = Registry("test") + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + config = MyImportConfig(target="collections.OrderedDict") + result = registry.build(config) + assert isinstance(result, collections.OrderedDict) + + def test_build_forwards_arguments_to_target(self): + """build() passes config.arguments as kwargs to the target.""" + import decimal + + registry: Registry[object, []] = Registry("test") + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + config = MyImportConfig( + target="decimal.Decimal", + arguments={"value": "3.14"}, + ) + result = registry.build(config) + assert isinstance(result, decimal.Decimal) + assert result == decimal.Decimal("3.14") + + def test_build_kwargs_override_config_arguments(self): + """kwargs passed to build() win over same-key entries in arguments.""" + import decimal + + registry: Registry[object, []] = Registry("test") + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + config = MyImportConfig( + target="decimal.Decimal", + arguments={"value": "1.00"}, + ) + result = registry.build(config, value="9.99") + assert isinstance(result, decimal.Decimal) + assert result == decimal.Decimal("9.99") + + def test_build_bad_target_raises(self): + """build() raises when the dotted target path cannot be resolved.""" + from hydra.errors import InstantiationException + + registry: Registry[object, []] = Registry("test") + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + config = MyImportConfig(target="nonexistent.module.DoesNotExist") + with pytest.raises(InstantiationException): + registry.build(config) + + def test_works_with_custom_discriminator_field(self): + """add_import_config works for registries with a non-default discriminator.""" + import collections + + registry: Registry[object, []] = Registry( + "test", + discriminator="format", + ) + + @add_import_config(registry) + class FormatImportConfig(ImportConfig): + format: Literal["import"] = "import" + + config = FormatImportConfig(target="collections.OrderedDict") + result = registry.build(config) + assert isinstance(result, collections.OrderedDict) + + def test_coexists_with_other_registered_entries(self): + """The import config entry does not interfere with other entries.""" + registry: Registry[object, []] = Registry("test") + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + + @add_import_config(registry) + class MyImportConfig(ImportConfig): + name: Literal["import"] = "import" + + registry.register(DummyConfig)(lambda c: "dummy_result") + + assert registry.build(DummyConfig()) == "dummy_result"