From 82b5726686dcbac9f9a2032537f53c3eb77f7698 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Sat, 24 Aug 2019 13:59:30 +0800
Subject: [PATCH] update transformer

---
 fastNLP/modules/encoder/transformer.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index ce9172d5..70b82bde 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -32,9 +32,10 @@ class TransformerEncoder(nn.Module):
             self.norm1 = nn.LayerNorm(model_size)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
-                                     nn.Linear(inner_size, model_size),
-                                     TimestepDropout(dropout), )
+                                     nn.Dropout(dropout),
+                                     nn.Linear(inner_size, model_size))
             self.norm2 = nn.LayerNorm(model_size)
+            self.dropout = nn.Dropout(dropout)
 
         def forward(self, input, seq_mask=None, atte_mask_out=None):
             """
@@ -43,17 +44,20 @@ class TransformerEncoder(nn.Module):
             :param seq_mask: [batch, seq_len]
             :return: [batch, seq_len, model_size]
             """
+            input = self.norm1(input)
             attention = self.atte(input, input, input, atte_mask_out)
-            norm_atte = self.norm1(attention + input)
-            attention *= seq_mask
-            output = self.ffn(norm_atte)
-            output = self.norm2(output + norm_atte)
-            output *= seq_mask
-            return output
+            input = input + self.dropout(attention)
+            # attention *= seq_mask
+            input = self.norm2(input)
+            output = self.ffn(input)
+            input = input + self.dropout(output)
+            # output *= seq_mask
+            return input
 
     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
         self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
+        self.norm = nn.LayerNorm(kargs['model_size'])
 
     def forward(self, x, seq_mask=None):
         """
@@ -70,4 +74,4 @@ class TransformerEncoder(nn.Module):
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
-        return output
+        return self.norm(output)
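
Below is a minimal, self-contained sketch of the pre-norm residual pattern this
patch adopts: LayerNorm before each sub-block, dropout on the residual branch,
and a final LayerNorm over the stack output. It uses torch.nn.MultiheadAttention
as a stand-in for fastNLP's MultiHeadAttention; the class name, constructor
arguments, and masking interface here are illustrative assumptions, not the
project's API.

# Sketch of the pre-norm sub-layer pattern used in the patched SubLayer.forward.
# nn.MultiheadAttention (batch_first=True, PyTorch >= 1.9) stands in for
# fastNLP's MultiHeadAttention; all names below are assumptions for illustration.
import torch
from torch import nn

class PreNormSubLayer(nn.Module):
    def __init__(self, model_size, inner_size, num_head, dropout=0.1):
        super().__init__()
        self.atte = nn.MultiheadAttention(model_size, num_head,
                                          dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(model_size)
        self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                 nn.ReLU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(inner_size, model_size))
        self.norm2 = nn.LayerNorm(model_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # Attention sub-block: normalize first, then add the dropped-out result back.
        h = self.norm1(x)
        attn, _ = self.atte(h, h, h, key_padding_mask=key_padding_mask)
        x = x + self.dropout(attn)
        # Feed-forward sub-block with the same pre-norm residual structure.
        h = self.norm2(x)
        x = x + self.dropout(self.ffn(h))
        return x

# Usage: [batch, seq_len, model_size] in, same shape out.
x = torch.randn(2, 5, 16)
layer = PreNormSubLayer(model_size=16, inner_size=64, num_head=4)
print(layer(x).shape)  # torch.Size([2, 5, 16])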