|
- import pytest
- import torch
-
- from mmdet.core.post_processing import mask_matrix_nms
-
-
- def _create_mask(N, h, w):
- masks = torch.rand((N, h, w)) > 0.5
- labels = torch.rand(N)
- scores = torch.rand(N)
- return masks, labels, scores
-
-
def test_nms_input_errors():
    """Check ``mask_matrix_nms`` input validation and basic filtering.

    Covers, in order: mismatched masks/labels/scores lengths, a
    ``mask_area`` whose length differs from the number of masks, an
    unsupported ``kernel`` name, empty input, truncation through
    ``nms_pre`` / ``max_num``, and score decay combined with
    ``filter_thr`` when every candidate shares one mask and one label.
    """
    # masks, labels and scores must all have the same length
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))

    # mask_area must have one entry per mask. labels and scores match the
    # number of masks here so that the truncated mask_area is the only
    # inconsistent argument; otherwise the length mismatch checked above
    # would satisfy ``raises`` and this case would test nothing new.
    # Setup stays outside the ``with`` block so that only the call under
    # test can raise the expected error.
    masks = torch.rand((10, 28, 28))
    with pytest.raises(AssertionError):
        mask_matrix_nms(
            masks,
            torch.rand(10),
            torch.rand(10),
            mask_area=masks.sum((1, 2)).float()[:8])

    # an unsupported kernel name is rejected
    with pytest.raises(NotImplementedError):
        mask_matrix_nms(
            torch.rand((10, 28, 28)),
            torch.rand(10),
            torch.rand(10),
            kernel='None')

    # empty input gives empty results
    masks, labels, scores = _create_mask(0, 28, 28)
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 0

    # without filter_thr, nms_pre and max_num every candidate is kept
    masks, labels, scores = _create_mask(1000, 28, 28)
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 1000

    # nms_pre alone limits the number of returned candidates
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores, nms_pre=500)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 500

    # max_num limits the number of returned candidates further
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks, labels, scores,
                        nms_pre=500, max_num=100)
    assert len(score) == len(label) == \
        len(mask) == len(keep_ind) == 100

    # 1000 candidates sharing one mask, one label and a score of 1.0
    masks, labels, _ = _create_mask(1, 28, 28)
    scores = torch.Tensor([1.0])
    masks = masks.expand(1000, 28, 28)
    labels = labels.expand(1000)
    scores = scores.expand(1000)

    # Scores must be decayed and filter_thr must take effect: with
    # identical masks and labels and all scores equal to 1, the first
    # candidate keeps its score of 1 while the others decay below
    # filter_thr and are dropped.
    score, label, mask, keep_ind = \
        mask_matrix_nms(masks,
                        labels,
                        scores,
                        nms_pre=500,
                        max_num=100,
                        kernel='gaussian',
                        sigma=2.0,
                        filter_thr=0.5)
    assert len(score) == 1
    assert score[0] == 1
- assert score[0] == 1
|