Browse Source

add Bi-Attention

tags/v0.3.0^2
xuyige 6 years ago
parent
commit
fcf1050512
1 changed files with 44 additions and 0 deletions
  1. +44
    -0
      fastNLP/modules/aggregator/attention.py

+ 44
- 0
fastNLP/modules/aggregator/attention.py View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
import torch.nn.functional as F
import math
from fastNLP.modules.utils import mask_softmax

@@ -62,3 +63,46 @@ class MultiHeadAtte(nn.Module):
heads.append(headi)
output = torch.cat(heads, dim=2)
return self.out_linear(output)


class Bi_Attention(nn.Module):
def __init__(self):
super(Bi_Attention, self).__init__()
self.inf = 10e12

def forward(self, in_x1, in_x2, x1_len, x2_len):
# in_x1: [batch_size, x1_seq_len, hidden_size]
# in_x2: [batch_size, x2_seq_len, hidden_size]
# x1_len: [batch_size, x1_seq_len]
# x2_len: [batch_size, x2_seq_len]

assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]

batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]

in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len]

attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len]

a_mask = x1_len.le(0.5).float() * -self.inf # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -self.inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len]

attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len]

out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size]

return out_x1, out_x2

Loading…
Cancel
Save