From 1caa83d0cafbb5df6470627fab8dea86b56df36a Mon Sep 17 00:00:00 2001
From: ZikaiGuo <634500098@qq.com>
Date: Sun, 8 Sep 2019 14:54:31 +0200
Subject: [PATCH] Update transformer.py

---
 fastNLP/modules/encoder/transformer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index d29a10c3..3d97c306 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -40,6 +40,8 @@ class TransformerEncoder(nn.Module):
         :param seq_mask: [batch, seq_len]
         :return: [batch, seq_len, model_size]
         """
+        if seq_mask is None:  # guard against errors in the subsequent multiplication
+            seq_mask = 1
         input = self.norm1(input)
         attention = self.atte(input, input, input, atte_mask_out)
         input = input + self.dropout(attention)