Browse Source

seq_len_to_mask修复测试失败的问题

tags/v0.4.10
yh_cc 5 years ago
parent
commit
e57b8e4fd3
1 changed files with 1 additions and 2 deletions
  1. +1
    -2
      test/core/test_utils.py

+ 1
- 2
test/core/test_utils.py View File

@@ -240,8 +240,7 @@ class TestSeqLenToMask(unittest.TestCase):
# 3. pad到指定长度 # 3. pad到指定长度
seq_len = np.random.randint(1, 10, size=(10,)) seq_len = np.random.randint(1, 10, size=(10,))
mask = seq_len_to_mask(seq_len, 100) mask = seq_len_to_mask(seq_len, 100)
self.assertEqual(100, mask.size(1))

self.assertEqual(100, mask.shape[1])


def test_pytorch_seq_len(self): def test_pytorch_seq_len(self):
# 1. 随机测试 # 1. 随机测试


Loading…
Cancel
Save