From 76e2330a2ee54db35457bbe65fdc2db2c9680bb3 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 19 Jun 2019 23:13:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0seq=5Flen=5Fto=5Fmask?= =?UTF-8?q?=E5=AF=B9=E5=A4=9A=E5=8D=A1=E5=9C=BA=E6=99=AF=E7=9A=84=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 1eb2b70e..df3c45cb 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -643,7 +643,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): warnings.warn(message=_unused_warn) -def seq_len_to_mask(seq_len): +def seq_len_to_mask(seq_len, max_len=None): """ 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 @@ -661,18 +661,20 @@ def seq_len_to_mask(seq_len): (14, 15) :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) + :param int max_len: 将长度pad到这个长度. 默认使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 + 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 :return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8 """ if isinstance(seq_len, np.ndarray): assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." - max_len = int(seq_len.max()) + max_len = max(max_len, int(seq_len.max())) if max_len else int(seq_len.max()) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) mask = broad_cast_seq_len < seq_len.reshape(-1, 1) elif isinstance(seq_len, torch.Tensor): assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." batch_size = seq_len.size(0) - max_len = seq_len.max().long() + max_len = max(max_len, seq_len.max().long()) if max_len else seq_len.max().long() broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) else: