Browse Source

Merge pull request #217 from ZikaiGuo/dev0.5.0

[bugfix]Update transformer.py
tags/v0.4.10
Yige Xu GitHub 5 years ago
parent
commit
2fc14790c7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      fastNLP/modules/encoder/transformer.py

+ 2
- 0
fastNLP/modules/encoder/transformer.py View File

@@ -40,6 +40,8 @@ class TransformerEncoder(nn.Module):
:param seq_mask: [batch, seq_len] :param seq_mask: [batch, seq_len]
:return: [batch, seq_len, model_size] :return: [batch, seq_len, model_size]
""" """
if seq_mask is None: # 防止后续乘法时出错
seq_mask = 1
input = self.norm1(input) input = self.norm1(input)
attention = self.atte(input, input, input, atte_mask_out) attention = self.atte(input, input, input, atte_mask_out)
input = input + self.dropout(attention) input = input + self.dropout(attention)


Loading…
Cancel
Save