|
|
@@ -40,6 +40,8 @@ class TransformerEncoder(nn.Module): |
|
|
|
:param seq_mask: [batch, seq_len] |
|
|
|
:return: [batch, seq_len, model_size] |
|
|
|
""" |
|
|
|
if seq_mask is None: # 防止后续乘法时出错 |
|
|
|
seq_mask = 1 |
|
|
|
input = self.norm1(input) |
|
|
|
attention = self.atte(input, input, input, atte_mask_out) |
|
|
|
input = input + self.dropout(attention) |
|
|
|