Add test for backbones

This commit is contained in:
mbsantiago 2026-03-08 16:04:54 +00:00
parent 45e3cf1434
commit 4207661da4
9 changed files with 302 additions and 28 deletions

View File

@ -32,7 +32,6 @@ model:
name: UNetBackbone
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- name: FreqCoordConvDown

View File

@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.backbones import BackboneConfig
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
@ -32,7 +32,7 @@ class BatDetect2Config(BaseConfig):
evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config
)
model: BackboneConfig = Field(default_factory=BackboneConfig)
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)

View File

@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Type, TypeVar
from typing import Any, Type, TypeVar, Union, overload
import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike
__all__ = [
@ -67,7 +67,10 @@ class BaseConfig(BaseModel):
return cls.model_validate(yaml.safe_load(yaml_str))
T = TypeVar("T", bound=BaseModel)
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
Schema = Union[Type[T_Model], TypeAdapter[T]]
def get_object_field(obj: dict, current_key: str) -> Any:
@ -129,24 +132,41 @@ def get_object_field(obj: dict, current_key: str) -> Any:
return get_object_field(subobj, rest)
@overload
def load_config(
path: PathLike,
schema: Type[T],
schema: Type[T_Model],
field: str | None = None,
) -> T:
) -> T_Model: ...
@overload
def load_config(
path: PathLike,
schema: TypeAdapter[T],
field: str | None = None,
) -> T: ...
def load_config(
path: PathLike,
schema: Type[T_Model] | TypeAdapter[T],
field: str | None = None,
) -> T_Model | T:
"""Load and validate configuration data from a file against a schema.
Reads a YAML file, optionally extracts a specific section using dot
notation, and then validates the resulting data against the provided
Pydantic `schema`.
Pydantic schema.
Parameters
----------
path : PathLike
The path to the configuration file (typically `.yaml`).
schema : Type[T]
The Pydantic `BaseModel` subclass that defines the expected structure
and types for the configuration data.
schema : Type[T_Model] | TypeAdapter[T]
Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
that defines the expected structure and types for the configuration
data.
field : str, optional
A dot-separated string indicating a nested section within the YAML
file to extract before validation. If None (default), the entire
@ -156,8 +176,8 @@ def load_config(
Returns
-------
T
An instance of the provided `schema`, populated and validated with
T_Model | T
An instance of the schema type, populated and validated with
data from the configuration file.
Raises
@ -183,6 +203,9 @@ def load_config(
if field:
config = get_object_field(config, field)
if isinstance(schema, TypeAdapter):
return schema.validate_python(config or {})
return schema.model_validate(config or {})
@ -193,7 +216,7 @@ default_merger = Merger(
)
def merge_configs(config1: T, config2: T) -> T:
def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
"""Merge two configuration objects."""
model = type(config1)
dict1 = config1.model_dump()

View File

@ -31,9 +31,9 @@ from typing import List
import torch
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackbone,
UNetBackboneConfig,
BackboneConfig,
build_backbone,
load_backbone_config,
)

View File

@ -22,7 +22,7 @@ from typing import Annotated, Literal, Tuple, Union
import torch
import torch.nn.functional as F
from pydantic import Field
from pydantic import Field, TypeAdapter
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
@ -57,7 +57,6 @@ class UNetBackboneConfig(BaseConfig):
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@ -293,4 +292,8 @@ def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
return load_config(path, schema=BackboneConfig, field=field)
return load_config(
path,
schema=TypeAdapter(BackboneConfig),
field=field,
)

View File

@ -27,7 +27,7 @@ A unified factory function `build_layer` allows creating instances
of these blocks based on configuration objects.
"""
from typing import Annotated, List, Literal, Protocol, Tuple, Union
from typing import Annotated, List, Literal, Tuple, Union
import torch
import torch.nn.functional as F

View File

@ -27,7 +27,6 @@ from batdetect2.models.blocks import (
VerticalConv,
build_layer,
)
from batdetect2.typing.models import BottleneckProtocol
__all__ = [
@ -159,9 +158,7 @@ class BottleneckConfig(BaseConfig):
"""
channels: int
layers: List[BottleneckLayerConfig] = Field(
default_factory=list,
)
layers: List[BottleneckLayerConfig] = Field(default_factory=list)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(

View File

@ -0,0 +1,252 @@
"""Tests for backbone configuration loading and the backbone registry.
Covers:
- UNetBackboneConfig default construction and field values.
- build_backbone with default and explicit configs.
- load_backbone_config loading from a YAML file.
- load_backbone_config with a nested field path.
- load_backbone_config round-trip: YAML config build_backbone.
- Registry registration and dispatch for UNetBackbone.
- BackboneConfig discriminated union validation.
"""
from pathlib import Path
from typing import Callable
import pytest
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackbone,
UNetBackboneConfig,
backbone_registry,
build_backbone,
load_backbone_config,
)
from batdetect2.typing.models import BackboneModel
# ---------------------------------------------------------------------------
# UNetBackboneConfig
# ---------------------------------------------------------------------------
def test_unet_backbone_config_defaults():
"""Default config has expected field values."""
config = UNetBackboneConfig()
assert config.name == "UNetBackbone"
assert config.input_height == 128
assert config.in_channels == 1
def test_unet_backbone_config_custom_fields():
"""Custom field values are stored correctly."""
config = UNetBackboneConfig(in_channels=2, input_height=64)
assert config.in_channels == 2
assert config.input_height == 64
def test_unet_backbone_config_extra_fields_ignored():
"""Extra/unknown fields are silently ignored (BaseConfig behaviour)."""
config = UNetBackboneConfig.model_validate(
{"name": "UNetBackbone", "unknown_field": 99}
)
assert config.name == "UNetBackbone"
assert not hasattr(config, "unknown_field")
# ---------------------------------------------------------------------------
# build_backbone
# ---------------------------------------------------------------------------
def test_build_backbone_default():
"""Building with no config uses UNetBackbone defaults."""
backbone = build_backbone()
assert isinstance(backbone, UNetBackbone)
assert backbone.input_height == 128
def test_build_backbone_custom_config():
"""Building with a custom config propagates input_height and in_channels."""
config = UNetBackboneConfig(in_channels=2, input_height=64)
backbone = build_backbone(config)
assert isinstance(backbone, UNetBackbone)
assert backbone.input_height == 64
assert backbone.encoder.in_channels == 2
def test_build_backbone_returns_backbone_model():
"""build_backbone always returns a BackboneModel instance."""
backbone = build_backbone()
assert isinstance(backbone, BackboneModel)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def test_registry_has_unet_backbone():
"""The backbone registry has UNetBackbone registered."""
config_types = backbone_registry.get_config_types()
assert UNetBackboneConfig in config_types
def test_registry_config_type_is_unet_backbone_config():
"""The config type stored for UNetBackbone is UNetBackboneConfig."""
config_type = backbone_registry.get_config_type("UNetBackbone")
assert config_type is UNetBackboneConfig
def test_registry_build_dispatches_correctly():
"""Registry.build dispatches to UNetBackbone.from_config."""
config = UNetBackboneConfig(input_height=128)
backbone = backbone_registry.build(config)
assert isinstance(backbone, UNetBackbone)
def test_registry_build_unknown_name_raises():
"""Registry.build raises NotImplementedError for an unknown config name."""
class FakeConfig:
name = "NonExistentBackbone"
with pytest.raises(NotImplementedError):
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# BackboneConfig discriminated union
# ---------------------------------------------------------------------------
def test_backbone_config_validates_unet_from_dict():
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
from pydantic import TypeAdapter
adapter = TypeAdapter(BackboneConfig)
config = adapter.validate_python(
{"name": "UNetBackbone", "input_height": 64}
)
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == 64
def test_backbone_config_invalid_name_raises():
"""BackboneConfig validation raises for an unknown name discriminator."""
from pydantic import TypeAdapter, ValidationError
adapter = TypeAdapter(BackboneConfig)
with pytest.raises(ValidationError):
adapter.validate_python({"name": "NonExistentBackbone"})
# ---------------------------------------------------------------------------
# load_backbone_config
# ---------------------------------------------------------------------------
def test_load_backbone_config_from_yaml(
create_temp_yaml: Callable[[str], Path],
):
"""load_backbone_config loads a UNetBackboneConfig from a YAML file."""
yaml_content = """\
name: UNetBackbone
input_height: 64
in_channels: 2
"""
path = create_temp_yaml(yaml_content)
config = load_backbone_config(path)
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == 64
assert config.in_channels == 2
def test_load_backbone_config_with_field(
create_temp_yaml: Callable[[str], Path],
):
"""load_backbone_config extracts a nested field before validation."""
yaml_content = """\
model:
name: UNetBackbone
input_height: 32
"""
path = create_temp_yaml(yaml_content)
config = load_backbone_config(path, field="model")
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == 32
def test_load_backbone_config_defaults_on_minimal_yaml(
create_temp_yaml: Callable[[str], Path],
):
"""Minimal YAML with only name fills remaining fields with defaults."""
yaml_content = "name: UNetBackbone\n"
path = create_temp_yaml(yaml_content)
config = load_backbone_config(path)
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == UNetBackboneConfig().input_height
assert config.in_channels == UNetBackboneConfig().in_channels
def test_load_backbone_config_extra_fields_ignored(
create_temp_yaml: Callable[[str], Path],
):
"""Extra YAML fields are silently ignored when loading backbone config."""
yaml_content = """\
name: UNetBackbone
input_height: 128
deprecated_field: 99
"""
path = create_temp_yaml(yaml_content)
config = load_backbone_config(path)
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == 128
# ---------------------------------------------------------------------------
# Round-trip: YAML → config → build_backbone
# ---------------------------------------------------------------------------
def test_round_trip_yaml_to_build_backbone(
create_temp_yaml: Callable[[str], Path],
):
"""A backbone config loaded from YAML can be used directly with build_backbone."""
yaml_content = """\
name: UNetBackbone
input_height: 128
in_channels: 1
"""
path = create_temp_yaml(yaml_content)
config = load_backbone_config(path)
backbone = build_backbone(config)
assert isinstance(backbone, UNetBackbone)
assert backbone.input_height == 128
def test_load_backbone_config_from_example_data(example_data_dir: Path):
"""load_backbone_config loads the real example config correctly."""
config = load_backbone_config(
example_data_dir / "config.yaml",
field="model",
)
assert isinstance(config, UNetBackboneConfig)
assert config.input_height == 128
assert config.in_channels == 1

View File

@ -2,7 +2,7 @@ import numpy as np
import pytest
import torch
from batdetect2.models.backbones import BackboneConfig
from batdetect2.models.backbones import UNetBackboneConfig
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput
@ -28,7 +28,7 @@ def test_build_detector_default():
def test_build_detector_custom_config():
"""Test building a detector with a custom BackboneConfig."""
num_classes = 3
config = BackboneConfig(in_channels=2, input_height=128)
config = UNetBackboneConfig(in_channels=2, input_height=128)
model = build_detector(num_classes=num_classes, config=config)
@ -41,7 +41,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
"""Test that the forward pass produces correctly shaped outputs."""
num_classes = 4
# Build model matching the dummy input shape
config = BackboneConfig(in_channels=1, input_height=256)
config = UNetBackboneConfig(in_channels=1, input_height=256)
model = build_detector(num_classes=num_classes, config=config)
# Process the spectrogram through the model
@ -108,7 +108,7 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
# Build model matching the preprocessor's output shape
# The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H
config = BackboneConfig(
config = UNetBackboneConfig(
in_channels=spec.shape[1], input_height=spec.shape[2]
)
model = build_detector(num_classes=3, config=config)