From bbda73c14f2352583f1a89bafdd1ff7471543cc4 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 30 Aug 2019 21:48:00 +0800
Subject: [PATCH] [update] transformer

---
 fastNLP/modules/encoder/attention.py   | 39 +++++++++++---------------
 fastNLP/modules/encoder/transformer.py | 17 ++++++-----
 2 files changed, 25 insertions(+), 31 deletions(-)

diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/encoder/attention.py
index 02bd078a..6a973864 100644
--- a/fastNLP/modules/encoder/attention.py
+++ b/fastNLP/modules/encoder/attention.py
@@ -30,14 +30,14 @@ class DotAttention(nn.Module):
 
     def forward(self, Q, K, V, mask_out=None):
         """
-        :param Q: [batch, seq_len_q, key_size]
-        :param K: [batch, seq_len_k, key_size]
-        :param V: [batch, seq_len_k, value_size]
-        :param mask_out: [batch, 1, seq_len] or [batch, seq_len_q, seq_len_k]
+        :param Q: [..., seq_len_q, key_size]
+        :param K: [..., seq_len_k, key_size]
+        :param V: [..., seq_len_k, value_size]
+        :param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k]
         """
-        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -1e18)
+            output.masked_fill_(mask_out, -1e9)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
@@ -65,17 +65,16 @@ class MultiHeadAttention(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        # follow the paper, do not apply dropout within dot-product
         self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.reset_parameters()
 
     def reset_parameters(self):
         sqrt = math.sqrt
-        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
-        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
-        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
-        nn.init.xavier_normal_(self.out.weight)
+        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size))
 
     def forward(self, Q, K, V, atte_mask_out=None):
         """
@@ -89,20 +88,16 @@ class MultiHeadAttention(nn.Module):
         sk = K.size(1)
         d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
         # input linear
-        q = self.q_in(Q).view(batch, sq, n_head, d_k)
-        k = self.k_in(K).view(batch, sk, n_head, d_k)
-        v = self.v_in(V).view(batch, sk, n_head, d_v)
-
-        # transpose q, k and v to do batch attention
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_v)
+        q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2)
+        k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2)
+        v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2)
+
         if atte_mask_out is not None:
-            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
-        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
+            atte_mask_out = atte_mask_out[:,None,:,:] # [bsz,1,1,len]
+        atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v)
 
         # concat all heads, do output linear
-        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
+        atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1)
         output = self.out(atte)
         return output
 
diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index 70b82bde..d8a612a0 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -5,8 +5,7 @@ __all__ = [
 ]
 
 from torch import nn
-from fastNLP.modules.encoder.attention import MultiHeadAttention
-from ..dropout import TimestepDropout
+from .attention import MultiHeadAttention
 
 
 class TransformerEncoder(nn.Module):
@@ -29,12 +28,12 @@ class TransformerEncoder(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
             self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
-            self.norm1 = nn.LayerNorm(model_size)
+            self.norm1 = nn.LayerNorm(model_size, eps=1e-6)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
                                      nn.Dropout(dropout),
                                      nn.Linear(inner_size, model_size))
-            self.norm2 = nn.LayerNorm(model_size)
+            self.norm2 = nn.LayerNorm(model_size, eps=1e-6)
             self.dropout = nn.Dropout(dropout)
 
         def forward(self, input, seq_mask=None, atte_mask_out=None):
             """
@@ -47,17 +46,17 @@ class TransformerEncoder(nn.Module):
             input = self.norm1(input)
             attention = self.atte(input, input, input, atte_mask_out)
             input = input + self.dropout(attention)
-            # attention *= seq_mask
+            attention *= seq_mask
             input = self.norm2(input)
             output = self.ffn(input)
             input = input + self.dropout(output)
-            # output *= seq_mask
-            return output
+            input *= seq_mask
+            return input
 
     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
         self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
-        self.norm = nn.LayerNorm(kargs['model_size'])
+        self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6)
 
     def forward(self, x, seq_mask=None):
         """
@@ -70,7 +69,7 @@ class TransformerEncoder(nn.Module):
         if seq_mask is None:
             atte_mask_out = None
         else:
-            atte_mask_out = (seq_mask < 1)[:, None, :]
+            atte_mask_out = (seq_mask == 0)[:, None, :]
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
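Note (not part of the patch): a standalone sketch of the head layout and mask broadcasting that attention.py switches to, assuming standard PyTorch only; all tensor names and sizes below are illustrative. It checks that view + transpose(1, 2) carries the same per-head content as the old permute + view flattening, and shows why a [batch, 1, 1, seq_len] boolean mask broadcasts against the [batch, n_head, seq_len_q, seq_len_k] scores inside DotAttention.

    import torch

    batch, seq_len, n_head, d_k = 2, 5, 4, 16

    # Projected queries as produced by the q_in linear: [batch, seq_len, n_head * d_k].
    q = torch.randn(batch, seq_len, n_head * d_k)

    # New layout in this patch: [batch, n_head, seq_len, d_k] via view + transpose(1, 2).
    q_new = q.view(batch, seq_len, n_head, d_k).transpose(1, 2)

    # Old layout: [n_head * batch, seq_len, d_k] via permute + contiguous + view.
    q_old = (q.view(batch, seq_len, n_head, d_k)
              .permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k))

    # Same per-head slices, just arranged along different axes (batch 0, head 1).
    assert torch.equal(q_new[0, 1], q_old[1 * batch + 0])

    # Padding mask: (seq_mask == 0)[:, None, :] is [batch, 1, seq_len]; the extra
    # [:, None, :, :] makes it [batch, 1, 1, seq_len], which broadcasts over heads
    # and query positions against [batch, n_head, seq_len_q, seq_len_k] scores.
    seq_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])
    atte_mask_out = (seq_mask == 0)[:, None, :][:, None, :, :]

    scores = torch.randn(batch, n_head, seq_len, seq_len)
    scores = scores.masked_fill(atte_mask_out, -1e9)  # same fill value as the patch
    assert torch.softmax(scores, dim=-1)[0, 0, 0, 3] < 1e-6  # padded key gets ~0 weight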
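Note (not part of the patch): a minimal usage sketch of the patched TransformerEncoder with a padding mask. The keyword arguments mirror SubLayer.__init__ in this diff; the import path is assumed from the file location and the concrete sizes are made up for illustration.

    import torch
    from fastNLP.modules.encoder.transformer import TransformerEncoder  # path assumed from this diff

    # Hypothetical sizes; only the keyword names come from the patched code.
    encoder = TransformerEncoder(num_layers=2,
                                 model_size=256,
                                 inner_size=1024,
                                 key_size=64,
                                 value_size=64,
                                 num_head=4,
                                 dropout=0.1)

    x = torch.randn(2, 7, 256)                        # [batch, seq_len, model_size]
    seq_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1],   # sample 0: no padding
                             [1, 1, 1, 1, 0, 0, 0]])  # sample 1: last 3 positions padded
    out = encoder(x, seq_mask)                        # [batch, seq_len, model_size]

    # Internally, forward() builds atte_mask_out = (seq_mask == 0)[:, None, :] and
    # MultiHeadAttention expands it to [batch, 1, 1, seq_len] before DotAttention,
    # so padded keys are masked out in every head and for every query position.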