mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add tests for registry
This commit is contained in:
parent
e0503487ec
commit
197cc38e3e
0
tests/test_core/__init__.py
Normal file
0
tests/test_core/__init__.py
Normal file
294
tests/test_core/test_registry.py
Normal file
294
tests/test_core/test_registry.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Tests for the Registry and SimpleRegistry classes.
|
||||
|
||||
Covers:
|
||||
- SimpleRegistry: registration, retrieval, and membership checks.
|
||||
- Registry: decorator-based registration, config type tracking,
|
||||
discriminator-based dispatch, and error handling.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from batdetect2.core.registries import Registry, SimpleRegistry
|
||||
|
||||
|
||||
class TestSimpleRegistry:
|
||||
def test_register_and_get(self):
|
||||
"""Registered objects can be retrieved by name."""
|
||||
registry = SimpleRegistry("test")
|
||||
|
||||
@registry.register("my_item")
|
||||
def item():
|
||||
return 42
|
||||
|
||||
assert registry.get("my_item")() == 42
|
||||
|
||||
def test_register_returns_original_object(self):
|
||||
"""The register decorator returns the decorated object unchanged."""
|
||||
registry = SimpleRegistry[int]("test")
|
||||
|
||||
@registry.register("x")
|
||||
def fn() -> int:
|
||||
return 7
|
||||
|
||||
assert fn() == 7
|
||||
|
||||
def test_has_returns_true_for_registered_name(self):
|
||||
"""has() returns True for a name that was registered."""
|
||||
registry = SimpleRegistry("test")
|
||||
registry.register("present")(lambda: None)
|
||||
assert registry.has("present") is True
|
||||
|
||||
def test_has_returns_false_for_unknown_name(self):
|
||||
"""has() returns False for a name that was never registered."""
|
||||
registry = SimpleRegistry("test")
|
||||
assert registry.has("absent") is False
|
||||
|
||||
def test_get_raises_for_unknown_name(self):
|
||||
"""get() raises KeyError for an unregistered name."""
|
||||
registry = SimpleRegistry("test")
|
||||
with pytest.raises(KeyError):
|
||||
registry.get("nonexistent")
|
||||
|
||||
def test_register_overwrites_existing_entry(self):
|
||||
"""Re-registering the same name replaces the previous entry."""
|
||||
registry = SimpleRegistry("test")
|
||||
registry.register("key")(lambda: 1)
|
||||
registry.register("key")(lambda: 2)
|
||||
assert registry.get("key")() == 2
|
||||
|
||||
def test_multiple_items_registered_independently(self):
|
||||
"""Multiple items can be registered without interfering."""
|
||||
registry = SimpleRegistry("test")
|
||||
registry.register("a")(lambda: "a")
|
||||
registry.register("b")(lambda: "b")
|
||||
assert registry.get("a")() == "a"
|
||||
assert registry.get("b")() == "b"
|
||||
|
||||
|
||||
class TestRegistryRegister:
|
||||
def test_register_makes_factory_callable_via_build(self):
|
||||
"""A registered factory is reachable through build()."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
value: int = 0
|
||||
|
||||
class DummyOutput:
|
||||
def __init__(self, config: DummyConfig):
|
||||
self.config = config
|
||||
|
||||
registry: Registry[DummyOutput, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: DummyOutput(c))
|
||||
result = registry.build(DummyConfig(value=3))
|
||||
assert isinstance(result, DummyOutput)
|
||||
assert result.config.value == 3
|
||||
|
||||
def test_register_makes_config_type_retrievable(self):
|
||||
"""A registered config type is reachable through get_config_type()."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
value: int = 0
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: c)
|
||||
assert registry.get_config_type("dummy") is DummyConfig
|
||||
|
||||
def test_register_raises_when_discriminator_field_missing(self):
|
||||
"""ValueError is raised if config has no discriminator field."""
|
||||
|
||||
class ConfigWithoutDiscriminator(BaseModel):
|
||||
unrelated_field: str = "hello"
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
with pytest.raises(ValueError, match="'name' field"):
|
||||
registry.register(ConfigWithoutDiscriminator)(lambda c: c)
|
||||
|
||||
def test_register_raises_when_discriminator_is_not_string(self):
|
||||
"""ValueError is raised if the discriminator default is not a str."""
|
||||
|
||||
class ConfigWithNonStringDiscriminator(BaseModel):
|
||||
name: int = 42
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
with pytest.raises(ValueError, match="'name' field must be a string"):
|
||||
registry.register(ConfigWithNonStringDiscriminator)(lambda c: c)
|
||||
|
||||
def test_register_uses_custom_discriminator_field(self):
|
||||
"""Registry respects a non-default discriminator field name."""
|
||||
|
||||
class FormatConfig(BaseModel):
|
||||
format: Literal["fmt"] = "fmt"
|
||||
|
||||
registry: Registry[object, []] = Registry(
|
||||
"test", discriminator="format"
|
||||
)
|
||||
registry.register(FormatConfig)(lambda c: c)
|
||||
assert registry.get_config_type("fmt") is FormatConfig
|
||||
|
||||
def test_register_decorator_returns_original_function(self):
|
||||
"""The register decorator returns the wrapped function unchanged."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
value: int = 0
|
||||
|
||||
class DummyOutput:
|
||||
def __init__(self, config: DummyConfig):
|
||||
self.config = config
|
||||
|
||||
registry: Registry[DummyOutput, []] = Registry("test")
|
||||
|
||||
def factory(config: DummyConfig) -> DummyOutput:
|
||||
return DummyOutput(config)
|
||||
|
||||
returned = registry.register(DummyConfig)(factory)
|
||||
assert returned is factory
|
||||
|
||||
|
||||
class TestRegistryConfigTypes:
|
||||
def test_get_config_types_returns_all_registered_types(self):
|
||||
"""get_config_types() returns every registered config class."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
|
||||
class AnotherConfig(BaseModel):
|
||||
name: Literal["another"] = "another"
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: c)
|
||||
registry.register(AnotherConfig)(lambda c: c)
|
||||
config_types = registry.get_config_types()
|
||||
assert DummyConfig in config_types
|
||||
assert AnotherConfig in config_types
|
||||
|
||||
def test_get_config_types_empty_when_nothing_registered(self):
|
||||
"""get_config_types() returns empty tuple for a fresh registry."""
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
assert registry.get_config_types() == ()
|
||||
|
||||
def test_get_config_type_returns_correct_class(self):
|
||||
"""get_config_type() returns the class registered under a key."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: c)
|
||||
assert registry.get_config_type("dummy") is DummyConfig
|
||||
|
||||
def test_get_config_type_raises_for_unknown_key(self):
|
||||
"""get_config_type() raises ValueError for an unregistered key."""
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
with pytest.raises(
|
||||
ValueError, match="No config type with name 'unknown'"
|
||||
):
|
||||
registry.get_config_type("unknown")
|
||||
|
||||
def test_get_config_type_error_message_lists_existing_keys(self):
|
||||
"""ValueError message includes the names of registered keys."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: c)
|
||||
with pytest.raises(ValueError, match="dummy"):
|
||||
registry.get_config_type("missing")
|
||||
|
||||
|
||||
class TestRegistryBuild:
|
||||
def test_build_dispatches_to_correct_factory(self):
|
||||
"""build() calls the factory registered for the config's discriminator."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
value: int = 0
|
||||
|
||||
class DummyOutput:
|
||||
def __init__(self, config: DummyConfig):
|
||||
self.config = config
|
||||
|
||||
registry: Registry[DummyOutput, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: DummyOutput(c))
|
||||
|
||||
config = DummyConfig(value=99)
|
||||
result = registry.build(config)
|
||||
|
||||
assert isinstance(result, DummyOutput)
|
||||
assert result.config.value == 99
|
||||
|
||||
def test_build_dispatches_to_correct_factory_among_multiple(self):
|
||||
"""build() picks the right factory when several are registered."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
|
||||
class AnotherConfig(BaseModel):
|
||||
name: Literal["another"] = "another"
|
||||
|
||||
class DummyOutput:
|
||||
def __init__(self, config: DummyConfig):
|
||||
self.config = config
|
||||
|
||||
class AnotherOutput:
|
||||
def __init__(self, config: AnotherConfig):
|
||||
self.config = config
|
||||
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
registry.register(DummyConfig)(lambda c: DummyOutput(c))
|
||||
registry.register(AnotherConfig)(lambda c: AnotherOutput(c))
|
||||
|
||||
assert isinstance(registry.build(DummyConfig()), DummyOutput)
|
||||
assert isinstance(registry.build(AnotherConfig()), AnotherOutput)
|
||||
|
||||
def test_build_raises_not_implemented_for_unregistered_format(self):
|
||||
"""build() raises NotImplementedError for an unregistered discriminator."""
|
||||
registry: Registry[object, []] = Registry("test")
|
||||
|
||||
class UnknownConfig(BaseModel):
|
||||
name: Literal["unknown"] = "unknown"
|
||||
|
||||
with pytest.raises(NotImplementedError, match="'unknown'"):
|
||||
registry.build(UnknownConfig())
|
||||
|
||||
def test_build_passes_config_to_factory(self):
|
||||
"""build() passes the exact config object through to the factory."""
|
||||
|
||||
class DummyConfig(BaseModel):
|
||||
name: Literal["dummy"] = "dummy"
|
||||
value: int = 0
|
||||
|
||||
registry: Registry[DummyConfig, []] = Registry("test")
|
||||
received: list[DummyConfig] = []
|
||||
registry.register(DummyConfig)(lambda c: received.append(c) or c)
|
||||
|
||||
config = DummyConfig(value=7)
|
||||
registry.build(config)
|
||||
|
||||
assert received == [config]
|
||||
|
||||
def test_build_uses_custom_discriminator_field(self):
|
||||
"""build() resolves the factory using the configured discriminator."""
|
||||
|
||||
class FormatConfig(BaseModel):
|
||||
format: Literal["fmt"] = "fmt"
|
||||
|
||||
registry: Registry[str, []] = Registry("test", discriminator="format")
|
||||
registry.register(FormatConfig)(lambda c: "fmt_result")
|
||||
|
||||
assert registry.build(FormatConfig()) == "fmt_result"
|
||||
|
||||
def test_build_error_message_includes_registry_name(self):
|
||||
"""NotImplementedError message names the registry for easier debugging."""
|
||||
registry: Registry[object, []] = Registry("my_registry")
|
||||
|
||||
class UnknownConfig(BaseModel):
|
||||
name: Literal["ghost"] = "ghost"
|
||||
|
||||
with pytest.raises(NotImplementedError, match="my_registry"):
|
||||
registry.build(UnknownConfig())
|
||||
Loading…
Reference in New Issue
Block a user