|
- # Copyright (c) OpenMMLab. All rights reserved.
- import pytest
- import torch
-
- from mmdet.models.utils import (LearnedPositionalEncoding,
- SinePositionalEncoding)
-
-
def test_sine_positional_encoding(num_feats=16, batch_size=2):
    """Check ``SinePositionalEncoding`` argument validation and its output.

    Args:
        num_feats (int): Feature count passed to the encoding. The output is
            expected to have ``num_feats * 2`` channels.
        batch_size (int): Batch dimension of the random padding mask.
    """
    # A non-scalar ``scale`` must be rejected when ``normalize=True``.
    # The constructed object is never used, so it is not bound to a name.
    with pytest.raises(AssertionError):
        SinePositionalEncoding(num_feats, scale=(3., ), normalize=True)

    h, w = 10, 6
    # Draw the mask from a private, seeded generator. This makes the test
    # input reproducible across runs without mutating the global torch RNG
    # state that other tests may depend on.
    generator = torch.Generator().manual_seed(0)
    mask = (torch.rand(batch_size, h, w, generator=generator) > 0.5).to(
        torch.int)

    # Default construction: normalization is disabled.
    module = SinePositionalEncoding(num_feats)
    assert not module.normalize
    out = module(mask)
    assert out.shape == (batch_size, num_feats * 2, h, w)
    # A sine/cosine embedding must be finite and bounded by 1 in magnitude.
    assert torch.isfinite(out).all()
    assert out.abs().max() <= 1

    # Normalized variant: positions are rescaled before the sine/cosine.
    # Also guard the normalization division against producing inf/NaN.
    module = SinePositionalEncoding(num_feats, normalize=True)
    assert module.normalize
    out = module(mask)
    assert out.shape == (batch_size, num_feats * 2, h, w)
    assert torch.isfinite(out).all()
    assert out.abs().max() <= 1
-
-
def test_learned_positional_encoding(num_feats=16,
                                     row_num_embed=10,
                                     col_num_embed=10,
                                     batch_size=2):
    """Check ``LearnedPositionalEncoding`` weight shapes and output shape.

    Args:
        num_feats (int): Width of each row/column embedding vector.
        row_num_embed (int): Number of rows in the row embedding table.
        col_num_embed (int): Number of rows in the column embedding table.
        batch_size (int): Batch dimension of the random padding mask.
    """
    encoder = LearnedPositionalEncoding(num_feats, row_num_embed,
                                        col_num_embed)

    # Each embedding table is expected to be (num_embed, num_feats).
    expected_weight_shapes = {
        'row_embed': (row_num_embed, num_feats),
        'col_embed': (col_num_embed, num_feats),
    }
    for attr_name, weight_shape in expected_weight_shapes.items():
        assert getattr(encoder, attr_name).weight.shape == weight_shape

    height, width = 10, 6
    padding_mask = torch.rand(batch_size, height, width) > 0.5
    encoding = encoder(padding_mask)
    # Row and column embeddings together give 2 * num_feats channels.
    expected_out_shape = (batch_size, 2 * num_feats, height, width)
    assert encoding.shape == expected_out_shape
|