From e57b8e4fd3f8b1a56d8011761c778611559da39b Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 21 Jun 2019 11:24:42 +0800 Subject: [PATCH] =?UTF-8?q?seq=5Flen=5Fto=5Fmask=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=A4=B1=E8=B4=A5=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/core/test_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/core/test_utils.py b/test/core/test_utils.py index a3e8bdf6..363d5fa1 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -240,8 +240,7 @@ class TestSeqLenToMask(unittest.TestCase): # 3. pad到指定长度 seq_len = np.random.randint(1, 10, size=(10,)) mask = seq_len_to_mask(seq_len, 100) - self.assertEqual(100, mask.size(1)) - + self.assertEqual(100, mask.shape[1]) def test_pytorch_seq_len(self): # 1. 随机测试