
add transformer

tags/v0.2.0
yunfan 6 years ago
commit 830d223344
4 changed files with 81 additions and 8 deletions
  1. fastNLP/modules/aggregator/attention.py (+43, -1)
  2. fastNLP/modules/encoder/transformer.py (+32, -0)
  3. fastNLP/modules/other_modules.py (+5, -6)
  4. reproduction/Biaffine_parser/cfg.cfg (+1, -1)

fastNLP/modules/aggregator/attention.py (+43, -1)

@@ -1,5 +1,6 @@
import torch

from torch import nn
import math
from fastNLP.modules.utils import mask_softmax


@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):

    def _atten_forward(self, query, memory):
        raise NotImplementedError


class DotAtte(nn.Module):
    def __init__(self, key_size, value_size):
        super(DotAtte, self).__init__()
        self.key_size = key_size
        self.value_size = value_size
        self.scale = math.sqrt(key_size)

    def forward(self, Q, K, V, seq_mask=None):
        """Scaled dot-product attention.

        :param Q: [batch, seq_len, key_size]
        :param K: [batch, seq_len, key_size]
        :param V: [batch, seq_len, value_size]
        :param seq_mask: [batch, seq_len]
        :return: attention output, [batch, seq_len, value_size]
        """
        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
        if seq_mask is not None:
            # positions where seq_mask < 1 are excluded from the softmax
            output.masked_fill_(seq_mask.lt(1), -float('inf'))
        output = nn.functional.softmax(output, dim=2)
        return torch.matmul(output, V)


class MultiHeadAtte(nn.Module):
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        super(MultiHeadAtte, self).__init__()
        # three linear projections per head: query and key map to key_size, value to value_size
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
            out_feat = key_size if (i % 3) != 2 else value_size
            self.in_linear.append(nn.Linear(input_size, out_feat))
        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
        self.out_linear = nn.Linear(value_size * num_atte, output_size)

    def forward(self, Q, K, V, seq_mask=None):
        heads = []
        for i in range(len(self.attes)):
            j = i * 3
            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j + 1](K), self.in_linear[j + 2](V)
            headi = self.attes[i](qi, ki, vi, seq_mask)
            heads.append(headi)
        # concatenate the heads and project back to output_size
        output = torch.cat(heads, dim=2)
        return self.out_linear(output)
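
Not part of the commit, but as a quick orientation, a minimal usage sketch of the two new modules; the batch size, sequence length and dimensions below are illustrative assumptions, and seq_mask is left at its default of None.

import torch

from fastNLP.modules.aggregator.attention import DotAtte, MultiHeadAtte

# Illustrative sizes (assumptions, not taken from the commit).
batch, seq_len, input_size = 2, 5, 16

# Scaled dot-product attention over already-projected queries, keys and values.
q = torch.rand(batch, seq_len, 8)              # [batch, seq_len, key_size]
k = torch.rand(batch, seq_len, 8)              # [batch, seq_len, key_size]
v = torch.rand(batch, seq_len, 8)              # [batch, seq_len, value_size]
single = DotAtte(key_size=8, value_size=8)
print(single(q, k, v).size())                  # torch.Size([2, 5, 8])

# Multi-head self-attention: the same tensor serves as query, key and value.
x = torch.rand(batch, seq_len, input_size)
multi = MultiHeadAtte(input_size=input_size, output_size=input_size,
                      key_size=8, value_size=8, num_atte=4)
print(multi(x, x, x).size())                   # torch.Size([2, 5, 16])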

fastNLP/modules/encoder/transformer.py (+32, -0)

@@ -0,0 +1,32 @@
import torch
from torch import nn
import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization

class TransformerEncoder(nn.Module):
    class SubLayer(nn.Module):
        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
            super(TransformerEncoder.SubLayer, self).__init__()
            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
            self.norm1 = LayerNormalization(output_size)
            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
                                     nn.ReLU(),
                                     nn.Linear(output_size, output_size))
            self.norm2 = LayerNormalization(output_size)

        def forward(self, input, seq_mask=None):
            # self-attention: the layer input serves as query, key and value
            attention = self.atte(input, input, input, seq_mask)
            norm_atte = self.norm1(attention + input)
            output = self.ffn(norm_atte)
            return self.norm2(output + norm_atte)

    def __init__(self, num_layers, **kargs):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

    def forward(self, x, seq_mask=None):
        # apply the sub-layers in order, threading seq_mask through each one
        for layer in self.layers:
            x = layer(x, seq_mask)
        return x
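
Again as an aside rather than part of the diff, a minimal usage sketch of the encoder; the sizes are illustrative assumptions. Note that input_size has to equal output_size, because each SubLayer adds its input to the attention output (residual connection).

import torch

from fastNLP.modules.encoder.transformer import TransformerEncoder

# Illustrative sizes (assumptions). input_size == output_size is required
# for the residual additions inside SubLayer to be well-formed.
encoder = TransformerEncoder(num_layers=2, input_size=16, output_size=16,
                             key_size=8, value_size=8, num_atte=4)
x = torch.rand(2, 5, 16)        # [batch, seq_len, input_size]
out = encoder(x)                # [batch, seq_len, output_size]
print(out.size())               # torch.Size([2, 5, 16])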



fastNLP/modules/other_modules.py (+5, -6)

@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
 class LayerNormalization(nn.Module):
     """ Layer normalization module """
 
-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
         super(LayerNormalization, self).__init__()
 
         self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))
 
     def forward(self, z):
         if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):

         mu = torch.mean(z, keepdim=True, dim=-1)
         sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
-
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2
         return ln_out
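
As a small aside, a shape check for the simplified layer normalization (sizes are illustrative assumptions): each feature vector is normalized over its last dimension, then scaled by a_2 and shifted by b_2.

import torch

from fastNLP.modules.other_modules import LayerNormalization

# Illustrative shapes (assumptions): normalize 8-dimensional feature vectors.
ln = LayerNormalization(layer_size=8)
z = torch.rand(4, 10, 8)        # [batch, seq_len, layer_size]
out = ln(z)                     # same shape; mean/std taken over the last dimension
print(out.size())               # torch.Size([4, 10, 8])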




reproduction/Biaffine_parser/cfg.cfg (+1, -1)

@@ -1,5 +1,5 @@
 [train]
-epochs = 50
+epochs = -1
 batch_size = 16
 pickle_path = "./save/"
 validate = true

