Browse Source

Update transformer.py

tags/v0.4.10
ZikaiGuo GitHub 5 years ago
parent
commit
1caa83d0ca
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      fastNLP/modules/encoder/transformer.py

+ 2
- 0
fastNLP/modules/encoder/transformer.py View File

@@ -40,6 +40,8 @@ class TransformerEncoder(nn.Module):
:param seq_mask: [batch, seq_len]
:return: [batch, seq_len, model_size]
"""
if seq_mask is None: # 防止后续乘法时出错
seq_mask = 1
input = self.norm1(input)
attention = self.atte(input, input, input, atte_mask_out)
input = input + self.dropout(attention)


Loading…
Cancel
Save