From f22cb585593dfdef01e547381904a7ca2a4e4de0 Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 23 Apr 2019 21:24:19 +0800 Subject: [PATCH] combine self attention module to attention.py --- fastNLP/modules/aggregator/__init__.py | 2 +- fastNLP/modules/aggregator/attention.py | 59 +++++++++++++++++ fastNLP/modules/aggregator/self_attention.py | 68 -------------------- 3 files changed, 60 insertions(+), 69 deletions(-) delete mode 100644 fastNLP/modules/aggregator/self_attention.py diff --git a/fastNLP/modules/aggregator/__init__.py b/fastNLP/modules/aggregator/__init__.py index c0a63fd3..51106a76 100644 --- a/fastNLP/modules/aggregator/__init__.py +++ b/fastNLP/modules/aggregator/__init__.py @@ -6,5 +6,5 @@ from .pooling import KMaxPool from .attention import Attention from .attention import BiAttention -from .self_attention import SelfAttention +from .attention import SelfAttention diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index 4155fdd6..f2f2ac68 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -7,6 +7,8 @@ from torch import nn from fastNLP.modules.dropout import TimestepDropout from fastNLP.modules.utils import mask_softmax +from fastNLP.modules.utils import initial_parameter + class Attention(torch.nn.Module): def __init__(self, normalize=False): @@ -168,3 +170,60 @@ class BiAttention(nn.Module): out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size] return out_x1, out_x2 + +class SelfAttention(nn.Module): + """Self Attention Module. + :param int input_size: 输入tensor的hidden维度 + :param int attention_unit: 输出tensor的hidden维度 + :param int attention_hops: + :param float drop: dropout概率,默认值为0.5 + :param str initial_method: 初始化参数方法 + """ + + def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None,): + super(SelfAttention, self).__init__() + + self.attention_hops = attention_hops + self.ws1 = nn.Linear(input_size, attention_unit, bias=False) + self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False) + self.I = torch.eye(attention_hops, requires_grad=False) + self.I_origin = self.I + self.drop = nn.Dropout(drop) + self.tanh = nn.Tanh() + initial_parameter(self, initial_method) + + def _penalization(self, attention): + """ + compute the penalization term for attention module + """ + baz = attention.size(0) + size = self.I.size() + if len(size) != 3 or size[0] != baz: + self.I = self.I_origin.expand(baz, -1, -1) + self.I = self.I.to(device=attention.device) + attention_t = torch.transpose(attention, 1, 2).contiguous() + mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)] + ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5 + return torch.sum(ret) / size[0] + + def forward(self, input, input_origin): + """ + :param torch.Tensor input: [baz, senLen, h_dim] 要做attention的矩阵 + :param torch.Tensor input_origin: [baz , senLen] 原始token的index组成的矩阵,含有pad部分内容 + :return torch.Tensor output1: [baz, multi-head , h_dim] 经过attention操作后输入矩阵的结果 + :return torch.Tensor output2: [1] attention惩罚项,是一个标量 + """ + input = input.contiguous() + size = input.size() # [bsz, len, nhid] + + input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] + input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] + + y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] + attention = self.ws2(y1).transpose(1, 2).contiguous() + # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] + + attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. + attention = F.softmax(attention, 2) # [baz ,hop, len] + return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid] + diff --git a/fastNLP/modules/aggregator/self_attention.py b/fastNLP/modules/aggregator/self_attention.py deleted file mode 100644 index b0f03791..00000000 --- a/fastNLP/modules/aggregator/self_attention.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable - -from fastNLP.modules.utils import initial_parameter - - -class SelfAttention(nn.Module): - """Self Attention Module. - - :param int input_size: - :param int attention_unit: - :param int attention_hops: - :param float drop: - :param str initial_method: - :param bool use_cuda: - """ - - def __init__(self, input_size, attention_unit=350, attention_hops=10, drop=0.5, initial_method=None, - use_cuda=False): - super(SelfAttention, self).__init__() - - self.attention_hops = attention_hops - self.ws1 = nn.Linear(input_size, attention_unit, bias=False) - self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False) - if use_cuda: - self.I = Variable(torch.eye(attention_hops).cuda(), requires_grad=False) - else: - self.I = Variable(torch.eye(attention_hops), requires_grad=False) - self.I_origin = self.I - self.drop = nn.Dropout(drop) - self.tanh = nn.Tanh() - initial_parameter(self, initial_method) - - def penalization(self, attention): - """ - compute the penalization term for attention module - """ - baz = attention.size(0) - size = self.I.size() - if len(size) != 3 or size[0] != baz: - self.I = self.I_origin.expand(baz, -1, -1) - attentionT = torch.transpose(attention, 1, 2).contiguous() - mat = torch.bmm(attention, attentionT) - self.I[:attention.size(0)] - ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5 - return torch.sum(ret) / size[0] - - def forward(self, input, input_origin): - """ - :param input: the matrix to do attention. [baz, senLen, h_dim] - :param inp: then token index include pad token( 0 ) [baz , senLen] - :return output1: the input matrix after attention operation [baz, multi-head , h_dim] - :return output2: the attention penalty term, a scalar [1] - """ - input = input.contiguous() - size = input.size() # [bsz, len, nhid] - - input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] - input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] - - y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] - attention = self.ws2(y1).transpose(1, 2).contiguous() - # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] - - attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. - attention = F.softmax(attention, 2) # [baz ,hop, len] - return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid]