From 23625fe954c996b91a5c6a7e856a33ba21ca103d Mon Sep 17 00:00:00 2001
From: xuyige
Date: Tue, 23 Apr 2019 21:05:24 +0800
Subject: [PATCH] update documents in snli

---
 fastNLP/models/snli.py                | 65 ++++++++++++++++++---------
 fastNLP/modules/aggregator/pooling.py | 10 ++---
 2 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py
index 901f2dd4..7ead5c18 100644
--- a/fastNLP/models/snli.py
+++ b/fastNLP/models/snli.py
@@ -5,38 +5,45 @@ from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder as Decoder
 from fastNLP.modules import encoder as Encoder
 from fastNLP.modules import aggregator as Aggregator
+from fastNLP.modules.utils import seq_mask
 
 my_inf = 10e12
 
 
 class ESIM(BaseModel):
-    """
-    PyTorch Network for SNLI task using ESIM model.
+    """A PyTorch implementation of the ESIM model.
+    Paper: Enhanced LSTM for Natural Language Inference (arXiv: 1609.06038)
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None):
+        """
+        :param int vocab_size: size of the vocabulary
+        :param int embed_dim: dimension of the word embeddings
+        :param int hidden_size: hidden size of the LSTM
+        :param float dropout: dropout rate, defaults to 0
+        :param int num_classes: number of target labels, defaults to 3
+        :param numpy.array init_embedding: initial embedding matrix of shape (vocab_size, embed_dim); defaults to None, in which case the embeddings are randomly initialized
+        """
         super(ESIM, self).__init__()
-        self.vocab_size = kwargs["vocab_size"]
-        self.embed_dim = kwargs["embed_dim"]
-        self.hidden_size = kwargs["hidden_size"]
-        self.batch_first = kwargs["batch_first"]
-        self.dropout = kwargs["dropout"]
-        self.n_labels = kwargs["num_classes"]
-        self.gpu = kwargs["gpu"] and torch.cuda.is_available()
+        self.vocab_size = vocab_size
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.n_labels = num_classes
 
         self.drop = nn.Dropout(self.dropout)
 
         self.embedding = Encoder.Embedding(
             self.vocab_size, self.embed_dim, dropout=self.dropout,
-            init_emb=kwargs["init_embedding"] if "inin_embedding" in kwargs.keys() else None,
+            init_emb=init_embedding,
         )
 
         self.embedding_layer = Encoder.Linear(self.embed_dim, self.hidden_size)
 
         self.encoder = Encoder.LSTM(
             input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True,
-            batch_first=self.batch_first, bidirectional=True
+            batch_first=True, bidirectional=True
         )
 
         self.bi_attention = Aggregator.BiAttention()
@@ -47,24 +54,34 @@ class ESIM(BaseModel):
 
         self.decoder = Encoder.LSTM(
             input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True,
-            batch_first=self.batch_first, bidirectional=True
+            batch_first=True, bidirectional=True
         )
 
         self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout)
 
-    def forward(self, words1, words2, seq_len1, seq_len2):
+    def forward(self, words1, words2, seq_len1=None, seq_len2=None):
         """ Forward function
-
-        :param words1: A Tensor represents premise: [batch size(B), premise seq len(PL)].
-        :param words2: A Tensor represents hypothesis: [B, hypothesis seq len(HL)].
-        :param seq_len1: A Tensor record which is a real word and which is a padding word in premise: [B].
-        :param seq_len2: A Tensor record which is a real word and which is a padding word in hypothesis: [B].
-        :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)].
+        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token indices of the premise
+        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token indices of the hypothesis
+        :param torch.LongTensor seq_len1: [B] lengths of the premise sequences
+        :param torch.LongTensor seq_len2: [B] lengths of the hypothesis sequences
+        :return: dict prediction: [B, n_labels(N)] classification scores
         """
         premise0 = self.embedding_layer(self.embedding(words1))
         hypothesis0 = self.embedding_layer(self.embedding(words2))
 
+        if seq_len1 is not None:
+            seq_len1 = seq_mask(seq_len1, premise0.size(1))
+        else:
+            seq_len1 = torch.ones(premise0.size(0), premise0.size(1))
+        seq_len1 = (seq_len1.long()).to(device=premise0.device)
+        if seq_len2 is not None:
+            seq_len2 = seq_mask(seq_len2, hypothesis0.size(1))
+        else:
+            seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
+        seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)
+
         _BP, _PSL, _HP = premise0.size()
         _BH, _HSL, _HH = hypothesis0.size()
         _BPL, _PLL = seq_len1.size()
@@ -109,6 +126,14 @@ class ESIM(BaseModel):
         return {'pred': prediction}
 
     def predict(self, words1, words2, seq_len1, seq_len2):
+        """ Predict function
+
+        :param torch.Tensor words1: [batch size(B), premise seq len(PL)] token indices of the premise
+        :param torch.Tensor words2: [B, hypothesis seq len(HL)] token indices of the hypothesis
+        :param torch.LongTensor seq_len1: [B] lengths of the premise sequences
+        :param torch.LongTensor seq_len2: [B] lengths of the hypothesis sequences
+        :return: dict prediction: [B] predicted label indices
+        """
         prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
         return {'pred': torch.argmax(prediction, dim=-1)}
 
diff --git a/fastNLP/modules/aggregator/pooling.py b/fastNLP/modules/aggregator/pooling.py
index 876f5fb1..9961b87f 100644
--- a/fastNLP/modules/aggregator/pooling.py
+++ b/fastNLP/modules/aggregator/pooling.py
@@ -63,8 +63,8 @@ class MaxPoolWithMask(nn.Module):
 
     def forward(self, tensor, mask, dim=1):
         """
-        :param torch.Tensor tensor: [batch_size, seq_len, channels] the input tensor
-        :param torch.Tensor mask: [batch_size, seq_len] 0/1 mask matrix
+        :param torch.FloatTensor tensor: [batch_size, seq_len, channels] the input tensor
+        :param torch.LongTensor mask: [batch_size, seq_len] 0/1 mask matrix
         :param int dim: the dimension along which to apply max pooling
         :return:
         """
@@ -120,13 +120,13 @@ class MeanPoolWithMask(nn.Module):
 
     def forward(self, tensor, mask, dim=1):
         """
-        :param torch.Tensor tensor: [batch_size, seq_len, channels] the input tensor
-        :param torch.Tensor mask: [batch_size, seq_len] 0/1 mask matrix
+        :param torch.FloatTensor tensor: [batch_size, seq_len, channels] the input tensor
+        :param torch.LongTensor mask: [batch_size, seq_len] 0/1 mask matrix
         :param int dim: the dimension along which to pool
        :return:
        """
        masks = mask.view(mask.size(0), mask.size(1), -1).float()
-       return torch.sum(tensor * masks, dim=dim) / torch.sum(masks, dim=1)
+       return torch.sum(tensor * masks.float(), dim=dim) / torch.sum(masks.float(), dim=1)
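
Below is a minimal usage sketch of the ESIM interface after this patch. It is illustrative only: the import path fastNLP.models.snli for the class above is assumed, and the inputs are random toy tensors whose shapes follow the docstrings introduced here.

    # Illustrative sketch, not part of the patch: random toy inputs, shapes as documented in ESIM.forward
    import torch
    from fastNLP.models.snli import ESIM  # assumed import path for the class defined in this patch

    model = ESIM(vocab_size=10000, embed_dim=300, hidden_size=300, dropout=0.3, num_classes=3)

    B, PL, HL = 4, 20, 15
    words1 = torch.randint(0, 10000, (B, PL))   # premise token indices, [B, PL]
    words2 = torch.randint(0, 10000, (B, HL))   # hypothesis token indices, [B, HL]
    seq_len1 = torch.randint(1, PL + 1, (B,))   # premise lengths, [B] (LongTensor)
    seq_len2 = torch.randint(1, HL + 1, (B,))   # hypothesis lengths, [B] (LongTensor)

    logits = model(words1, words2, seq_len1, seq_len2)['pred']          # FloatTensor [B, 3]
    labels = model.predict(words1, words2, seq_len1, seq_len2)['pred']  # LongTensor [B]

    # Lengths are now optional; without them every position is treated as a real token.
    logits_no_mask = model(words1, words2)['pred']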