diff --git a/test/core/test_utils.py b/test/core/test_utils.py index a3e8bdf6..363d5fa1 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -240,8 +240,7 @@ class TestSeqLenToMask(unittest.TestCase): # 3. pad到指定长度 seq_len = np.random.randint(1, 10, size=(10,)) mask = seq_len_to_mask(seq_len, 100) - self.assertEqual(100, mask.size(1)) - + self.assertEqual(100, mask.shape[1]) def test_pytorch_seq_len(self): # 1. 随机测试