Repository: https://github.com/macaodha/batdetect2.git
Commit: 54605ef269 (parent: b8b8a68f49)
Subject: Add blocks and detector tests
@@ -231,7 +231,7 @@ def _pad_adjust(
         - The amount of padding added to height (`h_pad`).
         - The amount of padding added to width (`w_pad`).
     """
-    h, w = spec.shape[2:]
+    h, w = spec.shape[-2:]
    h_pad = -h % factor
    w_pad = -w % factor
 
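Note: a minimal standalone sketch of the padding arithmetic in this hunk (`spec` and `factor` stand in for the function's arguments; values are illustrative). `-h % factor` is the smallest non-negative padding that rounds `h` up to a multiple of `factor`, and `shape[-2:]` reads the spatial dims regardless of how many leading dims the tensor has:

    import torch

    spec = torch.randn(2, 1, 100, 511)
    factor = 32
    h, w = spec.shape[-2:]  # last two dims, whatever the leading dims are
    h_pad = -h % factor     # 28, since 100 + 28 = 128 = 4 * 32
    w_pad = -w % factor     # 1,  since 511 + 1  = 512 = 16 * 32
    padded = torch.nn.functional.pad(spec, (0, w_pad, 0, h_pad))
    assert padded.shape[-2] % factor == 0
    assert padded.shape[-1] % factor == 0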
@@ -263,9 +263,9 @@ def _restore_pad(
         Tensor with padding removed, shape `(B, C, H_original, W_original)`.
     """
     if h_pad > 0:
-        x = x[:, :, :-h_pad, :]
+        x = x[..., :-h_pad, :]
 
     if w_pad > 0:
-        x = x[:, :, :, :-w_pad]
+        x = x[..., :-w_pad]
 
     return x
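Note: the `...` (Ellipsis) spelling is equivalent for 4-D input but no longer hard-codes the number of leading dimensions. A quick standalone check:

    import torch

    x = torch.randn(2, 3, 32, 40)
    h_pad, w_pad = 4, 8
    assert torch.equal(x[:, :, :-h_pad, :], x[..., :-h_pad, :])
    assert torch.equal(x[:, :, :, :-w_pad], x[..., :-w_pad])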
@@ -23,7 +23,7 @@ research:
 These blocks can be used directly in custom PyTorch model definitions or
 assembled into larger architectures.
 
-A unified factory function `build_layer_from_config` allows creating instances
+A unified factory function `build_layer` allows creating instances
 of these blocks based on configuration objects.
 """
 
@@ -57,7 +57,8 @@ __all__ = [
 
 
 class BlockProtocol(Protocol):
-    def get_output_channels(self) -> int: ...
+    def get_output_channels(self) -> int:
+        raise NotImplementedError
 
     def get_output_height(self, input_height: int) -> int:
        return input_height
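Note: changing the protocol's default body from `...` to `raise NotImplementedError` turns a silent bug into a loud one. A minimal sketch (class names here are illustrative, not from the library):

    from typing import Protocol

    class BlockProtocol(Protocol):
        def get_output_channels(self) -> int:
            raise NotImplementedError

    class ForgetfulBlock(BlockProtocol):  # forgot to override the method
        pass

    # With a `...` body this call would silently return None; with the
    # explicit raise, the missing override surfaces at the call site.
    try:
        ForgetfulBlock().get_output_channels()
    except NotImplementedError:
        print("missing override caught")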
@@ -867,26 +868,25 @@ class LayerGroup(nn.Module):
         input_channels: int,
     ):
         super().__init__()
-        self.blocks = layers
         self.layers = nn.Sequential(*layers)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.layers(x)
 
     def get_output_channels(self) -> int:
-        return self.blocks[-1].get_output_channels()
+        return self.layers[-1].get_output_channels()  # type: ignore
 
     def get_output_height(self, input_height: int) -> int:
-        for block in self.blocks:
-            input_height = block.get_output_height(input_height)
+        for block in self.layers:
+            input_height = block.get_output_height(input_height)  # type: ignore
         return input_height
 
     @block_registry.register(LayerGroupConfig)
     @staticmethod
     def from_config(
         config: LayerGroupConfig,
-        input_height: int,
         input_channels: int,
+        input_height: int,
     ):
         layers = []
 
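Note: the removed `self.blocks` list duplicated what `nn.Sequential` already provides, since a `Sequential` is both indexable and iterable. A minimal sketch of the behaviour the rewritten methods rely on:

    import torch.nn as nn

    seq = nn.Sequential(nn.Conv2d(1, 8, 3), nn.Conv2d(8, 16, 3))
    assert len(seq) == 2
    assert seq[-1] is list(seq)[-1]  # indexing and iteration agree

The `# type: ignore` comments acknowledge that `Sequential` items are typed as `nn.Module`, which does not declare the `get_output_channels`/`get_output_height` protocol methods.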
@@ -127,6 +127,9 @@ class Bottleneck(Block):
 
         return x.repeat([1, 1, self.input_height, 1])
 
+    def get_output_channels(self) -> int:
+        return self.layers[-1].get_output_channels()  # type: ignore
+
 
 BottleneckLayerConfig = Annotated[
     SelfAttentionConfig,
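Note: a small sketch of the `repeat` call in the context line above, which tiles the height-1 bottleneck output back to the full input height (values are illustrative):

    import torch

    input_height = 4
    x = torch.randn(2, 64, 1, 32)                 # bottleneck output, H == 1
    restored = x.repeat([1, 1, input_height, 1])  # tile along the height axis
    assert restored.shape == (2, 64, 4, 32)
    assert torch.equal(restored[:, :, 0], restored[:, :, 3])  # rows identical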
tests/test_models/test_blocks.py (new file, 190 lines)
@@ -0,0 +1,190 @@
+import pytest
+import torch
+
+from batdetect2.models.blocks import (
+    ConvBlock,
+    ConvConfig,
+    FreqCoordConvDownBlock,
+    FreqCoordConvDownConfig,
+    FreqCoordConvUpBlock,
+    FreqCoordConvUpConfig,
+    LayerGroup,
+    LayerGroupConfig,
+    SelfAttention,
+    SelfAttentionConfig,
+    StandardConvDownBlock,
+    StandardConvDownConfig,
+    StandardConvUpBlock,
+    StandardConvUpConfig,
+    VerticalConv,
+    VerticalConvConfig,
+    build_layer,
+)
+
+
+@pytest.fixture
+def dummy_input() -> torch.Tensor:
+    """Provides a standard (B, C, H, W) tensor for testing blocks."""
+    batch_size, in_channels, height, width = 2, 16, 32, 32
+    return torch.randn(batch_size, in_channels, height, width)
+
+
+@pytest.fixture
+def dummy_bottleneck_input() -> torch.Tensor:
+    """Provides an input typical for the Bottleneck/SelfAttention (H=1)."""
+    return torch.randn(2, 64, 1, 32)
+
+
+@pytest.mark.parametrize(
+    "block_class, expected_h_scale",
+    [
+        (ConvBlock, 1.0),
+        (StandardConvDownBlock, 0.5),
+        (StandardConvUpBlock, 2.0),
+    ],
+)
+def test_standard_block_protocol_methods(
+    block_class, expected_h_scale, dummy_input
+):
+    """Test get_output_channels and get_output_height for standard blocks."""
+    in_channels = dummy_input.size(1)
+    input_height = dummy_input.size(2)
+    out_channels = 32
+
+    block = block_class(in_channels=in_channels, out_channels=out_channels)
+
+    assert block.get_output_channels() == out_channels
+    assert block.get_output_height(input_height) == int(
+        input_height * expected_h_scale
+    )
+
+
+@pytest.mark.parametrize(
+    "block_class, expected_h_scale",
+    [
+        (FreqCoordConvDownBlock, 0.5),
+        (FreqCoordConvUpBlock, 2.0),
+    ],
+)
+def test_coord_block_protocol_methods(
+    block_class, expected_h_scale, dummy_input
+):
+    """Test get_output_channels and get_output_height for coord blocks."""
+    in_channels = dummy_input.size(1)
+    input_height = dummy_input.size(2)
+    out_channels = 32
+
+    block = block_class(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        input_height=input_height,
+    )
+
+    assert block.get_output_channels() == out_channels
+    assert block.get_output_height(input_height) == int(
+        input_height * expected_h_scale
+    )
+
+
+def test_vertical_conv_forward_shape(dummy_input):
+    """Test that VerticalConv correctly collapses the height dimension to 1."""
+    in_channels = dummy_input.size(1)
+    input_height = dummy_input.size(2)
+    out_channels = 32
+
+    block = VerticalConv(in_channels, out_channels, input_height)
+    output = block(dummy_input)
+
+    assert output.shape == (2, out_channels, 1, 32)
+    assert block.get_output_channels() == out_channels
+
+
+def test_self_attention_forward_shape(dummy_bottleneck_input):
+    """Test that SelfAttention maintains the exact shape."""
+    in_channels = dummy_bottleneck_input.size(1)
+    attention_channels = 32
+
+    block = SelfAttention(
+        in_channels=in_channels, attention_channels=attention_channels
+    )
+    output = block(dummy_bottleneck_input)
+
+    assert output.shape == dummy_bottleneck_input.shape
+    assert block.get_output_channels() == in_channels
+
+
+def test_self_attention_weights(dummy_bottleneck_input):
+    """Test that attention weights sum to 1 over the time sequence."""
+    in_channels = dummy_bottleneck_input.size(1)
+    block = SelfAttention(in_channels=in_channels, attention_channels=32)
+
+    weights = block.compute_attention_weights(dummy_bottleneck_input)
+
+    # Weights shape should be (B, T, T) where T is time (width)
+    batch_size = dummy_bottleneck_input.size(0)
+    time_steps = dummy_bottleneck_input.size(3)
+
+    assert weights.shape == (batch_size, time_steps, time_steps)
+
+    # Summing across the keys (dim=1) for each query should equal 1.0
+    sum_weights = weights.sum(dim=1)
+    assert torch.allclose(sum_weights, torch.ones_like(sum_weights), atol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "layer_config, expected_type",
+    [
+        (ConvConfig(out_channels=32), ConvBlock),
+        (StandardConvDownConfig(out_channels=32), StandardConvDownBlock),
+        (StandardConvUpConfig(out_channels=32), StandardConvUpBlock),
+        (FreqCoordConvDownConfig(out_channels=32), FreqCoordConvDownBlock),
+        (FreqCoordConvUpConfig(out_channels=32), FreqCoordConvUpBlock),
+        (SelfAttentionConfig(attention_channels=32), SelfAttention),
+        (VerticalConvConfig(channels=32), VerticalConv),
+    ],
+)
+def test_build_layer_factory(layer_config, expected_type):
+    """Test that the factory dynamically builds the correct block."""
+    input_height = 32
+    in_channels = 16
+
+    layer = build_layer(
+        input_height=input_height,
+        in_channels=in_channels,
+        config=layer_config,
+    )
+
+    assert isinstance(layer, expected_type)
+
+
+def test_layer_group_from_config_and_forward(dummy_input):
+    """Test that LayerGroup successfully chains multiple blocks."""
+    in_channels = dummy_input.size(1)
+    input_height = dummy_input.size(2)
+
+    config = LayerGroupConfig(
+        layers=[
+            ConvConfig(out_channels=32),
+            StandardConvDownConfig(out_channels=64),
+        ]
+    )
+
+    layer_group = build_layer(
+        input_height=input_height,
+        in_channels=in_channels,
+        config=config,
+    )
+
+    assert isinstance(layer_group, LayerGroup)
+    assert len(layer_group.layers) == 2
+
+    # The group should report the output channels of the LAST block
+    assert layer_group.get_output_channels() == 64
+
+    # The group should report the accumulated height changes
+    assert layer_group.get_output_height(input_height) == input_height // 2
+
+    output = layer_group(dummy_input)
+
+    # Shape should reflect: Conv (stays 32x32) -> DownConv (halves to 16x16)
+    assert output.shape == (2, 64, 16, 16)
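Note: `test_self_attention_weights` pins down the normalization convention: the softmax runs over dim=1 (the key axis), so each query's weights sum to 1. A sketch of that invariant using a plain softmax (the block's internals may differ):

    import torch

    scores = torch.randn(2, 32, 32)  # (B, keys, queries)
    weights = torch.softmax(scores, dim=1)
    sums = weights.sum(dim=1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)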
tests/test_models/test_detectors.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+import numpy as np
+import pytest
+import torch
+
+from batdetect2.models.config import BackboneConfig
+from batdetect2.models.detectors import Detector, build_detector
+from batdetect2.models.heads import BBoxHead, ClassifierHead
+from batdetect2.typing.models import ModelOutput
+
+
+@pytest.fixture
+def dummy_spectrogram() -> torch.Tensor:
+    """Provides a dummy spectrogram tensor (B, C, H, W)."""
+    return torch.randn(2, 1, 256, 128)
+
+
+def test_build_detector_default():
+    """Test building the default detector without a config."""
+    num_classes = 5
+    model = build_detector(num_classes=num_classes)
+
+    assert isinstance(model, Detector)
+    assert model.num_classes == num_classes
+    assert isinstance(model.classifier_head, ClassifierHead)
+    assert isinstance(model.bbox_head, BBoxHead)
+
+
+def test_build_detector_custom_config():
+    """Test building a detector with a custom BackboneConfig."""
+    num_classes = 3
+    config = BackboneConfig(in_channels=2, input_height=128)
+
+    model = build_detector(num_classes=num_classes, config=config)
+
+    assert isinstance(model, Detector)
+    assert model.backbone.input_height == 128
+    assert model.backbone.encoder.in_channels == 2
+
+
+def test_detector_forward_pass_shapes(dummy_spectrogram):
+    """Test that the forward pass produces correctly shaped outputs."""
+    num_classes = 4
+    # Build model matching the dummy input shape
+    config = BackboneConfig(in_channels=1, input_height=256)
+    model = build_detector(num_classes=num_classes, config=config)
+
+    # Process the spectrogram through the model
+    # PyTorch expects shape (Batch, Channels, Height, Width)
+    output = model(dummy_spectrogram)
+
+    # Verify the output is a NamedTuple ModelOutput
+    assert isinstance(output, ModelOutput)
+
+    batch_size = dummy_spectrogram.size(0)
+    input_height = dummy_spectrogram.size(2)
+    input_width = dummy_spectrogram.size(3)
+
+    # Check detection probabilities shape: (B, 1, H, W)
+    assert output.detection_probs.shape == (
+        batch_size,
+        1,
+        input_height,
+        input_width,
+    )
+
+    # Check size predictions shape: (B, 2, H, W)
+    assert output.size_preds.shape == (
+        batch_size,
+        2,
+        input_height,
+        input_width,
+    )
+
+    # Check class probabilities shape: (B, num_classes, H, W)
+    assert output.class_probs.shape == (
+        batch_size,
+        num_classes,
+        input_height,
+        input_width,
+    )
+
+    # Check features shape: (B, out_channels, H, W)
+    out_channels = model.backbone.out_channels
+    assert output.features.shape == (
+        batch_size,
+        out_channels,
+        input_height,
+        input_width,
+    )
+
+
+def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
+    """Test the full pipeline from audio to model output."""
+    # Generate random audio: 1 second at 256kHz
+    samplerate = 256000
+    duration = 1.0
+    audio = np.random.randn(int(samplerate * duration)).astype(np.float32)
+
+    # Create tensor: (Batch=1, Channels=1, Samples) - Preprocessor expects
+    # batched 1D waveforms
+    audio_tensor = torch.from_numpy(audio).unsqueeze(0).unsqueeze(0)
+
+    # Preprocess -> Output shape: (Batch=1, Channels=1, Height, Width)
+    spec = sample_preprocessor(audio_tensor)
+
+    # Just to be safe, make sure it has 4 dimensions if the preprocessor
+    # didn't add a batch dimension
+    if spec.ndim == 3:
+        spec = spec.unsqueeze(0)
+
+    # Build model matching the preprocessor's output shape
+    # The preprocessor output is (B, C, H, W) -> spec.shape[1] is C,
+    # spec.shape[2] is H
+    config = BackboneConfig(
+        in_channels=spec.shape[1], input_height=spec.shape[2]
+    )
+    model = build_detector(num_classes=3, config=config)
+
+    # Process
+    output = model(spec)
+
+    # Assert
+    assert isinstance(output, ModelOutput)
+    assert output.detection_probs.shape[0] == 1  # Batch size 1
+    assert output.class_probs.shape[1] == 3  # 3 classes
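Note: a minimal end-to-end sketch mirroring the new detector tests (shapes follow the fixtures above; an illustration, not an official usage snippet):

    import torch

    from batdetect2.models.config import BackboneConfig
    from batdetect2.models.detectors import build_detector

    config = BackboneConfig(in_channels=1, input_height=256)
    model = build_detector(num_classes=4, config=config)
    spec = torch.randn(2, 1, 256, 128)  # (B, C, H, W)
    output = model(spec)                # a ModelOutput NamedTuple
    assert output.detection_probs.shape == (2, 1, 256, 128)
    assert output.class_probs.shape == (2, 4, 256, 128)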
@@ -628,8 +628,5 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
     name = "non_existent_mapper"
 
     # Then
-    with pytest.raises(NotImplementedError) as excinfo:
+    with pytest.raises(NotImplementedError):
         build_roi_mapper(DummyConfig())  # type: ignore
-
-    # Check that the error message is informative.
-    assert "No ROI mapper of name 'non_existent_mapper'" in str(excinfo.value)
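Note: the simplified test drops the error-message assertion along with `excinfo`. If the message check is still wanted, pytest can fold it into the context manager via `match` (sketch with a stand-in function, not the real `build_roi_mapper`):

    import pytest

    def fake_build_roi_mapper(config):  # stand-in for illustration only
        raise NotImplementedError("No ROI mapper of name 'non_existent_mapper'")

    def test_unknown_mapper_message():
        with pytest.raises(NotImplementedError, match="No ROI mapper of name"):
            fake_build_roi_mapper(object())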