mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
420 lines
15 KiB
Python
420 lines
15 KiB
Python
"""Tests for the Registry and SimpleRegistry classes.

Covers:
- SimpleRegistry: registration, retrieval, and membership checks.
- Registry: decorator-based registration, config type tracking,
  discriminator-based dispatch, and error handling.
- ImportConfig base class and add_import_config decorator utility.
- AnnotationFormatImportConfig: concrete per-registry escape hatch.
  (NOTE(review): no test in this module imports or exercises it yet.)
"""
|
|
|
|
from typing import Literal
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
|
|
from batdetect2.core.registries import (
|
|
ImportConfig,
|
|
Registry,
|
|
SimpleRegistry,
|
|
add_import_config,
|
|
)
|
|
|
|
|
|
class TestSimpleRegistry:
    """Tests for SimpleRegistry: register(), get(), and has()."""

    def test_register_and_get(self):
        """Registered objects can be retrieved by name."""
        registry = SimpleRegistry("test")

        @registry.register("my_item")
        def item():
            return 42

        assert registry.get("my_item")() == 42

    def test_register_returns_original_object(self):
        """The register decorator returns the decorated object unchanged."""
        registry = SimpleRegistry[int]("test")

        def fn() -> int:
            return 7

        # Apply the decorator explicitly so the returned value can be compared
        # against the original function (``@`` syntax would rebind ``fn``).
        returned = registry.register("x")(fn)

        # Identity, not just equal output: a forwarding wrapper would also
        # return 7, so ``fn() == 7`` alone does not prove "unchanged".
        assert returned is fn
        assert returned() == 7

    def test_has_returns_true_for_registered_name(self):
        """has() returns True for a name that was registered."""
        registry = SimpleRegistry("test")
        registry.register("present")(lambda: None)
        assert registry.has("present") is True

    def test_has_returns_false_for_unknown_name(self):
        """has() returns False for a name that was never registered."""
        registry = SimpleRegistry("test")
        assert registry.has("absent") is False

    def test_get_raises_for_unknown_name(self):
        """get() raises KeyError for an unregistered name."""
        registry = SimpleRegistry("test")
        with pytest.raises(KeyError):
            registry.get("nonexistent")

    def test_register_overwrites_existing_entry(self):
        """Re-registering the same name replaces the previous entry."""
        registry = SimpleRegistry("test")
        registry.register("key")(lambda: 1)
        registry.register("key")(lambda: 2)
        assert registry.get("key")() == 2

    def test_multiple_items_registered_independently(self):
        """Multiple items can be registered without interfering."""
        registry = SimpleRegistry("test")
        registry.register("a")(lambda: "a")
        registry.register("b")(lambda: "b")
        assert registry.get("a")() == "a"
        assert registry.get("b")() == "b"
|
|
|
|
|
|
class TestRegistryRegister:
    """Registering config classes and their factories on a Registry."""

    def test_register_makes_factory_callable_via_build(self):
        """Once registered, a factory is what build() dispatches to."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        reg: Registry[DummyOutput, []] = Registry("test")

        def make_output(cfg: DummyConfig) -> DummyOutput:
            return DummyOutput(cfg)

        reg.register(DummyConfig)(make_output)

        built = reg.build(DummyConfig(value=3))

        assert isinstance(built, DummyOutput)
        assert built.config.value == 3

    def test_register_makes_config_type_retrievable(self):
        """Once registered, a config class is findable by its discriminator."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        reg: Registry[object, []] = Registry("test")
        reg.register(DummyConfig)(lambda cfg: cfg)

        looked_up = reg.get_config_type("dummy")
        assert looked_up is DummyConfig

    def test_register_raises_when_discriminator_field_missing(self):
        """A config class lacking the discriminator field is rejected."""

        class ConfigWithoutDiscriminator(BaseModel):
            unrelated_field: str = "hello"

        reg: Registry[object, []] = Registry("test")
        with pytest.raises(ValueError, match="'name' field"):
            decorate = reg.register(ConfigWithoutDiscriminator)
            decorate(lambda cfg: cfg)

    def test_register_raises_when_discriminator_is_not_string(self):
        """A discriminator whose default is not a str is rejected."""

        class ConfigWithNonStringDiscriminator(BaseModel):
            name: int = 42

        reg: Registry[object, []] = Registry("test")
        with pytest.raises(ValueError, match="'name' field must be a string"):
            decorate = reg.register(ConfigWithNonStringDiscriminator)
            decorate(lambda cfg: cfg)

    def test_register_uses_custom_discriminator_field(self):
        """A registry built with discriminator="format" keys on that field."""

        class FormatConfig(BaseModel):
            format: Literal["fmt"] = "fmt"

        reg: Registry[object, []] = Registry("test", discriminator="format")
        reg.register(FormatConfig)(lambda cfg: cfg)

        looked_up = reg.get_config_type("fmt")
        assert looked_up is FormatConfig

    def test_register_decorator_returns_original_function(self):
        """Applying the register decorator hands back the same function."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        reg: Registry[DummyOutput, []] = Registry("test")

        def factory(config: DummyConfig) -> DummyOutput:
            return DummyOutput(config)

        decorate = reg.register(DummyConfig)
        assert decorate(factory) is factory
|
|
|
|
|
|
class TestRegistryConfigTypes:
    """Config-class lookup via get_config_types() and get_config_type()."""

    def test_get_config_types_returns_all_registered_types(self):
        """Every config class that was registered shows up in the listing."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        class AnotherConfig(BaseModel):
            name: Literal["another"] = "another"

        reg: Registry[object, []] = Registry("test")
        expected = (DummyConfig, AnotherConfig)
        for config_cls in expected:
            reg.register(config_cls)(lambda cfg: cfg)

        registered = reg.get_config_types()
        for config_cls in expected:
            assert config_cls in registered

    def test_get_config_types_empty_when_nothing_registered(self):
        """A brand-new registry reports no config types (an empty tuple)."""
        fresh: Registry[object, []] = Registry("test")
        registered = fresh.get_config_types()
        assert registered == ()

    def test_get_config_type_returns_correct_class(self):
        """Looking up a registered key yields exactly that config class."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        reg: Registry[object, []] = Registry("test")
        reg.register(DummyConfig)(lambda cfg: cfg)

        found = reg.get_config_type("dummy")
        assert found is DummyConfig

    def test_get_config_type_raises_for_unknown_key(self):
        """Looking up a key nobody registered raises ValueError."""
        reg: Registry[object, []] = Registry("test")
        expected_msg = "No config type with name 'unknown'"
        with pytest.raises(ValueError, match=expected_msg):
            reg.get_config_type("unknown")

    def test_get_config_type_error_message_lists_existing_keys(self):
        """The ValueError text mentions the keys that do exist."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        reg: Registry[object, []] = Registry("test")
        reg.register(DummyConfig)(lambda cfg: cfg)

        with pytest.raises(ValueError, match="dummy"):
            reg.get_config_type("missing")
|
|
|
|
|
|
class TestRegistryBuild:
    """Dispatch of Registry.build() to the factory for a config's discriminator."""

    def test_build_dispatches_to_correct_factory(self):
        """build() calls the factory registered for the config's discriminator."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        registry: Registry[DummyOutput, []] = Registry("test")
        registry.register(DummyConfig)(lambda c: DummyOutput(c))

        config = DummyConfig(value=99)
        result = registry.build(config)

        assert isinstance(result, DummyOutput)
        assert result.config.value == 99

    def test_build_dispatches_to_correct_factory_among_multiple(self):
        """build() picks the right factory when several are registered."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        class AnotherConfig(BaseModel):
            name: Literal["another"] = "another"

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        class AnotherOutput:
            def __init__(self, config: AnotherConfig):
                self.config = config

        registry: Registry[object, []] = Registry("test")
        registry.register(DummyConfig)(lambda c: DummyOutput(c))
        registry.register(AnotherConfig)(lambda c: AnotherOutput(c))

        assert isinstance(registry.build(DummyConfig()), DummyOutput)
        assert isinstance(registry.build(AnotherConfig()), AnotherOutput)

    def test_build_raises_not_implemented_for_unregistered_format(self):
        """build() raises NotImplementedError for an unregistered discriminator."""
        registry: Registry[object, []] = Registry("test")

        class UnknownConfig(BaseModel):
            name: Literal["unknown"] = "unknown"

        with pytest.raises(NotImplementedError, match="'unknown'"):
            registry.build(UnknownConfig())

    def test_build_passes_config_to_factory(self):
        """build() passes the exact config object through to the factory."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        registry: Registry[DummyConfig, []] = Registry("test")
        received: list[DummyConfig] = []

        def factory(c: DummyConfig) -> DummyConfig:
            # Record what build() handed us so the test can inspect it.
            received.append(c)
            return c

        registry.register(DummyConfig)(factory)

        config = DummyConfig(value=7)
        registry.build(config)

        # Pydantic models compare equal by field values, so
        # ``received == [config]`` would also accept a copy. ``is`` checks
        # that the very same object was passed through, as documented.
        assert len(received) == 1
        assert received[0] is config

    def test_build_uses_custom_discriminator_field(self):
        """build() resolves the factory using the configured discriminator."""

        class FormatConfig(BaseModel):
            format: Literal["fmt"] = "fmt"

        registry: Registry[str, []] = Registry("test", discriminator="format")
        registry.register(FormatConfig)(lambda c: "fmt_result")

        assert registry.build(FormatConfig()) == "fmt_result"

    def test_build_error_message_includes_registry_name(self):
        """NotImplementedError message names the registry for easier debugging."""
        registry: Registry[object, []] = Registry("my_registry")

        class UnknownConfig(BaseModel):
            name: Literal["ghost"] = "ghost"

        with pytest.raises(NotImplementedError, match="my_registry"):
            registry.build(UnknownConfig())
|
|
|
|
|
|
class TestAddImportConfig:
    """Registration and building of ImportConfig subclasses via add_import_config.

    An import config carries a dotted ``target`` path and optional
    ``arguments`` that build() forwards as keyword arguments to the target.
    """

    def test_decorator_returns_config_class_unchanged(self):
        """add_import_config returns the decorated class as-is."""

        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        registry: Registry[object, []] = Registry("test")
        result = add_import_config(registry)(MyImportConfig)
        assert result is MyImportConfig

    def test_registered_config_type_is_discoverable(self):
        """After decoration, get_config_type('import') returns the class."""
        registry: Registry[object, []] = Registry("test")

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        assert registry.get_config_type("import") is MyImportConfig

    def test_build_instantiates_target(self):
        """build() with a registered import config instantiates the target."""
        import collections

        registry: Registry[object, []] = Registry("test")

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        config = MyImportConfig(target="collections.OrderedDict")
        result = registry.build(config)
        assert isinstance(result, collections.OrderedDict)

    def test_build_forwards_arguments_to_target(self):
        """build() passes config.arguments as kwargs to the target."""
        import decimal

        registry: Registry[object, []] = Registry("test")

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        config = MyImportConfig(
            target="decimal.Decimal",
            arguments={"value": "3.14"},
        )
        result = registry.build(config)
        assert isinstance(result, decimal.Decimal)
        assert result == decimal.Decimal("3.14")

    def test_build_kwargs_override_config_arguments(self):
        """kwargs passed to build() win over same-key entries in arguments."""
        import decimal

        registry: Registry[object, []] = Registry("test")

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        config = MyImportConfig(
            target="decimal.Decimal",
            arguments={"value": "1.00"},
        )
        result = registry.build(config, value="9.99")
        assert isinstance(result, decimal.Decimal)
        assert result == decimal.Decimal("9.99")

    def test_build_bad_target_raises(self):
        """build() raises when the dotted target path cannot be resolved."""
        from hydra.errors import InstantiationException

        registry: Registry[object, []] = Registry("test")

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        config = MyImportConfig(target="nonexistent.module.DoesNotExist")
        with pytest.raises(InstantiationException):
            registry.build(config)

    def test_works_with_custom_discriminator_field(self):
        """add_import_config works for registries with a non-default discriminator."""
        import collections

        registry: Registry[object, []] = Registry(
            "test",
            discriminator="format",
        )

        @add_import_config(registry)
        class FormatImportConfig(ImportConfig):
            format: Literal["import"] = "import"

        config = FormatImportConfig(target="collections.OrderedDict")
        result = registry.build(config)
        assert isinstance(result, collections.OrderedDict)

    def test_coexists_with_other_registered_entries(self):
        """The import config entry does not interfere with other entries."""
        import collections

        registry: Registry[object, []] = Registry("test")

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        @add_import_config(registry)
        class MyImportConfig(ImportConfig):
            name: Literal["import"] = "import"

        registry.register(DummyConfig)(lambda c: "dummy_result")

        # The ordinary entry dispatches to its own factory...
        assert registry.build(DummyConfig()) == "dummy_result"

        # ...and registering it after the import config did not clobber
        # the import entry: it is still discoverable and still builds.
        assert registry.get_config_type("import") is MyImportConfig
        import_result = registry.build(
            MyImportConfig(target="collections.OrderedDict")
        )
        assert isinstance(import_result, collections.OrderedDict)
|