@@ -1,5 +1,6 @@
import torch
from torch import nn
import torch.nn.functional as F
import math
from fastNLP.modules.utils import mask_softmax
@@ -62,3 +63,46 @@ class MultiHeadAtte(nn.Module):
            heads.append(headi)
        output = torch.cat(heads, dim=2)
        return self.out_linear(output)


class Bi_Attention(nn.Module):
    def __init__(self):
        super(Bi_Attention, self).__init__()
        self.inf = 10e12

    def forward(self, in_x1, in_x2, x1_len, x2_len):
        # in_x1: [batch_size, x1_seq_len, hidden_size]
        # in_x2: [batch_size, x2_seq_len, hidden_size]
        # x1_len: [batch_size, x1_seq_len], 1 for real tokens, 0 for padding
        # x2_len: [batch_size, x2_seq_len], 1 for real tokens, 0 for padding

        assert in_x1.size()[0] == in_x2.size()[0]
        assert in_x1.size()[2] == in_x2.size()[2]
        # The batch sizes and hidden sizes of in_x1 and in_x2 must be equal.
        assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
        # The sequence length of each input and of its length mask must be equal.
        assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

        batch_size = in_x1.size()[0]
        x1_max_len = in_x1.size()[1]
        x2_max_len = in_x2.size()[1]

        in_x2_t = torch.transpose(in_x2, 1, 2)  # [batch_size, hidden_size, x2_seq_len]

        attention_matrix = torch.bmm(in_x1, in_x2_t)  # [batch_size, x1_seq_len, x2_seq_len]
        # a_mask marks the padded positions of x1, b_mask the padded positions of x2.
        a_mask = x1_len.le(0.5).float() * -self.inf  # [batch_size, x1_seq_len]
        a_mask = a_mask.view(batch_size, x1_max_len, -1)
        a_mask = a_mask.expand(-1, -1, x2_max_len)  # [batch_size, x1_seq_len, x2_seq_len]
        b_mask = x2_len.le(0.5).float() * -self.inf  # [batch_size, x2_seq_len]
        b_mask = b_mask.view(batch_size, -1, x2_max_len)
        b_mask = b_mask.expand(-1, x1_max_len, -1)  # [batch_size, x1_seq_len, x2_seq_len]

        # attention_a normalises over the x2 dimension (dim=2), so padded x2 positions (b_mask)
        # must be suppressed; attention_b normalises over the x1 dimension (dim=1), so padded
        # x1 positions (a_mask) must be suppressed. Adding a_mask to the dim=2 softmax (and
        # b_mask to the dim=1 softmax) would shift every row/column by a constant and have
        # no masking effect at all.
        attention_a = F.softmax(attention_matrix + b_mask, dim=2)  # [batch_size, x1_seq_len, x2_seq_len]
        attention_b = F.softmax(attention_matrix + a_mask, dim=1)  # [batch_size, x1_seq_len, x2_seq_len]

        out_x1 = torch.bmm(attention_a, in_x2)  # [batch_size, x1_seq_len, hidden_size]
        attention_b_t = torch.transpose(attention_b, 1, 2)
        out_x2 = torch.bmm(attention_b_t, in_x1)  # [batch_size, x2_seq_len, hidden_size]

        return out_x1, out_x2
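
For reference, a minimal usage sketch of the new Bi_Attention module (not part of the patch). It assumes the imports and class definition from the hunks above are in scope; the tensor sizes, variable values, and trailing-padding layout below are illustrative assumptions only.

# Usage sketch: run Bi_Attention on a small padded batch.
batch_size, x1_seq_len, x2_seq_len, hidden_size = 2, 5, 7, 16
in_x1 = torch.randn(batch_size, x1_seq_len, hidden_size)
in_x2 = torch.randn(batch_size, x2_seq_len, hidden_size)
# 0/1 padding masks: 1 for real tokens, 0 for padding (here the last positions are padded)
x1_len = torch.ones(batch_size, x1_seq_len)
x1_len[:, -2:] = 0
x2_len = torch.ones(batch_size, x2_seq_len)
x2_len[:, -3:] = 0
bi_atte = Bi_Attention()
out_x1, out_x2 = bi_atte(in_x1, in_x2, x1_len, x2_len)
print(out_x1.size())  # torch.Size([2, 5, 16])
print(out_x2.size())  # torch.Size([2, 7, 16])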