Browse Source

change the code to do with sentence with padding tokens.

tags/v0.1.0
2017alan 6 years ago
parent
commit
5960aba9cb
1 changed files with 40 additions and 11 deletions
  1. +40
    -11
      fastNLP/modules/aggregation/self_attention.py

+ 40
- 11
fastNLP/modules/aggregation/self_attention.py View File

@@ -1,8 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
import torch.nn.functional as F




from fastNLP.modules.utils import initial_parameter
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
""" """
Self Attention Module. Self Attention Module.
@@ -13,13 +15,18 @@ class SelfAttention(nn.Module):
num_vec: int, the number of encoded vectors num_vec: int, the number of encoded vectors
""" """


def __init__(self, input_size, dim=10, num_vec=10):
def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None):
super(SelfAttention, self).__init__() super(SelfAttention, self).__init__()
self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
# self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True)
# self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True)
self.attention_hops = num_vec

self.ws1 = nn.Linear(input_size, dim, bias=False)
self.ws2 = nn.Linear(dim, num_vec, bias=False)
self.drop = nn.Dropout(drop)
self.softmax = nn.Softmax(dim=2) self.softmax = nn.Softmax(dim=2)
self.tanh = nn.Tanh() self.tanh = nn.Tanh()

initial_parameter(self, initial_method)
def penalization(self, A): def penalization(self, A):
""" """
compute the penalization term for attention module compute the penalization term for attention module
@@ -32,11 +39,33 @@ class SelfAttention(nn.Module):
M = M.view(M.size(0), -1) M = M.view(M.size(0), -1)
return torch.sum(M ** 2, dim=1) return torch.sum(M ** 2, dim=1)
def forward(self, x):
inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2)))
A = self.softmax(torch.matmul(self.W_s2, inter))
out = torch.matmul(A, x)
out = out.view(out.size(0), -1)
penalty = self.penalization(A)
return out, penalty
def forward(self, outp ,inp):
# the following code can not be use because some word are padding ,these is not such module!

# inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # []
# A = self.softmax(torch.matmul(self.W_s2, inter))
# out = torch.matmul(A, x)
# out = out.view(out.size(0), -1)
# penalty = self.penalization(A)
# return out, penalty
outp = outp.contiguous()
size = outp.size() # [bsz, len, nhid]

compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2]
transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len]
transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len]
concatenated_inp = [transformed_inp for i in range(self.attention_hops)]
concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len]

hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit]
attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop]
attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len]
penalized_alphas = attention + (
-10000 * (concatenated_inp == 0).float())
# [bsz, hop, len] + [bsz, hop, len]
attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len]
attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len]
return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid]





Loading…
Cancel
Save