Add blocks and detector tests

This commit is contained in:
mbsantiago 2026-03-08 14:17:47 +00:00
parent b8b8a68f49
commit 54605ef269
6 changed files with 326 additions and 14 deletions

View File

@ -231,7 +231,7 @@ def _pad_adjust(
- The amount of padding added to height (`h_pad`). - The amount of padding added to height (`h_pad`).
- The amount of padding added to width (`w_pad`). - The amount of padding added to width (`w_pad`).
""" """
h, w = spec.shape[2:] h, w = spec.shape[-2:]
h_pad = -h % factor h_pad = -h % factor
w_pad = -w % factor w_pad = -w % factor
@ -263,9 +263,9 @@ def _restore_pad(
Tensor with padding removed, shape `(B, C, H_original, W_original)`. Tensor with padding removed, shape `(B, C, H_original, W_original)`.
""" """
if h_pad > 0: if h_pad > 0:
x = x[:, :, :-h_pad, :] x = x[..., :-h_pad, :]
if w_pad > 0: if w_pad > 0:
x = x[:, :, :, :-w_pad] x = x[..., :-w_pad]
return x return x

View File

@ -23,7 +23,7 @@ research:
These blocks can be used directly in custom PyTorch model definitions or These blocks can be used directly in custom PyTorch model definitions or
assembled into larger architectures. assembled into larger architectures.
A unified factory function `build_layer_from_config` allows creating instances A unified factory function `build_layer` allows creating instances
of these blocks based on configuration objects. of these blocks based on configuration objects.
""" """
@ -57,7 +57,8 @@ __all__ = [
class BlockProtocol(Protocol): class BlockProtocol(Protocol):
def get_output_channels(self) -> int: ... def get_output_channels(self) -> int:
raise NotImplementedError
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
return input_height return input_height
@ -867,26 +868,25 @@ class LayerGroup(nn.Module):
input_channels: int, input_channels: int,
): ):
super().__init__() super().__init__()
self.blocks = layers
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layers(x) return self.layers(x)
def get_output_channels(self) -> int: def get_output_channels(self) -> int:
return self.blocks[-1].get_output_channels() return self.layers[-1].get_output_channels() # type: ignore
def get_output_height(self, input_height: int) -> int: def get_output_height(self, input_height: int) -> int:
for block in self.blocks: for block in self.layers:
input_height = block.get_output_height(input_height) input_height = block.get_output_height(input_height) # type: ignore
return input_height return input_height
@block_registry.register(LayerGroupConfig) @block_registry.register(LayerGroupConfig)
@staticmethod @staticmethod
def from_config( def from_config(
config: LayerGroupConfig, config: LayerGroupConfig,
input_height: int,
input_channels: int, input_channels: int,
input_height: int,
): ):
layers = [] layers = []

View File

@ -127,6 +127,9 @@ class Bottleneck(Block):
return x.repeat([1, 1, self.input_height, 1]) return x.repeat([1, 1, self.input_height, 1])
def get_output_channels(self) -> int:
return self.layers[-1].get_output_channels() # type: ignore
BottleneckLayerConfig = Annotated[ BottleneckLayerConfig = Annotated[
SelfAttentionConfig, SelfAttentionConfig,

View File

@ -0,0 +1,190 @@
import pytest
import torch
from batdetect2.models.blocks import (
ConvBlock,
ConvConfig,
FreqCoordConvDownBlock,
FreqCoordConvDownConfig,
FreqCoordConvUpBlock,
FreqCoordConvUpConfig,
LayerGroup,
LayerGroupConfig,
SelfAttention,
SelfAttentionConfig,
StandardConvDownBlock,
StandardConvDownConfig,
StandardConvUpBlock,
StandardConvUpConfig,
VerticalConv,
VerticalConvConfig,
build_layer,
)
@pytest.fixture
def dummy_input() -> torch.Tensor:
    """Generic (batch, channels, height, width) tensor used by the block tests."""
    shape = (2, 16, 32, 32)
    return torch.randn(*shape)
@pytest.fixture
def dummy_bottleneck_input() -> torch.Tensor:
    """Input shaped as the bottleneck/self-attention stage expects (height == 1)."""
    batch, channels, height, width = 2, 64, 1, 32
    return torch.randn(batch, channels, height, width)
@pytest.mark.parametrize(
    "block_class, expected_h_scale",
    [
        (ConvBlock, 1.0),
        (StandardConvDownBlock, 0.5),
        (StandardConvUpBlock, 2.0),
    ],
)
def test_standard_block_protocol_methods(
    block_class, expected_h_scale, dummy_input
):
    """Check the protocol accessors on blocks that do not need input_height."""
    _, channels, height, _ = dummy_input.shape
    target_channels = 32

    block = block_class(in_channels=channels, out_channels=target_channels)

    assert block.get_output_channels() == target_channels
    expected_height = int(height * expected_h_scale)
    assert block.get_output_height(height) == expected_height
@pytest.mark.parametrize(
    "block_class, expected_h_scale",
    [
        (FreqCoordConvDownBlock, 0.5),
        (FreqCoordConvUpBlock, 2.0),
    ],
)
def test_coord_block_protocol_methods(
    block_class, expected_h_scale, dummy_input
):
    """Check the protocol accessors on coord-conv blocks (require input_height)."""
    _, channels, height, _ = dummy_input.shape
    target_channels = 32

    block = block_class(
        in_channels=channels,
        out_channels=target_channels,
        input_height=height,
    )

    assert block.get_output_channels() == target_channels
    assert block.get_output_height(height) == int(height * expected_h_scale)
def test_vertical_conv_forward_shape(dummy_input):
    """VerticalConv should squeeze the frequency (height) axis down to 1."""
    _, channels, height, _ = dummy_input.shape
    target_channels = 32

    layer = VerticalConv(channels, target_channels, height)
    result = layer(dummy_input)

    assert result.shape == (2, target_channels, 1, 32)
    assert layer.get_output_channels() == target_channels
def test_self_attention_forward_shape(dummy_bottleneck_input):
    """SelfAttention must be shape-preserving end to end."""
    channels = dummy_bottleneck_input.shape[1]

    layer = SelfAttention(
        in_channels=channels, attention_channels=32
    )
    result = layer(dummy_bottleneck_input)

    assert result.shape == dummy_bottleneck_input.shape
    assert layer.get_output_channels() == channels
def test_self_attention_weights(dummy_bottleneck_input):
    """Attention weights must form a valid distribution over the time axis."""
    batch, channels, _, width = dummy_bottleneck_input.shape
    layer = SelfAttention(in_channels=channels, attention_channels=32)

    weights = layer.compute_attention_weights(dummy_bottleneck_input)

    # Weight matrix is (B, T, T), with T the number of time steps (width).
    assert weights.shape == (batch, width, width)

    # For every query, the weights over the keys (dim=1) must sum to 1.
    totals = weights.sum(dim=1)
    assert torch.allclose(totals, torch.ones_like(totals), atol=1e-5)
@pytest.mark.parametrize(
    "layer_config, expected_type",
    [
        (ConvConfig(out_channels=32), ConvBlock),
        (StandardConvDownConfig(out_channels=32), StandardConvDownBlock),
        (StandardConvUpConfig(out_channels=32), StandardConvUpBlock),
        (FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownBlock),
        (FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpBlock),
        (SelfAttentionConfig(attention_channels=32), SelfAttention),
        (VerticalConvConfig(channels=32), VerticalConv),
    ],
)
def test_build_layer_factory(layer_config, expected_type):
    """build_layer should map every config type onto its block class."""
    built = build_layer(
        input_height=32,
        in_channels=16,
        config=layer_config,
    )
    assert isinstance(built, expected_type)
def test_layer_group_from_config_and_forward(dummy_input):
    """A LayerGroup should chain blocks and aggregate their protocol methods."""
    _, channels, height, _ = dummy_input.shape

    group_config = LayerGroupConfig(
        layers=[
            ConvConfig(out_channels=32),
            StandardConvDownConfig(out_channels=64),
        ]
    )
    group = build_layer(
        input_height=height,
        in_channels=channels,
        config=group_config,
    )

    assert isinstance(group, LayerGroup)
    assert len(group.layers) == 2

    # Output channels come from the final block in the chain.
    assert group.get_output_channels() == 64

    # Height changes accumulate across every block in the chain.
    assert group.get_output_height(height) == height // 2

    result = group(dummy_input)
    # Conv keeps 32x32; the down block then halves it to 16x16.
    assert result.shape == (2, 64, 16, 16)

View File

@ -0,0 +1,122 @@
import numpy as np
import pytest
import torch
from batdetect2.models.config import BackboneConfig
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput
@pytest.fixture
def dummy_spectrogram() -> torch.Tensor:
    """Random (batch, channel, freq, time) spectrogram tensor for detector tests."""
    batch, channels, height, width = 2, 1, 256, 128
    return torch.randn(batch, channels, height, width)
def test_build_detector_default():
    """build_detector with no config should yield a fully wired Detector."""
    n_classes = 5
    detector = build_detector(num_classes=n_classes)

    assert isinstance(detector, Detector)
    assert detector.num_classes == n_classes
    assert isinstance(detector.classifier_head, ClassifierHead)
    assert isinstance(detector.bbox_head, BBoxHead)
def test_build_detector_custom_config():
    """A custom BackboneConfig should propagate into the built backbone."""
    backbone_config = BackboneConfig(in_channels=2, input_height=128)
    detector = build_detector(num_classes=3, config=backbone_config)

    assert isinstance(detector, Detector)
    assert detector.backbone.input_height == 128
    assert detector.backbone.encoder.in_channels == 2
def test_detector_forward_pass_shapes(dummy_spectrogram):
    """Every head of the ModelOutput should match the input's spatial size."""
    n_classes = 4
    batch, _, height, width = dummy_spectrogram.shape

    # Backbone configured to match the fixture's spectrogram geometry.
    detector = build_detector(
        num_classes=n_classes,
        config=BackboneConfig(in_channels=1, input_height=256),
    )

    result = detector(dummy_spectrogram)

    # The forward pass returns a ModelOutput NamedTuple.
    assert isinstance(result, ModelOutput)

    spatial = (height, width)
    # Detection probabilities: a single channel per spatial position.
    assert result.detection_probs.shape == (batch, 1, *spatial)
    # Size predictions: two regression channels (width, height).
    assert result.size_preds.shape == (batch, 2, *spatial)
    # Per-class probability maps.
    assert result.class_probs.shape == (batch, n_classes, *spatial)
    # Backbone feature maps at full input resolution.
    assert result.features.shape == (
        batch,
        detector.backbone.out_channels,
        *spatial,
    )
def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
    """Exercise the full audio -> spectrogram -> detector pipeline."""
    # One second of random audio at 256 kHz.
    samplerate = 256000
    duration = 1.0
    waveform = np.random.randn(int(samplerate * duration)).astype(np.float32)

    # Shape (batch=1, channel=1, samples): the preprocessor wants batched 1D audio.
    waveform_tensor = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0)

    spec = sample_preprocessor(waveform_tensor)

    # Defensive: add a batch dimension if the preprocessor returned (C, H, W).
    if spec.ndim == 3:
        spec = spec.unsqueeze(0)

    # Configure the backbone from the spectrogram's channel and height dims.
    detector = build_detector(
        num_classes=3,
        config=BackboneConfig(
            in_channels=spec.shape[1], input_height=spec.shape[2]
        ),
    )

    result = detector(spec)

    assert isinstance(result, ModelOutput)
    assert result.detection_probs.shape[0] == 1  # single-item batch
    assert result.class_probs.shape[1] == 3  # three classes

View File

@ -628,8 +628,5 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
name = "non_existent_mapper" name = "non_existent_mapper"
# Then # Then
with pytest.raises(NotImplementedError) as excinfo: with pytest.raises(NotImplementedError):
build_roi_mapper(DummyConfig()) # type: ignore build_roi_mapper(DummyConfig()) # type: ignore
# Check that the error message is informative.
assert "No ROI mapper of name 'non_existent_mapper'" in str(excinfo.value)