|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
-
- from mmdet.models.backbones.hrnet import HRModule, HRNet
- from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
-
-
@pytest.mark.parametrize('block', [BasicBlock, Bottleneck])
def test_hrmodule(block):
    """Check the number and shapes of the feature maps HRModule returns.

    A two-branch module is built and run twice for the given ``block``
    class: once with its default arguments and once with
    ``multiscale_output=False``.
    """
    branch_widths = (32, 64)

    def forward_once(**module_kwargs):
        # A fresh channel list is created for every module so the two
        # runs never share a list object handed to HRModule.
        channels = [width * block.expansion for width in branch_widths]
        module = HRModule(
            num_branches=2,
            blocks=block,
            in_channels=channels,
            num_blocks=(4, 4),
            num_channels=branch_widths,
            **module_kwargs)
        inputs = [
            torch.randn(1, channels[0], 64, 64),
            torch.randn(1, channels[1], 32, 32),
        ]
        return channels, module(inputs)

    # Multi-scale forward: one output per branch, input shapes preserved.
    channels, outputs = forward_once()
    assert len(outputs) == 2
    assert outputs[0].shape == torch.Size([1, channels[0], 64, 64])
    assert outputs[1].shape == torch.Size([1, channels[1], 32, 32])

    # Single-scale forward: only the first (64x64) branch is returned.
    channels, outputs = forward_once(multiscale_output=False)
    assert len(outputs) == 1
    assert outputs[0].shape == torch.Size([1, channels[0], 64, 64])
-
-
def test_hrnet_backbone():
    """Check HRNet config validation and the shapes of its outputs.

    The test first expects ``AssertionError`` for two invalid ``extra``
    configs (a missing fourth stage, then a ``stage4`` whose
    ``num_branches`` does not match its per-branch tuples). It then runs
    a valid four-stage config on a 256x256 input, with the default
    arguments and with ``multiscale_output=False``.
    """
    # Start with only three stages configured.
    extra = {
        'stage1': {
            'num_modules': 1,
            'num_branches': 1,
            'block': 'BOTTLENECK',
            'num_blocks': (4, ),
            'num_channels': (64, ),
        },
        'stage2': {
            'num_modules': 1,
            'num_branches': 2,
            'block': 'BASIC',
            'num_blocks': (4, 4),
            'num_channels': (32, 64),
        },
        'stage3': {
            'num_modules': 4,
            'num_branches': 3,
            'block': 'BASIC',
            'num_blocks': (4, 4, 4),
            'num_channels': (32, 64, 128),
        },
    }

    # HRNet currently supports only 4-stage configs, so this must fail.
    with pytest.raises(AssertionError):
        HRNet(extra=extra)

    extra['stage4'] = {
        'num_modules': 3,
        'num_branches': 3,  # deliberately wrong: should be 4
        'block': 'BASIC',
        'num_blocks': (4, 4, 4, 4),
        'num_channels': (32, 64, 128, 256),
    }

    # len(num_blocks) must equal num_branches, so this must fail as well.
    with pytest.raises(AssertionError):
        HRNet(extra=extra)

    extra['stage4']['num_branches'] = 4

    def build_and_run(**backbone_kwargs):
        # Build the backbone, initialise it, switch to train mode and
        # forward one random 256x256 RGB image.
        backbone = HRNet(extra=extra, **backbone_kwargs)
        backbone.init_weights()
        backbone.train()
        return backbone(torch.randn(1, 3, 256, 256))

    # Test hrnetv2p_w32: four feature maps are returned.
    outputs = build_and_run()
    assert len(outputs) == 4
    assert outputs[0].shape == torch.Size([1, 32, 64, 64])
    assert outputs[3].shape == torch.Size([1, 256, 8, 8])

    # Test single scale output: only the first feature map is returned.
    outputs = build_and_run(multiscale_output=False)
    assert len(outputs) == 1
    assert outputs[0].shape == torch.Size([1, 32, 64, 64])
|