@@ -1,5 +1,6 @@
import torch
from torch import nn
import math
from fastNLP.modules.utils import mask_softmax
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):

    def _atten_forward(self, query, memory):
        raise NotImplementedError
class DotAtte(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, key_size, value_size):
        super(DotAtte, self).__init__()
        self.key_size = key_size
        self.value_size = value_size
        self.scale = math.sqrt(key_size)

    def forward(self, Q, K, V, seq_mask=None):
        """
        :param Q: [batch, seq_len, key_size]
        :param K: [batch, seq_len, key_size]
        :param V: [batch, seq_len, value_size]
        :param seq_mask: [batch, seq_len], 1 for valid positions, 0 for padding
        """
        # scaled dot-product scores: [batch, seq_len, seq_len]
        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
        if seq_mask is not None:
            # broadcast the [batch, seq_len] key mask over the query dimension
            output.masked_fill_(seq_mask.unsqueeze(1).lt(1), -float('inf'))
        output = nn.functional.softmax(output, dim=2)
        return torch.matmul(output, V)
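# A minimal usage sketch of DotAtte (illustration only; batch size, sequence
# length and feature sizes below are made-up values, and the mask marks the
# last two positions of each sequence as padding):
#
#   atte = DotAtte(key_size=64, value_size=64)
#   Q = torch.rand(2, 5, 64)
#   K = torch.rand(2, 5, 64)
#   V = torch.rand(2, 5, 64)
#   mask = torch.ones(2, 5)
#   mask[:, -2:] = 0
#   out = atte(Q, K, V, seq_mask=mask)   # -> [2, 5, 64]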
class MultiHeadAtte(nn.Module):
    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
        super(MultiHeadAtte, self).__init__()
        # three linear projections (Q, K, V) per attention head
        self.in_linear = nn.ModuleList()
        for i in range(num_atte * 3):
            out_feat = key_size if (i % 3) != 2 else value_size
            self.in_linear.append(nn.Linear(input_size, out_feat))
        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
        self.out_linear = nn.Linear(value_size * num_atte, output_size)

    def forward(self, Q, K, V, seq_mask=None):
        heads = []
        for i in range(len(self.attes)):
            j = i * 3
            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j + 1](K), self.in_linear[j + 2](V)
            headi = self.attes[i](qi, ki, vi, seq_mask)
            heads.append(headi)
        # concatenate the heads along the feature dimension and project back
        output = torch.cat(heads, dim=2)
        return self.out_linear(output)
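# A minimal usage sketch of MultiHeadAtte (illustration only; all sizes are
# made-up values). With num_atte=4 heads of value_size=32, the concatenated
# heads give 4 * 32 = 128 features, which out_linear projects to output_size:
#
#   mha = MultiHeadAtte(input_size=128, output_size=128,
#                       key_size=32, value_size=32, num_atte=4)
#   x = torch.rand(2, 5, 128)
#   out = mha(x, x, x)                   # self-attention -> [2, 5, 128]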
@@ -0,0 +1,32 @@
import torch
from torch import nn
import torch.nn.functional as F

from ..aggregator.attention import MultiHeadAtte
from ..other_modules import LayerNormalization
class TransformerEncoder(nn.Module):
    class SubLayer(nn.Module):
        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
            super(TransformerEncoder.SubLayer, self).__init__()
            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
            self.norm1 = LayerNormalization(output_size)
            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
                                     nn.ReLU(),
                                     nn.Linear(output_size, output_size))
            self.norm2 = LayerNormalization(output_size)

        def forward(self, input, seq_mask=None):
            # self-attention: Q, K and V all come from the same input;
            # the residual connection assumes output_size == input_size
            attention = self.atte(input, input, input, seq_mask)
            norm_atte = self.norm1(attention + input)
            output = self.ffn(norm_atte)
            return self.norm2(output + norm_atte)

    def __init__(self, num_layers, **kargs):
        super(TransformerEncoder, self).__init__()
        # nn.Sequential cannot forward two arguments, so keep the sub-layers
        # in a ModuleList and pass seq_mask through explicitly
        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])

    def forward(self, x, seq_mask=None):
        for layer in self.layers:
            x = layer(x, seq_mask)
        return x
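# A minimal usage sketch of TransformerEncoder (illustration only; all sizes
# are made-up values). The keyword arguments are forwarded to SubLayer, and
# the residual connections assume input_size == output_size:
#
#   encoder = TransformerEncoder(num_layers=2, input_size=128, output_size=128,
#                                key_size=32, value_size=32, num_atte=4)
#   x = torch.rand(2, 5, 128)
#   out = encoder(x)                     # -> [2, 5, 128]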
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
class LayerNormalization(nn.Module):
    """ Layer normalization module """

-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
        super(LayerNormalization, self).__init__()
        self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

    def forward(self, z):
        if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):
        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2

        return ln_out
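# Illustration only: per position, the new code computes
#   ln_out = a_2 * (z - mean(z, dim=-1)) / (std(z, dim=-1) + eps) + b_2
# which is close to torch.nn.LayerNorm, except that eps is added to the
# standard deviation (not to the variance inside the square root) and
# torch.std is unbiased by default. A quick shape check with made-up sizes:
#
#   ln = LayerNormalization(layer_size=128)
#   z = torch.rand(2, 5, 128)
#   out = ln(z)                          # -> [2, 5, 128], same shape as z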
@@ -1,5 +1,5 @@
[train]
-epochs = 50
+epochs = -1
batch_size = 16
pickle_path = "./save/"
validate = true
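# Illustration only: the [train] section can be read with the standard
# library's configparser (fastNLP has its own config loading utilities;
# "train.cfg" is a made-up file name):
#
#   import configparser
#   cfg = configparser.ConfigParser()
#   cfg.read("train.cfg")
#   epochs = cfg.getint("train", "epochs")          # -1 after this change
#   batch_size = cfg.getint("train", "batch_size")  # 16
#   validate = cfg.getboolean("train", "validate")  # True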