diff --git a/tests/test_core/__init__.py b/tests/test_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_core/test_registry.py b/tests/test_core/test_registry.py new file mode 100644 index 0000000..505dbcb --- /dev/null +++ b/tests/test_core/test_registry.py @@ -0,0 +1,294 @@ +"""Tests for the Registry and SimpleRegistry classes. + +Covers: +- SimpleRegistry: registration, retrieval, and membership checks. +- Registry: decorator-based registration, config type tracking, + discriminator-based dispatch, and error handling. +""" + +from typing import Literal + +import pytest +from pydantic import BaseModel + +from batdetect2.core.registries import Registry, SimpleRegistry + + +class TestSimpleRegistry: + def test_register_and_get(self): + """Registered objects can be retrieved by name.""" + registry = SimpleRegistry("test") + + @registry.register("my_item") + def item(): + return 42 + + assert registry.get("my_item")() == 42 + + def test_register_returns_original_object(self): + """The register decorator returns the decorated object unchanged.""" + registry = SimpleRegistry[int]("test") + + @registry.register("x") + def fn() -> int: + return 7 + + assert fn() == 7 + + def test_has_returns_true_for_registered_name(self): + """has() returns True for a name that was registered.""" + registry = SimpleRegistry("test") + registry.register("present")(lambda: None) + assert registry.has("present") is True + + def test_has_returns_false_for_unknown_name(self): + """has() returns False for a name that was never registered.""" + registry = SimpleRegistry("test") + assert registry.has("absent") is False + + def test_get_raises_for_unknown_name(self): + """get() raises KeyError for an unregistered name.""" + registry = SimpleRegistry("test") + with pytest.raises(KeyError): + registry.get("nonexistent") + + def test_register_overwrites_existing_entry(self): + """Re-registering the same name replaces the previous entry.""" + registry = SimpleRegistry("test") + registry.register("key")(lambda: 1) + registry.register("key")(lambda: 2) + assert registry.get("key")() == 2 + + def test_multiple_items_registered_independently(self): + """Multiple items can be registered without interfering.""" + registry = SimpleRegistry("test") + registry.register("a")(lambda: "a") + registry.register("b")(lambda: "b") + assert registry.get("a")() == "a" + assert registry.get("b")() == "b" + + +class TestRegistryRegister: + def test_register_makes_factory_callable_via_build(self): + """A registered factory is reachable through build().""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + value: int = 0 + + class DummyOutput: + def __init__(self, config: DummyConfig): + self.config = config + + registry: Registry[DummyOutput, []] = Registry("test") + registry.register(DummyConfig)(lambda c: DummyOutput(c)) + result = registry.build(DummyConfig(value=3)) + assert isinstance(result, DummyOutput) + assert result.config.value == 3 + + def test_register_makes_config_type_retrievable(self): + """A registered config type is reachable through get_config_type().""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + value: int = 0 + + registry: Registry[object, []] = Registry("test") + registry.register(DummyConfig)(lambda c: c) + assert registry.get_config_type("dummy") is DummyConfig + + def test_register_raises_when_discriminator_field_missing(self): + """ValueError is raised if config has no discriminator field.""" + + class ConfigWithoutDiscriminator(BaseModel): + unrelated_field: str = "hello" + + registry: Registry[object, []] = Registry("test") + with pytest.raises(ValueError, match="'name' field"): + registry.register(ConfigWithoutDiscriminator)(lambda c: c) + + def test_register_raises_when_discriminator_is_not_string(self): + """ValueError is raised if the discriminator default is not a str.""" + + class ConfigWithNonStringDiscriminator(BaseModel): + name: int = 42 + + registry: Registry[object, []] = Registry("test") + with pytest.raises(ValueError, match="'name' field must be a string"): + registry.register(ConfigWithNonStringDiscriminator)(lambda c: c) + + def test_register_uses_custom_discriminator_field(self): + """Registry respects a non-default discriminator field name.""" + + class FormatConfig(BaseModel): + format: Literal["fmt"] = "fmt" + + registry: Registry[object, []] = Registry( + "test", discriminator="format" + ) + registry.register(FormatConfig)(lambda c: c) + assert registry.get_config_type("fmt") is FormatConfig + + def test_register_decorator_returns_original_function(self): + """The register decorator returns the wrapped function unchanged.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + value: int = 0 + + class DummyOutput: + def __init__(self, config: DummyConfig): + self.config = config + + registry: Registry[DummyOutput, []] = Registry("test") + + def factory(config: DummyConfig) -> DummyOutput: + return DummyOutput(config) + + returned = registry.register(DummyConfig)(factory) + assert returned is factory + + +class TestRegistryConfigTypes: + def test_get_config_types_returns_all_registered_types(self): + """get_config_types() returns every registered config class.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + + class AnotherConfig(BaseModel): + name: Literal["another"] = "another" + + registry: Registry[object, []] = Registry("test") + registry.register(DummyConfig)(lambda c: c) + registry.register(AnotherConfig)(lambda c: c) + config_types = registry.get_config_types() + assert DummyConfig in config_types + assert AnotherConfig in config_types + + def test_get_config_types_empty_when_nothing_registered(self): + """get_config_types() returns empty tuple for a fresh registry.""" + registry: Registry[object, []] = Registry("test") + assert registry.get_config_types() == () + + def test_get_config_type_returns_correct_class(self): + """get_config_type() returns the class registered under a key.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + + registry: Registry[object, []] = Registry("test") + registry.register(DummyConfig)(lambda c: c) + assert registry.get_config_type("dummy") is DummyConfig + + def test_get_config_type_raises_for_unknown_key(self): + """get_config_type() raises ValueError for an unregistered key.""" + registry: Registry[object, []] = Registry("test") + with pytest.raises( + ValueError, match="No config type with name 'unknown'" + ): + registry.get_config_type("unknown") + + def test_get_config_type_error_message_lists_existing_keys(self): + """ValueError message includes the names of registered keys.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + + registry: Registry[object, []] = Registry("test") + registry.register(DummyConfig)(lambda c: c) + with pytest.raises(ValueError, match="dummy"): + registry.get_config_type("missing") + + +class TestRegistryBuild: + def test_build_dispatches_to_correct_factory(self): + """build() calls the factory registered for the config's discriminator.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + value: int = 0 + + class DummyOutput: + def __init__(self, config: DummyConfig): + self.config = config + + registry: Registry[DummyOutput, []] = Registry("test") + registry.register(DummyConfig)(lambda c: DummyOutput(c)) + + config = DummyConfig(value=99) + result = registry.build(config) + + assert isinstance(result, DummyOutput) + assert result.config.value == 99 + + def test_build_dispatches_to_correct_factory_among_multiple(self): + """build() picks the right factory when several are registered.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + + class AnotherConfig(BaseModel): + name: Literal["another"] = "another" + + class DummyOutput: + def __init__(self, config: DummyConfig): + self.config = config + + class AnotherOutput: + def __init__(self, config: AnotherConfig): + self.config = config + + registry: Registry[object, []] = Registry("test") + registry.register(DummyConfig)(lambda c: DummyOutput(c)) + registry.register(AnotherConfig)(lambda c: AnotherOutput(c)) + + assert isinstance(registry.build(DummyConfig()), DummyOutput) + assert isinstance(registry.build(AnotherConfig()), AnotherOutput) + + def test_build_raises_not_implemented_for_unregistered_format(self): + """build() raises NotImplementedError for an unregistered discriminator.""" + registry: Registry[object, []] = Registry("test") + + class UnknownConfig(BaseModel): + name: Literal["unknown"] = "unknown" + + with pytest.raises(NotImplementedError, match="'unknown'"): + registry.build(UnknownConfig()) + + def test_build_passes_config_to_factory(self): + """build() passes the exact config object through to the factory.""" + + class DummyConfig(BaseModel): + name: Literal["dummy"] = "dummy" + value: int = 0 + + registry: Registry[DummyConfig, []] = Registry("test") + received: list[DummyConfig] = [] + registry.register(DummyConfig)(lambda c: received.append(c) or c) + + config = DummyConfig(value=7) + registry.build(config) + + assert received == [config] + + def test_build_uses_custom_discriminator_field(self): + """build() resolves the factory using the configured discriminator.""" + + class FormatConfig(BaseModel): + format: Literal["fmt"] = "fmt" + + registry: Registry[str, []] = Registry("test", discriminator="format") + registry.register(FormatConfig)(lambda c: "fmt_result") + + assert registry.build(FormatConfig()) == "fmt_result" + + def test_build_error_message_includes_registry_name(self): + """NotImplementedError message names the registry for easier debugging.""" + registry: Registry[object, []] = Registry("my_registry") + + class UnknownConfig(BaseModel): + name: Literal["ghost"] = "ghost" + + with pytest.raises(NotImplementedError, match="my_registry"): + registry.build(UnknownConfig())