mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add interfaces for encoder/decoder/bottleneck
This commit is contained in:
parent
54605ef269
commit
e393709258
@ -26,9 +26,14 @@ from torch import nn
|
|||||||
|
|
||||||
from batdetect2.models.bottleneck import build_bottleneck
|
from batdetect2.models.bottleneck import build_bottleneck
|
||||||
from batdetect2.models.config import BackboneConfig
|
from batdetect2.models.config import BackboneConfig
|
||||||
from batdetect2.models.decoder import Decoder, build_decoder
|
from batdetect2.models.decoder import build_decoder
|
||||||
from batdetect2.models.encoder import Encoder, build_encoder
|
from batdetect2.models.encoder import build_encoder
|
||||||
from batdetect2.typing.models import BackboneModel
|
from batdetect2.typing.models import (
|
||||||
|
BackboneModel,
|
||||||
|
BottleneckProtocol,
|
||||||
|
DecoderProtocol,
|
||||||
|
EncoderProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Backbone",
|
"Backbone",
|
||||||
@ -55,11 +60,11 @@ class Backbone(BackboneModel):
|
|||||||
Expected height of the input spectrogram.
|
Expected height of the input spectrogram.
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of channels in the final output feature map.
|
Number of channels in the final output feature map.
|
||||||
encoder : Encoder
|
encoder : EncoderProtocol
|
||||||
The instantiated encoder module.
|
The instantiated encoder module.
|
||||||
decoder : Decoder
|
decoder : DecoderProtocol
|
||||||
The instantiated decoder module.
|
The instantiated decoder module.
|
||||||
bottleneck : nn.Module
|
bottleneck : BottleneckProtocol
|
||||||
The instantiated bottleneck module.
|
The instantiated bottleneck module.
|
||||||
final_conv : ConvBlock
|
final_conv : ConvBlock
|
||||||
Final convolutional block applied after the decoder.
|
Final convolutional block applied after the decoder.
|
||||||
@ -71,9 +76,9 @@ class Backbone(BackboneModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
encoder: Encoder,
|
encoder: EncoderProtocol,
|
||||||
decoder: Decoder,
|
decoder: DecoderProtocol,
|
||||||
bottleneck: nn.Module,
|
bottleneck: BottleneckProtocol,
|
||||||
):
|
):
|
||||||
"""Initialize the Backbone network.
|
"""Initialize the Backbone network.
|
||||||
|
|
||||||
@ -83,11 +88,11 @@ class Backbone(BackboneModel):
|
|||||||
Expected height of the input spectrogram.
|
Expected height of the input spectrogram.
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Desired number of output channels for the backbone's feature map.
|
Desired number of output channels for the backbone's feature map.
|
||||||
encoder : Encoder
|
encoder : EncoderProtocol
|
||||||
An initialized Encoder module.
|
An initialized Encoder module.
|
||||||
decoder : Decoder
|
decoder : DecoderProtocol
|
||||||
An initialized Decoder module.
|
An initialized Decoder module.
|
||||||
bottleneck : nn.Module
|
bottleneck : BottleneckProtocol
|
||||||
An initialized Bottleneck module.
|
An initialized Bottleneck module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
@ -185,7 +190,7 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
decoder = build_decoder(
|
decoder = build_decoder(
|
||||||
in_channels=bottleneck.get_output_channels(),
|
in_channels=bottleneck.out_channels,
|
||||||
input_height=encoder.output_height,
|
input_height=encoder.output_height,
|
||||||
config=config.decoder,
|
config=config.decoder,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -56,17 +56,14 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BlockProtocol(Protocol):
|
class Block(nn.Module):
|
||||||
def get_output_channels(self) -> int:
|
in_channels: int
|
||||||
raise NotImplementedError
|
out_channels: int
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
return input_height
|
return input_height
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module, BlockProtocol): ...
|
|
||||||
|
|
||||||
|
|
||||||
block_registry: Registry[Block, [int, int]] = Registry("block")
|
block_registry: Registry[Block, [int, int]] = Registry("block")
|
||||||
|
|
||||||
|
|
||||||
@ -132,6 +129,8 @@ class SelfAttention(Block):
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
# Note, does not encode position information (absolute or relative)
|
# Note, does not encode position information (absolute or relative)
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
@ -206,9 +205,6 @@ class SelfAttention(Block):
|
|||||||
att_weights = F.softmax(kk_qq, 1)
|
att_weights = F.softmax(kk_qq, 1)
|
||||||
return att_weights
|
return att_weights
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.output_channels
|
|
||||||
|
|
||||||
@block_registry.register(SelfAttentionConfig)
|
@block_registry.register(SelfAttentionConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
@ -267,6 +263,8 @@ class ConvBlock(Block):
|
|||||||
pad_size: int = 1,
|
pad_size: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
self.conv = nn.Conv2d(
|
self.conv = nn.Conv2d(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -290,9 +288,6 @@ class ConvBlock(Block):
|
|||||||
"""
|
"""
|
||||||
return F.relu_(self.batch_norm(self.conv(x)))
|
return F.relu_(self.batch_norm(self.conv(x)))
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
@block_registry.register(ConvConfig)
|
@block_registry.register(ConvConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
@ -342,6 +337,8 @@ class VerticalConv(Block):
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
self.conv = nn.Conv2d(
|
self.conv = nn.Conv2d(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
@ -366,9 +363,6 @@ class VerticalConv(Block):
|
|||||||
"""
|
"""
|
||||||
return F.relu_(self.bn(self.conv(x)))
|
return F.relu_(self.bn(self.conv(x)))
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
@block_registry.register(VerticalConvConfig)
|
@block_registry.register(VerticalConvConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
@ -438,6 +432,8 @@ class FreqCoordConvDownBlock(Block):
|
|||||||
pad_size: int = 1,
|
pad_size: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
self.coords = nn.Parameter(
|
self.coords = nn.Parameter(
|
||||||
torch.linspace(-1, 1, input_height)[None, None, ..., None],
|
torch.linspace(-1, 1, input_height)[None, None, ..., None],
|
||||||
@ -472,9 +468,6 @@ class FreqCoordConvDownBlock(Block):
|
|||||||
x = F.relu(self.batch_norm(x), inplace=True)
|
x = F.relu(self.batch_norm(x), inplace=True)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
return input_height // 2
|
return input_height // 2
|
||||||
|
|
||||||
@ -538,6 +531,8 @@ class StandardConvDownBlock(Block):
|
|||||||
pad_size: int = 1,
|
pad_size: int = 1,
|
||||||
):
|
):
|
||||||
super(StandardConvDownBlock, self).__init__()
|
super(StandardConvDownBlock, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
self.conv = nn.Conv2d(
|
self.conv = nn.Conv2d(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
@ -563,9 +558,6 @@ class StandardConvDownBlock(Block):
|
|||||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||||
return F.relu(self.batch_norm(x), inplace=True)
|
return F.relu(self.batch_norm(x), inplace=True)
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
return input_height // 2
|
return input_height // 2
|
||||||
|
|
||||||
@ -652,6 +644,8 @@ class FreqCoordConvUpBlock(Block):
|
|||||||
up_scale: Tuple[int, int] = (2, 2),
|
up_scale: Tuple[int, int] = (2, 2),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
self.up_scale = up_scale
|
self.up_scale = up_scale
|
||||||
self.up_mode = up_mode
|
self.up_mode = up_mode
|
||||||
@ -698,9 +692,6 @@ class FreqCoordConvUpBlock(Block):
|
|||||||
op = F.relu(self.batch_norm(op), inplace=True)
|
op = F.relu(self.batch_norm(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
return input_height * 2
|
return input_height * 2
|
||||||
|
|
||||||
@ -779,6 +770,8 @@ class StandardConvUpBlock(Block):
|
|||||||
up_scale: Tuple[int, int] = (2, 2),
|
up_scale: Tuple[int, int] = (2, 2),
|
||||||
):
|
):
|
||||||
super(StandardConvUpBlock, self).__init__()
|
super(StandardConvUpBlock, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
self.up_scale = up_scale
|
self.up_scale = up_scale
|
||||||
self.up_mode = up_mode
|
self.up_mode = up_mode
|
||||||
self.conv = nn.Conv2d(
|
self.conv = nn.Conv2d(
|
||||||
@ -815,9 +808,6 @@ class StandardConvUpBlock(Block):
|
|||||||
op = F.relu(self.batch_norm(op), inplace=True)
|
op = F.relu(self.batch_norm(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.conv.out_channels
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
return input_height * 2
|
return input_height * 2
|
||||||
|
|
||||||
@ -868,14 +858,15 @@ class LayerGroup(nn.Module):
|
|||||||
input_channels: int,
|
input_channels: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_channels = input_channels
|
||||||
|
self.out_channels = (
|
||||||
|
layers[-1].out_channels if layers else input_channels
|
||||||
|
)
|
||||||
self.layers = nn.Sequential(*layers)
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.layers(x)
|
return self.layers(x)
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.layers[-1].get_output_channels() # type: ignore
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
def get_output_height(self, input_height: int) -> int:
|
||||||
for block in self.layers:
|
for block in self.layers:
|
||||||
input_height = block.get_output_height(input_height) # type: ignore
|
input_height = block.get_output_height(input_height) # type: ignore
|
||||||
@ -898,7 +889,7 @@ class LayerGroup(nn.Module):
|
|||||||
)
|
)
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
input_height = layer.get_output_height(input_height)
|
input_height = layer.get_output_height(input_height)
|
||||||
input_channels = layer.get_output_channels()
|
input_channels = layer.out_channels
|
||||||
|
|
||||||
return LayerGroup(
|
return LayerGroup(
|
||||||
layers=layers,
|
layers=layers,
|
||||||
|
|||||||
@ -28,6 +28,8 @@ from batdetect2.models.blocks import (
|
|||||||
build_layer,
|
build_layer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from batdetect2.typing.models import BottleneckProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
"Bottleneck",
|
"Bottleneck",
|
||||||
@ -127,9 +129,6 @@ class Bottleneck(Block):
|
|||||||
|
|
||||||
return x.repeat([1, 1, self.input_height, 1])
|
return x.repeat([1, 1, self.input_height, 1])
|
||||||
|
|
||||||
def get_output_channels(self) -> int:
|
|
||||||
return self.layers[-1].get_output_channels() # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
BottleneckLayerConfig = Annotated[
|
BottleneckLayerConfig = Annotated[
|
||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
@ -177,7 +176,7 @@ def build_bottleneck(
|
|||||||
input_height: int,
|
input_height: int,
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
config: BottleneckConfig | None = None,
|
config: BottleneckConfig | None = None,
|
||||||
) -> Block:
|
) -> BottleneckProtocol:
|
||||||
"""Factory function to build the Bottleneck module from configuration.
|
"""Factory function to build the Bottleneck module from configuration.
|
||||||
|
|
||||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||||
@ -217,7 +216,7 @@ def build_bottleneck(
|
|||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
current_height = layer.get_output_height(current_height)
|
current_height = layer.get_output_height(current_height)
|
||||||
current_channels = layer.get_output_channels()
|
current_channels = layer.out_channels
|
||||||
assert current_height == input_height, (
|
assert current_height == input_height, (
|
||||||
"Bottleneck layers should not change the spectrogram height"
|
"Bottleneck layers should not change the spectrogram height"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -265,7 +265,7 @@ def build_decoder(
|
|||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
current_height = layer.get_output_height(current_height)
|
current_height = layer.get_output_height(current_height)
|
||||||
current_channels = layer.get_output_channels()
|
current_channels = layer.out_channels
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
|
|
||||||
return Decoder(
|
return Decoder(
|
||||||
|
|||||||
@ -307,7 +307,7 @@ def build_encoder(
|
|||||||
)
|
)
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
current_height = layer.get_output_height(current_height)
|
current_height = layer.get_output_height(current_height)
|
||||||
current_channels = layer.get_output_channels()
|
current_channels = layer.out_channels
|
||||||
|
|
||||||
return Encoder(
|
return Encoder(
|
||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
|
|||||||
@ -16,17 +16,80 @@ Key components:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import NamedTuple
|
from typing import List, NamedTuple, Protocol
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
|
"EncoderDecoderModel",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
|
"BlockProtocol",
|
||||||
|
"EncoderProtocol",
|
||||||
|
"BottleneckProtocol",
|
||||||
|
"DecoderProtocol",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProtocol(Protocol):
|
||||||
|
"""Interface for blocks of network layers."""
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Forward pass of the block."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_output_height(self, input_height: int) -> int:
|
||||||
|
"""Calculate the output height based on input height."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderProtocol(Protocol):
|
||||||
|
"""Interface for the downsampling path of a network."""
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
output_height: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
"""Forward pass must return intermediate tensors for skip connections."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckProtocol(Protocol):
|
||||||
|
"""Interface for the middle part of a U-Net-like network."""
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
|
||||||
|
def __call__(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Processes the features from the encoder."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderProtocol(Protocol):
|
||||||
|
"""Interface for the upsampling reconstruction path."""
|
||||||
|
|
||||||
|
in_channels: int
|
||||||
|
out_channels: int
|
||||||
|
input_height: int
|
||||||
|
output_height: int
|
||||||
|
depth: int
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residuals: List[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Upsamples features while integrating skip connections."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class ModelOutput(NamedTuple):
|
class ModelOutput(NamedTuple):
|
||||||
"""Standard container for the outputs of a BatDetect2 detection model.
|
"""Standard container for the outputs of a BatDetect2 detection model.
|
||||||
|
|
||||||
|
|||||||
@ -53,7 +53,7 @@ def test_standard_block_protocol_methods(
|
|||||||
|
|
||||||
block = block_class(in_channels=in_channels, out_channels=out_channels)
|
block = block_class(in_channels=in_channels, out_channels=out_channels)
|
||||||
|
|
||||||
assert block.get_output_channels() == out_channels
|
assert block.out_channels == out_channels
|
||||||
assert block.get_output_height(input_height) == int(
|
assert block.get_output_height(input_height) == int(
|
||||||
input_height * expected_h_scale
|
input_height * expected_h_scale
|
||||||
)
|
)
|
||||||
@ -80,7 +80,7 @@ def test_coord_block_protocol_methods(
|
|||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert block.get_output_channels() == out_channels
|
assert block.out_channels == out_channels
|
||||||
assert block.get_output_height(input_height) == int(
|
assert block.get_output_height(input_height) == int(
|
||||||
input_height * expected_h_scale
|
input_height * expected_h_scale
|
||||||
)
|
)
|
||||||
@ -96,7 +96,7 @@ def test_vertical_conv_forward_shape(dummy_input):
|
|||||||
output = block(dummy_input)
|
output = block(dummy_input)
|
||||||
|
|
||||||
assert output.shape == (2, out_channels, 1, 32)
|
assert output.shape == (2, out_channels, 1, 32)
|
||||||
assert block.get_output_channels() == out_channels
|
assert block.out_channels == out_channels
|
||||||
|
|
||||||
|
|
||||||
def test_self_attention_forward_shape(dummy_bottleneck_input):
|
def test_self_attention_forward_shape(dummy_bottleneck_input):
|
||||||
@ -110,7 +110,7 @@ def test_self_attention_forward_shape(dummy_bottleneck_input):
|
|||||||
output = block(dummy_bottleneck_input)
|
output = block(dummy_bottleneck_input)
|
||||||
|
|
||||||
assert output.shape == dummy_bottleneck_input.shape
|
assert output.shape == dummy_bottleneck_input.shape
|
||||||
assert block.get_output_channels() == in_channels
|
assert block.out_channels == in_channels
|
||||||
|
|
||||||
|
|
||||||
def test_self_attention_weights(dummy_bottleneck_input):
|
def test_self_attention_weights(dummy_bottleneck_input):
|
||||||
@ -179,7 +179,7 @@ def test_layer_group_from_config_and_forward(dummy_input):
|
|||||||
assert len(layer_group.layers) == 2
|
assert len(layer_group.layers) == 2
|
||||||
|
|
||||||
# The group should report the output channels of the LAST block
|
# The group should report the output channels of the LAST block
|
||||||
assert layer_group.get_output_channels() == 64
|
assert layer_group.out_channels == 64
|
||||||
|
|
||||||
# The group should report the accumulated height changes
|
# The group should report the accumulated height changes
|
||||||
assert layer_group.get_output_height(input_height) == input_height // 2
|
assert layer_group.get_output_height(input_height) == input_height // 2
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user