From d978b0c1259b5bc39ee6eccc40e145b01b385501 Mon Sep 17 00:00:00 2001 From: xuyige Date: Wed, 9 Jan 2019 00:11:05 +0800 Subject: [PATCH] update ESIM model and rename it --- fastNLP/models/__init__.py | 2 +- fastNLP/models/snli.py | 121 ++++++++++++------------------------- 2 files changed, 38 insertions(+), 85 deletions(-) diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py index 5bb2bc3d..a83c3936 100644 --- a/fastNLP/models/__init__.py +++ b/fastNLP/models/__init__.py @@ -3,4 +3,4 @@ from .biaffine_parser import BiaffineParser, GraphParser from .char_language_model import CharLM from .cnn_text_classification import CNNText from .sequence_modeling import SeqLabeling, AdvSeqLabel -from .snli import SNLI +from .snli import ESIM diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 9be7f43d..6a7d8d84 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -3,29 +3,35 @@ import torch.nn as nn import torch.nn.functional as F from fastNLP.models.base_model import BaseModel -from fastNLP.modules import decoder as Decoder, encoder as Encoder +from fastNLP.modules import decoder as Decoder +from fastNLP.modules import encoder as Encoder +from fastNLP.modules import aggregator as Aggregator my_inf = 10e12 -class SNLI(BaseModel): +class ESIM(BaseModel): """ - PyTorch Network for SNLI. + PyTorch Network for SNLI task using ESIM model. """ - def __init__(self, args, init_embedding=None): - super(SNLI, self).__init__() - self.vocab_size = args["vocab_size"] - self.embed_dim = args["embed_dim"] - self.hidden_size = args["hidden_size"] - self.batch_first = args["batch_first"] - self.dropout = args["dropout"] - self.n_labels = args["num_classes"] - self.gpu = args["gpu"] and torch.cuda.is_available() - - self.embedding = Encoder.embedding.Embedding(self.vocab_size, self.embed_dim, init_emb=init_embedding, - dropout=self.dropout) + def __init__(self, **kwargs): + super(ESIM, self).__init__() + self.vocab_size = kwargs["vocab_size"] + self.embed_dim = kwargs["embed_dim"] + self.hidden_size = kwargs["hidden_size"] + self.batch_first = kwargs["batch_first"] + self.dropout = kwargs["dropout"] + self.n_labels = kwargs["num_classes"] + self.gpu = kwargs["gpu"] and torch.cuda.is_available() + + self.drop = nn.Dropout(self.dropout) + + self.embedding = Encoder.Embedding( + self.vocab_size, self.embed_dim, dropout=self.dropout, + init_emb=kwargs["init_embedding"] if "inin_embedding" in kwargs.keys() else None, + ) self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size) @@ -34,6 +40,10 @@ class SNLI(BaseModel): batch_first=self.batch_first, bidirectional=True ) + self.bi_attention = Aggregator.Bi_Attention() + self.mean_pooling = Aggregator.MeanPoolWithMask() + self.max_pooling = Aggregator.MaxPoolWithMask() + self.inference_layer = Encoder.Linear(self.hidden_size * 4, self.hidden_size) self.decoder = Encoder.LSTM( @@ -41,16 +51,16 @@ class SNLI(BaseModel): batch_first=self.batch_first, bidirectional=True ) - self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh') + self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) def forward(self, premise, hypothesis, premise_len, hypothesis_len): """ Forward function - :param premise: A Tensor represents premise: [batch size(B), premise seq len(PL), hidden size(H)]. - :param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL), H]. + :param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)]. + :param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. :param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. :param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. - :return: prediction: A Tensor of classification result: [B, n_labels(N)]. + :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. """ premise0 = self.embedding_layer(self.embedding(premise)) @@ -68,16 +78,13 @@ class SNLI(BaseModel): B, PL, H = premise0.size() B, HL, H = hypothesis0.size() - # a0, (ah0, ac0) = self.encoder(premise) # a0: [B, PL, H * 2], ah0: [2, B, H] - # b0, (bh0, bc0) = self.encoder(hypothesis) # b0: [B, HL, H * 2] - - a0 = self.encoder(premise0) # a0: [B, PL, H * 2] - b0 = self.encoder(hypothesis0) # b0: [B, HL, H * 2] + a0 = self.encoder(self.drop(premise0)) # a0: [B, PL, H * 2] + b0 = self.encoder(self.drop(hypothesis0)) # b0: [B, HL, H * 2] a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] - ai, bi = self.calc_bi_attention(a, b, premise_len, hypothesis_len) + ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] @@ -85,17 +92,12 @@ class SNLI(BaseModel): f_ma = self.inference_layer(ma) f_mb = self.inference_layer(mb) - vat = self.decoder(f_ma) - vbt = self.decoder(f_mb) + vat = self.decoder(self.drop(f_ma)) + vbt = self.decoder(self.drop(f_mb)) va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] - # va_ave = torch.mean(va, dim=1) # va_ave: [B, H] - # va_max, va_arg_max = torch.max(va, dim=1) # va_max: [B, H] - # vb_ave = torch.mean(vb, dim=1) # vb_ave: [B, H] - # vb_max, vb_arg_max = torch.max(vb, dim=1) # vb_max: [B, H] - va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] @@ -103,59 +105,10 @@ class SNLI(BaseModel): v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] - # v_mlp = F.tanh(self.mlp_layer1(v)) # v_mlp: [B, H] - # prediction = self.mlp_layer2(v_mlp) # prediction: [B, N] - prediction = F.tanh(self.output(v)) # prediction: [B, N] - return prediction - - @staticmethod - def calc_bi_attention(in_x1, in_x2, x1_len, x2_len): - - # in_x1: [batch_size, x1_seq_len, hidden_size] - # in_x2: [batch_size, x2_seq_len, hidden_size] - # x1_len: [batch_size, x1_seq_len] - # x2_len: [batch_size, x2_seq_len] - - assert in_x1.size()[0] == in_x2.size()[0] - assert in_x1.size()[2] == in_x2.size()[2] - # The batch size and hidden size must be equal. - assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1] - # The seq len in in_x and x_len must be equal. - assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0] - - batch_size = in_x1.size()[0] - x1_max_len = in_x1.size()[1] - x2_max_len = in_x2.size()[1] - - in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len] - - attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len] - - a_mask = x1_len.le(0.5).float() * -my_inf # [batch_size, x1_seq_len] - a_mask = a_mask.view(batch_size, x1_max_len, -1) - a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len] - b_mask = x2_len.le(0.5).float() * -my_inf - b_mask = b_mask.view(batch_size, -1, x2_max_len) - b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len] - - attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len] - attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len] - - out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size] - attention_b_t = torch.transpose(attention_b, 1, 2) - out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size] - - return out_x1, out_x2 + return {'pred': prediction} - @staticmethod - def mean_pooling(tensor, mask, dim=0): - masks = mask.view(mask.size(0), mask.size(1), -1).float() - return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1) + def predict(self, premise, hypothesis, premise_len, hypothesis_len): + return self.forward(premise, hypothesis, premise_len, hypothesis_len) - @staticmethod - def max_pooling(tensor, mask, dim=0): - masks = mask.view(mask.size(0), mask.size(1), -1) - masks = masks.expand(-1, -1, tensor.size(2)).float() - return torch.max(tensor + masks.le(0.5).float() * -my_inf, dim=dim)