Browse Source

增加seq_len_to_mask对多卡场景的支持

tags/v0.4.10
yh 6 years ago
parent
commit
76e2330a2e
1 changed files with 5 additions and 3 deletions
  1. +5
    -3
      fastNLP/core/utils.py

+ 5
- 3
fastNLP/core/utils.py View File

@@ -643,7 +643,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
warnings.warn(message=_unused_warn)


def seq_len_to_mask(seq_len):
def seq_len_to_mask(seq_len, max_len=None):
"""

将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。
@@ -661,18 +661,20 @@ def seq_len_to_mask(seq_len):
(14, 15)

:param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
:param int max_len: 将长度pad到这个长度. 默认使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
:return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8
"""
if isinstance(seq_len, np.ndarray):
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
max_len = int(seq_len.max())
max_len = max(max_len, int(seq_len.max())) if max_len else int(seq_len.max())
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
elif isinstance(seq_len, torch.Tensor):
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
batch_size = seq_len.size(0)
max_len = seq_len.max().long()
max_len = max(max_len, seq_len.max().long()) if max_len else seq_len.max().long()
broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
else:


Loading…
Cancel
Save