Add config for dynamic imports, and tests

This commit is contained in:
mbsantiago 2026-03-16 09:30:23 +00:00
parent 47f418a63c
commit 038d58ed99
5 changed files with 246 additions and 7 deletions

View File

@ -1,8 +1,14 @@
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
__all__ = [
"add_import_config",
"BaseConfig",
"ImportConfig",
"load_config",
"Registry",
"merge_configs",

View File

@ -1,4 +1,5 @@
from typing import (
Any,
Callable,
Concatenate,
Generic,
@ -7,18 +8,20 @@ from typing import (
TypeVar,
)
from pydantic import BaseModel
from hydra.utils import instantiate
from pydantic import BaseModel, Field
__all__ = [
"add_import_config",
"ImportConfig",
"Registry",
"SimpleRegistry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")
T = TypeVar("T")
@ -114,3 +117,84 @@ class Registry(Generic[T_Type, P_Type]):
)
return self._registry[name](config, *args, **kwargs)
class ImportConfig(BaseModel):
"""Base config for dynamic instantiation via Hydra.
Subclass this to create a registry-specific import escape hatch.
The subclass must add a discriminator field whose name matches the
registry's own discriminator key, with its value fixed to
``Literal["import"]``.
Attributes
----------
target : str
Fully-qualified dotted path to the callable to instantiate,
e.g. ``"mypackage.module.MyClass"``.
arguments : dict[str, Any]
Base keyword arguments forwarded to the callable. When the
same key also appears in ``kwargs`` passed to ``build()``,
the ``kwargs`` value takes priority.
"""
target: str
arguments: dict[str, Any] = Field(default_factory=dict)
T_Import = TypeVar("T_Import", bound=ImportConfig)
def add_import_config(
registry: Registry[T_Type, P_Type],
) -> Callable[[Type[T_Import]], Type[T_Import]]:
"""Decorator that registers an ImportConfig subclass as an escape hatch.
Wraps the decorated class in a builder that calls
``hydra.utils.instantiate`` using ``config.target`` and
``config.arguments``. The builder is registered on *registry*
under the discriminator value ``"import"``.
Parameters
----------
registry : Registry
The registry instance on which the config should be registered.
Returns
-------
Callable[[type[ImportConfig]], type[ImportConfig]]
A class decorator that registers the class and returns it
unchanged.
Examples
--------
Define a per-registry import escape hatch::
@add_import_config(my_registry)
class MyRegistryImportConfig(ImportConfig):
name: Literal["import"] = "import"
"""
def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
def builder(
config: T_Import,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type:
if len(args) > 0:
raise ValueError(
"Positional arguments are not supported "
"for import escape hatch."
)
hydra_cfg = {
"_target_": config.target,
**config.arguments,
**kwargs,
}
return instantiate(hydra_cfg)
registry.register(config_cls)(builder)
return config_cls
return decorator

View File

@ -1,7 +1,10 @@
from batdetect2.core import Registry
from typing import Literal
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.data.annotations.types import AnnotationLoader
__all__ = [
"AnnotationFormatImportConfig",
"annotation_format_registry",
]
@ -9,3 +12,24 @@ annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
"annotation_format",
discriminator="format",
)
@add_import_config(annotation_format_registry)
class AnnotationFormatImportConfig(ImportConfig):
"""Import escape hatch for the annotation format registry.
Use this config to dynamically instantiate any callable as an
annotation loader without registering it in
``annotation_format_registry`` ahead of time.
Parameters
----------
format : Literal["import"]
Discriminator value; must always be ``"import"``.
target : str
Fully-qualified dotted path to the callable to instantiate.
arguments : dict[str, Any]
Keyword arguments forwarded to the callable.
"""
format: Literal["import"] = "import"

View File

@ -104,7 +104,7 @@ def create_dvclive_logger(
run_name: str | None = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
from dvclive.lightning import DVCLiveLogger
except ImportError as error:
raise ValueError(
"DVCLive is not installed and cannot be used for logging"

View File

@ -4,6 +4,8 @@ Covers:
- SimpleRegistry: registration, retrieval, and membership checks.
- Registry: decorator-based registration, config type tracking,
discriminator-based dispatch, and error handling.
- ImportConfig base class and add_import_config decorator utility.
- AnnotationFormatImportConfig: concrete per-registry escape hatch.
"""
from typing import Literal
@ -11,7 +13,12 @@ from typing import Literal
import pytest
from pydantic import BaseModel
from batdetect2.core.registries import Registry, SimpleRegistry
from batdetect2.core.registries import (
ImportConfig,
Registry,
SimpleRegistry,
add_import_config,
)
class TestSimpleRegistry:
@ -292,3 +299,121 @@ class TestRegistryBuild:
with pytest.raises(NotImplementedError, match="my_registry"):
registry.build(UnknownConfig())
class TestAddImportConfig:
def test_decorator_returns_config_class_unchanged(self):
"""add_import_config returns the decorated class as-is."""
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
registry: Registry[object, []] = Registry("test")
result = add_import_config(registry)(MyImportConfig)
assert result is MyImportConfig
def test_registered_config_type_is_discoverable(self):
"""After decoration, get_config_type('import') returns the class."""
registry: Registry[object, []] = Registry("test")
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
assert registry.get_config_type("import") is MyImportConfig
def test_build_instantiates_target(self):
"""build() with a registered import config instantiates the target."""
import collections
registry: Registry[object, []] = Registry("test")
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
config = MyImportConfig(target="collections.OrderedDict")
result = registry.build(config)
assert isinstance(result, collections.OrderedDict)
def test_build_forwards_arguments_to_target(self):
"""build() passes config.arguments as kwargs to the target."""
import decimal
registry: Registry[object, []] = Registry("test")
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
config = MyImportConfig(
target="decimal.Decimal",
arguments={"value": "3.14"},
)
result = registry.build(config)
assert isinstance(result, decimal.Decimal)
assert result == decimal.Decimal("3.14")
def test_build_kwargs_override_config_arguments(self):
"""kwargs passed to build() win over same-key entries in arguments."""
import decimal
registry: Registry[object, []] = Registry("test")
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
config = MyImportConfig(
target="decimal.Decimal",
arguments={"value": "1.00"},
)
result = registry.build(config, value="9.99")
assert isinstance(result, decimal.Decimal)
assert result == decimal.Decimal("9.99")
def test_build_bad_target_raises(self):
"""build() raises when the dotted target path cannot be resolved."""
from hydra.errors import InstantiationException
registry: Registry[object, []] = Registry("test")
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
config = MyImportConfig(target="nonexistent.module.DoesNotExist")
with pytest.raises(InstantiationException):
registry.build(config)
def test_works_with_custom_discriminator_field(self):
"""add_import_config works for registries with a non-default discriminator."""
import collections
registry: Registry[object, []] = Registry(
"test",
discriminator="format",
)
@add_import_config(registry)
class FormatImportConfig(ImportConfig):
format: Literal["import"] = "import"
config = FormatImportConfig(target="collections.OrderedDict")
result = registry.build(config)
assert isinstance(result, collections.OrderedDict)
def test_coexists_with_other_registered_entries(self):
"""The import config entry does not interfere with other entries."""
registry: Registry[object, []] = Registry("test")
class DummyConfig(BaseModel):
name: Literal["dummy"] = "dummy"
@add_import_config(registry)
class MyImportConfig(ImportConfig):
name: Literal["import"] = "import"
registry.register(DummyConfig)(lambda c: "dummy_result")
assert registry.build(DummyConfig()) == "dummy_result"