mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add test for backbones
This commit is contained in:
parent
45e3cf1434
commit
4207661da4
@ -32,7 +32,6 @@ model:
|
|||||||
name: UNetBackbone
|
name: UNetBackbone
|
||||||
input_height: 128
|
input_height: 128
|
||||||
in_channels: 1
|
in_channels: 1
|
||||||
out_channels: 32
|
|
||||||
encoder:
|
encoder:
|
||||||
layers:
|
layers:
|
||||||
- name: FreqCoordConvDown
|
- name: FreqCoordConvDown
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from batdetect2.evaluate.config import (
|
|||||||
get_default_eval_config,
|
get_default_eval_config,
|
||||||
)
|
)
|
||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.models.backbones import BackboneConfig
|
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
from batdetect2.targets.config import TargetConfig
|
from batdetect2.targets.config import TargetConfig
|
||||||
@ -32,7 +32,7 @@ class BatDetect2Config(BaseConfig):
|
|||||||
evaluation: EvaluationConfig = Field(
|
evaluation: EvaluationConfig = Field(
|
||||||
default_factory=get_default_eval_config
|
default_factory=get_default_eval_config
|
||||||
)
|
)
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
||||||
preprocess: PreprocessingConfig = Field(
|
preprocess: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
|
|||||||
configuration sections.
|
configuration sections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Type, TypeVar
|
from typing import Any, Type, TypeVar, Union, overload
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from deepmerge.merger import Merger
|
from deepmerge.merger import Merger
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -67,7 +67,10 @@ class BaseConfig(BaseModel):
|
|||||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T")
|
||||||
|
T_Model = TypeVar("T_Model", bound=BaseModel)
|
||||||
|
|
||||||
|
Schema = Union[Type[T_Model], TypeAdapter[T]]
|
||||||
|
|
||||||
|
|
||||||
def get_object_field(obj: dict, current_key: str) -> Any:
|
def get_object_field(obj: dict, current_key: str) -> Any:
|
||||||
@ -129,24 +132,41 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
|||||||
return get_object_field(subobj, rest)
|
return get_object_field(subobj, rest)
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
def load_config(
|
def load_config(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
schema: Type[T],
|
schema: Type[T_Model],
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
) -> T:
|
) -> T_Model: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def load_config(
|
||||||
|
path: PathLike,
|
||||||
|
schema: TypeAdapter[T],
|
||||||
|
field: str | None = None,
|
||||||
|
) -> T: ...
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(
|
||||||
|
path: PathLike,
|
||||||
|
schema: Type[T_Model] | TypeAdapter[T],
|
||||||
|
field: str | None = None,
|
||||||
|
) -> T_Model | T:
|
||||||
"""Load and validate configuration data from a file against a schema.
|
"""Load and validate configuration data from a file against a schema.
|
||||||
|
|
||||||
Reads a YAML file, optionally extracts a specific section using dot
|
Reads a YAML file, optionally extracts a specific section using dot
|
||||||
notation, and then validates the resulting data against the provided
|
notation, and then validates the resulting data against the provided
|
||||||
Pydantic `schema`.
|
Pydantic schema.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : PathLike
|
path : PathLike
|
||||||
The path to the configuration file (typically `.yaml`).
|
The path to the configuration file (typically `.yaml`).
|
||||||
schema : Type[T]
|
schema : Type[T_Model] | TypeAdapter[T]
|
||||||
The Pydantic `BaseModel` subclass that defines the expected structure
|
Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
|
||||||
and types for the configuration data.
|
that defines the expected structure and types for the configuration
|
||||||
|
data.
|
||||||
field : str, optional
|
field : str, optional
|
||||||
A dot-separated string indicating a nested section within the YAML
|
A dot-separated string indicating a nested section within the YAML
|
||||||
file to extract before validation. If None (default), the entire
|
file to extract before validation. If None (default), the entire
|
||||||
@ -156,8 +176,8 @@ def load_config(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
T
|
T_Model | T
|
||||||
An instance of the provided `schema`, populated and validated with
|
An instance of the schema type, populated and validated with
|
||||||
data from the configuration file.
|
data from the configuration file.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
@ -183,6 +203,9 @@ def load_config(
|
|||||||
if field:
|
if field:
|
||||||
config = get_object_field(config, field)
|
config = get_object_field(config, field)
|
||||||
|
|
||||||
|
if isinstance(schema, TypeAdapter):
|
||||||
|
return schema.validate_python(config or {})
|
||||||
|
|
||||||
return schema.model_validate(config or {})
|
return schema.model_validate(config or {})
|
||||||
|
|
||||||
|
|
||||||
@ -193,7 +216,7 @@ default_merger = Merger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def merge_configs(config1: T, config2: T) -> T:
|
def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
|
||||||
"""Merge two configuration objects."""
|
"""Merge two configuration objects."""
|
||||||
model = type(config1)
|
model = type(config1)
|
||||||
dict1 = config1.model_dump()
|
dict1 = config1.model_dump()
|
||||||
|
|||||||
@ -31,9 +31,9 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
|
BackboneConfig,
|
||||||
UNetBackbone,
|
UNetBackbone,
|
||||||
UNetBackboneConfig,
|
UNetBackboneConfig,
|
||||||
BackboneConfig,
|
|
||||||
build_backbone,
|
build_backbone,
|
||||||
load_backbone_config,
|
load_backbone_config,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from typing import Annotated, Literal, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pydantic import Field
|
from pydantic import Field, TypeAdapter
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
@ -57,7 +57,6 @@ class UNetBackboneConfig(BaseConfig):
|
|||||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
out_channels: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||||
@ -293,4 +292,8 @@ def load_backbone_config(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
) -> BackboneConfig:
|
) -> BackboneConfig:
|
||||||
return load_config(path, schema=BackboneConfig, field=field)
|
return load_config(
|
||||||
|
path,
|
||||||
|
schema=TypeAdapter(BackboneConfig),
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
|
|||||||
@ -27,7 +27,7 @@ A unified factory function `build_layer` allows creating instances
|
|||||||
of these blocks based on configuration objects.
|
of these blocks based on configuration objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Protocol, Tuple, Union
|
from typing import Annotated, List, Literal, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
@ -27,7 +27,6 @@ from batdetect2.models.blocks import (
|
|||||||
VerticalConv,
|
VerticalConv,
|
||||||
build_layer,
|
build_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from batdetect2.typing.models import BottleneckProtocol
|
from batdetect2.typing.models import BottleneckProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -159,9 +158,7 @@ class BottleneckConfig(BaseConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
channels: int
|
channels: int
|
||||||
layers: List[BottleneckLayerConfig] = Field(
|
layers: List[BottleneckLayerConfig] = Field(default_factory=list)
|
||||||
default_factory=list,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
||||||
|
|||||||
252
tests/test_models/test_backbones.py
Normal file
252
tests/test_models/test_backbones.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
"""Tests for backbone configuration loading and the backbone registry.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- UNetBackboneConfig default construction and field values.
|
||||||
|
- build_backbone with default and explicit configs.
|
||||||
|
- load_backbone_config loading from a YAML file.
|
||||||
|
- load_backbone_config with a nested field path.
|
||||||
|
- load_backbone_config round-trip: YAML → config → build_backbone.
|
||||||
|
- Registry registration and dispatch for UNetBackbone.
|
||||||
|
- BackboneConfig discriminated union validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from batdetect2.models.backbones import (
|
||||||
|
BackboneConfig,
|
||||||
|
UNetBackbone,
|
||||||
|
UNetBackboneConfig,
|
||||||
|
backbone_registry,
|
||||||
|
build_backbone,
|
||||||
|
load_backbone_config,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.models import BackboneModel
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# UNetBackboneConfig
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_unet_backbone_config_defaults():
|
||||||
|
"""Default config has expected field values."""
|
||||||
|
config = UNetBackboneConfig()
|
||||||
|
|
||||||
|
assert config.name == "UNetBackbone"
|
||||||
|
assert config.input_height == 128
|
||||||
|
assert config.in_channels == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_unet_backbone_config_custom_fields():
|
||||||
|
"""Custom field values are stored correctly."""
|
||||||
|
config = UNetBackboneConfig(in_channels=2, input_height=64)
|
||||||
|
|
||||||
|
assert config.in_channels == 2
|
||||||
|
assert config.input_height == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_unet_backbone_config_extra_fields_ignored():
|
||||||
|
"""Extra/unknown fields are silently ignored (BaseConfig behaviour)."""
|
||||||
|
config = UNetBackboneConfig.model_validate(
|
||||||
|
{"name": "UNetBackbone", "unknown_field": 99}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.name == "UNetBackbone"
|
||||||
|
assert not hasattr(config, "unknown_field")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# build_backbone
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_backbone_default():
|
||||||
|
"""Building with no config uses UNetBackbone defaults."""
|
||||||
|
backbone = build_backbone()
|
||||||
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
assert backbone.input_height == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_backbone_custom_config():
|
||||||
|
"""Building with a custom config propagates input_height and in_channels."""
|
||||||
|
config = UNetBackboneConfig(in_channels=2, input_height=64)
|
||||||
|
backbone = build_backbone(config)
|
||||||
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
assert backbone.input_height == 64
|
||||||
|
assert backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_backbone_returns_backbone_model():
|
||||||
|
"""build_backbone always returns a BackboneModel instance."""
|
||||||
|
backbone = build_backbone()
|
||||||
|
|
||||||
|
assert isinstance(backbone, BackboneModel)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_has_unet_backbone():
|
||||||
|
"""The backbone registry has UNetBackbone registered."""
|
||||||
|
config_types = backbone_registry.get_config_types()
|
||||||
|
|
||||||
|
assert UNetBackboneConfig in config_types
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_config_type_is_unet_backbone_config():
|
||||||
|
"""The config type stored for UNetBackbone is UNetBackboneConfig."""
|
||||||
|
config_type = backbone_registry.get_config_type("UNetBackbone")
|
||||||
|
|
||||||
|
assert config_type is UNetBackboneConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_build_dispatches_correctly():
|
||||||
|
"""Registry.build dispatches to UNetBackbone.from_config."""
|
||||||
|
config = UNetBackboneConfig(input_height=128)
|
||||||
|
backbone = backbone_registry.build(config)
|
||||||
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_build_unknown_name_raises():
|
||||||
|
"""Registry.build raises NotImplementedError for an unknown config name."""
|
||||||
|
|
||||||
|
class FakeConfig:
|
||||||
|
name = "NonExistentBackbone"
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BackboneConfig discriminated union
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_backbone_config_validates_unet_from_dict():
|
||||||
|
"""BackboneConfig TypeAdapter resolves to UNetBackboneConfig via name."""
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
adapter = TypeAdapter(BackboneConfig)
|
||||||
|
config = adapter.validate_python(
|
||||||
|
{"name": "UNetBackbone", "input_height": 64}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_backbone_config_invalid_name_raises():
|
||||||
|
"""BackboneConfig validation raises for an unknown name discriminator."""
|
||||||
|
from pydantic import TypeAdapter, ValidationError
|
||||||
|
|
||||||
|
adapter = TypeAdapter(BackboneConfig)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
adapter.validate_python({"name": "NonExistentBackbone"})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# load_backbone_config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_backbone_config_from_yaml(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
"""load_backbone_config loads a UNetBackboneConfig from a YAML file."""
|
||||||
|
yaml_content = """\
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 64
|
||||||
|
in_channels: 2
|
||||||
|
"""
|
||||||
|
path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_backbone_config(path)
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == 64
|
||||||
|
assert config.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_backbone_config_with_field(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
"""load_backbone_config extracts a nested field before validation."""
|
||||||
|
yaml_content = """\
|
||||||
|
model:
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 32
|
||||||
|
"""
|
||||||
|
path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_backbone_config(path, field="model")
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_backbone_config_defaults_on_minimal_yaml(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
"""Minimal YAML with only name fills remaining fields with defaults."""
|
||||||
|
yaml_content = "name: UNetBackbone\n"
|
||||||
|
path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_backbone_config(path)
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == UNetBackboneConfig().input_height
|
||||||
|
assert config.in_channels == UNetBackboneConfig().in_channels
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_backbone_config_extra_fields_ignored(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
"""Extra YAML fields are silently ignored when loading backbone config."""
|
||||||
|
yaml_content = """\
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 128
|
||||||
|
deprecated_field: 99
|
||||||
|
"""
|
||||||
|
path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_backbone_config(path)
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == 128
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Round-trip: YAML → config → build_backbone
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_round_trip_yaml_to_build_backbone(
|
||||||
|
create_temp_yaml: Callable[[str], Path],
|
||||||
|
):
|
||||||
|
"""A backbone config loaded from YAML can be used directly with build_backbone."""
|
||||||
|
yaml_content = """\
|
||||||
|
name: UNetBackbone
|
||||||
|
input_height: 128
|
||||||
|
in_channels: 1
|
||||||
|
"""
|
||||||
|
path = create_temp_yaml(yaml_content)
|
||||||
|
config = load_backbone_config(path)
|
||||||
|
backbone = build_backbone(config)
|
||||||
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
assert backbone.input_height == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_backbone_config_from_example_data(example_data_dir: Path):
|
||||||
|
"""load_backbone_config loads the real example config correctly."""
|
||||||
|
config = load_backbone_config(
|
||||||
|
example_data_dir / "config.yaml",
|
||||||
|
field="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(config, UNetBackboneConfig)
|
||||||
|
assert config.input_height == 128
|
||||||
|
assert config.in_channels == 1
|
||||||
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.backbones import BackboneConfig
|
from batdetect2.models.backbones import UNetBackboneConfig
|
||||||
from batdetect2.models.detectors import Detector, build_detector
|
from batdetect2.models.detectors import Detector, build_detector
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
@ -28,7 +28,7 @@ def test_build_detector_default():
|
|||||||
def test_build_detector_custom_config():
|
def test_build_detector_custom_config():
|
||||||
"""Test building a detector with a custom BackboneConfig."""
|
"""Test building a detector with a custom BackboneConfig."""
|
||||||
num_classes = 3
|
num_classes = 3
|
||||||
config = BackboneConfig(in_channels=2, input_height=128)
|
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||||
|
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(num_classes=num_classes, config=config)
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
|||||||
"""Test that the forward pass produces correctly shaped outputs."""
|
"""Test that the forward pass produces correctly shaped outputs."""
|
||||||
num_classes = 4
|
num_classes = 4
|
||||||
# Build model matching the dummy input shape
|
# Build model matching the dummy input shape
|
||||||
config = BackboneConfig(in_channels=1, input_height=256)
|
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(num_classes=num_classes, config=config)
|
||||||
|
|
||||||
# Process the spectrogram through the model
|
# Process the spectrogram through the model
|
||||||
@ -108,7 +108,7 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
|||||||
|
|
||||||
# Build model matching the preprocessor's output shape
|
# Build model matching the preprocessor's output shape
|
||||||
# The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H
|
# The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H
|
||||||
config = BackboneConfig(
|
config = UNetBackboneConfig(
|
||||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||||
)
|
)
|
||||||
model = build_detector(num_classes=3, config=config)
|
model = build_detector(num_classes=3, config=config)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user