
[update] transformer

tags/v0.4.10
yunfan, 5 years ago
commit bbda73c14f

2 changed files with 25 additions and 31 deletions
  1. fastNLP/modules/encoder/attention.py    +17  -22
  2. fastNLP/modules/encoder/transformer.py   +8   -9

fastNLP/modules/encoder/attention.py  (+17 -22)

@@ -30,14 +30,14 @@ class DotAttention(nn.Module):
     def forward(self, Q, K, V, mask_out=None):
         """
 
-        :param Q: [batch, seq_len_q, key_size]
-        :param K: [batch, seq_len_k, key_size]
-        :param V: [batch, seq_len_k, value_size]
-        :param mask_out: [batch, 1, seq_len] or [batch, seq_len_q, seq_len_k]
+        :param Q: [..., seq_len_q, key_size]
+        :param K: [..., seq_len_k, key_size]
+        :param V: [..., seq_len_k, value_size]
+        :param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k]
         """
-        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -1e18)
+            output.masked_fill_(mask_out, -1e9)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
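Note: the docstring and transpose change make DotAttention shape-agnostic; scores are computed over the last two dimensions, so 3-D [batch, seq, dim] and 4-D [batch, head, seq, dim] inputs both work. A minimal runnable sketch of the updated behavior (the __init__ scaffolding with scale/softmax/dropout is assumed here, it is not part of the diff):

import math
import torch
from torch import nn

class DotAttention(nn.Module):
    """Scaled dot-product attention over the last two dimensions."""

    def __init__(self, key_size, value_size, dropout=0.0):
        super().__init__()
        self.scale = math.sqrt(key_size)   # divide scores by sqrt(d_k); assumed, not shown in the diff
        self.softmax = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(dropout)    # value_size is unused here; kept to match the call signature

    def forward(self, Q, K, V, mask_out=None):
        # Q: [..., seq_len_q, key_size], K: [..., seq_len_k, key_size], V: [..., seq_len_k, value_size]
        output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
        if mask_out is not None:
            # large negative score so softmax drives masked positions to ~0
            output.masked_fill_(mask_out, -1e9)
        output = self.softmax(output)
        output = self.drop(output)
        return torch.matmul(output, V)

# works for both 3-D and 4-D inputs, since only the last two dimensions matter:
att = DotAttention(key_size=16, value_size=16, dropout=0.1)
q = torch.randn(2, 4, 5, 16)       # [batch, heads, seq_q, d_k]
k = v = torch.randn(2, 4, 7, 16)   # [batch, heads, seq_k, d_k]
print(att(q, k, v).shape)          # torch.Size([2, 4, 5, 16])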
@@ -65,17 +65,16 @@ class MultiHeadAttention(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        # follow the paper, do not apply dropout within dot-product
         self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.reset_parameters()
 
     def reset_parameters(self):
         sqrt = math.sqrt
-        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
-        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
-        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
-        nn.init.xavier_normal_(self.out.weight)
+        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
+        nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size))
 
     def forward(self, Q, K, V, atte_mask_out=None):
         """
@@ -89,20 +88,16 @@ class MultiHeadAttention(nn.Module):
         sk = K.size(1)
         d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
         # input linear
-        q = self.q_in(Q).view(batch, sq, n_head, d_k)
-        k = self.k_in(K).view(batch, sk, n_head, d_k)
-        v = self.v_in(V).view(batch, sk, n_head, d_v)
-
-        # transpose q, k and v to do batch attention
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_v)
+        q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2)
+        k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2)
+        v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2)
 
         if atte_mask_out is not None:
-            atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
-        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
+            atte_mask_out = atte_mask_out[:, None, :, :]  # [bsz, 1, 1, len]
+        atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v)
 
         # concat all heads, do output linear
-        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
+        atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1)
         output = self.out(atte)
         return output
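Note: the rewritten forward keeps every tensor 4-D, [batch, n_head, seq, d], and lets broadcasting apply a single [bsz, 1, 1, len] mask to all heads instead of repeating it n_head times. A self-contained sketch of the reshape and mask broadcast (shapes are illustrative; the attention is written out as an explicit matmul, with k reused as the values for brevity):

import torch

batch, sq, sk, n_head, d_k = 2, 5, 7, 4, 16

q = torch.randn(batch, sq, n_head * d_k)
k = torch.randn(batch, sk, n_head * d_k)

# split heads: [batch, seq, n_head * d_k] -> [batch, n_head, seq, d_k]
qh = q.view(batch, sq, n_head, d_k).transpose(1, 2)
kh = k.view(batch, sk, n_head, d_k).transpose(1, 2)

scores = torch.matmul(qh, kh.transpose(-1, -2))   # [batch, n_head, sq, sk]

# padding mask [batch, 1, sk] -> [batch, 1, 1, sk]:
# broadcasts over heads and query positions, no .repeat(n_head, 1, 1) needed
pad_mask = torch.zeros(batch, 1, sk, dtype=torch.bool)
pad_mask[:, :, -2:] = True                         # pretend the last 2 tokens are padding
scores = scores.masked_fill(pad_mask[:, None, :, :], -1e9)

# merge heads back: [batch, n_head, sq, d_v] -> [batch, sq, n_head * d_v]
attn = torch.softmax(scores, dim=-1)
merged = torch.matmul(attn, kh).transpose(1, 2).contiguous().view(batch, sq, -1)
print(merged.shape)   # torch.Size([2, 5, 64])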




fastNLP/modules/encoder/transformer.py  (+8 -9)

@@ -5,8 +5,7 @@ __all__ = [
 ]
 from torch import nn
 
-from fastNLP.modules.encoder.attention import MultiHeadAttention
-from ..dropout import TimestepDropout
+from .attention import MultiHeadAttention
 
 
 class TransformerEncoder(nn.Module):
@@ -29,12 +28,12 @@ class TransformerEncoder(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
             self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
-            self.norm1 = nn.LayerNorm(model_size)
+            self.norm1 = nn.LayerNorm(model_size, eps=1e-6)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
                                      nn.Dropout(dropout),
                                      nn.Linear(inner_size, model_size))
-            self.norm2 = nn.LayerNorm(model_size)
+            self.norm2 = nn.LayerNorm(model_size, eps=1e-6)
             self.dropout = nn.Dropout(dropout)
 
         def forward(self, input, seq_mask=None, atte_mask_out=None):
@@ -47,17 +46,17 @@ class TransformerEncoder(nn.Module):
             input = self.norm1(input)
             attention = self.atte(input, input, input, atte_mask_out)
             input = input + self.dropout(attention)
-            # attention *= seq_mask
+            attention *= seq_mask
             input = self.norm2(input)
             output = self.ffn(input)
             input = input + self.dropout(output)
-            # output *= seq_mask
-            return output
+            input *= seq_mask
+            return input
 
     def __init__(self, num_layers, **kargs):
         super(TransformerEncoder, self).__init__()
         self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
-        self.norm = nn.LayerNorm(kargs['model_size'])
+        self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6)
 
     def forward(self, x, seq_mask=None):
         """
@@ -70,7 +69,7 @@ class TransformerEncoder(nn.Module):
         if seq_mask is None:
             atte_mask_out = None
         else:
-            atte_mask_out = (seq_mask < 1)[:, None, :]
+            atte_mask_out = (seq_mask == 0)[:, None, :]
             seq_mask = seq_mask[:, :, None]
         for layer in self.layers:
             output = layer(output, seq_mask, atte_mask_out)
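Note: for reference, the snippet below builds the boolean attention mask and the broadcastable padding mask the same way the encoder's forward now does (the padding pattern and sizes are made up for illustration):

import torch

batch, seq_len, model_size = 2, 6, 8

x = torch.randn(batch, seq_len, model_size)
seq_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                         [1, 1, 1, 0, 0, 0]])    # 1 = real token, 0 = padding

atte_mask_out = (seq_mask == 0)[:, None, :]   # bool, [batch, 1, seq_len]; True marks padding
seq_mask = seq_mask[:, :, None]               # [batch, seq_len, 1], broadcasts over model_size
# inside MultiHeadAttention the mask later gains a head axis: [batch, 1, 1, seq_len]

print(atte_mask_out.dtype, atte_mask_out.shape)   # torch.bool torch.Size([2, 1, 6])
print((x * seq_mask)[0, -1])                      # padded position is zeroed out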

