diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index df3c45cb..d26df966 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -659,22 +659,26 @@ def seq_len_to_mask(seq_len, max_len=None): >>> mask = seq_len_to_mask(seq_len) >>> print(mask.shape) (14, 15) + >>> seq_len = torch.arange(2, 16) + >>> mask = seq_len_to_mask(seq_len, max_len=100) + >>>print(mask.size()) + torch.Size([14, 100]) :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) - :param int max_len: 将长度pad到这个长度. 默认使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 + :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 :return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8 """ if isinstance(seq_len, np.ndarray): assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." - max_len = max(max_len, int(seq_len.max())) if max_len else int(seq_len.max()) + max_len = int(max_len) if max_len else int(seq_len.max()) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) mask = broad_cast_seq_len < seq_len.reshape(-1, 1) elif isinstance(seq_len, torch.Tensor): assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." batch_size = seq_len.size(0) - max_len = max(max_len, seq_len.max().long()) if max_len else seq_len.max().long() + max_len = int(max_len) if max_len else seq_len.max().long() broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) else: diff --git a/test/core/test_utils.py b/test/core/test_utils.py index e3e019c6..a3e8bdf6 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -237,6 +237,11 @@ class TestSeqLenToMask(unittest.TestCase): with self.assertRaises(AssertionError): mask = seq_len_to_mask(seq_len) + # 3. pad到指定长度 + seq_len = np.random.randint(1, 10, size=(10,)) + mask = seq_len_to_mask(seq_len, 100) + self.assertEqual(100, mask.size(1)) + def test_pytorch_seq_len(self): # 1. 随机测试 @@ -250,3 +255,8 @@ class TestSeqLenToMask(unittest.TestCase): seq_len = torch.randn(3, 4) with self.assertRaises(AssertionError): mask = seq_len_to_mask(seq_len) + + # 3. pad到指定长度 + seq_len = torch.randint(1, 10, size=(10, )) + mask = seq_len_to_mask(seq_len, 100) + self.assertEqual(100, mask.size(1)) \ No newline at end of file