mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add config for dynamic imports, and tests
This commit is contained in:
parent
47f418a63c
commit
038d58ed99
@ -1,8 +1,14 @@
|
||||
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"add_import_config",
|
||||
"BaseConfig",
|
||||
"ImportConfig",
|
||||
"load_config",
|
||||
"Registry",
|
||||
"merge_configs",
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Concatenate,
|
||||
Generic,
|
||||
@ -7,18 +8,20 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
from hydra.utils import instantiate
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = [
|
||||
"add_import_config",
|
||||
"ImportConfig",
|
||||
"Registry",
|
||||
"SimpleRegistry",
|
||||
]
|
||||
|
||||
|
||||
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
|
||||
T_Type = TypeVar("T_Type", covariant=True)
|
||||
P_Type = ParamSpec("P_Type")
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@ -114,3 +117,84 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
)
|
||||
|
||||
return self._registry[name](config, *args, **kwargs)
|
||||
|
||||
|
||||
class ImportConfig(BaseModel):
|
||||
"""Base config for dynamic instantiation via Hydra.
|
||||
|
||||
Subclass this to create a registry-specific import escape hatch.
|
||||
The subclass must add a discriminator field whose name matches the
|
||||
registry's own discriminator key, with its value fixed to
|
||||
``Literal["import"]``.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
target : str
|
||||
Fully-qualified dotted path to the callable to instantiate,
|
||||
e.g. ``"mypackage.module.MyClass"``.
|
||||
arguments : dict[str, Any]
|
||||
Base keyword arguments forwarded to the callable. When the
|
||||
same key also appears in ``kwargs`` passed to ``build()``,
|
||||
the ``kwargs`` value takes priority.
|
||||
"""
|
||||
|
||||
target: str
|
||||
arguments: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
T_Import = TypeVar("T_Import", bound=ImportConfig)
|
||||
|
||||
|
||||
def add_import_config(
|
||||
registry: Registry[T_Type, P_Type],
|
||||
) -> Callable[[Type[T_Import]], Type[T_Import]]:
|
||||
"""Decorator that registers an ImportConfig subclass as an escape hatch.
|
||||
|
||||
Wraps the decorated class in a builder that calls
|
||||
``hydra.utils.instantiate`` using ``config.target`` and
|
||||
``config.arguments``. The builder is registered on *registry*
|
||||
under the discriminator value ``"import"``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
registry : Registry
|
||||
The registry instance on which the config should be registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable[[type[ImportConfig]], type[ImportConfig]]
|
||||
A class decorator that registers the class and returns it
|
||||
unchanged.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Define a per-registry import escape hatch::
|
||||
|
||||
@add_import_config(my_registry)
|
||||
class MyRegistryImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
"""
|
||||
|
||||
def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
|
||||
def builder(
|
||||
config: T_Import,
|
||||
*args: P_Type.args,
|
||||
**kwargs: P_Type.kwargs,
|
||||
) -> T_Type:
|
||||
if len(args) > 0:
|
||||
raise ValueError(
|
||||
"Positional arguments are not supported "
|
||||
"for import escape hatch."
|
||||
)
|
||||
|
||||
hydra_cfg = {
|
||||
"_target_": config.target,
|
||||
**config.arguments,
|
||||
**kwargs,
|
||||
}
|
||||
return instantiate(hydra_cfg)
|
||||
|
||||
registry.register(config_cls)(builder)
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
from batdetect2.core import Registry
|
||||
from typing import Literal
|
||||
|
||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
||||
from batdetect2.data.annotations.types import AnnotationLoader
|
||||
|
||||
__all__ = [
|
||||
"AnnotationFormatImportConfig",
|
||||
"annotation_format_registry",
|
||||
]
|
||||
|
||||
@ -9,3 +12,24 @@ annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
|
||||
"annotation_format",
|
||||
discriminator="format",
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(annotation_format_registry)
|
||||
class AnnotationFormatImportConfig(ImportConfig):
|
||||
"""Import escape hatch for the annotation format registry.
|
||||
|
||||
Use this config to dynamically instantiate any callable as an
|
||||
annotation loader without registering it in
|
||||
``annotation_format_registry`` ahead of time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
format : Literal["import"]
|
||||
Discriminator value; must always be ``"import"``.
|
||||
target : str
|
||||
Fully-qualified dotted path to the callable to instantiate.
|
||||
arguments : dict[str, Any]
|
||||
Keyword arguments forwarded to the callable.
|
||||
"""
|
||||
|
||||
format: Literal["import"] = "import"
|
||||
|
||||
@ -104,7 +104,7 @@ def create_dvclive_logger(
|
||||
run_name: str | None = None,
|
||||
) -> Logger:
|
||||
try:
|
||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||
from dvclive.lightning import DVCLiveLogger
|
||||
except ImportError as error:
|
||||
raise ValueError(
|
||||
"DVCLive is not installed and cannot be used for logging"
|
||||
|
||||
@ -4,6 +4,8 @@ Covers:
|
||||
- SimpleRegistry: registration, retrieval, and membership checks.
|
||||
- Registry: decorator-based registration, config type tracking,
|
||||
discriminator-based dispatch, and error handling.
|
||||
- ImportConfig base class and add_import_config decorator utility.
|
||||
- AnnotationFormatImportConfig: concrete per-registry escape hatch.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
@ -11,7 +13,12 @@ from typing import Literal
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from batdetect2.core.registries import Registry, SimpleRegistry
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
SimpleRegistry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleRegistry:
|
||||
@ -292,3 +299,121 @@ class TestRegistryBuild:
|
||||
|
||||
with pytest.raises(NotImplementedError, match="my_registry"):
|
||||
registry.build(UnknownConfig())
|
||||
|
||||
|
||||
class TestAddImportConfig:
|
||||
def test_decorator_returns_config_class_unchanged(self):
|
||||
"""add_import_config returns the decorated class as-is."""
|
||||
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
result = add_import_config(registry)(MyImportConfig)
|
||||
assert result is MyImportConfig
|
||||
|
||||
def test_registered_config_type_is_discoverable(self):
|
||||
"""After decoration, get_config_type('import') returns the class."""
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
assert registry.get_config_type("import") is MyImportConfig
|
||||
|
||||
def test_build_instantiates_target(self):
|
||||
"""build() with a registered import config instantiates the target."""
|
||||
import collections
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
config = MyImportConfig(target="collections.OrderedDict")
|
||||
result = registry.build(config)
|
||||
assert isinstance(result, collections.OrderedDict)
|
||||
|
||||
def test_build_forwards_arguments_to_target(self):
|
||||
"""build() passes config.arguments as kwargs to the target."""
|
||||
import decimal
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
config = MyImportConfig(
|
||||
target="decimal.Decimal",
|
||||
arguments={"value": "3.14"},
|
||||
)
|
||||
result = registry.build(config)
|
||||
assert isinstance(result, decimal.Decimal)
|
||||
assert result == decimal.Decimal("3.14")
|
||||
|
||||
def test_build_kwargs_override_config_arguments(self):
|
||||
"""kwargs passed to build() win over same-key entries in arguments."""
|
||||
import decimal
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
config = MyImportConfig(
|
||||
target="decimal.Decimal",
|
||||
arguments={"value": "1.00"},
|
||||
)
|
||||
result = registry.build(config, value="9.99")
|
||||
assert isinstance(result, decimal.Decimal)
|
||||
assert result == decimal.Decimal("9.99")
|
||||
|
||||
def test_build_bad_target_raises(self):
|
||||
"""build() raises when the dotted target path cannot be resolved."""
|
||||
from hydra.errors import InstantiationException
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
config = MyImportConfig(target="nonexistent.module.DoesNotExist")
|
||||
with pytest.raises(InstantiationException):
|
||||
registry.build(config)
|
||||
|
||||
def test_works_with_custom_discriminator_field(self):
|
||||
"""add_import_config works for registries with a non-default discriminator."""
|
||||
import collections
|
||||
|
||||
registry: Registry[object, []] = Registry(
|
||||
"test",
|
||||
discriminator="format",
|
||||
)
|
||||
|
||||
@add_import_config(registry)
|
||||
class FormatImportConfig(ImportConfig):
|
||||
format: Literal["import"] = "import"
|
||||
|
||||
config = FormatImportConfig(target="collections.OrderedDict")
|
||||
result = registry.build(config)
|
||||
assert isinstance(result, collections.OrderedDict)
|
||||
|
||||
def test_coexists_with_other_registered_entries(self):
|
||||
"""The import config entry does not interfere with other entries."""
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
|
||||
@add_import_config(registry)
|
||||
class MyImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
registry.register(DummyConfig)(lambda c: "dummy_result")
|
||||
|
||||
assert registry.build(DummyConfig()) == "dummy_result"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user