From 5960aba9cb1fc3106c51200b965cd6579e04d2ab Mon Sep 17 00:00:00 2001 From: 2017alan <17210240044@fudan.edu.cn> Date: Sat, 15 Sep 2018 17:16:36 +0800 Subject: [PATCH] change the code to do with sentence with padding tokens. --- fastNLP/modules/aggregation/self_attention.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/fastNLP/modules/aggregation/self_attention.py b/fastNLP/modules/aggregation/self_attention.py index aeaef4db..4155d708 100644 --- a/fastNLP/modules/aggregation/self_attention.py +++ b/fastNLP/modules/aggregation/self_attention.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn from torch.autograd import Variable +import torch.nn.functional as F +from fastNLP.modules.utils import initial_parameter class SelfAttention(nn.Module): """ Self Attention Module. @@ -13,13 +15,18 @@ class SelfAttention(nn.Module): num_vec: int, the number of encoded vectors """ - def __init__(self, input_size, dim=10, num_vec=10): + def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None): super(SelfAttention, self).__init__() - self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) - self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) + # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) + # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) + self.attention_hops = num_vec + + self.ws1 = nn.Linear(input_size, dim, bias=False) + self.ws2 = nn.Linear(dim, num_vec, bias=False) + self.drop = nn.Dropout(drop) self.softmax = nn.Softmax(dim=2) self.tanh = nn.Tanh() - + initial_parameter(self, initial_method) def penalization(self, A): """ compute the penalization term for attention module @@ -32,11 +39,33 @@ class SelfAttention(nn.Module): M = M.view(M.size(0), -1) return torch.sum(M ** 2, dim=1) - def forward(self, x): - inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) - A = self.softmax(torch.matmul(self.W_s2, inter)) - out = torch.matmul(A, x) - out = out.view(out.size(0), -1) - penalty = self.penalization(A) - return out, penalty + def forward(self, outp ,inp): + # the following code can not be use because some word are padding ,these is not such module! + + # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # [] + # A = self.softmax(torch.matmul(self.W_s2, inter)) + # out = torch.matmul(A, x) + # out = out.view(out.size(0), -1) + # penalty = self.penalization(A) + # return out, penalty + outp = outp.contiguous() + size = outp.size() # [bsz, len, nhid] + + compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2] + transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len] + transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len] + concatenated_inp = [transformed_inp for i in range(self.attention_hops)] + concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len] + + hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit] + attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] + attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len] + penalized_alphas = attention + ( + -10000 * (concatenated_inp == 0).float()) + # [bsz, hop, len] + [bsz, hop, len] + attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len] + attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len] + return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid] + +