
update transformer

tags/v0.4.10
yunfan committed 6 years ago · commit 82b5726686
1 changed file with 12 additions and 8 deletions:
  fastNLP/modules/encoder/transformer.py  +12  -8

fastNLP/modules/encoder/transformer.py

@@ -32,9 +32,10 @@ class TransformerEncoder(nn.Module):
             self.norm1 = nn.LayerNorm(model_size)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
-                                     nn.Linear(inner_size, model_size),
-                                     TimestepDropout(dropout), )
+                                     nn.Dropout(dropout),
+                                     nn.Linear(inner_size, model_size))
             self.norm2 = nn.LayerNorm(model_size)
+            self.dropout = nn.Dropout(dropout)
 
         def forward(self, input, seq_mask=None, atte_mask_out=None):
             """
@@ -43,17 +44,20 @@ class TransformerEncoder(nn.Module):
             :param seq_mask: [batch, seq_len]
             :return: [batch, seq_len, model_size]
             """
+            input = self.norm1(input)
             attention = self.atte(input, input, input, atte_mask_out)
-            norm_atte = self.norm1(attention + input)
-            attention *= seq_mask
-            output = self.ffn(norm_atte)
-            output = self.norm2(output + norm_atte)
-            output *= seq_mask
+            input = input + self.dropout(attention)
+            # attention *= seq_mask
+            input = self.norm2(input)
+            output = self.ffn(input)
+            input = input + self.dropout(output)
+            # output *= seq_mask
             return output
 
     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
         self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
+        self.norm = nn.LayerNorm(kargs['model_size'])
 
     def forward(self, x, seq_mask=None):
         """
@@ -70,4 +74,4 @@ class TransformerEncoder(nn.Module):
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
-        return output
+        return self.norm(output)
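For context: the change replaces the post-norm residual layout (sub-block, add, then LayerNorm) with a pre-norm one (LayerNorm first, then the sub-block, with dropout applied on each residual branch), and the outer forward gains a final LayerNorm. The sketch below is a minimal, generic pre-norm sublayer under those assumptions; it uses torch.nn.MultiheadAttention as a stand-in for fastNLP's MultiHeadAttention, and the class and argument names are illustrative rather than fastNLP's API.

import torch
from torch import nn

class PreNormSubLayer(nn.Module):
    def __init__(self, model_size, inner_size, num_head, dropout=0.1):
        super().__init__()
        # Stand-in attention module; fastNLP uses its own MultiHeadAttention.
        self.atte = nn.MultiheadAttention(model_size, num_head, dropout=dropout,
                                          batch_first=True)
        self.norm1 = nn.LayerNorm(model_size)
        self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                 nn.ReLU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(inner_size, model_size))
        self.norm2 = nn.LayerNorm(model_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # Attention sub-block: normalize first, attend, then add the
        # dropped-out result back onto the residual stream.
        normed = self.norm1(x)
        attn_out, _ = self.atte(normed, normed, normed,
                                key_padding_mask=key_padding_mask)
        x = x + self.dropout(attn_out)
        # Feed-forward sub-block follows the same pattern.
        normed = self.norm2(x)
        x = x + self.dropout(self.ffn(normed))
        return x

# Usage: a [batch, seq_len, model_size] tensor in, same shape out.
layer = PreNormSubLayer(model_size=64, inner_size=256, num_head=4)
out = layer(torch.randn(2, 10, 64))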
