diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py
index 5cdc77c9..69c5fdf6 100644
--- a/fastNLP/modules/aggregator/attention.py
+++ b/fastNLP/modules/aggregator/attention.py
@@ -1,5 +1,6 @@
 import torch
-
+from torch import nn
+import math
 from fastNLP.modules.utils import mask_softmax


@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):

     def _atten_forward(self, query, memory):
         raise NotImplementedError
+
+class DotAtte(nn.Module):
+    def __init__(self, key_size, value_size):
+        super(DotAtte, self).__init__()
+        self.key_size = key_size
+        self.value_size = value_size
+        self.scale = math.sqrt(key_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        """Scaled dot-product attention.
+
+        :param Q: [batch, seq_len, key_size]
+        :param K: [batch, seq_len, key_size]
+        :param V: [batch, seq_len, value_size]
+        :param seq_mask: [batch, seq_len], 0 marks padding positions
+        """
+        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        if seq_mask is not None:
+            output.masked_fill_(seq_mask.unsqueeze(1).lt(1), -float('inf'))
+        output = nn.functional.softmax(output, dim=2)
+        return torch.matmul(output, V)
+
+class MultiHeadAtte(nn.Module):
+    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        super(MultiHeadAtte, self).__init__()
+        self.in_linear = nn.ModuleList()
+        for i in range(num_atte * 3):
+            out_feat = key_size if (i % 3) != 2 else value_size
+            self.in_linear.append(nn.Linear(input_size, out_feat))
+        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
+        self.out_linear = nn.Linear(value_size * num_atte, output_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        heads = []
+        for i in range(len(self.attes)):
+            j = i * 3
+            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j+1](K), self.in_linear[j+2](V)
+            headi = self.attes[i](qi, ki, vi, seq_mask)
+            heads.append(headi)
+        output = torch.cat(heads, dim=2)
+        return self.out_linear(output)
diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
new file mode 100644
index 00000000..46badcfe
--- /dev/null
+++ b/fastNLP/modules/encoder/transformer.py
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from ..aggregator.attention import MultiHeadAtte
+from ..other_modules import LayerNormalization
+
+class TransformerEncoder(nn.Module):
+    class SubLayer(nn.Module):
+        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+            super(TransformerEncoder.SubLayer, self).__init__()
+            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
+            self.norm1 = LayerNormalization(output_size)
+            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+                                     nn.ReLU(),
+                                     nn.Linear(output_size, output_size))
+            self.norm2 = LayerNormalization(output_size)
+
+        def forward(self, input, seq_mask):
+            attention = self.atte(input, input, input, seq_mask)
+            norm_atte = self.norm1(attention + input)
+            output = self.ffn(norm_atte)
+            return self.norm2(output + norm_atte)
+
+    def __init__(self, num_layers, **kargs):
+        super(TransformerEncoder, self).__init__()
+        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
+
+    def forward(self, x, seq_mask=None):
+        for layer in self.layers:
+            x = layer(x, seq_mask)
+        return x
diff --git a/fastNLP/modules/other_modules.py b/fastNLP/modules/other_modules.py
index ea1423be..5cd10e7e 100644
--- a/fastNLP/modules/other_modules.py
+++ b/fastNLP/modules/other_modules.py
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
 class LayerNormalization(nn.Module):
     """ Layer normalization module """

-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
         super(LayerNormalization, self).__init__()

         self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

     def forward(self, z):
         if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):

         mu = torch.mean(z, keepdim=True, dim=-1)
         sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
-
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2
         return ln_out


diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/Biaffine_parser/cfg.cfg
index 946e4c51..84e0f288 100644
--- a/reproduction/Biaffine_parser/cfg.cfg
+++ b/reproduction/Biaffine_parser/cfg.cfg
@@ -1,5 +1,5 @@
 [train]
-epochs = 50
+epochs = -1
 batch_size = 16
 pickle_path = "./save/"
 validate = true
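Below is a minimal usage sketch (not part of the patch) of the TransformerEncoder added in this diff. The layer sizes, batch shape, and mask values are illustrative assumptions, not settings taken from fastNLP.

# Hypothetical smoke test for the modules added above; sizes are arbitrary.
import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

batch, seq_len, d_model = 4, 20, 64
x = torch.randn(batch, seq_len, d_model)        # [batch, seq_len, input_size]
seq_mask = torch.ones(batch, seq_len).long()    # 1 = real token, 0 = padding
seq_mask[:, 15:] = 0                            # pretend the last 5 positions are padding

# output_size must equal input_size because SubLayer adds a residual (attention + input).
encoder = TransformerEncoder(num_layers=2, input_size=d_model, output_size=d_model,
                             key_size=16, value_size=16, num_atte=4)
out = encoder(x, seq_mask)                      # [batch, seq_len, output_size]
print(out.shape)                                # torch.Size([4, 20, 64])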