mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add test for backbones
This commit is contained in:
parent
45e3cf1434
commit
4207661da4
@ -32,7 +32,6 @@ model:
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
out_channels: 32
|
||||
encoder:
|
||||
layers:
|
||||
- name: FreqCoordConvDown
|
||||
|
||||
@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.models.backbones import BackboneConfig
|
||||
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
@ -32,7 +32,7 @@ class BatDetect2Config(BaseConfig):
|
||||
evaluation: EvaluationConfig = Field(
|
||||
default_factory=get_default_eval_config
|
||||
)
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
|
||||
@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
|
||||
configuration sections.
|
||||
"""
|
||||
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import Any, Type, TypeVar, Union, overload
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from soundevent.data import PathLike
|
||||
|
||||
__all__ = [
|
||||
@ -67,7 +67,10 @@ class BaseConfig(BaseModel):
|
||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
T = TypeVar("T")
|
||||
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||
|
||||
Schema = Union[Type[T_Model], TypeAdapter[T]]
|
||||
|
||||
|
||||
def get_object_field(obj: dict, current_key: str) -> Any:
|
||||
@ -129,24 +132,41 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
||||
return get_object_field(subobj, rest)
|
||||
|
||||
|
||||
@overload
|
||||
def load_config(
|
||||
path: PathLike,
|
||||
schema: Type[T],
|
||||
schema: Type[T_Model],
|
||||
field: str | None = None,
|
||||
) -> T:
|
||||
) -> T_Model: ...
|
||||
|
||||
|
||||
@overload
|
||||
def load_config(
|
||||
path: PathLike,
|
||||
schema: TypeAdapter[T],
|
||||
field: str | None = None,
|
||||
) -> T: ...
|
||||
|
||||
|
||||
def load_config(
|
||||
path: PathLike,
|
||||
schema: Type[T_Model] | TypeAdapter[T],
|
||||
field: str | None = None,
|
||||
) -> T_Model | T:
|
||||
"""Load and validate configuration data from a file against a schema.
|
||||
|
||||
Reads a YAML file, optionally extracts a specific section using dot
|
||||
notation, and then validates the resulting data against the provided
|
||||
Pydantic `schema`.
|
||||
Pydantic schema.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
The path to the configuration file (typically `.yaml`).
|
||||
schema : Type[T]
|
||||
The Pydantic `BaseModel` subclass that defines the expected structure
|
||||
and types for the configuration data.
|
||||
schema : Type[T_Model] | TypeAdapter[T]
|
||||
Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
|
||||
that defines the expected structure and types for the configuration
|
||||
data.
|
||||
field : str, optional
|
||||
A dot-separated string indicating a nested section within the YAML
|
||||
file to extract before validation. If None (default), the entire
|
||||
@ -156,8 +176,8 @@ def load_config(
|
||||
|
||||
Returns
|
||||
-------
|
||||
T
|
||||
An instance of the provided `schema`, populated and validated with
|
||||
T_Model | T
|
||||
An instance of the schema type, populated and validated with
|
||||
data from the configuration file.
|
||||
|
||||
Raises
|
||||
@ -183,6 +203,9 @@ def load_config(
|
||||
if field:
|
||||
config = get_object_field(config, field)
|
||||
|
||||
if isinstance(schema, TypeAdapter):
|
||||
return schema.validate_python(config or {})
|
||||
|
||||
return schema.model_validate(config or {})
|
||||
|
||||
|
||||
@ -193,7 +216,7 @@ default_merger = Merger(
|
||||
)
|
||||
|
||||
|
||||
def merge_configs(config1: T, config2: T) -> T:
|
||||
def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
|
||||
"""Merge two configuration objects."""
|
||||
model = type(config1)
|
||||
dict1 = config1.model_dump()
|
||||
|
||||
@ -31,9 +31,9 @@ from typing import List
|
||||
import torch
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackbone,
|
||||
UNetBackboneConfig,
|
||||
BackboneConfig,
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
|
||||
@ -22,7 +22,7 @@ from typing import Annotated, Literal, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from pydantic import Field, TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
@ -57,7 +57,6 @@ class UNetBackboneConfig(BaseConfig):
|
||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
@ -293,4 +292,8 @@ def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> BackboneConfig:
|
||||
return load_config(path, schema=BackboneConfig, field=field)
|
||||
return load_config(
|
||||
path,
|
||||
schema=TypeAdapter(BackboneConfig),
|
||||
field=field,
|
||||
)
|
||||
|
||||
@ -27,7 +27,7 @@ A unified factory function `build_layer` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Literal, Protocol, Tuple, Union
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -27,7 +27,6 @@ from batdetect2.models.blocks import (
|
||||
VerticalConv,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
from batdetect2.typing.models import BottleneckProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -159,9 +158,7 @@ class BottleneckConfig(BaseConfig):
|
||||
"""
|
||||
|
||||
channels: int
|
||||
layers: List[BottleneckLayerConfig] = Field(
|
||||
default_factory=list,
|
||||
)
|
||||
layers: List[BottleneckLayerConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
||||
|
||||
252
tests/test_models/test_backbones.py
Normal file
252
tests/test_models/test_backbones.py
Normal file
@ -0,0 +1,252 @@
|
||||
"""Tests for backbone configuration loading and the backbone registry.
|
||||
|
||||
Covers:
|
||||
- UNetBackboneConfig default construction and field values.
|
||||
- build_backbone with default and explicit configs.
|
||||
- load_backbone_config loading from a YAML file.
|
||||
- load_backbone_config with a nested field path.
|
||||
- load_backbone_config round-trip: YAML → config → build_backbone.
|
||||
- Registry registration and dispatch for UNetBackbone.
|
||||
- BackboneConfig discriminated union validation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackbone,
|
||||
UNetBackboneConfig,
|
||||
backbone_registry,
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# UNetBackboneConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_unet_backbone_config_defaults():
|
||||
"""Default config has expected field values."""
|
||||
config = UNetBackboneConfig()
|
||||
|
||||
assert config.name == "UNetBackbone"
|
||||
assert config.input_height == 128
|
||||
assert config.in_channels == 1
|
||||
|
||||
|
||||
def test_unet_backbone_config_custom_fields():
|
||||
"""Custom field values are stored correctly."""
|
||||
config = UNetBackboneConfig(in_channels=2, input_height=64)
|
||||
|
||||
assert config.in_channels == 2
|
||||
assert config.input_height == 64
|
||||
|
||||
|
||||
def test_unet_backbone_config_extra_fields_ignored():
|
||||
"""Extra/unknown fields are silently ignored (BaseConfig behaviour)."""
|
||||
config = UNetBackboneConfig.model_validate(
|
||||
{"name": "UNetBackbone", "unknown_field": 99}
|
||||
)
|
||||
|
||||
assert config.name == "UNetBackbone"
|
||||
assert not hasattr(config, "unknown_field")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_backbone_default():
|
||||
"""Building with no config uses UNetBackbone defaults."""
|
||||
backbone = build_backbone()
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
assert backbone.input_height == 128
|
||||
|
||||
|
||||
def test_build_backbone_custom_config():
|
||||
"""Building with a custom config propagates input_height and in_channels."""
|
||||
config = UNetBackboneConfig(in_channels=2, input_height=64)
|
||||
backbone = build_backbone(config)
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
assert backbone.input_height == 64
|
||||
assert backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_backbone_returns_backbone_model():
|
||||
"""build_backbone always returns a BackboneModel instance."""
|
||||
backbone = build_backbone()
|
||||
|
||||
assert isinstance(backbone, BackboneModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_has_unet_backbone():
|
||||
"""The backbone registry has UNetBackbone registered."""
|
||||
config_types = backbone_registry.get_config_types()
|
||||
|
||||
assert UNetBackboneConfig in config_types
|
||||
|
||||
|
||||
def test_registry_config_type_is_unet_backbone_config():
|
||||
"""The config type stored for UNetBackbone is UNetBackboneConfig."""
|
||||
config_type = backbone_registry.get_config_type("UNetBackbone")
|
||||
|
||||
assert config_type is UNetBackboneConfig
|
||||
|
||||
|
||||
def test_registry_build_dispatches_correctly():
|
||||
"""Registry.build dispatches to UNetBackbone.from_config."""
|
||||
config = UNetBackboneConfig(input_height=128)
|
||||
backbone = backbone_registry.build(config)
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
|
||||
|
||||
def test_registry_build_unknown_name_raises():
|
||||
"""Registry.build raises NotImplementedError for an unknown config name."""
|
||||
|
||||
class FakeConfig:
|
||||
name = "NonExistentBackbone"
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BackboneConfig discriminated union
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_backbone_config_validates_unet_from_dict():
|
||||
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
adapter = TypeAdapter(BackboneConfig)
|
||||
config = adapter.validate_python(
|
||||
{"name": "UNetBackbone", "input_height": 64}
|
||||
)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == 64
|
||||
|
||||
|
||||
def test_backbone_config_invalid_name_raises():
|
||||
"""BackboneConfig validation raises for an unknown name discriminator."""
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
adapter = TypeAdapter(BackboneConfig)
|
||||
with pytest.raises(ValidationError):
|
||||
adapter.validate_python({"name": "NonExistentBackbone"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_backbone_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_backbone_config_from_yaml(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
"""load_backbone_config loads a UNetBackboneConfig from a YAML file."""
|
||||
yaml_content = """\
|
||||
name: UNetBackbone
|
||||
input_height: 64
|
||||
in_channels: 2
|
||||
"""
|
||||
path = create_temp_yaml(yaml_content)
|
||||
config = load_backbone_config(path)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == 64
|
||||
assert config.in_channels == 2
|
||||
|
||||
|
||||
def test_load_backbone_config_with_field(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
"""load_backbone_config extracts a nested field before validation."""
|
||||
yaml_content = """\
|
||||
model:
|
||||
name: UNetBackbone
|
||||
input_height: 32
|
||||
"""
|
||||
path = create_temp_yaml(yaml_content)
|
||||
config = load_backbone_config(path, field="model")
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == 32
|
||||
|
||||
|
||||
def test_load_backbone_config_defaults_on_minimal_yaml(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
"""Minimal YAML with only name fills remaining fields with defaults."""
|
||||
yaml_content = "name: UNetBackbone\n"
|
||||
path = create_temp_yaml(yaml_content)
|
||||
config = load_backbone_config(path)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == UNetBackboneConfig().input_height
|
||||
assert config.in_channels == UNetBackboneConfig().in_channels
|
||||
|
||||
|
||||
def test_load_backbone_config_extra_fields_ignored(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
"""Extra YAML fields are silently ignored when loading backbone config."""
|
||||
yaml_content = """\
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
deprecated_field: 99
|
||||
"""
|
||||
path = create_temp_yaml(yaml_content)
|
||||
config = load_backbone_config(path)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip: YAML → config → build_backbone
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_round_trip_yaml_to_build_backbone(
|
||||
create_temp_yaml: Callable[[str], Path],
|
||||
):
|
||||
"""A backbone config loaded from YAML can be used directly with build_backbone."""
|
||||
yaml_content = """\
|
||||
name: UNetBackbone
|
||||
input_height: 128
|
||||
in_channels: 1
|
||||
"""
|
||||
path = create_temp_yaml(yaml_content)
|
||||
config = load_backbone_config(path)
|
||||
backbone = build_backbone(config)
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
assert backbone.input_height == 128
|
||||
|
||||
|
||||
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||
"""load_backbone_config loads the real example config correctly."""
|
||||
config = load_backbone_config(
|
||||
example_data_dir / "config.yaml",
|
||||
field="model",
|
||||
)
|
||||
|
||||
assert isinstance(config, UNetBackboneConfig)
|
||||
assert config.input_height == 128
|
||||
assert config.in_channels == 1
|
||||
@ -2,7 +2,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from batdetect2.models.backbones import BackboneConfig
|
||||
from batdetect2.models.backbones import UNetBackboneConfig
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
@ -28,7 +28,7 @@ def test_build_detector_default():
|
||||
def test_build_detector_custom_config():
|
||||
"""Test building a detector with a custom BackboneConfig."""
|
||||
num_classes = 3
|
||||
config = BackboneConfig(in_channels=2, input_height=128)
|
||||
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
|
||||
@ -41,7 +41,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
"""Test that the forward pass produces correctly shaped outputs."""
|
||||
num_classes = 4
|
||||
# Build model matching the dummy input shape
|
||||
config = BackboneConfig(in_channels=1, input_height=256)
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
|
||||
# Process the spectrogram through the model
|
||||
@ -108,7 +108,7 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
||||
|
||||
# Build model matching the preprocessor's output shape
|
||||
# The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H
|
||||
config = BackboneConfig(
|
||||
config = UNetBackboneConfig(
|
||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||
)
|
||||
model = build_detector(num_classes=3, config=config)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user