@@ -1,5 +1,6 @@
 import torch
+from torch import nn
+import math
 from fastNLP.modules.utils import mask_softmax
@@ -17,3 +18,44 @@ class Attention(torch.nn.Module):
     def _atten_forward(self, query, memory):
         raise NotImplementedError
+
+
+class DotAtte(nn.Module):
+    def __init__(self, key_size, value_size):
+        super(DotAtte, self).__init__()
+        self.key_size = key_size
+        self.value_size = value_size
+        self.scale = math.sqrt(key_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        """
+        :param Q: [batch, seq_len, key_size]
+        :param K: [batch, seq_len, key_size]
+        :param V: [batch, seq_len, value_size]
+        :param seq_mask: [batch, seq_len], 1 for real tokens, 0 for padding
+        :return: [batch, seq_len, value_size]
+        """
+        # Scaled dot-product attention scores: [batch, seq_len, seq_len]
+        output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
+        if seq_mask is not None:
+            # Broadcast the [batch, seq_len] mask over the query dimension so
+            # padded key positions get zero attention weight.
+            output.masked_fill_(seq_mask.lt(1).unsqueeze(1), -float('inf'))
+        output = nn.functional.softmax(output, dim=2)
+        return torch.matmul(output, V)
+
+
+class MultiHeadAtte(nn.Module):
+    def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        super(MultiHeadAtte, self).__init__()
+        # Three input projections (Q, K, V) per head, stored flat:
+        # indices 3*i, 3*i+1, 3*i+2 belong to head i.
+        self.in_linear = nn.ModuleList()
+        for i in range(num_atte * 3):
+            out_feat = key_size if (i % 3) != 2 else value_size
+            self.in_linear.append(nn.Linear(input_size, out_feat))
+        self.attes = nn.ModuleList([DotAtte(key_size, value_size) for _ in range(num_atte)])
+        self.out_linear = nn.Linear(value_size * num_atte, output_size)
+
+    def forward(self, Q, K, V, seq_mask=None):
+        heads = []
+        for i in range(len(self.attes)):
+            j = i * 3
+            qi, ki, vi = self.in_linear[j](Q), self.in_linear[j + 1](K), self.in_linear[j + 2](V)
+            headi = self.attes[i](qi, ki, vi, seq_mask)
+            heads.append(headi)
+        # Concatenate per-head outputs and project back to output_size.
+        output = torch.cat(heads, dim=2)
+        return self.out_linear(output)
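
For reference, a quick shape sanity check of the new attention modules, written as a standalone snippet rather than as part of the diff. The import path is inferred from the relative import in the transformer hunk below (fastNLP/modules/aggregator/attention.py); the batch and size numbers are arbitrary.

    # Illustrative usage only, not part of the diff.
    import torch
    from fastNLP.modules.aggregator.attention import DotAtte, MultiHeadAtte

    batch, seq_len, input_size = 2, 7, 32
    x = torch.randn(batch, seq_len, input_size)
    seq_mask = torch.ones(batch, seq_len, dtype=torch.long)
    seq_mask[:, -2:] = 0                    # pretend the last two positions are padding

    # Single scaled dot-product attention head.
    atte = DotAtte(key_size=input_size, value_size=input_size)
    out = atte(x, x, x, seq_mask)           # [batch, seq_len, value_size]

    # Four heads, each projecting to key_size / value_size 16.
    mh = MultiHeadAtte(input_size=input_size, output_size=input_size,
                       key_size=16, value_size=16, num_atte=4)
    out = mh(x, x, x, seq_mask)             # [batch, seq_len, output_size]
    print(out.shape)                        # torch.Size([2, 7, 32])
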
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from ..aggregator.attention import MultiHeadAtte
+from ..other_modules import LayerNormalization
+
+
+class TransformerEncoder(nn.Module):
+    class SubLayer(nn.Module):
+        def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+            super(TransformerEncoder.SubLayer, self).__init__()
+            self.atte = MultiHeadAtte(input_size, output_size, key_size, value_size, num_atte)
+            self.norm1 = LayerNormalization(output_size)
+            self.ffn = nn.Sequential(nn.Linear(output_size, output_size),
+                                     nn.ReLU(),
+                                     nn.Linear(output_size, output_size))
+            self.norm2 = LayerNormalization(output_size)
+
+        def forward(self, input, seq_mask=None):
+            # Self-attention: queries, keys and values all come from the input.
+            attention = self.atte(input, input, input, seq_mask)
+            norm_atte = self.norm1(attention + input)
+            output = self.ffn(norm_atte)
+            return self.norm2(output + norm_atte)
+
+    def __init__(self, num_layers, **kargs):
+        super(TransformerEncoder, self).__init__()
+        # nn.Sequential cannot pass the extra seq_mask argument through, so keep
+        # the sub-layers in a ModuleList and chain them explicitly in forward().
+        self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
+
+    def forward(self, x, seq_mask=None):
+        for layer in self.layers:
+            x = layer(x, seq_mask)
+        return x
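
Again as an illustration only, not part of the diff: the encoder can be exercised with random input once the keyword arguments match SubLayer's signature. The import path is a guess (the diff shows the new file's contents but not its location), and the residual connections require output_size == input_size.

    import torch
    from fastNLP.modules.encoder.transformer import TransformerEncoder  # path assumed from the relative imports

    encoder = TransformerEncoder(num_layers=2, input_size=32, output_size=32,
                                 key_size=16, value_size=16, num_atte=4)
    x = torch.randn(2, 7, 32)                      # [batch, seq_len, input_size]
    seq_mask = torch.ones(2, 7, dtype=torch.long)  # all positions are real tokens
    out = encoder(x, seq_mask)
    print(out.shape)                               # torch.Size([2, 7, 32])
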
@@ -31,12 +31,12 @@ class GroupNorm(nn.Module):
 class LayerNormalization(nn.Module):
     """ Layer normalization module """

-    def __init__(self, d_hid, eps=1e-3):
+    def __init__(self, layer_size, eps=1e-3):
         super(LayerNormalization, self).__init__()

         self.eps = eps
-        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
-        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
+        self.a_2 = nn.Parameter(torch.ones(1, layer_size, requires_grad=True))
+        self.b_2 = nn.Parameter(torch.zeros(1, layer_size, requires_grad=True))

     def forward(self, z):
         if z.size(1) == 1:
@@ -44,9 +44,8 @@ class LayerNormalization(nn.Module):
         mu = torch.mean(z, keepdim=True, dim=-1)
         sigma = torch.std(z, keepdim=True, dim=-1)
-        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
-        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
+        ln_out = (z - mu) / (sigma + self.eps)
+        ln_out = ln_out * self.a_2 + self.b_2

         return ln_out
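
A brief numerical check of the rewritten module (illustrative, not part of the diff; the import path follows the relative import `from ..other_modules import LayerNormalization` in the new transformer file). Note that torch.std defaults to the unbiased estimator, so the result differs slightly from torch.nn.LayerNorm, which uses the biased variance.

    import torch
    from fastNLP.modules.other_modules import LayerNormalization

    ln = LayerNormalization(layer_size=8)
    z = torch.randn(4, 5, 8)
    out = ln(z)
    # With a_2 initialised to ones and b_2 to zeros, each position is normalised
    # to roughly zero mean and unit standard deviation over the last dimension.
    print(out.mean(dim=-1).abs().max())   # close to 0
    print(out.std(dim=-1).mean())         # close to 1
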
@@ -1,5 +1,5 @@
 [train]
-epochs = 50
+epochs = -1
 batch_size = 16
 pickle_path = "./save/"
 validate = true
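
The only change here flips epochs from 50 to -1; how the trainer interprets a negative value is not shown in this diff. Purely as a hedged sketch of the apparent intent (an assumption, not fastNLP's actual logic), a common convention is to treat a non-positive epoch count as "no fixed budget, stop on some external criterion":

    # Assumption: epochs <= 0 means "no fixed epoch budget". Demo only.
    import configparser
    import itertools

    cfg = configparser.ConfigParser()
    cfg.read_string("""
    [train]
    epochs = -1
    batch_size = 16
    pickle_path = "./save/"
    validate = true
    """)
    epochs = cfg.getint("train", "epochs")   # -1 with the section above

    def should_stop(epoch):
        # Stand-in for whatever stopping criterion the trainer really uses.
        return epoch >= 3

    epoch_iter = itertools.count(1) if epochs <= 0 else range(1, epochs + 1)
    for epoch in epoch_iter:
        # ... one training epoch would run here ...
        if should_stop(epoch):
            break
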