|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
-
- from mmdet.models.plugins import DropBlock
-
-
def test_dropblock():
    """Exercise DropBlock's forward pass and its constructor argument checks.

    With ``drop_prob=1.0`` and a block as large as the 11x11 feature map, the
    output must be all zeros. With ``drop_prob=0.5`` and a 5x5 block, only the
    output shape is checked. Invalid constructor arguments must be rejected
    with ``AssertionError``.
    """
    inputs = torch.rand(1, 1, 11, 11)

    # Full drop: block_size equals the spatial size and warmup is disabled,
    # so every element of the output is expected to be zero.
    full_drop = DropBlock(1.0, block_size=11, warmup_iters=0)
    dropped = full_drop(inputs)
    assert (dropped == 0).all() and dropped.shape == inputs.shape

    # Partial drop: only the output shape is asserted here.
    partial_drop = DropBlock(0.5, block_size=5, warmup_iters=0)
    partial_out = partial_drop(inputs)
    assert partial_out.shape == inputs.shape

    # Each tuple is passed positionally as (drop_prob, block_size[, warmup_iters]).
    invalid_args = (
        (1.5, 3),      # drop_prob must be in (0, 1]
        (0.5, 2),      # block_size cannot be an even number
        (0.5, 3, -1),  # warmup_iters cannot be less than 0
    )
    for args in invalid_args:
        with pytest.raises(AssertionError):
            DropBlock(*args)
|