Browse Source

修改seq_len_to_mask的jittor实现及测试例中的一处传参错误

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
a6cfc4086f
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      fastNLP/core/utils/seq_len_to_mask.py
  2. +1
    -1
      tests/core/utils/test_seq_len_to_mask.py

+ 1
- 1
fastNLP/core/utils/seq_len_to_mask.py View File

@@ -74,7 +74,7 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
if isinstance(seq_len, jittor.Var):
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}."
batch_size = seq_len.shape[0]
broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1)
broad_cast_seq_len = jittor.arange(max_len).reshape(1, max_len).expand(batch_size, -1)
mask = broad_cast_seq_len < seq_len.unsqueeze(1)
return mask
except NameError as e:


+ 1
- 1
tests/core/utils/test_seq_len_to_mask.py View File

@@ -78,7 +78,7 @@ class TestSeqLenToMask:
mask = seq_len_to_mask(seq_len)

# 3. pad到指定长度
seq_len = paddle.randint(1, 10, size=(10,))
seq_len = paddle.randint(1, 10, shape=(10,))
mask = seq_len_to_mask(seq_len, 100)
assert 100 == mask.shape[1]



Loading…
Cancel
Save