diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py
index 6a7d8d84..5816d2af 100644
--- a/fastNLP/models/snli.py
+++ b/fastNLP/models/snli.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder as Decoder
@@ -40,7 +39,7 @@ class ESIM(BaseModel):
             batch_first=self.batch_first,
             bidirectional=True
         )
 
-        self.bi_attention = Aggregator.Bi_Attention()
+        self.bi_attention = Aggregator.BiAttention()
         self.mean_pooling = Aggregator.MeanPoolWithMask()
         self.max_pooling = Aggregator.MaxPoolWithMask()
@@ -53,23 +52,23 @@ class ESIM(BaseModel):
         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)
 
-    def forward(self, premise, hypothesis, premise_len, hypothesis_len):
+    def forward(self, words1, words2, seq_len1, seq_len2):
         """ Forward function
 
-        :param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)].
-        :param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)].
-        :param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL].
-        :param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL].
+        :param words1: A Tensor representing the premise: [batch size(B), premise seq len(PL)].
+        :param words2: A Tensor representing the hypothesis: [B, hypothesis seq len(HL)].
+        :param seq_len1: A Tensor marking which premise positions are real words and which are padding: [B, PL].
+        :param seq_len2: A Tensor marking which hypothesis positions are real words and which are padding: [B, HL].
         :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)].
         """
 
-        premise0 = self.embedding_layer(self.embedding(premise))
-        hypothesis0 = self.embedding_layer(self.embedding(hypothesis))
+        premise0 = self.embedding_layer(self.embedding(words1))
+        hypothesis0 = self.embedding_layer(self.embedding(words2))
 
         _BP, _PSL, _HP = premise0.size()
         _BH, _HSL, _HH = hypothesis0.size()
-        _BPL, _PLL = premise_len.size()
-        _HPL, _HLL = hypothesis_len.size()
+        _BPL, _PLL = seq_len1.size()
+        _HPL, _HLL = seq_len2.size()
 
         assert _BP == _BH and _BPL == _HPL and _BP == _BPL
         assert _HP == _HH
@@ -84,7 +83,7 @@ class ESIM(BaseModel):
         a = torch.mean(a0.view(B, PL, -1, H), dim=2)  # a: [B, PL, H]
         b = torch.mean(b0.view(B, HL, -1, H), dim=2)  # b: [B, HL, H]
 
-        ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len)
+        ai, bi = self.bi_attention(a, b, seq_len1, seq_len2)
 
         ma = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 4 * H]
         mb = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 4 * H]
@@ -98,17 +97,18 @@ class ESIM(BaseModel):
         va = torch.mean(vat.view(B, PL, -1, H), dim=2)  # va: [B, PL, H]
         vb = torch.mean(vbt.view(B, HL, -1, H), dim=2)  # vb: [B, HL, H]
 
-        va_ave = self.mean_pooling(va, premise_len, dim=1)  # va_ave: [B, H]
-        va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1)  # va_max: [B, H]
-        vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1)  # vb_ave: [B, H]
-        vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1)  # vb_max: [B, H]
+        va_ave = self.mean_pooling(va, seq_len1, dim=1)  # va_ave: [B, H]
+        va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1)  # va_max: [B, H]
+        vb_ave = self.mean_pooling(vb, seq_len2, dim=1)  # vb_ave: [B, H]
+        vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1)  # vb_max: [B, H]
 
         v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1)  # v: [B, 4 * H]
 
-        prediction = F.tanh(self.output(v))  # prediction: [B, N]
+        prediction = torch.tanh(self.output(v))  # prediction: [B, N]
 
         return {'pred': prediction}
 
-    def predict(self, premise, hypothesis, premise_len, hypothesis_len):
-        return self.forward(premise, hypothesis, premise_len, hypothesis_len)
+    def predict(self, words1, words2, seq_len1, seq_len2):
+        prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
+        return torch.argmax(prediction, dim=-1)
diff --git a/fastNLP/modules/aggregator/__init__.py b/fastNLP/modules/aggregator/__init__.py
index 2fabb89e..43d60cac 100644
--- a/fastNLP/modules/aggregator/__init__.py
+++ b/fastNLP/modules/aggregator/__init__.py
@@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask
 from .kmax_pool import KMaxPool
 
 from .attention import Attention
-from .attention import Bi_Attention
+from .attention import BiAttention
 from .self_attention import SelfAttention
diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py
index ef9d159d..33d73a07 100644
--- a/fastNLP/modules/aggregator/attention.py
+++ b/fastNLP/modules/aggregator/attention.py
@@ -23,9 +23,9 @@ class Attention(torch.nn.Module):
         raise NotImplementedError
 
 
-class DotAtte(nn.Module):
+class DotAttention(nn.Module):
     def __init__(self, key_size, value_size, dropout=0.1):
-        super(DotAtte, self).__init__()
+        super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
         self.scale = math.sqrt(key_size)
@@ -48,7 +48,7 @@
         return torch.matmul(output, V)
 
 
-class MultiHeadAtte(nn.Module):
+class MultiHeadAttention(nn.Module):
     def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
         """
@@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module):
        :param num_head: int, the number of attention heads.
        :param dropout: float.
        """
-        super(MultiHeadAtte, self).__init__()
+        super(MultiHeadAttention, self).__init__()
         self.input_size = input_size
         self.key_size = key_size
         self.value_size = value_size
@@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAtte(key_size=key_size, value_size=value_size)
+        self.attention = DotAttention(key_size=key_size, value_size=value_size)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
@@ -109,16 +109,30 @@ class MultiHeadAtte(nn.Module):
         return output
 
 
-class Bi_Attention(nn.Module):
+class BiAttention(nn.Module):
+    """Bi Attention module
+    Calculate the Bi Attention matrix `e`
+    .. math::
+        \begin{array}{ll} \\
+            e_{ij} = {a}^{\mathbf{T}}_{i}{b}_{j} \\
+            {\hat{a}}_{i} = \sum_{j=1}^{\ell_{b}}{\frac{\exp(e_{ij})}{\sum_{k=1}^{\ell_{b}}{\exp(e_{ik})}}}{b}_{j} \\
+            {\hat{b}}_{j} = \sum_{i=1}^{\ell_{a}}{\frac{\exp(e_{ij})}{\sum_{k=1}^{\ell_{a}}{\exp(e_{kj})}}}{a}_{i} \\
+        \end{array}
+    """
+
     def __init__(self):
-        super(Bi_Attention, self).__init__()
+        super(BiAttention, self).__init__()
         self.inf = 10e12
 
     def forward(self, in_x1, in_x2, x1_len, x2_len):
-        # in_x1: [batch_size, x1_seq_len, hidden_size]
-        # in_x2: [batch_size, x2_seq_len, hidden_size]
-        # x1_len: [batch_size, x1_seq_len]
-        # x2_len: [batch_size, x2_seq_len]
+        """
+        :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] feature representation of the first sentence
+        :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] feature representation of the second sentence
+        :param torch.Tensor x1_len: [batch_size, x1_seq_len] 0/1 mask matrix for the first sentence
+        :param torch.Tensor x2_len: [batch_size, x2_seq_len] 0/1 mask matrix for the second sentence
+        :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] attended representation of the first sentence
+                 torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] attended representation of the second sentence
+        """
         assert in_x1.size()[0] == in_x2.size()[0]
         assert in_x1.size()[2] == in_x2.size()[2]
diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index d7b8c544..d1262141 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -1,6 +1,6 @@
 from torch import nn
 
-from ..aggregator.attention import MultiHeadAtte
+from ..aggregator.attention import MultiHeadAttention
 from ..dropout import TimestepDropout
 
 
@@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module):
     class SubLayer(nn.Module):
         def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
             super(TransformerEncoder.SubLayer, self).__init__()
-            self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout)
+            self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout)
             self.norm1 = nn.LayerNorm(model_size)
             self.ffn = nn.Sequential(nn.Linear(model_size, inner_size),
                                      nn.ReLU(),
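Not part of the patch: a minimal usage sketch of the renamed `BiAttention` module, assuming a fastNLP build containing this diff is installed. The call signature and the import path follow the docstring and `__init__.py` export added above; the shapes and random inputs are illustrative only.

```python
import torch
from fastNLP.modules.aggregator import BiAttention  # renamed from Bi_Attention in this diff

# Illustrative shapes: batch of 2, premise length 5, hypothesis length 7, hidden size 16.
B, L1, L2, H = 2, 5, 7, 16
x1 = torch.randn(B, L1, H)   # feature representation of the first sentence
x2 = torch.randn(B, L2, H)   # feature representation of the second sentence
mask1 = torch.ones(B, L1)    # 0/1 mask marking real (non-padding) tokens of the first sentence
mask2 = torch.ones(B, L2)    # 0/1 mask marking real (non-padding) tokens of the second sentence

attention = BiAttention()
out_x1, out_x2 = attention(x1, x2, mask1, mask2)  # each sentence attended over the other
print(out_x1.shape, out_x2.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 7, 16])
```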