| @@ -1,8 +1,10 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from torch.autograd import Variable | from torch.autograd import Variable | ||||
| import torch.nn.functional as F | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class SelfAttention(nn.Module): | class SelfAttention(nn.Module): | ||||
| """ | """ | ||||
| Self Attention Module. | Self Attention Module. | ||||
| @@ -13,13 +15,18 @@ class SelfAttention(nn.Module): | |||||
| num_vec: int, the number of encoded vectors | num_vec: int, the number of encoded vectors | ||||
| """ | """ | ||||
| def __init__(self, input_size, dim=10, num_vec=10): | |||||
| def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None): | |||||
| super(SelfAttention, self).__init__() | super(SelfAttention, self).__init__() | ||||
| self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| self.attention_hops = num_vec | |||||
| self.ws1 = nn.Linear(input_size, dim, bias=False) | |||||
| self.ws2 = nn.Linear(dim, num_vec, bias=False) | |||||
| self.drop = nn.Dropout(drop) | |||||
| self.softmax = nn.Softmax(dim=2) | self.softmax = nn.Softmax(dim=2) | ||||
| self.tanh = nn.Tanh() | self.tanh = nn.Tanh() | ||||
| initial_parameter(self, initial_method) | |||||
| def penalization(self, A): | def penalization(self, A): | ||||
| """ | """ | ||||
| compute the penalization term for attention module | compute the penalization term for attention module | ||||
| @@ -32,11 +39,33 @@ class SelfAttention(nn.Module): | |||||
| M = M.view(M.size(0), -1) | M = M.view(M.size(0), -1) | ||||
| return torch.sum(M ** 2, dim=1) | return torch.sum(M ** 2, dim=1) | ||||
| def forward(self, x): | |||||
| inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) | |||||
| A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| out = torch.matmul(A, x) | |||||
| out = out.view(out.size(0), -1) | |||||
| penalty = self.penalization(A) | |||||
| return out, penalty | |||||
| def forward(self, outp ,inp): | |||||
| # the following code can not be use because some word are padding ,these is not such module! | |||||
| # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # [] | |||||
| # A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| # out = torch.matmul(A, x) | |||||
| # out = out.view(out.size(0), -1) | |||||
| # penalty = self.penalization(A) | |||||
| # return out, penalty | |||||
| outp = outp.contiguous() | |||||
| size = outp.size() # [bsz, len, nhid] | |||||
| compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2] | |||||
| transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len] | |||||
| transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len] | |||||
| concatenated_inp = [transformed_inp for i in range(self.attention_hops)] | |||||
| concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len] | |||||
| hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit] | |||||
| attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] | |||||
| attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len] | |||||
| penalized_alphas = attention + ( | |||||
| -10000 * (concatenated_inp == 0).float()) | |||||
| # [bsz, hop, len] + [bsz, hop, len] | |||||
| attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len] | |||||
| attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len] | |||||
| return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid] | |||||