From fcf105051259b8b10d5e4603ebbeba8590606964 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Wed, 9 Jan 2019 00:09:22 +0800
Subject: [PATCH] add Bi-Attention

---
 fastNLP/modules/aggregator/attention.py | 44 +++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py
index 882807f8..5fc8a091 100644
--- a/fastNLP/modules/aggregator/attention.py
+++ b/fastNLP/modules/aggregator/attention.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import torch.nn.functional as F
 import math
 from fastNLP.modules.utils import mask_softmax
 
@@ -62,3 +63,46 @@ class MultiHeadAtte(nn.Module):
             heads.append(headi)
         output = torch.cat(heads, dim=2)
         return self.out_linear(output)
+
+
+class Bi_Attention(nn.Module):
+    def __init__(self):
+        super(Bi_Attention, self).__init__()
+        self.inf = 10e12
+
+    def forward(self, in_x1, in_x2, x1_len, x2_len):
+        # in_x1: [batch_size, x1_seq_len, hidden_size]
+        # in_x2: [batch_size, x2_seq_len, hidden_size]
+        # x1_len: [batch_size, x1_seq_len]
+        # x2_len: [batch_size, x2_seq_len]
+
+        assert in_x1.size()[0] == in_x2.size()[0]
+        assert in_x1.size()[2] == in_x2.size()[2]
+        # The batch size and hidden size must be equal.
+        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
+        # The seq len in in_x and x_len must be equal.
+        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
+
+        batch_size = in_x1.size()[0]
+        x1_max_len = in_x1.size()[1]
+        x2_max_len = in_x2.size()[1]
+
+        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]
+
+        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
+
+        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
+        a_mask = a_mask.view(batch_size, x1_max_len, -1)
+        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
+        b_mask = x2_len.le(0.5).float() * -self.inf
+        b_mask = b_mask.view(batch_size, -1, x2_max_len)
+        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]
+
+        attention_a = F.softmax(attention_matrix + a_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
+        attention_b = F.softmax(attention_matrix + b_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]
+
+        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
+        attention_b_t = torch.transpose(attention_b, 1, 2)
+        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]
+
+        return out_x1, out_x2
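
Below is a minimal usage sketch of the Bi_Attention module added by this patch. It is not part of the patch itself: the batch size, sequence lengths, hidden size, and random inputs are assumptions chosen for illustration. Following the shape comments in forward, x1_len and x2_len are treated here as binary position masks (1.0 for real tokens, 0.0 for padding), not integer lengths.

    import torch
    from fastNLP.modules.aggregator.attention import Bi_Attention

    # Hypothetical shapes, for illustration only.
    batch_size, x1_seq_len, x2_seq_len, hidden_size = 2, 5, 7, 16

    in_x1 = torch.randn(batch_size, x1_seq_len, hidden_size)
    in_x2 = torch.randn(batch_size, x2_seq_len, hidden_size)

    # Binary masks over positions: 1.0 = real token, 0.0 = padding.
    x1_len = torch.ones(batch_size, x1_seq_len)
    x2_len = torch.ones(batch_size, x2_seq_len)
    x1_len[:, -1] = 0.0   # pretend the last x1 position is padding
    x2_len[:, -2:] = 0.0  # pretend the last two x2 positions are padding

    bi_att = Bi_Attention()
    out_x1, out_x2 = bi_att(in_x1, in_x2, x1_len, x2_len)
    print(out_x1.shape)  # torch.Size([2, 5, 16])
    print(out_x2.shape)  # torch.Size([2, 7, 16])

out_x1 re-represents each x1 position as an attention-weighted sum of x2 states, and out_x2 does the reverse, which is the bi-directional (co-)attention pattern this module provides.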