diff --git a/fastNLP/io/data_loader/mnli.py b/fastNLP/io/data_loader/mnli.py index 48923736..5d857533 100644 --- a/fastNLP/io/data_loader/mnli.py +++ b/fastNLP/io/data_loader/mnli.py @@ -1,5 +1,5 @@ -from ...core import Const +from ...core.const import Const from .matching import MatchingLoader from ..dataset_loader import CSVLoader diff --git a/fastNLP/io/data_loader/qnli.py b/fastNLP/io/data_loader/qnli.py index 650c6be7..ff6302b2 100644 --- a/fastNLP/io/data_loader/qnli.py +++ b/fastNLP/io/data_loader/qnli.py @@ -1,5 +1,5 @@ -from ...core import Const +from ...core.const import Const from .matching import MatchingLoader from ..dataset_loader import CSVLoader diff --git a/fastNLP/io/data_loader/quora.py b/fastNLP/io/data_loader/quora.py index 2c466a24..12cc42ce 100644 --- a/fastNLP/io/data_loader/quora.py +++ b/fastNLP/io/data_loader/quora.py @@ -1,5 +1,5 @@ -from ...core import Const +from ...core.const import Const from .matching import MatchingLoader from ..dataset_loader import CSVLoader diff --git a/fastNLP/io/data_loader/rte.py b/fastNLP/io/data_loader/rte.py index 9bf05d60..c6c64ef8 100644 --- a/fastNLP/io/data_loader/rte.py +++ b/fastNLP/io/data_loader/rte.py @@ -1,5 +1,5 @@ -from ...core import Const +from ...core.const import Const from .matching import MatchingLoader from ..dataset_loader import CSVLoader diff --git a/fastNLP/io/data_loader/snli.py b/fastNLP/io/data_loader/snli.py index 7c91ca86..8334fcfd 100644 --- a/fastNLP/io/data_loader/snli.py +++ b/fastNLP/io/data_loader/snli.py @@ -1,5 +1,5 @@ -from ...core import Const +from ...core.const import Const from .matching import MatchingLoader from ..dataset_loader import JsonLoader diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 395a9bbf..d12524cc 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -4,149 +4,209 @@ __all__ = [ import torch import torch.nn as nn +import torch.nn.functional as F -from .base_model import BaseModel -from ..core.const import Const -from ..modules import decoder as Decoder -from ..modules import encoder as Encoder -from ..modules import aggregator as Aggregator -from ..core.utils import seq_len_to_mask +from torch.nn import CrossEntropyLoss -my_inf = 10e12 +from fastNLP.models import BaseModel +from fastNLP.modules.encoder.embedding import TokenEmbedding +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.core.const import Const +from fastNLP.core.utils import seq_len_to_mask class ESIM(BaseModel): - """ - 别名::class:`fastNLP.models.ESIM` :class:`fastNLP.models.snli.ESIM` - - ESIM模型的一个PyTorch实现。 - ESIM模型的论文: Enhanced LSTM for Natural Language Inference (arXiv: 1609.06038) + """ESIM model的一个PyTorch实现 + 论文参见: https://arxiv.org/pdf/1609.06038.pdf - :param int vocab_size: 词表大小 - :param int embed_dim: 词嵌入维度 - :param int hidden_size: LSTM隐层大小 - :param float dropout: dropout大小,默认为0 - :param int num_classes: 标签数目,默认为3 - :param numpy.array init_embedding: 初始词嵌入矩阵,形状为(vocab_size, embed_dim),默认为None,即随机初始化词嵌入矩阵 + :param fastNLP.TokenEmbedding init_embedding: 初始化的TokenEmbedding + :param int hidden_size: 隐藏层大小,默认值为Embedding的维度 + :param int num_labels: 目标标签种类数量,默认值为3 + :param float dropout_rate: dropout的比率,默认值为0.3 + :param float dropout_embed: 对Embedding的dropout比率,默认值为0.1 """ - - def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.0, num_classes=3, init_embedding=None): - + + def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, + dropout_embed=0.1): super(ESIM, self).__init__() - self.vocab_size = vocab_size - self.embed_dim = embed_dim - self.hidden_size = hidden_size - self.dropout = dropout - self.n_labels = num_classes - - self.drop = nn.Dropout(self.dropout) - - self.embedding = Encoder.Embedding( - (self.vocab_size, self.embed_dim), dropout=self.dropout, - ) - - self.embedding_layer = nn.Linear(self.embed_dim, self.hidden_size) - - self.encoder = Encoder.LSTM( - input_size=self.embed_dim, hidden_size=self.hidden_size, num_layers=1, bias=True, - batch_first=True, bidirectional=True - ) - - self.bi_attention = Aggregator.BiAttention() - self.mean_pooling = Aggregator.AvgPoolWithMask() - self.max_pooling = Aggregator.MaxPoolWithMask() - - self.inference_layer = nn.Linear(self.hidden_size * 4, self.hidden_size) - - self.decoder = Encoder.LSTM( - input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=1, bias=True, - batch_first=True, bidirectional=True - ) - - self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) - - def forward(self, words1, words2, seq_len1=None, seq_len2=None, target=None): - """ Forward function - - :param torch.Tensor words1: [batch size(B), premise seq len(PL)] premise的token表示 - :param torch.Tensor words2: [B, hypothesis seq len(HL)] hypothesis的token表示 - :param torch.LongTensor seq_len1: [B] premise的长度 - :param torch.LongTensor seq_len2: [B] hypothesis的长度 - :param torch.LongTensor target: [B] 真实目标值 - :return: dict prediction: [B, n_labels(N)] 预测结果 + + self.embedding = init_embedding + self.dropout_embed = EmbedDropout(p=dropout_embed) + if hidden_size is None: + hidden_size = self.embedding.embed_size + self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) + + self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.ReLU()) + nn.init.xavier_uniform_(self.interfere[1].weight.data) + self.bi_attention = SoftmaxAttention() + + self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,) + + self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.Tanh(), + nn.Dropout(p=dropout_rate), + nn.Linear(hidden_size, num_labels)) + + self.dropout_rnn = nn.Dropout(p=dropout_rate) + + nn.init.xavier_uniform_(self.classifier[1].weight.data) + nn.init.xavier_uniform_(self.classifier[4].weight.data) + + def forward(self, words1, words2, seq_len1, seq_len2, target=None): """ - - premise0 = self.embedding_layer(self.embedding(words1)) - hypothesis0 = self.embedding_layer(self.embedding(words2)) - - if seq_len1 is not None: - seq_len1 = seq_len_to_mask(seq_len1) - else: - seq_len1 = torch.ones(premise0.size(0), premise0.size(1)) - seq_len1 = (seq_len1.long()).to(device=premise0.device) - if seq_len2 is not None: - seq_len2 = seq_len_to_mask(seq_len2) - else: - seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1)) - seq_len2 = (seq_len2.long()).to(device=hypothesis0.device) - - _BP, _PSL, _HP = premise0.size() - _BH, _HSL, _HH = hypothesis0.size() - _BPL, _PLL = seq_len1.size() - _HPL, _HLL = seq_len2.size() - - assert _BP == _BH and _BPL == _HPL and _BP == _BPL - assert _HP == _HH - assert _PSL == _PLL and _HSL == _HLL - - B, PL, H = premise0.size() - B, HL, H = hypothesis0.size() - - a0 = self.encoder(self.drop(premise0)) # a0: [B, PL, H * 2] - b0 = self.encoder(self.drop(hypothesis0)) # b0: [B, HL, H * 2] - - a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] - b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] - - ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) - - ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] - mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] - - f_ma = self.inference_layer(ma) - f_mb = self.inference_layer(mb) - - vat = self.decoder(self.drop(f_ma)) - vbt = self.decoder(self.drop(f_mb)) - - va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] - vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] - - va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] - va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] - vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] - vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] - - v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] - - prediction = torch.tanh(self.output(v)) # prediction: [B, N] - - if target is not None: - func = nn.CrossEntropyLoss() - loss = func(prediction, target) - return {Const.OUTPUT: prediction, Const.LOSS: loss} - - return {Const.OUTPUT: prediction} - - def predict(self, words1, words2, seq_len1=None, seq_len2=None, target=None): - """ Predict function - - :param torch.Tensor words1: [batch size(B), premise seq len(PL)] premise的token表示 - :param torch.Tensor words2: [B, hypothesis seq len(HL)] hypothesis的token表示 - :param torch.LongTensor seq_len1: [B] premise的长度 - :param torch.LongTensor seq_len2: [B] hypothesis的长度 - :param torch.LongTensor target: [B] 真实目标值 - :return: dict prediction: [B, n_labels(N)] 预测结果 + :param words1: [batch, seq_len] + :param words2: [batch, seq_len] + :param seq_len1: [batch] + :param seq_len2: [batch] + :param target: + :return: """ - prediction = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT] - return {Const.OUTPUT: torch.argmax(prediction, dim=-1)} + mask1 = seq_len_to_mask(seq_len1, words1.size(1)) + mask2 = seq_len_to_mask(seq_len2, words2.size(1)) + a0 = self.embedding(words1) # B * len * emb_dim + b0 = self.embedding(words2) + a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) + a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] + b = self.rnn(b0, mask2.byte()) + # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0]) # a: [B, PL, 2 * H] + # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0]) + + ai, bi = self.bi_attention(a, mask1, b, mask2) + + a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] + b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) + a_f = self.interfere(a_) + b_f = self.interfere(b_) + + a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] + b_h = self.rnn_high(b_f, mask2.byte()) + # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0]) # ma: [B, PL, 2 * H] + # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0]) + + a_avg = self.mean_pooling(a_h, mask1, dim=1) + a_max, _ = self.max_pooling(a_h, mask1, dim=1) + b_avg = self.mean_pooling(b_h, mask2, dim=1) + b_max, _ = self.max_pooling(b_h, mask2, dim=1) + + out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] + logits = torch.tanh(self.classifier(out)) + + if target is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, target) + + return {Const.LOSS: loss, Const.OUTPUT: logits} + else: + return {Const.OUTPUT: logits} + + def predict(self, **kwargs): + pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) + return {Const.OUTPUT: pred} + + # input [batch_size, len , hidden] + # mask [batch_size, len] (111...00) + @staticmethod + def mean_pooling(input, mask, dim=1): + masks = mask.view(mask.size(0), mask.size(1), -1).float() + return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) + + @staticmethod + def max_pooling(input, mask, dim=1): + my_inf = 10e12 + masks = mask.view(mask.size(0), mask.size(1), -1) + masks = masks.expand(-1, -1, input.size(2)).float() + return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) + + +class EmbedDropout(nn.Dropout): + + def forward(self, sequences_batch): + ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) + dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) + return dropout_mask.unsqueeze(1) * sequences_batch + + +class BiRNN(nn.Module): + def __init__(self, input_size, hidden_size, dropout_rate=0.3): + super(BiRNN, self).__init__() + self.dropout_rate = dropout_rate + self.rnn = nn.LSTM(input_size, hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + + def forward(self, x, x_mask): + # Sort x + lengths = x_mask.data.eq(1).long().sum(1) + _, idx_sort = torch.sort(lengths, dim=0, descending=True) + _, idx_unsort = torch.sort(idx_sort, dim=0) + lengths = list(lengths[idx_sort]) + + x = x.index_select(0, idx_sort) + # Pack it up + rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) + # Apply dropout to input + if self.dropout_rate > 0: + dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) + rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) + output = self.rnn(rnn_input)[0] + # Unpack everything + output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] + output = output.index_select(0, idx_unsort) + if output.size(1) != x_mask.size(1): + padding = torch.zeros(output.size(0), + x_mask.size(1) - output.size(1), + output.size(2)).type(output.data.type()) + output = torch.cat([output, padding], 1) + return output + + +def masked_softmax(tensor, mask): + tensor_shape = tensor.size() + reshaped_tensor = tensor.view(-1, tensor_shape[-1]) + + # Reshape the mask so it matches the size of the input tensor. + while mask.dim() < tensor.dim(): + mask = mask.unsqueeze(1) + mask = mask.expand_as(tensor).contiguous().float() + reshaped_mask = mask.view(-1, mask.size()[-1]) + result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) + result = result * reshaped_mask + # 1e-13 is added to avoid divisions by zero. + result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) + return result.view(*tensor_shape) + + +def weighted_sum(tensor, weights, mask): + w_sum = weights.bmm(tensor) + while mask.dim() < w_sum.dim(): + mask = mask.unsqueeze(1) + mask = mask.transpose(-1, -2) + mask = mask.expand_as(w_sum).contiguous().float() + return w_sum * mask + + +class SoftmaxAttention(nn.Module): + + def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): + similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) + .contiguous()) + + prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) + hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) + .contiguous(), + premise_mask) + + attended_premises = weighted_sum(hypothesis_batch, + prem_hyp_attn, + premise_mask) + attended_hypotheses = weighted_sum(premise_batch, + hyp_prem_attn, + hypothesis_mask) + + return attended_premises, attended_hypotheses