@@ -1,7 +1,9 @@
 """undocumented"""

 __all__ = [
-    "MultiHeadAttention"
+    "MultiHeadAttention",
+    "BiAttention",
+    "SelfAttention",
 ]

 import math
@@ -15,8 +17,7 @@ from fastNLP.modules.utils import initial_parameter

 class DotAttention(nn.Module):
     """
-    .. todo::
-        Fill in the documentation.
+    The DotAttention used in the Transformer.
     """

     def __init__(self, key_size, value_size, dropout=0.0):
@@ -45,7 +46,7 @@ class DotAttention(nn.Module):

 class MultiHeadAttention(nn.Module):
     """
-
+    The MultiHeadAttention used in the Transformer.
     """

     def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
@@ -104,74 +105,78 @@
         return output


+def _masked_softmax(tensor, mask):
+    tensor_shape = tensor.size()
+    reshaped_tensor = tensor.view(-1, tensor_shape[-1])
+
+    # Reshape the mask so it matches the size of the input tensor.
+    while mask.dim() < tensor.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.expand_as(tensor).contiguous().float()
+    reshaped_mask = mask.view(-1, mask.size()[-1])
+    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
+    result = result * reshaped_mask
+    # 1e-13 is added to avoid divisions by zero.
+    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
+    return result.view(*tensor_shape)
+
+
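# Illustration only, not part of the patch: a minimal sketch of what
# _masked_softmax computes, assuming PyTorch is available. Key positions whose
# mask entry is 0 receive zero attention weight and the rest is renormalized.
import torch

scores = torch.tensor([[[1.0, 2.0, 3.0]]])    # [batch=1, num_query=1, num_key=3]
mask = torch.tensor([[1, 1, 0]])              # the last key position is padding
probs = _masked_softmax(scores, mask)         # roughly [[[0.27, 0.73, 0.00]]]
assert torch.allclose(probs.sum(dim=-1), torch.ones(1, 1))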
+def _weighted_sum(tensor, weights, mask):
+    w_sum = weights.bmm(tensor)
+    while mask.dim() < w_sum.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.transpose(-1, -2)
+    mask = mask.expand_as(w_sum).contiguous().float()
+    return w_sum * mask
+
+
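# Illustration only, not part of the patch: how the two helpers above are meant
# to be combined (this mirrors what BiAttention.forward does below), assuming
# PyTorch is available and the usual [batch, seq_len, hidden] shape convention.
import torch

a = torch.randn(2, 4, 8)                      # first sequence
b = torch.randn(2, 5, 8)                      # second sequence
a_mask = torch.ones(2, 4)                     # no padding in a
b_mask = torch.tensor([[1., 1., 1., 0., 0.],
                       [1., 1., 1., 1., 1.]])  # 0 marks padding in b

scores = a.bmm(b.transpose(1, 2))             # [2, 4, 5] similarity matrix
attn = _masked_softmax(scores, b_mask)        # padded positions of b get weight 0
attended_a = _weighted_sum(b, attn, a_mask)   # [2, 4, 8] context vector for each a_i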
 class BiAttention(nn.Module):
-    r"""Bi Attention module
-
-    .. todo::
-        The maintainer of this module should continue to improve it.
-
-    Calculate Bi Attention matrix `e`
-
+    r"""
+    Bi Attention module
+
+    For two given vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the attention result with the following formulas
+
     .. math::

         \begin{array}{ll} \\
-            e_ij = {a}^{\mathbf{T}}_{i}{b}_{j} \\
-            a_i =
-            b_j =
+            e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\
+            {\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
+            {\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\
         \end{array}

     """

-    def __init__(self):
-        super(BiAttention, self).__init__()
-        self.inf = 10e12
-
-    def forward(self, in_x1, in_x2, x1_len, x2_len):
+    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
         """
-        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
-        :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
-        :param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix of the first sentence
-        :param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix of the second sentence
-        :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] attended feature representation of the first sentence
-                 torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] attended feature representation of the second sentence
-
+        :param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size]
+        :param torch.Tensor premise_mask: [batch_size, a_seq_len]
+        :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
+        :param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
+        :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size]
+                 torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
         """
+        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
+                                              .contiguous())
+
-        assert in_x1.size()[0] == in_x2.size()[0]
-        assert in_x1.size()[2] == in_x2.size()[2]
-        # The batch size and hidden size must be equal.
-        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
-        # The seq len in in_x and x_len must be equal.
-        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
-
-        batch_size = in_x1.size()[0]
-        x1_max_len = in_x1.size()[1]
-        x2_max_len = in_x2.size()[1]
-
-        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]
-
-        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
-
-        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
-        a_mask = a_mask.view(batch_size, x1_max_len, -1)
-        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
-        b_mask = x2_len.le(0.5).float() * -self.inf
-        b_mask = b_mask.view(batch_size, -1, x2_max_len)
-        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]
-
-        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
-        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]
+        prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask)
+        hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2)
+                                        .contiguous(),
+                                        premise_mask)

-        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
-        attention_b_t = torch.transpose(attention_b, 1, 2)
-        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]
+        attended_premises = _weighted_sum(hypothesis_batch,
+                                          prem_hyp_attn,
+                                          premise_mask)
+        attended_hypotheses = _weighted_sum(premise_batch,
+                                            hyp_prem_attn,
+                                            hypothesis_mask)

-        return out_x1, out_x2
+        return attended_premises, attended_hypotheses


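# Illustration only, not part of the patch: a minimal usage sketch of the new
# BiAttention interface, assuming PyTorch is available. The masks are 0/1
# tensors marking real (non-padded) tokens, as described in the docstring.
import torch

premise = torch.randn(2, 4, 16)                     # [batch_size, a_seq_len, hidden_size]
hypothesis = torch.randn(2, 6, 16)                  # [batch_size, b_seq_len, hidden_size]
premise_mask = torch.ones(2, 4)
hypothesis_mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
                                [1., 1., 1., 1., 1., 1.]])

bi_attention = BiAttention()
attended_premise, attended_hypothesis = bi_attention(
    premise, premise_mask, hypothesis, hypothesis_mask)
# attended_premise: [2, 4, 16], attended_hypothesis: [2, 6, 16]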
 class SelfAttention(nn.Module):
     """
-    Self Attention Module.
+    A Self Attention module based on the paper
+    `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ .
     """

     def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
@@ -210,9 +215,9 @@ class SelfAttention(nn.Module):

     def forward(self, input, input_origin):
         """
-        :param torch.Tensor input: [baz, senLen, h_dim] the matrix to apply attention over
-        :param torch.Tensor input_origin: [baz, senLen] the original token indices, including the padded positions
-        :return torch.Tensor output1: [baz, multi-head, h_dim] the input matrix after the attention operation
+        :param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to apply attention over
+        :param torch.Tensor input_origin: [batch_size, seq_len] the original token indices, including the padded positions
+        :return torch.Tensor output1: [batch_size, multi-head, hidden_size] the input matrix after the attention operation
         :return torch.Tensor output2: [1] the attention penalty term, a scalar
         """
         input = input.contiguous()