diff --git a/example_data/config.yaml b/example_data/config.yaml index 99d0a10..9a95182 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -32,7 +32,6 @@ model: name: UNetBackbone input_height: 128 in_channels: 1 - out_channels: 32 encoder: layers: - name: FreqCoordConvDown diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index f0274af..9b51dfa 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -12,7 +12,7 @@ from batdetect2.evaluate.config import ( get_default_eval_config, ) from batdetect2.inference.config import InferenceConfig -from batdetect2.models.backbones import BackboneConfig +from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig from batdetect2.postprocess.config import PostprocessConfig from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.targets.config import TargetConfig @@ -32,7 +32,7 @@ class BatDetect2Config(BaseConfig): evaluation: EvaluationConfig = Field( default_factory=get_default_eval_config ) - model: BackboneConfig = Field(default_factory=BackboneConfig) + model: BackboneConfig = Field(default_factory=UNetBackboneConfig) preprocess: PreprocessingConfig = Field( default_factory=PreprocessingConfig ) diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index c39533c..60d4266 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested configuration sections. """ -from typing import Any, Type, TypeVar +from typing import Any, Type, TypeVar, Union, overload import yaml from deepmerge.merger import Merger -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, TypeAdapter from soundevent.data import PathLike __all__ = [ @@ -67,7 +67,10 @@ class BaseConfig(BaseModel): return cls.model_validate(yaml.safe_load(yaml_str)) -T = TypeVar("T", bound=BaseModel) +T = TypeVar("T") +T_Model = TypeVar("T_Model", bound=BaseModel) + +Schema = Union[Type[T_Model], TypeAdapter[T]] def get_object_field(obj: dict, current_key: str) -> Any: @@ -129,24 +132,41 @@ def get_object_field(obj: dict, current_key: str) -> Any: return get_object_field(subobj, rest) +@overload def load_config( path: PathLike, - schema: Type[T], + schema: Type[T_Model], field: str | None = None, -) -> T: +) -> T_Model: ... + + +@overload +def load_config( + path: PathLike, + schema: TypeAdapter[T], + field: str | None = None, +) -> T: ... + + +def load_config( + path: PathLike, + schema: Type[T_Model] | TypeAdapter[T], + field: str | None = None, +) -> T_Model | T: """Load and validate configuration data from a file against a schema. Reads a YAML file, optionally extracts a specific section using dot notation, and then validates the resulting data against the provided - Pydantic `schema`. + Pydantic schema. Parameters ---------- path : PathLike The path to the configuration file (typically `.yaml`). - schema : Type[T] - The Pydantic `BaseModel` subclass that defines the expected structure - and types for the configuration data. + schema : Type[T_Model] | TypeAdapter[T] + Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance + that defines the expected structure and types for the configuration + data. field : str, optional A dot-separated string indicating a nested section within the YAML file to extract before validation. If None (default), the entire @@ -156,8 +176,8 @@ def load_config( Returns ------- - T - An instance of the provided `schema`, populated and validated with + T_Model | T + An instance of the schema type, populated and validated with data from the configuration file. Raises @@ -183,6 +203,9 @@ def load_config( if field: config = get_object_field(config, field) + if isinstance(schema, TypeAdapter): + return schema.validate_python(config or {}) + return schema.model_validate(config or {}) @@ -193,7 +216,7 @@ default_merger = Merger( ) -def merge_configs(config1: T, config2: T) -> T: +def merge_configs(config1: T_Model, config2: T_Model) -> T_Model: """Merge two configuration objects.""" model = type(config1) dict1 = config1.model_dump() diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 80e00ba..195e4a2 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -31,9 +31,9 @@ from typing import List import torch from batdetect2.models.backbones import ( + BackboneConfig, UNetBackbone, UNetBackboneConfig, - BackboneConfig, build_backbone, load_backbone_config, ) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 07bc4b8..a83555a 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -22,7 +22,7 @@ from typing import Annotated, Literal, Tuple, Union import torch import torch.nn.functional as F -from pydantic import Field +from pydantic import Field, TypeAdapter from soundevent import data from batdetect2.core.configs import BaseConfig, load_config @@ -57,7 +57,6 @@ class UNetBackboneConfig(BaseConfig): encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG decoder: DecoderConfig = DEFAULT_DECODER_CONFIG - out_channels: int = 32 backbone_registry: Registry[BackboneModel, []] = Registry("backbone") @@ -293,4 +292,8 @@ def load_backbone_config( path: data.PathLike, field: str | None = None, ) -> BackboneConfig: - return load_config(path, schema=BackboneConfig, field=field) + return load_config( + path, + schema=TypeAdapter(BackboneConfig), + field=field, + ) diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index fa42bdf..e21ce2f 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -27,7 +27,7 @@ A unified factory function `build_layer` allows creating instances of these blocks based on configuration objects. """ -from typing import Annotated, List, Literal, Protocol, Tuple, Union +from typing import Annotated, List, Literal, Tuple, Union import torch import torch.nn.functional as F diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 290fb64..420ec2c 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -27,7 +27,6 @@ from batdetect2.models.blocks import ( VerticalConv, build_layer, ) - from batdetect2.typing.models import BottleneckProtocol __all__ = [ @@ -159,9 +158,7 @@ class BottleneckConfig(BaseConfig): """ channels: int - layers: List[BottleneckLayerConfig] = Field( - default_factory=list, - ) + layers: List[BottleneckLayerConfig] = Field(default_factory=list) DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py new file mode 100644 index 0000000..f04fcfa --- /dev/null +++ b/tests/test_models/test_backbones.py @@ -0,0 +1,252 @@ +"""Tests for backbone configuration loading and the backbone registry. + +Covers: +- UNetBackboneConfig default construction and field values. +- build_backbone with default and explicit configs. +- load_backbone_config loading from a YAML file. +- load_backbone_config with a nested field path. +- load_backbone_config round-trip: YAML → config → build_backbone. +- Registry registration and dispatch for UNetBackbone. +- BackboneConfig discriminated union validation. +""" + +from pathlib import Path +from typing import Callable + +import pytest + +from batdetect2.models.backbones import ( + BackboneConfig, + UNetBackbone, + UNetBackboneConfig, + backbone_registry, + build_backbone, + load_backbone_config, +) +from batdetect2.typing.models import BackboneModel + +# --------------------------------------------------------------------------- +# UNetBackboneConfig +# --------------------------------------------------------------------------- + + +def test_unet_backbone_config_defaults(): + """Default config has expected field values.""" + config = UNetBackboneConfig() + + assert config.name == "UNetBackbone" + assert config.input_height == 128 + assert config.in_channels == 1 + + +def test_unet_backbone_config_custom_fields(): + """Custom field values are stored correctly.""" + config = UNetBackboneConfig(in_channels=2, input_height=64) + + assert config.in_channels == 2 + assert config.input_height == 64 + + +def test_unet_backbone_config_extra_fields_ignored(): + """Extra/unknown fields are silently ignored (BaseConfig behaviour).""" + config = UNetBackboneConfig.model_validate( + {"name": "UNetBackbone", "unknown_field": 99} + ) + + assert config.name == "UNetBackbone" + assert not hasattr(config, "unknown_field") + + +# --------------------------------------------------------------------------- +# build_backbone +# --------------------------------------------------------------------------- + + +def test_build_backbone_default(): + """Building with no config uses UNetBackbone defaults.""" + backbone = build_backbone() + + assert isinstance(backbone, UNetBackbone) + assert backbone.input_height == 128 + + +def test_build_backbone_custom_config(): + """Building with a custom config propagates input_height and in_channels.""" + config = UNetBackboneConfig(in_channels=2, input_height=64) + backbone = build_backbone(config) + + assert isinstance(backbone, UNetBackbone) + assert backbone.input_height == 64 + assert backbone.encoder.in_channels == 2 + + +def test_build_backbone_returns_backbone_model(): + """build_backbone always returns a BackboneModel instance.""" + backbone = build_backbone() + + assert isinstance(backbone, BackboneModel) + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + + +def test_registry_has_unet_backbone(): + """The backbone registry has UNetBackbone registered.""" + config_types = backbone_registry.get_config_types() + + assert UNetBackboneConfig in config_types + + +def test_registry_config_type_is_unet_backbone_config(): + """The config type stored for UNetBackbone is UNetBackboneConfig.""" + config_type = backbone_registry.get_config_type("UNetBackbone") + + assert config_type is UNetBackboneConfig + + +def test_registry_build_dispatches_correctly(): + """Registry.build dispatches to UNetBackbone.from_config.""" + config = UNetBackboneConfig(input_height=128) + backbone = backbone_registry.build(config) + + assert isinstance(backbone, UNetBackbone) + + +def test_registry_build_unknown_name_raises(): + """Registry.build raises NotImplementedError for an unknown config name.""" + + class FakeConfig: + name = "NonExistentBackbone" + + with pytest.raises(NotImplementedError): + backbone_registry.build(FakeConfig()) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# BackboneConfig discriminated union +# --------------------------------------------------------------------------- + + +def test_backbone_config_validates_unet_from_dict(): + """BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name.""" + from pydantic import TypeAdapter + + adapter = TypeAdapter(BackboneConfig) + config = adapter.validate_python( + {"name": "UNetBackbone", "input_height": 64} + ) + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == 64 + + +def test_backbone_config_invalid_name_raises(): + """BackboneConfig validation raises for an unknown name discriminator.""" + from pydantic import TypeAdapter, ValidationError + + adapter = TypeAdapter(BackboneConfig) + with pytest.raises(ValidationError): + adapter.validate_python({"name": "NonExistentBackbone"}) + + +# --------------------------------------------------------------------------- +# load_backbone_config +# --------------------------------------------------------------------------- + + +def test_load_backbone_config_from_yaml( + create_temp_yaml: Callable[[str], Path], +): + """load_backbone_config loads a UNetBackboneConfig from a YAML file.""" + yaml_content = """\ +name: UNetBackbone +input_height: 64 +in_channels: 2 +""" + path = create_temp_yaml(yaml_content) + config = load_backbone_config(path) + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == 64 + assert config.in_channels == 2 + + +def test_load_backbone_config_with_field( + create_temp_yaml: Callable[[str], Path], +): + """load_backbone_config extracts a nested field before validation.""" + yaml_content = """\ +model: + name: UNetBackbone + input_height: 32 +""" + path = create_temp_yaml(yaml_content) + config = load_backbone_config(path, field="model") + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == 32 + + +def test_load_backbone_config_defaults_on_minimal_yaml( + create_temp_yaml: Callable[[str], Path], +): + """Minimal YAML with only name fills remaining fields with defaults.""" + yaml_content = "name: UNetBackbone\n" + path = create_temp_yaml(yaml_content) + config = load_backbone_config(path) + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == UNetBackboneConfig().input_height + assert config.in_channels == UNetBackboneConfig().in_channels + + +def test_load_backbone_config_extra_fields_ignored( + create_temp_yaml: Callable[[str], Path], +): + """Extra YAML fields are silently ignored when loading backbone config.""" + yaml_content = """\ +name: UNetBackbone +input_height: 128 +deprecated_field: 99 +""" + path = create_temp_yaml(yaml_content) + config = load_backbone_config(path) + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == 128 + + +# --------------------------------------------------------------------------- +# Round-trip: YAML → config → build_backbone +# --------------------------------------------------------------------------- + + +def test_round_trip_yaml_to_build_backbone( + create_temp_yaml: Callable[[str], Path], +): + """A backbone config loaded from YAML can be used directly with build_backbone.""" + yaml_content = """\ +name: UNetBackbone +input_height: 128 +in_channels: 1 +""" + path = create_temp_yaml(yaml_content) + config = load_backbone_config(path) + backbone = build_backbone(config) + + assert isinstance(backbone, UNetBackbone) + assert backbone.input_height == 128 + + +def test_load_backbone_config_from_example_data(example_data_dir: Path): + """load_backbone_config loads the real example config correctly.""" + config = load_backbone_config( + example_data_dir / "config.yaml", + field="model", + ) + + assert isinstance(config, UNetBackboneConfig) + assert config.input_height == 128 + assert config.in_channels == 1 diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index a2b4b69..0edb694 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -2,7 +2,7 @@ import numpy as np import pytest import torch -from batdetect2.models.backbones import BackboneConfig +from batdetect2.models.backbones import UNetBackboneConfig from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.typing.models import ModelOutput @@ -28,7 +28,7 @@ def test_build_detector_default(): def test_build_detector_custom_config(): """Test building a detector with a custom BackboneConfig.""" num_classes = 3 - config = BackboneConfig(in_channels=2, input_height=128) + config = UNetBackboneConfig(in_channels=2, input_height=128) model = build_detector(num_classes=num_classes, config=config) @@ -41,7 +41,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram): """Test that the forward pass produces correctly shaped outputs.""" num_classes = 4 # Build model matching the dummy input shape - config = BackboneConfig(in_channels=1, input_height=256) + config = UNetBackboneConfig(in_channels=1, input_height=256) model = build_detector(num_classes=num_classes, config=config) # Process the spectrogram through the model @@ -108,7 +108,7 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor): # Build model matching the preprocessor's output shape # The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H - config = BackboneConfig( + config = UNetBackboneConfig( in_channels=spec.shape[1], input_height=spec.shape[2] ) model = build_detector(num_classes=3, config=config)