diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index d29a10c3..3d97c306 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -40,6 +40,8 @@ class TransformerEncoder(nn.Module):
         :param seq_mask: [batch, seq_len]
         :return: [batch, seq_len, model_size]
         """
+        if seq_mask is None:  # guard against errors in the multiplications below
+            seq_mask = 1
         input = self.norm1(input)
         attention = self.atte(input, input, input, atte_mask_out)
         input = input + self.dropout(attention)
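
Note (not part of the patch): a minimal sketch of why the scalar 1 is a safe fallback. Per the added comment, the default exists to protect the mask multiplications later in forward(); multiplying a tensor by the Python scalar 1 broadcasts to a no-op, whereas multiplying by None raises a TypeError. The tensor shapes below are assumptions for illustration only:

    import torch

    x = torch.randn(2, 5, 8)      # hypothetical [batch, seq_len, model_size] activations
    seq_mask = None               # caller passed no mask
    if seq_mask is None:          # same fallback as the patch
        seq_mask = 1
    assert torch.equal(x * seq_mask, x)  # scalar 1 broadcasts; the mask step is a no-op
    # x * None                    # without the fallback this would raise a TypeError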