# batdetect2/tests/test_models/test_blocks.py

import pytest
import torch
from batdetect2.models.blocks import (
ConvBlock,
ConvConfig,
FreqCoordConvDownBlock,
FreqCoordConvDownConfig,
FreqCoordConvUpBlock,
FreqCoordConvUpConfig,
LayerGroup,
LayerGroupConfig,
SelfAttention,
SelfAttentionConfig,
StandardConvDownBlock,
StandardConvDownConfig,
StandardConvUpBlock,
StandardConvUpConfig,
VerticalConv,
VerticalConvConfig,
build_layer,
)
@pytest.fixture
def dummy_input() -> torch.Tensor:
    """Standard 4D (batch, channels, height, width) tensor used by block tests."""
    shape = (2, 16, 32, 32)
    return torch.randn(*shape)
@pytest.fixture
def dummy_bottleneck_input() -> torch.Tensor:
    """Input shaped like a bottleneck/SelfAttention activation (height collapsed to 1)."""
    batch, channels, height, width = 2, 64, 1, 32
    return torch.randn(batch, channels, height, width)
@pytest.mark.parametrize(
    "block_class, expected_h_scale",
    [
        (ConvBlock, 1.0),
        (StandardConvDownBlock, 0.5),
        (StandardConvUpBlock, 2.0),
    ],
)
def test_standard_block_protocol_methods(
    block_class, expected_h_scale, dummy_input
):
    """Test get_output_channels and get_output_height for standard blocks."""
    _, in_channels, input_height, _ = dummy_input.shape
    out_channels = 32
    block = block_class(in_channels=in_channels, out_channels=out_channels)
    expected_height = int(input_height * expected_h_scale)
    assert block.get_output_channels() == out_channels
    assert block.get_output_height(input_height) == expected_height
@pytest.mark.parametrize(
    "block_class, expected_h_scale",
    [
        (FreqCoordConvDownBlock, 0.5),
        (FreqCoordConvUpBlock, 2.0),
    ],
)
def test_coord_block_protocol_methods(
    block_class, expected_h_scale, dummy_input
):
    """Test get_output_channels and get_output_height for coord blocks."""
    _, in_channels, input_height, _ = dummy_input.shape
    out_channels = 32
    block = block_class(
        in_channels=in_channels,
        out_channels=out_channels,
        input_height=input_height,
    )
    expected_height = int(input_height * expected_h_scale)
    assert block.get_output_channels() == out_channels
    assert block.get_output_height(input_height) == expected_height
def test_vertical_conv_forward_shape(dummy_input):
    """Test that VerticalConv correctly collapses the height dimension to 1.

    The expected output shape is derived from the fixture rather than
    hard-coded, so the test stays valid if the fixture dimensions change.
    """
    batch_size, in_channels, input_height, width = dummy_input.shape
    out_channels = 32
    block = VerticalConv(in_channels, out_channels, input_height)
    output = block(dummy_input)
    # Height is fully collapsed to 1; batch and width pass through unchanged.
    assert output.shape == (batch_size, out_channels, 1, width)
    assert block.get_output_channels() == out_channels
def test_self_attention_forward_shape(dummy_bottleneck_input):
    """SelfAttention must preserve the input shape exactly."""
    channels = dummy_bottleneck_input.size(1)
    attention_channels = 32
    block = SelfAttention(
        in_channels=channels, attention_channels=attention_channels
    )
    result = block(dummy_bottleneck_input)
    assert result.shape == dummy_bottleneck_input.shape
    assert block.get_output_channels() == channels
def test_self_attention_weights(dummy_bottleneck_input):
    """Attention weights should form a distribution over the time sequence."""
    batch_size, in_channels, _, time_steps = dummy_bottleneck_input.shape
    block = SelfAttention(in_channels=in_channels, attention_channels=32)
    weights = block.compute_attention_weights(dummy_bottleneck_input)
    # Weights are (B, T, T): one weight per (query, key) pair over width/time.
    assert weights.shape == (batch_size, time_steps, time_steps)
    # Each query's weights over the keys (dim=1) must sum to one.
    totals = weights.sum(dim=1)
    assert torch.allclose(totals, torch.ones_like(totals), atol=1e-5)
@pytest.mark.parametrize(
    "layer_config, expected_type",
    [
        (ConvConfig(out_channels=32), ConvBlock),
        (StandardConvDownConfig(out_channels=32), StandardConvDownBlock),
        (StandardConvUpConfig(out_channels=32), StandardConvUpBlock),
        (FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownBlock),
        (FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpBlock),
        (SelfAttentionConfig(attention_channels=32), SelfAttention),
        (VerticalConvConfig(channels=32), VerticalConv),
    ],
)
def test_build_layer_factory(layer_config, expected_type):
    """The factory should instantiate the block class matching each config."""
    layer = build_layer(
        input_height=32,
        in_channels=16,
        config=layer_config,
    )
    assert isinstance(layer, expected_type)
def test_layer_group_from_config_and_forward(dummy_input):
    """Test that LayerGroup successfully chains multiple blocks.

    Expected shapes are derived from the fixture dimensions instead of
    hard-coded values, so the test remains valid if the fixture changes.
    """
    batch_size, in_channels, input_height, width = dummy_input.shape
    final_channels = 64
    config = LayerGroupConfig(
        layers=[
            ConvConfig(out_channels=32),
            StandardConvDownConfig(out_channels=final_channels),
        ]
    )
    layer_group = build_layer(
        input_height=input_height,
        in_channels=in_channels,
        config=config,
    )
    assert isinstance(layer_group, LayerGroup)
    assert len(layer_group.layers) == 2
    # The group should report the output channels of the LAST block
    assert layer_group.get_output_channels() == final_channels
    # The group should report the accumulated height changes
    assert layer_group.get_output_height(input_height) == input_height // 2
    output = layer_group(dummy_input)
    # Conv preserves H and W; the down block halves both spatial dims.
    expected_shape = (
        batch_size,
        final_channels,
        input_height // 2,
        width // 2,
    )
    assert output.shape == expected_shape