diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py
index 9a48f967..1661d191 100644
--- a/fastNLP/models/snli.py
+++ b/fastNLP/models/snli.py
@@ -15,6 +15,7 @@ from .base_model import BaseModel
 from ..core.const import Const
 from ..core.utils import seq_len_to_mask
 from ..embeddings.embedding import TokenEmbedding, Embedding
+from ..modules.encoder import BiAttention
 
 
 class ESIM(BaseModel):
@@ -50,7 +51,7 @@ class ESIM(BaseModel):
                                        nn.Linear(8 * hidden_size, hidden_size),
                                        nn.ReLU())
         nn.init.xavier_uniform_(self.interfere[1].weight.data)
-        self.bi_attention = SoftmaxAttention()
+        self.bi_attention = BiAttention()
 
         self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
         # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,)
@@ -174,48 +175,3 @@ class BiRNN(nn.Module):
             output = torch.cat([output, padding], 1)
 
         return output
-
-def masked_softmax(tensor, mask):
-    tensor_shape = tensor.size()
-    reshaped_tensor = tensor.view(-1, tensor_shape[-1])
-
-    # Reshape the mask so it matches the size of the input tensor.
-    while mask.dim() < tensor.dim():
-        mask = mask.unsqueeze(1)
-    mask = mask.expand_as(tensor).contiguous().float()
-    reshaped_mask = mask.view(-1, mask.size()[-1])
-    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
-    result = result * reshaped_mask
-    # 1e-13 is added to avoid divisions by zero.
-    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
-    return result.view(*tensor_shape)
-
-
-def weighted_sum(tensor, weights, mask):
-    w_sum = weights.bmm(tensor)
-    while mask.dim() < w_sum.dim():
-        mask = mask.unsqueeze(1)
-    mask = mask.transpose(-1, -2)
-    mask = mask.expand_as(w_sum).contiguous().float()
-    return w_sum * mask
-
-
-class SoftmaxAttention(nn.Module):
-
-    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
-        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
-                                              .contiguous())
-
-        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
-        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
-                                       .contiguous(),
-                                       premise_mask)
-
-        attended_premises = weighted_sum(hypothesis_batch,
-                                         prem_hyp_attn,
-                                         premise_mask)
-        attended_hypotheses = weighted_sum(premise_batch,
-                                           hyp_prem_attn,
-                                           hypothesis_mask)
-
-        return attended_premises, attended_hypotheses
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 0dfc18de..7fbc4b71 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -27,9 +27,11 @@ __all__ = [
     "AvgPoolWithMask",
 
     "MultiHeadAttention",
+    "BiAttention",
+    "SelfAttention",
 ]
 
-from .attention import MultiHeadAttention
+from .attention import MultiHeadAttention, BiAttention, SelfAttention
 from .bert import BertModel
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/encoder/attention.py
index 32f59c22..fdfcf0fd 100644
--- a/fastNLP/modules/encoder/attention.py
+++ b/fastNLP/modules/encoder/attention.py
@@ -1,7 +1,9 @@
 """undocumented"""
 
 __all__ = [
-    "MultiHeadAttention"
+    "MultiHeadAttention",
+    "BiAttention",
+    "SelfAttention",
 ]
 
 import math
@@ -15,8 +17,7 @@ from fastNLP.modules.utils import initial_parameter
 
 class DotAttention(nn.Module):
     """
-    .. todo::
-        add documentation
+    The DotAttention used in the Transformer
     """
 
     def __init__(self, key_size, value_size, dropout=0.0):
@@ -45,7 +46,7 @@ class DotAttention(nn.Module):
 
 class MultiHeadAttention(nn.Module):
     """
-
+    The MultiHeadAttention used in the Transformer
     """
 
     def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
@@ -104,74 +105,78 @@ class MultiHeadAttention(nn.Module):
         return output
 
 
+def _masked_softmax(tensor, mask):
+    tensor_shape = tensor.size()
+    reshaped_tensor = tensor.view(-1, tensor_shape[-1])
+
+    # Reshape the mask so it matches the size of the input tensor.
+    while mask.dim() < tensor.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.expand_as(tensor).contiguous().float()
+    reshaped_mask = mask.view(-1, mask.size()[-1])
+    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
+    result = result * reshaped_mask
+    # 1e-13 is added to avoid divisions by zero.
+    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
+    return result.view(*tensor_shape)
+
+
+def _weighted_sum(tensor, weights, mask):
+    w_sum = weights.bmm(tensor)
+    while mask.dim() < w_sum.dim():
+        mask = mask.unsqueeze(1)
+    mask = mask.transpose(-1, -2)
+    mask = mask.expand_as(w_sum).contiguous().float()
+    return w_sum * mask
+
+
 class BiAttention(nn.Module):
-    r"""Bi Attention module
-
-    .. todo::
-        The owner of this module should continue to improve it
-
-    Calculate Bi Attention matrix `e`
-
+    r"""
+    Bi Attention module
+
+    Given two vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the attention results with the following formulas
+
     .. math::
-
+
         \begin{array}{ll} \\
-            e_ij = {a}^{\mathbf{T}}_{i}{b}_{j} \\
-            a_i =
-            b_j =
+            e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\
+            {\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
+            {\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\
         \end{array}
-
-    """
-
-    def __init__(self):
-        super(BiAttention, self).__init__()
-        self.inf = 10e12
+
+    """
 
-    def forward(self, in_x1, in_x2, x1_len, x2_len):
+    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
         """
-        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
-        :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
-        :param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix of the first sentence
-        :param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix of the second sentence
-        :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] attended representation of the first sentence
-                 torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] attended representation of the second sentence
-
+        :param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size]
+        :param torch.Tensor premise_mask: [batch_size, a_seq_len]
+        :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
+        :param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
+        :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size]
+                 torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
         """
+        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
+                                              .contiguous())
 
-        assert in_x1.size()[0] == in_x2.size()[0]
-        assert in_x1.size()[2] == in_x2.size()[2]
-        # The batch size and hidden size must be equal.
-        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
-        # The seq len in in_x and x_len must be equal.
-        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
-
-        batch_size = in_x1.size()[0]
-        x1_max_len = in_x1.size()[1]
-        x2_max_len = in_x2.size()[1]
-
-        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]
-
-        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
-
-        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
-        a_mask = a_mask.view(batch_size, x1_max_len, -1)
-        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
-        b_mask = x2_len.le(0.5).float() * -self.inf
-        b_mask = b_mask.view(batch_size, -1, x2_max_len)
-        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]
-
-        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
-        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]
+        prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask)
+        hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2)
+                                        .contiguous(),
+                                        premise_mask)
 
-        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
-        attention_b_t = torch.transpose(attention_b, 1, 2)
-        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]
+        attended_premises = _weighted_sum(hypothesis_batch,
+                                          prem_hyp_attn,
+                                          premise_mask)
+        attended_hypotheses = _weighted_sum(premise_batch,
+                                            hyp_prem_attn,
+                                            hypothesis_mask)
 
-        return out_x1, out_x2
+        return attended_premises, attended_hypotheses
 
 
 class SelfAttention(nn.Module):
     """
-    Self Attention Module.
+    A Self Attention module based on the paper
+    `A structured self-attentive sentence embedding <https://arxiv.org/abs/1703.03130>`_ .
     """
 
     def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
@@ -210,9 +215,9 @@ class SelfAttention(nn.Module):
 
     def forward(self, input, input_origin):
         """
-        :param torch.Tensor input: [baz, senLen, h_dim] the matrix to apply attention over
-        :param torch.Tensor input_origin: [baz , senLen] matrix of the original token indices, including the padded positions
-        :return torch.Tensor output1: [baz, multi-head , h_dim] the input matrix after the attention operation
+        :param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to apply attention over
+        :param torch.Tensor input_origin: [batch_size, seq_len] matrix of the original token indices, including the padded positions
+        :return torch.Tensor output1: [batch_size, multi-head, hidden_size] the input matrix after the attention operation
         :return torch.Tensor output2: [1] the attention penalty term, a scalar
         """
         input = input.contiguous()
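
Not part of the patch itself, but for reviewers, here is a minimal usage sketch of the relocated BiAttention module. The import path and call signature come from the diff above; the tensors, shapes, and mask values are made-up toy data following the 0/1 mask convention in the new docstring.

```python
import torch

from fastNLP.modules.encoder import BiAttention

# Toy premise/hypothesis batches: the hidden sizes must match, and each mask is
# [batch_size, seq_len] with 1 for real tokens and 0 for padding positions.
premise = torch.randn(2, 5, 16)                  # [batch_size, a_seq_len, hidden_size]
premise_mask = torch.tensor([[1., 1., 1., 0., 0.],
                             [1., 1., 1., 1., 1.]])
hypothesis = torch.randn(2, 7, 16)               # [batch_size, b_seq_len, hidden_size]
hypothesis_mask = torch.ones(2, 7)

bi_attention = BiAttention()
attended_premises, attended_hypotheses = bi_attention(
    premise, premise_mask, hypothesis, hypothesis_mask)

print(attended_premises.shape)    # torch.Size([2, 5, 16])
print(attended_hypotheses.shape)  # torch.Size([2, 7, 16])
```

SelfAttention is exported the same way (`from fastNLP.modules.encoder import SelfAttention`), and ESIM itself is unchanged apart from swapping its private SoftmaxAttention for the shared BiAttention.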