|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
-
- from mmdet.models.backbones import TridentResNet
- from mmdet.models.backbones.trident_resnet import TridentBottleneck
-
-
def test_trident_resnet_bottleneck():
    """Check ``TridentBottleneck`` argument validation, styles and plugins.

    The test first verifies that invalid ``style`` values, invalid plugin
    positions, duplicated plugin names and unknown plugin types are
    rejected. It then builds blocks with checkpointing, with both stride
    styles and with several plugin combinations, and checks that a forward
    pass on a ``(1, 64, 56, 56)`` input yields ``block.num_branch`` outputs
    stacked along the first dimension.
    """
    branch_dilations = (1, 2, 3)
    eval_branch_idx = 1
    concat = True
    # Positional arguments shared by every block built in this test.
    trident_args = (branch_dilations, eval_branch_idx, concat)

    def make_block(**kwargs):
        """Build a ``TridentBottleneck`` with the shared trident arguments."""
        return TridentBottleneck(*trident_args, **kwargs)

    def context_plugin(position, **extra_cfg):
        """Return a fresh ContextBlock plugin entry for ``position``."""
        return dict(
            cfg=dict(type='ContextBlock', ratio=1. / 16, **extra_cfg),
            position=position)

    def attention_plugin():
        """Return a fresh GeneralizedAttention plugin entry (after conv2)."""
        return dict(
            cfg=dict(
                type='GeneralizedAttention',
                spatial_range=-1,
                num_heads=8,
                attention_type='0010',
                kv_stride=2),
            position='after_conv2')

    def check_forward(module):
        """Run one forward pass and verify the branch-stacked output shape."""
        inputs = torch.randn(1, 64, 56, 56)
        outputs = module(inputs)
        expected_shape = torch.Size([module.num_branch, 64, 56, 56])
        assert outputs.shape == expected_shape

    # ---- Invalid configurations -------------------------------------------

    # Style must be in ['pytorch', 'caffe']
    with pytest.raises(AssertionError):
        make_block(inplanes=64, planes=64, style='tensorflow')

    # Allowed positions are 'after_conv1', 'after_conv2', 'after_conv3'
    bad_position = [context_plugin('after_conv4')]
    with pytest.raises(AssertionError):
        make_block(inplanes=64, planes=16, plugins=bad_position)

    # Two identical plugins at the same position need distinct postfixes,
    # otherwise their names collide.
    duplicated = [
        context_plugin('after_conv3'),
        context_plugin('after_conv3'),
    ]
    with pytest.raises(AssertionError):
        make_block(inplanes=64, planes=16, plugins=duplicated)

    # Plugin type is not supported
    unknown_type = [dict(cfg=dict(type='WrongPlugin'), position='after_conv3')]
    with pytest.raises(KeyError):
        make_block(inplanes=64, planes=16, plugins=unknown_type)

    # ---- Checkpointed forward ---------------------------------------------
    cp_block = make_block(inplanes=64, planes=16, with_cp=True)
    assert cp_block.with_cp
    check_forward(cp_block)

    # ---- Stride placement per style ---------------------------------------
    # Each row: style, expected conv1 stride, expected conv2 stride.
    stride_cases = (
        ('pytorch', (1, 1), (2, 2)),
        ('caffe', (2, 2), (1, 1)),
    )
    for style_name, conv1_stride, conv2_stride in stride_cases:
        styled = make_block(
            inplanes=64, planes=64, stride=2, style=style_name)
        assert styled.conv1.stride == conv1_stride
        assert styled.conv2.stride == conv2_stride

    # ---- Plain forward ----------------------------------------------------
    check_forward(make_block(inplanes=64, planes=16))

    # ---- One ContextBlock after conv3 -------------------------------------
    with_context = make_block(
        inplanes=64, planes=16, plugins=[context_plugin('after_conv3')])
    assert with_context.context_block.in_channels == 64
    check_forward(with_context)

    # ---- One GeneralizedAttention after conv2 -----------------------------
    with_attention = make_block(
        inplanes=64, planes=16, plugins=[attention_plugin()])
    assert with_attention.gen_attention_block.in_channels == 16
    check_forward(with_attention)

    # ---- GeneralizedAttention + NonLocal2d after conv2, ContextBlock ------
    # ---- after conv3 ------------------------------------------------------
    mixed_plugins = [
        attention_plugin(),
        dict(cfg=dict(type='NonLocal2d'), position='after_conv2'),
        context_plugin('after_conv3'),
    ]
    mixed = make_block(inplanes=64, planes=16, plugins=mixed_plugins)
    assert mixed.gen_attention_block.in_channels == 16
    assert mixed.nonlocal_block.in_channels == 16
    assert mixed.context_block.in_channels == 64
    check_forward(mixed)

    # ---- One ContextBlock after conv2, two after conv3 (postfixed) --------
    postfixed_plugins = [
        context_plugin('after_conv2', postfix=1),
        context_plugin('after_conv3', postfix=2),
        context_plugin('after_conv3', postfix=3),
    ]
    postfixed = make_block(inplanes=64, planes=16, plugins=postfixed_plugins)
    assert postfixed.context_block1.in_channels == 16
    assert postfixed.context_block2.in_channels == 64
    assert postfixed.context_block3.in_channels == 64
    check_forward(postfixed)
-
-
def test_trident_resnet_backbone():
    """Test tridentresnet backbone.

    Verifies that ``TridentResNet`` rejects an unsupported depth and an
    unsupported number of stages, then runs a training-mode forward pass
    on a ``(1, 3, 32, 32)`` input and checks that a single feature map of
    shape ``(3, 1024, 2, 2)`` is returned.
    """
    # Keyword arguments shared by every TridentResNet built below.
    tridentresnet_config = dict(
        num_branch=3,
        test_branch_idx=1,
        strides=(1, 2, 2),
        dilations=(1, 1, 1),
        trident_dilations=(1, 2, 3),
        out_indices=(2, ),
    )

    with pytest.raises(AssertionError):
        # TridentResNet depth should be in [50, 101, 152]
        TridentResNet(18, **tridentresnet_config)

    with pytest.raises(AssertionError):
        # In TridentResNet: num_stages == 3
        TridentResNet(50, num_stages=4, **tridentresnet_config)

    model = TridentResNet(50, num_stages=3, **tridentresnet_config)
    model.train()

    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    # Only one output index was requested, so exactly one feature map
    # is expected back.
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([3, 1024, 2, 2])
|