batdetect2/tests/test_core/test_registry.py
2026-03-15 21:17:25 +00:00

295 lines
11 KiB
Python

"""Tests for the Registry and SimpleRegistry classes.
Covers:
- SimpleRegistry: registration, retrieval, and membership checks.
- Registry: decorator-based registration, config type tracking,
discriminator-based dispatch, and error handling.
"""
from typing import Literal
import pytest
from pydantic import BaseModel
from batdetect2.core.registries import Registry, SimpleRegistry
class TestSimpleRegistry:
    """Tests for ``SimpleRegistry``: name-keyed registration and lookup."""

    def test_register_and_get(self):
        """Registered objects can be retrieved by name."""
        registry = SimpleRegistry("test")

        @registry.register("my_item")
        def item():
            return 42

        assert registry.get("my_item")() == 42

    def test_register_returns_original_object(self):
        """The register decorator returns the decorated object unchanged."""
        registry = SimpleRegistry[int]("test")

        def fn() -> int:
            return 7

        # Apply the decorator explicitly so the original function object is
        # still at hand for an identity check. Asserting only ``fn() == 7``
        # would also pass if the decorator returned a wrapper.
        returned = registry.register("x")(fn)

        assert returned is fn
        assert returned() == 7

    def test_has_returns_true_for_registered_name(self):
        """has() returns True for a name that was registered."""
        registry = SimpleRegistry("test")
        registry.register("present")(lambda: None)
        assert registry.has("present") is True

    def test_has_returns_false_for_unknown_name(self):
        """has() returns False for a name that was never registered."""
        registry = SimpleRegistry("test")
        assert registry.has("absent") is False

    def test_get_raises_for_unknown_name(self):
        """get() raises KeyError for an unregistered name."""
        registry = SimpleRegistry("test")
        with pytest.raises(KeyError):
            registry.get("nonexistent")

    def test_register_overwrites_existing_entry(self):
        """Re-registering the same name replaces the previous entry."""
        registry = SimpleRegistry("test")
        registry.register("key")(lambda: 1)
        registry.register("key")(lambda: 2)
        assert registry.get("key")() == 2

    def test_multiple_items_registered_independently(self):
        """Multiple items can be registered without interfering."""
        registry = SimpleRegistry("test")
        registry.register("a")(lambda: "a")
        registry.register("b")(lambda: "b")
        assert registry.get("a")() == "a"
        assert registry.get("b")() == "b"
class TestRegistryRegister:
    """Registration behaviour of ``Registry``.

    Covers factories reachable via ``build()``, config types reachable via
    ``get_config_type()``, and validation of the discriminator field.
    """

    @staticmethod
    def _dummy_types():
        """Return a freshly defined ``(config_cls, output_cls)`` pair.

        Each call creates new classes, so no state is shared between tests.
        """

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        return DummyConfig, DummyOutput

    def test_register_makes_factory_callable_via_build(self):
        """Once registered, a factory is invoked by build()."""
        config_cls, output_cls = self._dummy_types()
        reg: Registry[object, []] = Registry("test")

        def make(cfg):
            return output_cls(cfg)

        reg.register(config_cls)(make)
        built = reg.build(config_cls(value=3))

        assert isinstance(built, output_cls)
        assert built.config.value == 3

    def test_register_makes_config_type_retrievable(self):
        """Once registered, a config class is returned by get_config_type()."""
        config_cls, _ = self._dummy_types()
        reg: Registry[object, []] = Registry("test")
        reg.register(config_cls)(lambda cfg: cfg)
        assert reg.get_config_type("dummy") is config_cls

    def test_register_raises_when_discriminator_field_missing(self):
        """A config class lacking the discriminator field is rejected."""

        class ConfigWithoutDiscriminator(BaseModel):
            unrelated_field: str = "hello"

        reg: Registry[object, []] = Registry("test")
        with pytest.raises(ValueError, match="'name' field"):
            reg.register(ConfigWithoutDiscriminator)(lambda cfg: cfg)

    def test_register_raises_when_discriminator_is_not_string(self):
        """A discriminator whose default is not a str is rejected."""

        class ConfigWithNonStringDiscriminator(BaseModel):
            name: int = 42

        reg: Registry[object, []] = Registry("test")
        with pytest.raises(ValueError, match="'name' field must be a string"):
            reg.register(ConfigWithNonStringDiscriminator)(lambda cfg: cfg)

    def test_register_uses_custom_discriminator_field(self):
        """A registry built with discriminator='format' keys on that field."""

        class FormatConfig(BaseModel):
            format: Literal["fmt"] = "fmt"

        reg: Registry[object, []] = Registry("test", discriminator="format")
        reg.register(FormatConfig)(lambda cfg: cfg)
        assert reg.get_config_type("fmt") is FormatConfig

    def test_register_decorator_returns_original_function(self):
        """register(...) hands back the very function it was given."""
        config_cls, output_cls = self._dummy_types()
        reg: Registry[object, []] = Registry("test")

        def factory(config):
            return output_cls(config)

        assert reg.register(config_cls)(factory) is factory
class TestRegistryConfigTypes:
    """Lookup of registered config classes via get_config_type(s)()."""

    def test_get_config_types_returns_all_registered_types(self):
        """Each registered config class appears in get_config_types()."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        class AnotherConfig(BaseModel):
            name: Literal["another"] = "another"

        expected = (DummyConfig, AnotherConfig)
        reg: Registry[object, []] = Registry("test")
        for config_cls in expected:
            reg.register(config_cls)(lambda cfg: cfg)

        registered = reg.get_config_types()
        assert all(cls in registered for cls in expected)

    def test_get_config_types_empty_when_nothing_registered(self):
        """A brand-new registry reports an empty tuple of config types."""
        fresh: Registry[object, []] = Registry("test")
        assert fresh.get_config_types() == ()

    def test_get_config_type_returns_correct_class(self):
        """get_config_type() maps a key back to the class registered for it."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        reg: Registry[object, []] = Registry("test")
        reg.register(DummyConfig)(lambda cfg: cfg)
        looked_up = reg.get_config_type("dummy")
        assert looked_up is DummyConfig

    def test_get_config_type_raises_for_unknown_key(self):
        """Asking for an unregistered key raises ValueError."""
        reg: Registry[object, []] = Registry("test")
        expected_message = "No config type with name 'unknown'"
        with pytest.raises(ValueError, match=expected_message):
            reg.get_config_type("unknown")

    def test_get_config_type_error_message_lists_existing_keys(self):
        """The ValueError text mentions the keys that are registered."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        reg: Registry[object, []] = Registry("test")
        reg.register(DummyConfig)(lambda cfg: cfg)
        with pytest.raises(ValueError, match="dummy"):
            reg.get_config_type("missing")
class TestRegistryBuild:
    """Tests for ``Registry.build``: discriminator-based factory dispatch."""

    def test_build_dispatches_to_correct_factory(self):
        """build() calls the factory registered for the config's discriminator."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        registry: Registry[DummyOutput, []] = Registry("test")
        registry.register(DummyConfig)(lambda c: DummyOutput(c))
        config = DummyConfig(value=99)
        result = registry.build(config)
        assert isinstance(result, DummyOutput)
        assert result.config.value == 99

    def test_build_dispatches_to_correct_factory_among_multiple(self):
        """build() picks the right factory when several are registered."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"

        class AnotherConfig(BaseModel):
            name: Literal["another"] = "another"

        class DummyOutput:
            def __init__(self, config: DummyConfig):
                self.config = config

        class AnotherOutput:
            def __init__(self, config: AnotherConfig):
                self.config = config

        registry: Registry[object, []] = Registry("test")
        registry.register(DummyConfig)(lambda c: DummyOutput(c))
        registry.register(AnotherConfig)(lambda c: AnotherOutput(c))
        assert isinstance(registry.build(DummyConfig()), DummyOutput)
        assert isinstance(registry.build(AnotherConfig()), AnotherOutput)

    def test_build_raises_not_implemented_for_unregistered_format(self):
        """build() raises NotImplementedError for an unregistered discriminator."""
        registry: Registry[object, []] = Registry("test")

        class UnknownConfig(BaseModel):
            name: Literal["unknown"] = "unknown"

        with pytest.raises(NotImplementedError, match="'unknown'"):
            registry.build(UnknownConfig())

    def test_build_passes_config_to_factory(self):
        """build() passes the exact config object through to the factory."""

        class DummyConfig(BaseModel):
            name: Literal["dummy"] = "dummy"
            value: int = 0

        registry: Registry[DummyConfig, []] = Registry("test")
        received: list[DummyConfig] = []

        def factory(c: DummyConfig) -> DummyConfig:
            # Record every config the registry hands us.
            received.append(c)
            return c

        registry.register(DummyConfig)(factory)
        config = DummyConfig(value=7)
        registry.build(config)

        # Pydantic models compare equal by field values, so a copy would also
        # satisfy ``received == [config]``. Check identity to pin "exact object".
        assert len(received) == 1
        assert received[0] is config

    def test_build_uses_custom_discriminator_field(self):
        """build() resolves the factory using the configured discriminator."""

        class FormatConfig(BaseModel):
            format: Literal["fmt"] = "fmt"

        registry: Registry[str, []] = Registry("test", discriminator="format")
        registry.register(FormatConfig)(lambda c: "fmt_result")
        assert registry.build(FormatConfig()) == "fmt_result"

    def test_build_error_message_includes_registry_name(self):
        """NotImplementedError message names the registry for easier debugging."""
        registry: Registry[object, []] = Registry("my_registry")

        class UnknownConfig(BaseModel):
            name: Literal["ghost"] = "ghost"

        with pytest.raises(NotImplementedError, match="my_registry"):
            registry.build(UnknownConfig())