
update documents in attention module

tags/v0.4.10
Yige Xu committed 5 years ago
parent commit cee9fda6c7
3 changed files with 70 additions and 107 deletions
  1. +2  -46  fastNLP/models/snli.py
  2. +3  -1   fastNLP/modules/encoder/__init__.py
  3. +65 -60  fastNLP/modules/encoder/attention.py

+2 -46  fastNLP/models/snli.py

@@ -15,6 +15,7 @@ from .base_model import BaseModel
from ..core.const import Const
from ..core.utils import seq_len_to_mask
from ..embeddings.embedding import TokenEmbedding, Embedding
from ..modules.encoder import BiAttention


class ESIM(BaseModel):
@@ -50,7 +51,7 @@ class ESIM(BaseModel):
nn.Linear(8 * hidden_size, hidden_size),
nn.ReLU())
nn.init.xavier_uniform_(self.interfere[1].weight.data)
self.bi_attention = SoftmaxAttention()
self.bi_attention = BiAttention()

self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,)
@@ -174,48 +175,3 @@ class BiRNN(nn.Module):
output = torch.cat([output, padding], 1)
return output


def masked_softmax(tensor, mask):
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor_shape[-1])

# Reshape the mask so it matches the size of the input tensor.
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
reshaped_mask = mask.view(-1, mask.size()[-1])
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
result = result * reshaped_mask
# 1e-13 is added to avoid divisions by zero.
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result.view(*tensor_shape)


def weighted_sum(tensor, weights, mask):
w_sum = weights.bmm(tensor)
while mask.dim() < w_sum.dim():
mask = mask.unsqueeze(1)
mask = mask.transpose(-1, -2)
mask = mask.expand_as(w_sum).contiguous().float()
return w_sum * mask


class SoftmaxAttention(nn.Module):

def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
.contiguous())

prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
.contiguous(),
premise_mask)

attended_premises = weighted_sum(hypothesis_batch,
prem_hyp_attn,
premise_mask)
attended_hypotheses = weighted_sum(premise_batch,
hyp_prem_attn,
hypothesis_mask)

return attended_premises, attended_hypotheses
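
The removed SoftmaxAttention has the same call signature as the new BiAttention, so ESIM keeps the same call site. A minimal sketch of the updated usage, not part of the diff; the import path and forward signature come from this commit, while the shapes and data are illustrative:

import torch
from fastNLP.modules.encoder import BiAttention

# Toy premise/hypothesis representations and 0/1 masks (illustrative shapes).
premise = torch.randn(2, 5, 8)        # [batch_size, a_seq_len, hidden_size]
hypothesis = torch.randn(2, 7, 8)     # [batch_size, b_seq_len, hidden_size]
premise_mask = torch.ones(2, 5)       # [batch_size, a_seq_len]
hypothesis_mask = torch.ones(2, 7)    # [batch_size, b_seq_len]

bi_attention = BiAttention()          # replaces the removed SoftmaxAttention()
attended_premises, attended_hypotheses = bi_attention(
    premise, premise_mask, hypothesis, hypothesis_mask)
# attended_premises:  [2, 5, 8], attended_hypotheses: [2, 7, 8]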

+3 -1  fastNLP/modules/encoder/__init__.py

@@ -27,9 +27,11 @@ __all__ = [
"AvgPoolWithMask",

"MultiHeadAttention",
"BiAttention",
"SelfAttention",
]

from .attention import MultiHeadAttention
from .attention import MultiHeadAttention, BiAttention, SelfAttention
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
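
With the updated __all__ and re-exports, the attention classes are available at the package level. A short illustrative sketch, not part of the diff:

# Package-level imports now work for all three classes:
from fastNLP.modules.encoder import MultiHeadAttention, BiAttention, SelfAttention
# equivalent to importing them from fastNLP.modules.encoder.attention directly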


+65 -60  fastNLP/modules/encoder/attention.py

@@ -1,7 +1,9 @@
"""undocumented"""

__all__ = [
"MultiHeadAttention"
"MultiHeadAttention",
"BiAttention",
"SelfAttention",
]

import math
@@ -15,8 +17,7 @@ from fastNLP.modules.utils import initial_parameter

class DotAttention(nn.Module):
"""
.. todo::
add the missing documentation
DotAttention as used in the Transformer
"""

def __init__(self, key_size, value_size, dropout=0.0):
@@ -45,7 +46,7 @@ class DotAttention(nn.Module):

class MultiHeadAttention(nn.Module):
"""
MultiHeadAttention as used in the Transformer
"""

def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
@@ -104,74 +105,78 @@ class MultiHeadAttention(nn.Module):
return output


def _masked_softmax(tensor, mask):
tensor_shape = tensor.size()
reshaped_tensor = tensor.view(-1, tensor_shape[-1])

# Reshape the mask so it matches the size of the input tensor.
while mask.dim() < tensor.dim():
mask = mask.unsqueeze(1)
mask = mask.expand_as(tensor).contiguous().float()
reshaped_mask = mask.view(-1, mask.size()[-1])
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
result = result * reshaped_mask
# 1e-13 is added to avoid divisions by zero.
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
return result.view(*tensor_shape)


def _weighted_sum(tensor, weights, mask):
w_sum = weights.bmm(tensor)
while mask.dim() < w_sum.dim():
mask = mask.unsqueeze(1)
mask = mask.transpose(-1, -2)
mask = mask.expand_as(w_sum).contiguous().float()
return w_sum * mask
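
For intuition, _masked_softmax zeroes the weights at padded positions and renormalizes each row so it still sums to 1, exactly as in the helper above. A small numeric sketch of that behavior, standalone and not part of the diff, assuming only torch:

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 3.0]])   # one row of similarity scores
mask = torch.tensor([[1.0, 1.0, 0.0]])     # last position is padding

masked = F.softmax(scores * mask, dim=-1) * mask
masked = masked / (masked.sum(dim=-1, keepdim=True) + 1e-13)
# masked ≈ [[0.731, 0.269, 0.000]]: the padded slot gets zero weight
# and the remaining weights are renormalized to sum to 1.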


class BiAttention(nn.Module):
r"""Bi Attention module
.. todo::
the maintainer of this module should continue to improve it
Calculate Bi Attention matrix `e`
r"""
Bi Attention module

For two given vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the attention result with the following formulas

.. math::

\begin{array}{ll} \\
    e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\
    {\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
    {\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\
\end{array}
"""

def __init__(self):
super(BiAttention, self).__init__()
self.inf = 10e12
"""

def forward(self, in_x1, in_x2, x1_len, x2_len):
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
"""
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
:param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
:param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix for the first sentence
:param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix for the second sentence
:return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] attended feature representation of the first sentence
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] attended feature representation of the second sentence
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size]
:param torch.Tensor premise_mask: [batch_size, a_seq_len]
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
:param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
:return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size]
torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
"""
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
.contiguous())

assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]

in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len]

attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len]

a_mask = x1_len.le(0.5).float() * -self.inf # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -self.inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len]

attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len]
prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask)
hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2)
.contiguous(),
premise_mask)

out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size]
attended_premises = _weighted_sum(hypothesis_batch,
prem_hyp_attn,
premise_mask)
attended_hypotheses = _weighted_sum(premise_batch,
hyp_prem_attn,
hypothesis_mask)

return out_x1, out_x2
return attended_premises, attended_hypotheses
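
To connect the code to the docstring formulas: similarity_matrix holds e_ij = a_i^T b_j, and each attended premise is the softmax-weighted sum of hypothesis vectors. A quick sanity check, illustrative and not part of the diff, assuming no padding so both masks are all ones:

import torch
import torch.nn.functional as F

a = torch.randn(1, 4, 6)             # premise    [batch, l_a, hidden]
b = torch.randn(1, 5, 6)             # hypothesis [batch, l_b, hidden]

e = a.bmm(b.transpose(1, 2))         # e[:, i, j] = a_i . b_j
a_hat = F.softmax(e, dim=-1).bmm(b)  # \hat{a}_i = sum_j softmax_j(e_ij) b_j

# With all-ones masks, BiAttention's attended_premises reduces to exactly a_hat.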


class SelfAttention(nn.Module):
"""
Self Attention Module.
A Self Attention module based on the paper
`A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ .
"""

def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
@@ -210,9 +215,9 @@ class SelfAttention(nn.Module):

def forward(self, input, input_origin):
"""
:param torch.Tensor input: [baz, senLen, h_dim] the matrix to apply attention to
:param torch.Tensor input_origin: [baz, senLen] matrix of the original token indices, including the padded positions
:return torch.Tensor output1: [baz, multi-head, h_dim] the input matrix after the attention operation
:param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to apply attention to
:param torch.Tensor input_origin: [batch_size, seq_len] matrix of the original token indices, including the padded positions
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] the input matrix after the attention operation
:return torch.Tensor output2: [1] the attention penalty term, a scalar
"""
input = input.contiguous()
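
Per the documented shapes, SelfAttention returns a [batch_size, attention_hops, hidden_size] representation plus a scalar penalty that is typically added to the training loss. A usage sketch, not part of the diff; the constructor arguments and documented shapes come from this commit, while the data and values are illustrative:

import torch
from fastNLP.modules.encoder import SelfAttention

self_attn = SelfAttention(input_size=16, attention_unit=32, attention_hops=4)

hidden = torch.randn(2, 7, 16)            # [batch_size, seq_len, hidden_size]
tokens = torch.randint(1, 100, (2, 7))    # original token indices (padding would be included here)

output, penalty = self_attn(hidden, tokens)
# output:  [2, 4, 16] -- one attended vector per attention hop
# penalty: scalar regularization term on the attention weights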

