mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add blocks and detector tests
This commit is contained in:
parent
b8b8a68f49
commit
54605ef269
@ -231,7 +231,7 @@ def _pad_adjust(
|
||||
- The amount of padding added to height (`h_pad`).
|
||||
- The amount of padding added to width (`w_pad`).
|
||||
"""
|
||||
h, w = spec.shape[2:]
|
||||
h, w = spec.shape[-2:]
|
||||
h_pad = -h % factor
|
||||
w_pad = -w % factor
|
||||
|
||||
@ -263,9 +263,9 @@ def _restore_pad(
|
||||
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
|
||||
"""
|
||||
if h_pad > 0:
|
||||
x = x[:, :, :-h_pad, :]
|
||||
x = x[..., :-h_pad, :]
|
||||
|
||||
if w_pad > 0:
|
||||
x = x[:, :, :, :-w_pad]
|
||||
x = x[..., :-w_pad]
|
||||
|
||||
return x
|
||||
|
||||
@ -23,7 +23,7 @@ research:
|
||||
These blocks can be used directly in custom PyTorch model definitions or
|
||||
assembled into larger architectures.
|
||||
|
||||
A unified factory function `build_layer_from_config` allows creating instances
|
||||
A unified factory function `build_layer` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
"""
|
||||
|
||||
@ -57,7 +57,8 @@ __all__ = [
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
def get_output_channels(self) -> int: ...
|
||||
def get_output_channels(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height
|
||||
@ -867,26 +868,25 @@ class LayerGroup(nn.Module):
|
||||
input_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = layers
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layers(x)
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.blocks[-1].get_output_channels()
|
||||
return self.layers[-1].get_output_channels() # type: ignore
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
for block in self.blocks:
|
||||
input_height = block.get_output_height(input_height)
|
||||
for block in self.layers:
|
||||
input_height = block.get_output_height(input_height) # type: ignore
|
||||
return input_height
|
||||
|
||||
@block_registry.register(LayerGroupConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: LayerGroupConfig,
|
||||
input_height: int,
|
||||
input_channels: int,
|
||||
input_height: int,
|
||||
):
|
||||
layers = []
|
||||
|
||||
|
||||
@ -127,6 +127,9 @@ class Bottleneck(Block):
|
||||
|
||||
return x.repeat([1, 1, self.input_height, 1])
|
||||
|
||||
def get_output_channels(self) -> int:
|
||||
return self.layers[-1].get_output_channels() # type: ignore
|
||||
|
||||
|
||||
BottleneckLayerConfig = Annotated[
|
||||
SelfAttentionConfig,
|
||||
|
||||
190
tests/test_models/test_blocks.py
Normal file
190
tests/test_models/test_blocks.py
Normal file
@ -0,0 +1,190 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from batdetect2.models.blocks import (
|
||||
ConvBlock,
|
||||
ConvConfig,
|
||||
FreqCoordConvDownBlock,
|
||||
FreqCoordConvDownConfig,
|
||||
FreqCoordConvUpBlock,
|
||||
FreqCoordConvUpConfig,
|
||||
LayerGroup,
|
||||
LayerGroupConfig,
|
||||
SelfAttention,
|
||||
SelfAttentionConfig,
|
||||
StandardConvDownBlock,
|
||||
StandardConvDownConfig,
|
||||
StandardConvUpBlock,
|
||||
StandardConvUpConfig,
|
||||
VerticalConv,
|
||||
VerticalConvConfig,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_input() -> torch.Tensor:
|
||||
"""Provides a standard (B, C, H, W) tensor for testing blocks."""
|
||||
batch_size, in_channels, height, width = 2, 16, 32, 32
|
||||
return torch.randn(batch_size, in_channels, height, width)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_bottleneck_input() -> torch.Tensor:
|
||||
"""Provides an input typical for the Bottleneck/SelfAttention (H=1)."""
|
||||
return torch.randn(2, 64, 1, 32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_class, expected_h_scale",
|
||||
[
|
||||
(ConvBlock, 1.0),
|
||||
(StandardConvDownBlock, 0.5),
|
||||
(StandardConvUpBlock, 2.0),
|
||||
],
|
||||
)
|
||||
def test_standard_block_protocol_methods(
|
||||
block_class, expected_h_scale, dummy_input
|
||||
):
|
||||
"""Test get_output_channels and get_output_height for standard blocks."""
|
||||
in_channels = dummy_input.size(1)
|
||||
input_height = dummy_input.size(2)
|
||||
out_channels = 32
|
||||
|
||||
block = block_class(in_channels=in_channels, out_channels=out_channels)
|
||||
|
||||
assert block.get_output_channels() == out_channels
|
||||
assert block.get_output_height(input_height) == int(
|
||||
input_height * expected_h_scale
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_class, expected_h_scale",
|
||||
[
|
||||
(FreqCoordConvDownBlock, 0.5),
|
||||
(FreqCoordConvUpBlock, 2.0),
|
||||
],
|
||||
)
|
||||
def test_coord_block_protocol_methods(
|
||||
block_class, expected_h_scale, dummy_input
|
||||
):
|
||||
"""Test get_output_channels and get_output_height for coord blocks."""
|
||||
in_channels = dummy_input.size(1)
|
||||
input_height = dummy_input.size(2)
|
||||
out_channels = 32
|
||||
|
||||
block = block_class(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
assert block.get_output_channels() == out_channels
|
||||
assert block.get_output_height(input_height) == int(
|
||||
input_height * expected_h_scale
|
||||
)
|
||||
|
||||
|
||||
def test_vertical_conv_forward_shape(dummy_input):
|
||||
"""Test that VerticalConv correctly collapses the height dimension to 1."""
|
||||
in_channels = dummy_input.size(1)
|
||||
input_height = dummy_input.size(2)
|
||||
out_channels = 32
|
||||
|
||||
block = VerticalConv(in_channels, out_channels, input_height)
|
||||
output = block(dummy_input)
|
||||
|
||||
assert output.shape == (2, out_channels, 1, 32)
|
||||
assert block.get_output_channels() == out_channels
|
||||
|
||||
|
||||
def test_self_attention_forward_shape(dummy_bottleneck_input):
|
||||
"""Test that SelfAttention maintains the exact shape."""
|
||||
in_channels = dummy_bottleneck_input.size(1)
|
||||
attention_channels = 32
|
||||
|
||||
block = SelfAttention(
|
||||
in_channels=in_channels, attention_channels=attention_channels
|
||||
)
|
||||
output = block(dummy_bottleneck_input)
|
||||
|
||||
assert output.shape == dummy_bottleneck_input.shape
|
||||
assert block.get_output_channels() == in_channels
|
||||
|
||||
|
||||
def test_self_attention_weights(dummy_bottleneck_input):
|
||||
"""Test that attention weights sum to 1 over the time sequence."""
|
||||
in_channels = dummy_bottleneck_input.size(1)
|
||||
block = SelfAttention(in_channels=in_channels, attention_channels=32)
|
||||
|
||||
weights = block.compute_attention_weights(dummy_bottleneck_input)
|
||||
|
||||
# Weights shape should be (B, T, T) where T is time (width)
|
||||
batch_size = dummy_bottleneck_input.size(0)
|
||||
time_steps = dummy_bottleneck_input.size(3)
|
||||
|
||||
assert weights.shape == (batch_size, time_steps, time_steps)
|
||||
|
||||
# Summing across the keys (dim=1) for each query should equal 1.0
|
||||
sum_weights = weights.sum(dim=1)
|
||||
assert torch.allclose(sum_weights, torch.ones_like(sum_weights), atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"layer_config, expected_type",
|
||||
[
|
||||
(ConvConfig(out_channels=32), ConvBlock),
|
||||
(StandardConvDownConfig(out_channels=32), StandardConvDownBlock),
|
||||
(StandardConvUpConfig(out_channels=32), StandardConvUpBlock),
|
||||
(FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownBlock),
|
||||
(FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpBlock),
|
||||
(SelfAttentionConfig(attention_channels=32), SelfAttention),
|
||||
(VerticalConvConfig(channels=32), VerticalConv),
|
||||
],
|
||||
)
|
||||
def test_build_layer_factory(layer_config, expected_type):
|
||||
"""Test that the factory dynamically builds the correct block."""
|
||||
input_height = 32
|
||||
in_channels = 16
|
||||
|
||||
layer = build_layer(
|
||||
input_height=input_height,
|
||||
in_channels=in_channels,
|
||||
config=layer_config,
|
||||
)
|
||||
|
||||
assert isinstance(layer, expected_type)
|
||||
|
||||
|
||||
def test_layer_group_from_config_and_forward(dummy_input):
|
||||
"""Test that LayerGroup successfully chains multiple blocks."""
|
||||
in_channels = dummy_input.size(1)
|
||||
input_height = dummy_input.size(2)
|
||||
|
||||
config = LayerGroupConfig(
|
||||
layers=[
|
||||
ConvConfig(out_channels=32),
|
||||
StandardConvDownConfig(out_channels=64),
|
||||
]
|
||||
)
|
||||
|
||||
layer_group = build_layer(
|
||||
input_height=input_height,
|
||||
in_channels=in_channels,
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert isinstance(layer_group, LayerGroup)
|
||||
assert len(layer_group.layers) == 2
|
||||
|
||||
# The group should report the output channels of the LAST block
|
||||
assert layer_group.get_output_channels() == 64
|
||||
|
||||
# The group should report the accumulated height changes
|
||||
assert layer_group.get_output_height(input_height) == input_height // 2
|
||||
|
||||
output = layer_group(dummy_input)
|
||||
|
||||
# Shape should reflect: Conv (stays 32x32) -> DownConv (halves to 16x16)
|
||||
assert output.shape == (2, 64, 16, 16)
|
||||
122
tests/test_models/test_detectors.py
Normal file
122
tests/test_models/test_detectors.py
Normal file
@ -0,0 +1,122 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_spectrogram() -> torch.Tensor:
|
||||
"""Provides a dummy spectrogram tensor (B, C, H, W)."""
|
||||
return torch.randn(2, 1, 256, 128)
|
||||
|
||||
|
||||
def test_build_detector_default():
|
||||
"""Test building the default detector without a config."""
|
||||
num_classes = 5
|
||||
model = build_detector(num_classes=num_classes)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.num_classes == num_classes
|
||||
assert isinstance(model.classifier_head, ClassifierHead)
|
||||
assert isinstance(model.bbox_head, BBoxHead)
|
||||
|
||||
|
||||
def test_build_detector_custom_config():
|
||||
"""Test building a detector with a custom BackboneConfig."""
|
||||
num_classes = 3
|
||||
config = BackboneConfig(in_channels=2, input_height=128)
|
||||
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.backbone.input_height == 128
|
||||
assert model.backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
"""Test that the forward pass produces correctly shaped outputs."""
|
||||
num_classes = 4
|
||||
# Build model matching the dummy input shape
|
||||
config = BackboneConfig(in_channels=1, input_height=256)
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
|
||||
# Process the spectrogram through the model
|
||||
# PyTorch expects shape (Batch, Channels, Height, Width)
|
||||
output = model(dummy_spectrogram)
|
||||
|
||||
# Verify the output is a NamedTuple ModelOutput
|
||||
assert isinstance(output, ModelOutput)
|
||||
|
||||
batch_size = dummy_spectrogram.size(0)
|
||||
input_height = dummy_spectrogram.size(2)
|
||||
input_width = dummy_spectrogram.size(3)
|
||||
|
||||
# Check detection probabilities shape: (B, 1, H, W)
|
||||
assert output.detection_probs.shape == (
|
||||
batch_size,
|
||||
1,
|
||||
input_height,
|
||||
input_width,
|
||||
)
|
||||
|
||||
# Check size predictions shape: (B, 2, H, W)
|
||||
assert output.size_preds.shape == (
|
||||
batch_size,
|
||||
2,
|
||||
input_height,
|
||||
input_width,
|
||||
)
|
||||
|
||||
# Check class probabilities shape: (B, num_classes, H, W)
|
||||
assert output.class_probs.shape == (
|
||||
batch_size,
|
||||
num_classes,
|
||||
input_height,
|
||||
input_width,
|
||||
)
|
||||
|
||||
# Check features shape: (B, out_channels, H, W)
|
||||
out_channels = model.backbone.out_channels
|
||||
assert output.features.shape == (
|
||||
batch_size,
|
||||
out_channels,
|
||||
input_height,
|
||||
input_width,
|
||||
)
|
||||
|
||||
|
||||
def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
||||
"""Test the full pipeline from audio to model output."""
|
||||
# Generate random audio: 1 second at 256kHz
|
||||
samplerate = 256000
|
||||
duration = 1.0
|
||||
audio = np.random.randn(int(samplerate * duration)).astype(np.float32)
|
||||
|
||||
# Create tensor: (Batch=1, Channels=1, Samples) - Preprocessor expects batched 1D waveforms
|
||||
audio_tensor = torch.from_numpy(audio).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Preprocess -> Output shape: (Batch=1, Channels=1, Height, Width)
|
||||
spec = sample_preprocessor(audio_tensor)
|
||||
|
||||
# Just to be safe, make sure it has 4 dimensions if the preprocessor didn't add batch
|
||||
if spec.ndim == 3:
|
||||
spec = spec.unsqueeze(0)
|
||||
|
||||
# Build model matching the preprocessor's output shape
|
||||
# The preprocessor output is (B, C, H, W) -> spec.shape[1] is C, spec.shape[2] is H
|
||||
config = BackboneConfig(
|
||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||
)
|
||||
model = build_detector(num_classes=3, config=config)
|
||||
|
||||
# Process
|
||||
output = model(spec)
|
||||
|
||||
# Assert
|
||||
assert isinstance(output, ModelOutput)
|
||||
assert output.detection_probs.shape[0] == 1 # Batch size 1
|
||||
assert output.class_probs.shape[1] == 3 # 3 classes
|
||||
@ -628,8 +628,5 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
|
||||
name = "non_existent_mapper"
|
||||
|
||||
# Then
|
||||
with pytest.raises(NotImplementedError) as excinfo:
|
||||
with pytest.raises(NotImplementedError):
|
||||
build_roi_mapper(DummyConfig()) # type: ignore
|
||||
|
||||
# Check that the error message is informative.
|
||||
assert "No ROI mapper of name 'non_existent_mapper'" in str(excinfo.value)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user