
update attention

tags/v0.4.10
xuyige 6 years ago
commit 9d43239fc1
4 changed files with 47 additions and 33 deletions
  1. +19 -19 fastNLP/models/snli.py
  2. +1 -1 fastNLP/modules/aggregator/__init__.py
  3. +25 -11 fastNLP/modules/aggregator/attention.py
  4. +2 -2 fastNLP/modules/encoder/transformer.py

+19 -19 fastNLP/models/snli.py

@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder as Decoder
@@ -40,7 +39,7 @@ class ESIM(BaseModel):
             batch_first=self.batch_first, bidirectional=True
         )

-        self.bi_attention = Aggregator.Bi_Attention()
+        self.bi_attention = Aggregator.BiAttention()
         self.mean_pooling = Aggregator.MeanPoolWithMask()
         self.max_pooling = Aggregator.MaxPoolWithMask()

@@ -53,23 +52,23 @@ class ESIM(BaseModel):

         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)

-    def forward(self, premise, hypothesis, premise_len, hypothesis_len):
+    def forward(self, words1, words2, seq_len1, seq_len2):
         """ Forward function

-        :param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)].
-        :param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)].
-        :param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL].
-        :param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL].
+        :param words1: A Tensor representing the premise: [batch size(B), premise seq len(PL)].
+        :param words2: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)].
+        :param seq_len1: A Tensor recording which premise positions are real words and which are padding: [B].
+        :param seq_len2: A Tensor recording which hypothesis positions are real words and which are padding: [B].
         :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)].
         """

-        premise0 = self.embedding_layer(self.embedding(premise))
-        hypothesis0 = self.embedding_layer(self.embedding(hypothesis))
+        premise0 = self.embedding_layer(self.embedding(words1))
+        hypothesis0 = self.embedding_layer(self.embedding(words2))

         _BP, _PSL, _HP = premise0.size()
         _BH, _HSL, _HH = hypothesis0.size()
-        _BPL, _PLL = premise_len.size()
-        _HPL, _HLL = hypothesis_len.size()
+        _BPL, _PLL = seq_len1.size()
+        _HPL, _HLL = seq_len2.size()

         assert _BP == _BH and _BPL == _HPL and _BP == _BPL
         assert _HP == _HH
@@ -84,7 +83,7 @@ class ESIM(BaseModel):
         a = torch.mean(a0.view(B, PL, -1, H), dim=2)  # a: [B, PL, H]
         b = torch.mean(b0.view(B, HL, -1, H), dim=2)  # b: [B, HL, H]

-        ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len)
+        ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)

         ma = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 4 * H]
         mb = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 4 * H]
@@ -98,17 +97,18 @@ class ESIM(BaseModel):
         va = torch.mean(vat.view(B, PL, -1, H), dim=2)  # va: [B, PL, H]
         vb = torch.mean(vbt.view(B, HL, -1, H), dim=2)  # vb: [B, HL, H]

-        va_ave = self.mean_pooling(va, premise_len, dim=1)  # va_ave: [B, H]
-        va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1)  # va_max: [B, H]
-        vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1)  # vb_ave: [B, H]
-        vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1)  # vb_max: [B, H]
+        va_ave = self.mean_pooling(va, seq_len1, dim=1)  # va_ave: [B, H]
+        va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1)  # va_max: [B, H]
+        vb_ave = self.mean_pooling(vb, seq_len2, dim=1)  # vb_ave: [B, H]
+        vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1)  # vb_max: [B, H]

         v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1)  # v: [B, 4 * H]

-        prediction = F.tanh(self.output(v))  # prediction: [B, N]
+        prediction = torch.tanh(self.output(v))  # prediction: [B, N]

         return {'pred': prediction}

-    def predict(self, premise, hypothesis, premise_len, hypothesis_len):
-        return self.forward(premise, hypothesis, premise_len, hypothesis_len)
+    def predict(self, words1, words2, seq_len1, seq_len2):
+        prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
+        return torch.argmax(prediction, dim=-1)
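
For orientation, a minimal usage sketch of the renamed interface. `model` is assumed to be an already constructed ESIM instance and the tensors are hypothetical, named after the new signature; in practice fastNLP drives these calls through its DataSet/Trainer pipeline.

# Hedged sketch: hypothetical inputs following the new argument names.
logits = model(words1, words2, seq_len1, seq_len2)['pred']  # forward() returns {'pred': [B, n_labels]}
labels = torch.argmax(logits, dim=-1)                       # what the new predict() returns: [B] label ids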


+1 -1 fastNLP/modules/aggregator/__init__.py

@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask
 from .kmax_pool import KMaxPool

 from .attention import Attention
-from .attention import Bi_Attention
+from .attention import BiAttention
 from .self_attention import SelfAttention


+25 -11 fastNLP/modules/aggregator/attention.py

@@ -23,9 +23,9 @@ class Attention(torch.nn.Module):
         raise NotImplementedError


-class DotAtte(nn.Module):
+class DotAttention(nn.Module):
     def __init__(self, key_size, value_size, dropout=0.1):
-        super(DotAtte, self).__init__()
+        super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
         self.scale = math.sqrt(key_size)
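
For reference, a hedged sketch of the scaled dot-product attention that the renamed DotAttention computes, softmax(QK^T / sqrt(key_size)) V. The function name and masking convention below are illustrative, not fastNLP's exact code.

import math
import torch
import torch.nn.functional as F

def scaled_dot_attention(Q, K, V, mask=None, dropout=None):
    # scores: [..., L_q, L_k], scaled by sqrt(d_k) as DotAttention's self.scale suggests
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(K.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # assumed masking convention
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, V)  # same final step as the `torch.matmul(output, V)` in the next hunk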
@@ -48,7 +48,7 @@ class DotAtte(nn.Module):
         return torch.matmul(output, V)


-class MultiHeadAtte(nn.Module):
+class MultiHeadAttention(nn.Module):
     def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
         """

@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module):
        :param num_head: int, the number of attention heads.
        :param dropout: float.
        """
-        super(MultiHeadAtte, self).__init__()
+        super(MultiHeadAttention, self).__init__()
         self.input_size = input_size
         self.key_size = key_size
         self.value_size = value_size
@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAtte(key_size=key_size, value_size=value_size)
+        self.attention = DotAttention(key_size=key_size, value_size=value_size)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
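
To make the wiring above concrete, here is a hedged, self-contained sketch of the multi-head pattern (per-head projections, scaled dot-product attention, head concatenation, output projection). It is illustrative rather than MultiHeadAttention's exact code, and it omits the TimestepDropout and reset_parameters() steps shown in the diff.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniMultiHead(nn.Module):
    # Illustrative multi-head attention: project to num_head * key/value sizes,
    # attend per head, then concatenate heads and project back to input_size.
    def __init__(self, input_size, key_size, value_size, num_head):
        super().__init__()
        self.num_head, self.key_size, self.value_size = num_head, key_size, value_size
        self.q_in = nn.Linear(input_size, key_size * num_head)
        self.k_in = nn.Linear(input_size, key_size * num_head)
        self.v_in = nn.Linear(input_size, value_size * num_head)
        self.out = nn.Linear(value_size * num_head, input_size)

    def forward(self, x):
        B, L, _ = x.size()
        # [B, L, num_head * d] -> [B, num_head, L, d]
        q = self.q_in(x).view(B, L, self.num_head, self.key_size).transpose(1, 2)
        k = self.k_in(x).view(B, L, self.num_head, self.key_size).transpose(1, 2)
        v = self.v_in(x).view(B, L, self.num_head, self.value_size).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.key_size)
        heads = torch.matmul(F.softmax(scores, dim=-1), v)              # [B, num_head, L, value_size]
        concat = heads.transpose(1, 2).reshape(B, L, self.num_head * self.value_size)
        return self.out(concat)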
@@ -109,16 +109,30 @@ class MultiHeadAtte(nn.Module):
         return output


-class Bi_Attention(nn.Module):
+class BiAttention(nn.Module):
+    """Bi Attention module
+    Calculate the Bi Attention matrix `e`
+    .. math::
+        \begin{array}{ll} \\
+            e_{ij} = {a}^{\mathbf{T}}_{i}{b}_{j} \\
+            \tilde{a}_i = \sum_{j=1}^{\mathcal{l}_b}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_b}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
+            \tilde{b}_j = \sum_{i=1}^{\mathcal{l}_a}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_a}{\mathrm{exp}(e_{kj})}}}{a}_{i} \\
+        \end{array}
+    """
+
     def __init__(self):
-        super(Bi_Attention, self).__init__()
+        super(BiAttention, self).__init__()
         self.inf = 10e12

     def forward(self, in_x1, in_x2, x1_len, x2_len):
-        # in_x1: [batch_size, x1_seq_len, hidden_size]
-        # in_x2: [batch_size, x2_seq_len, hidden_size]
-        # x1_len: [batch_size, x1_seq_len]
-        # x2_len: [batch_size, x2_seq_len]
+        """
+        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
+        :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
+        :param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix for the first sentence
+        :param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix for the second sentence
+        :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] the attended representation of the first sentence
+                 torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] the attended representation of the second sentence
+        """

         assert in_x1.size()[0] == in_x2.size()[0]
         assert in_x1.size()[2] == in_x2.size()[2]
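
The docstring math corresponds to standard ESIM-style bi-attention. A hedged sketch of that computation follows; the masking constant and broadcasting are illustrative, and BiAttention's own implementation may differ in detail.

import torch
import torch.nn.functional as F

def bi_attention(in_x1, in_x2, x1_mask, x2_mask, inf=1e12):
    # e[b, i, j] = <in_x1[b, i], in_x2[b, j]>; padded positions are pushed to -inf
    e = torch.bmm(in_x1, in_x2.transpose(1, 2))                  # [B, L1, L2]
    pair_mask = x1_mask.unsqueeze(2) * x2_mask.unsqueeze(1)      # [B, L1, L2], 0/1
    e = e.masked_fill(pair_mask == 0, -inf)
    out_x1 = torch.bmm(F.softmax(e, dim=2), in_x2)               # each x1 token attends over x2
    out_x2 = torch.bmm(F.softmax(e, dim=1).transpose(1, 2), in_x1)  # each x2 token attends over x1
    return out_x1, out_x2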


+2 -2 fastNLP/modules/encoder/transformer.py

@@ -1,6 +1,6 @@
 from torch import nn

-from ..aggregator.attention import MultiHeadAtte
+from ..aggregator.attention import MultiHeadAttention
 from ..dropout import TimestepDropout


@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module):
     class SubLayer(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
-            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
+            self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
             self.norm1 = nn.LayerNorm(model_size)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
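
For context, SubLayer follows the usual attention, residual + LayerNorm, feed-forward pattern. Below is a hedged sketch of that pattern in plain PyTorch, using torch.nn.MultiheadAttention in place of fastNLP's MultiHeadAttention; the post-norm placement and hyper-parameter names are assumptions based on the snippet above, not a verbatim copy.

import torch.nn as nn

class PostNormSubLayer(nn.Module):
    # Illustrative Transformer sub-layer: self-attention, residual + LayerNorm,
    # then a position-wise feed-forward block with its own residual + LayerNorm.
    def __init__(self, model_size, inner_size, num_head, dropout=0.1):
        super().__init__()
        self.atte = nn.MultiheadAttention(model_size, num_head, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(model_size)
        self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                 nn.ReLU(),
                                 nn.Linear(inner_size, model_size))
        self.norm2 = nn.LayerNorm(model_size)

    def forward(self, x, key_padding_mask=None):
        attn_out, _ = self.atte(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ffn(x))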

