|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
- from mmcv.utils import ConfigDict
-
- from mmdet.models.utils.transformer import (AdaptivePadding,
- DetrTransformerDecoder,
- DetrTransformerEncoder, PatchEmbed,
- PatchMerging, Transformer)
-
-
def test_adaptive_padding():
    """Check that ``AdaptivePadding`` pads H/W just enough for the kernel.

    For both ``'same'`` and ``'corner'`` modes the padded size must let the
    (dilated) kernel slide with ``stride`` over the whole input; an input
    that is already covered comes back unpadded. The local tensor is named
    ``x`` rather than ``input`` so the builtin is not shadowed.
    """
    for padding in ('same', 'corner'):
        kernel_size = 16
        stride = 16
        dilation = 1
        x = torch.rand(1, 1, 15, 17)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)
        # here pad_h = 16 - 15 = 1 and pad_w = 32 - 17 = 15: 'corner' puts
        # all of it after the input, 'same' puts half of it before, so the
        # original values must be found at the matching offset
        top, left = (1 // 2, 15 // 2) if padding == 'same' else (0, 0)
        assert torch.equal(out[..., top:top + 15, left:left + 17], x)
        x = torch.rand(1, 1, 16, 17)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)

        kernel_size = (2, 2)
        stride = (2, 2)
        dilation = (1, 1)

        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # padding to divisible by 2
        assert (out.shape[2], out.shape[3]) == (12, 14)

        kernel_size = (2, 2)
        stride = (10, 10)
        dilation = (1, 1)

        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 10, 13)
        out = adap_pad(x)
        # no padding: the strided kernel already covers the input
        assert (out.shape[2], out.shape[3]) == (10, 13)

        kernel_size = (11, 11)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # both spatial dims get padded
        assert (out.shape[2], out.shape[3]) == (21, 21)

        # a (4, 5) kernel with dilation (2, 2) spans (7, 9), so it must pad
        # exactly like an undilated (7, 9) kernel
        x = torch.rand(1, 1, 11, 13)
        stride = (3, 4)
        kernel_size = (4, 5)
        dilation = (2, 2)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        dilation_out = adap_pad(x)
        assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)
        kernel_size = (7, 9)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        kernel79_out = adap_pad(x)
        assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
        assert kernel79_out.shape == dilation_out.shape

    # only the string modes "same" and "corner" are accepted; this does not
    # depend on the loop above, so it is checked once with explicit args
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=1)
-
-
def test_patch_embed():
    """Check ``PatchEmbed`` tokens, ``out_size`` and ``init_out_size``.

    Covers integer padding with and without dilation / norm, the
    ``init_out_size`` computed up front when ``input_size`` is given, and
    both adaptive padding modes (``'same'`` and ``'corner'``).
    """
    B = 2
    H = 3
    W = 4
    C = 3
    embed_dims = 10
    kernel_size = 3
    stride = 1
    dummy_input = torch.rand(B, C, H, W)
    patch_embed_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=1,
        norm_cfg=None)

    x1, shape = patch_embed_1(dummy_input)
    # test out shape
    assert x1.shape == (2, 2, 10)
    # test outsize is correct
    assert shape == (1, 2)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x1.shape[1]

    B = 2
    H = 10
    W = 10
    C = 3
    embed_dims = 10
    kernel_size = 5
    stride = 2
    dummy_input = torch.rand(B, C, H, W)
    # test dilation
    patch_embed_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=None,
    )

    x2, shape = patch_embed_2(dummy_input)
    # test out shape
    assert x2.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x2.shape[1]

    input_size = (10, 10)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    x3, shape = patch_embed_3(dummy_input)
    # test out shape
    assert x3.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x3.shape[1]

    # init_out_size follows the conv / nn.Unfold output-size formula
    # (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    # with padding=0, dilation=2, kernel_size=5 (hence 2 * 4), stride=2;
    # each axis must use its own input_size entry
    assert patch_embed_3.init_out_size[0] == (input_size[0] - 2 * 4 -
                                              1) // 2 + 1
    assert patch_embed_3.init_out_size[1] == (input_size[1] - 2 * 4 -
                                              1) // 2 + 1

    H = 11
    W = 12
    input_size = (H, W)
    dummy_input = torch.rand(B, C, H, W)
    # non-square input with stride and norm
    patch_embed_4 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    _, shape = patch_embed_4(dummy_input)
    # when input_size equals the real input size
    # the out_size should be equal to `init_out_size`
    assert shape == patch_embed_4.init_out_size
    # (11 - 8 - 1) // 2 + 1 == 2 and (12 - 8 - 1) // 2 + 1 == 2
    assert shape == (2, 2)

    # test adaptive padding; each case is
    # (input_size, kernel_size, stride, expected out_size)
    adap_cases = [
        # stride is 1
        ((5, 5), (5, 5), (1, 1), (5, 5)),
        # kernel_size == stride
        ((5, 5), (5, 5), (5, 5), (1, 1)),
        # kernel_size == stride, input height not divisible by stride
        ((6, 5), (5, 5), (5, 5), (2, 1)),
        # different kernel_size with different stride
        ((6, 5), (6, 2), (6, 2), (1, 3)),
    ]
    in_c = 2
    embed_dims = 3
    B = 2
    for padding in ('same', 'corner'):
        for input_size, kernel_size, stride, expected in adap_cases:
            x = torch.rand(B, in_c, *input_size)
            patch_embed = PatchEmbed(
                in_channels=in_c,
                embed_dims=embed_dims,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                bias=False)

            x_out, out_size = patch_embed(x)
            assert x_out.size() == (B, expected[0] * expected[1],
                                    embed_dims)
            assert out_size == expected
            assert x_out.size(1) == out_size[0] * out_size[1]
-
-
def test_patch_merging():
    """Check ``PatchMerging`` merged-token count and ``out_size``.

    Covers an explicit integer ``padding`` (no adaptive padding, with and
    without dilation) and both adaptive padding modes. Inputs are token
    sequences of shape (B, H * W, C) plus their (H, W) ``input_size``.
    """
    # integer padding, so no adaptive padding is involved
    in_c = 3
    out_c = 4
    kernel_size = 3
    stride = 3
    padding = 1
    dilation = 1
    bias = False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 3
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 16, 4)
    # (10 + 2 * 1 - 1 * (3 - 1) - 1) // 3 + 1 == 4 per side
    assert out_size == (4, 4)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    # integer padding combined with dilation
    in_c = 4
    out_c = 5
    kernel_size = 6
    stride = 3
    padding = 2
    dilation = 2
    bias = False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 4
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 4, 5)
    # (10 + 2 * 2 - 2 * (6 - 1) - 1) // 3 + 1 == 2 per side
    assert out_size == (2, 2)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    # adaptive padding; each case is
    # (input_size, kernel_size, stride, expected out_size)
    adap_cases = [
        # stride is 1
        ((5, 5), (5, 5), (1, 1), (5, 5)),
        # kernel_size == stride
        ((5, 5), (5, 5), (5, 5), (1, 1)),
        # kernel_size == stride, input height not divisible by stride
        ((6, 5), (5, 5), (5, 5), (2, 1)),
        # different kernel_size with different stride
        ((6, 5), (6, 2), (6, 2), (1, 3)),
    ]
    in_c = 2
    out_c = 3
    B = 2
    for padding in ('same', 'corner'):
        for input_size, kernel_size, stride, expected in adap_cases:
            L = input_size[0] * input_size[1]
            x = torch.rand(B, L, in_c)
            patch_merge = PatchMerging(
                in_channels=in_c,
                out_channels=out_c,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                bias=False)

            x_out, out_size = patch_merge(x, input_size)
            assert x_out.size() == (B, expected[0] * expected[1], out_c)
            assert out_size == expected
            assert x_out.size(1) == out_size[0] * out_size[1]
-
-
def test_detr_transformer_dencoder_encoder_layer():
    """Check DETR decoder/encoder construction and config validation.

    A 6-layer ``DetrTransformerDecoder`` built from a single layer config
    whose ``operation_order`` starts with ``'norm'`` is expected to be
    pre-norm. Two configs are expected to raise ``AssertionError``:
    a list of 5 layer configs with ``num_layers=6``, and an
    ``operation_order`` with 7 entries. Each ``pytest.raises`` wraps only
    the constructor call, so an error while building the config itself
    cannot make the check pass.
    """
    config = ConfigDict(
        dict(
            return_intermediate=True,
            num_layers=6,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=(
                    'norm',
                    'self_attn',
                    'norm',
                    'cross_attn',
                    'norm',
                    'ffn',
                ))))
    # build once and inspect that single instance
    decoder = DetrTransformerDecoder(**config)
    # operation_order starts with 'norm', so the layers should be pre-norm
    assert decoder.layers[0].pre_norm
    assert len(decoder.layers) == 6

    # 5 layer configs do not match num_layers=6
    config = ConfigDict(
        dict(
            return_intermediate=True,
            num_layers=6,
            transformerlayers=[
                dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=256,
                        num_heads=8,
                        dropout=0.1),
                    feedforward_channels=2048,
                    ffn_dropout=0.1,
                    operation_order=('self_attn', 'norm', 'cross_attn',
                                     'norm', 'ffn', 'norm'))
            ] * 5))
    with pytest.raises(AssertionError):
        DetrTransformerDecoder(**config)

    config = ConfigDict(
        dict(
            num_layers=6,
            transformerlayers=dict(
                type='DetrTransformerDecoderLayer',
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                feedforward_channels=2048,
                ffn_dropout=0.1,
                operation_order=('norm', 'self_attn', 'norm', 'cross_attn',
                                 'norm', 'ffn', 'norm'))))

    with pytest.raises(AssertionError):
        # operation_order has 7 entries; the layer is expected to
        # require exactly 6 (len(operation_order) == 6)
        DetrTransformerEncoder(**config)
-
-
def test_transformer():
    """Smoke-test building a DETR ``Transformer`` and initialising it.

    The encoder is a 6-layer ``DetrTransformerEncoder`` of post-norm
    ``BaseTransformerLayer``s; the decoder is a 6-layer
    ``DetrTransformerDecoder`` that returns intermediate outputs.
    """
    encoder_cfg = dict(
        type='DetrTransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1)
            ],
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'ffn', 'norm')))
    decoder_cfg = dict(
        type='DetrTransformerDecoder',
        return_intermediate=True,
        num_layers=6,
        transformerlayers=dict(
            type='DetrTransformerDecoderLayer',
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=256,
                num_heads=8,
                dropout=0.1),
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm')),
    )
    model = Transformer(
        **ConfigDict(dict(encoder=encoder_cfg, decoder=decoder_cfg)))
    model.init_weights()
|