From 93620e76edf0162f8b9d8f844728dfd1c203e58d Mon Sep 17 00:00:00 2001 From: xuyige Date: Tue, 18 Jun 2019 02:04:53 +0800 Subject: [PATCH] update framework of matching --- fastNLP/io/dataset_loader.py | 25 ++-- fastNLP/modules/encoder/bert.py | 4 +- fastNLP/modules/encoder/embedding.py | 11 +- reproduction/matching/matching.py | 26 ++-- reproduction/matching/model/esim.py | 182 +++++++++++++++++++++++++++ 5 files changed, 221 insertions(+), 27 deletions(-) create mode 100644 reproduction/matching/model/esim.py diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index c63ff2f4..b0bf2e60 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -269,7 +269,7 @@ class MatchingLoader(DataSetLoader): def _load(self, path: str) -> DataSet: raise NotImplementedError - def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: + def process(self, paths: Union[str, Dict[str, str]], input_field=None) -> DataInfo: if isinstance(paths, str): paths = {'train': paths} @@ -289,6 +289,13 @@ class MatchingLoader(DataSetLoader): raise RuntimeError(f'Your model is {self.data_format}, ' f'Please choose from [esim, bert]') + if input_field is not None: + if isinstance(input_field, str): + data.set_input(input_field) + elif isinstance(input_field, list): + for field in input_field: + data.set_input(field) + data_set[n] = data print(f'successfully load {n} set!') @@ -298,11 +305,11 @@ class MatchingLoader(DataSetLoader): raise RuntimeError(f'There is NOT label vocab attribute built!') if self.for_model != 'bert': - from fastNLP.modules.encoder.embedding import StaticEmbedding - embedding = StaticEmbedding(self.vocab, model_dir_or_name='en') + from fastNLP.modules.encoder.embedding import ElmoEmbedding + embedding = ElmoEmbedding(self.vocab, model_dir_or_name='en', requires_grad=True, layers='2') data_info = DataInfo(vocabs={'vocab': self.vocab, 'target_vocab': self.label_vocab}, - embeddings={'glove': embedding} if self.for_model != 'bert' else None, + embeddings={'elmo': embedding} if self.for_model != 'bert' else None, datasets=data_set) return data_info @@ -338,15 +345,17 @@ class MatchingLoader(DataSetLoader): raw_ds.drop(lambda x: x[Const.TARGET] == '-') if not hasattr(self, 'vocab'): - self.vocab = Vocabulary().from_dataset(raw_ds, [Const.INPUTS(0), Const.INPUTS(1)]) + self.vocab = Vocabulary().from_dataset(raw_ds, field_name=[Const.INPUTS(0), Const.INPUTS(1)]) if not hasattr(self, 'label_vocab'): self.label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(raw_ds, field_name=Const.TARGET) raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0)) raw_ds.apply(lambda ins: [self.vocab.to_index(w) for w in ins[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1)) - raw_ds.apply(lambda ins: self.label_vocab.to_index(Const.TARGET), new_field_name=Const.TARGET) + raw_ds.apply(lambda ins: self.label_vocab.to_index(ins[Const.TARGET]), new_field_name=Const.TARGET) + raw_ds.apply(lambda ins: len(ins[Const.INPUTS(0)]), new_field_name=Const.INPUT_LENS(0)) + raw_ds.apply(lambda ins: len(ins[Const.INPUTS(1)]), new_field_name=Const.INPUT_LENS(1)) - raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1)) + raw_ds.set_input(Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)) raw_ds.set_target(Const.TARGET) return raw_ds @@ -405,6 +414,8 @@ class MatchingLoader(DataSetLoader): raw_ds.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1)) raw_ds.set_target(Const.TARGET) + return raw_ds + class SNLILoader(JsonLoader): """ diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index e9739c28..4948d022 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -2,9 +2,9 @@ import os from torch import nn import torch -from ...core import Vocabulary +from ...core.vocabulary import Vocabulary from ...io.file_utils import _get_base_url, cached_path -from ._bert import _WordPieceBertModel +from ._bert import _WordPieceBertModel, BertModel class BertWordPieceEncoder(nn.Module): diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 7fd85578..9c1bf35f 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -152,6 +152,8 @@ class StaticEmbedding(TokenEmbedding): Example:: + >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50') + :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding @@ -311,8 +313,7 @@ class ElmoEmbedding(ContextualEmbedding): Example:: - >>> - >>> + >>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True) :param vocab: 词表 :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, @@ -403,7 +404,7 @@ class BertEmbedding(ContextualEmbedding): Example:: - >>> + >>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') :param fastNLP.Vocabulary vocab: 词表 @@ -513,7 +514,7 @@ class CNNCharEmbedding(TokenEmbedding): Example:: - >>> + >>> cnn_char_embed = CNNCharEmbedding(vocab) :param vocab: 词表 @@ -647,7 +648,7 @@ class LSTMCharEmbedding(TokenEmbedding): Example:: - >>> + >>> lstm_char_embed = LSTMCharEmbedding(vocab) :param vocab: 词表 :param embed_size: embedding的大小。默认值为50. diff --git a/reproduction/matching/matching.py b/reproduction/matching/matching.py index 52c1c3b5..8251b3bc 100644 --- a/reproduction/matching/matching.py +++ b/reproduction/matching/matching.py @@ -2,31 +2,31 @@ import os import torch -from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const from fastNLP.io.dataset_loader import MatchingLoader from reproduction.matching.model.bert import BertForNLI +from reproduction.matching.model.esim import ESIMModel -# bert_dirs = 'path/to/bert/dir' -bert_dirs = '/remote-home/ygxu/BERT/BERT_English_uncased_L-12_H-768_A_12' +bert_dirs = 'path/to/bert/dir' # load data set -data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process( - {#'train': './data/snli/snli_1.0_train.jsonl', +# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... +data_info = MatchingLoader(data_format='snli', for_model='esim').process( + {'train': './data/snli/snli_1.0_train.jsonl', 'dev': './data/snli/snli_1.0_dev.jsonl', - 'test': './data/snli/snli_1.0_test.jsonl'} + 'test': './data/snli/snli_1.0_test.jsonl'}, + input_field=[Const.TARGET] ) -print('successfully load data sets!') +# model = BertForNLI(bert_dir=bert_dirs) +model = ESIMModel(data_info.embeddings['elmo'],) - -model = BertForNLI(bert_dir=bert_dirs) - -trainer = Trainer(train_data=data_info.datasets['dev'], model=model, - optimizer=Adam(lr=2e-5, model_params=model.parameters()), - batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, +trainer = Trainer(train_data=data_info.datasets['train'], model=model, + optimizer=Adam(lr=1e-4, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, dev_data=data_info.datasets['dev'], metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1) diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py new file mode 100644 index 00000000..0551bbdb --- /dev/null +++ b/reproduction/matching/model/esim.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import CrossEntropyLoss + +from fastNLP.models import BaseModel +from fastNLP.modules.encoder.embedding import TokenEmbedding +from fastNLP.modules.encoder.lstm import LSTM +from fastNLP.core.const import Const +from fastNLP.core.utils import seq_len_to_mask + + +class ESIMModel(BaseModel): + def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, + dropout_embed=0.1): + super(ESIMModel, self).__init__() + + self.embedding = init_embedding + self.dropout_embed = EmbedDropout(p=dropout_embed) + if hidden_size is None: + hidden_size = self.embedding.embed_size + self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) + + self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.ReLU()) + nn.init.xavier_uniform_(self.interfere[1].weight.data) + self.bi_attention = SoftmaxAttention() + + self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) + # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) + + self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(8 * hidden_size, hidden_size), + nn.Tanh(), + nn.Dropout(p=dropout_rate), + nn.Linear(hidden_size, num_labels)) + nn.init.xavier_uniform_(self.classifier[1].weight.data) + nn.init.xavier_uniform_(self.classifier[4].weight.data) + + def forward(self, words1, words2, seq_len1, seq_len2, target=None): + mask1 = seq_len_to_mask(seq_len1) + mask2 = seq_len_to_mask(seq_len2) + a0 = self.embedding(words1) # B * len * emb_dim + b0 = self.embedding(words2) + a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) + a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] + b = self.rnn(b0, mask2.byte()) + + ai, bi = self.bi_attention(a, mask1, b, mask2) + + a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] + b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) + a_f = self.interfere(a_) + b_f = self.interfere(b_) + + a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] + b_h = self.rnn_high(b_f, mask2.byte()) + + a_avg = self.mean_pooling(a_h, mask1, dim=1) + a_max, _ = self.max_pooling(a_h, mask1, dim=1) + b_avg = self.mean_pooling(b_h, mask2, dim=1) + b_max, _ = self.max_pooling(b_h, mask2, dim=1) + + out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] + logits = torch.tanh(self.classifier(out)) + + if target is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, target) + + return {Const.LOSS: loss, Const.OUTPUT: logits} + else: + return {Const.OUTPUT: logits} + + def predict(self, **kwargs): + return self.forward(**kwargs) + + # input [batch_size, len , hidden] + # mask [batch_size, len] (111...00) + @staticmethod + def mean_pooling(input, mask, dim=1): + masks = mask.view(mask.size(0), mask.size(1), -1).float() + return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) + + @staticmethod + def max_pooling(input, mask, dim=1): + my_inf = 10e12 + masks = mask.view(mask.size(0), mask.size(1), -1) + masks = masks.expand(-1, -1, input.size(2)).float() + return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) + + +class EmbedDropout(nn.Dropout): + + def forward(self, sequences_batch): + ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) + dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) + return dropout_mask.unsqueeze(1) * sequences_batch + + +class BiRNN(nn.Module): + def __init__(self, input_size, hidden_size, dropout_rate=0.3): + super(BiRNN, self).__init__() + self.dropout_rate = dropout_rate + self.rnn = nn.LSTM(input_size, hidden_size, + num_layers=1, + bidirectional=True, + batch_first=True) + + def forward(self, x, x_mask): + # Sort x + lengths = x_mask.data.eq(1).long().sum(1).squeeze() + _, idx_sort = torch.sort(lengths, dim=0, descending=True) + _, idx_unsort = torch.sort(idx_sort, dim=0) + lengths = list(lengths[idx_sort]) + + x = x.index_select(0, idx_sort) + # Pack it up + rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) + # Apply dropout to input + if self.dropout_rate > 0: + dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) + rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) + output = self.rnn(rnn_input)[0] + # Unpack everything + output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] + output = output.index_select(0, idx_unsort) + if output.size(1) != x_mask.size(1): + padding = torch.zeros(output.size(0), + x_mask.size(1) - output.size(1), + output.size(2)).type(output.data.type()) + output = torch.cat([output, padding], 1) + return output + + +def masked_softmax(tensor, mask): + tensor_shape = tensor.size() + reshaped_tensor = tensor.view(-1, tensor_shape[-1]) + + # Reshape the mask so it matches the size of the input tensor. + while mask.dim() < tensor.dim(): + mask = mask.unsqueeze(1) + mask = mask.expand_as(tensor).contiguous().float() + reshaped_mask = mask.view(-1, mask.size()[-1]) + result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) + result = result * reshaped_mask + # 1e-13 is added to avoid divisions by zero. + result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) + return result.view(*tensor_shape) + + +def weighted_sum(tensor, weights, mask): + w_sum = weights.bmm(tensor) + while mask.dim() < w_sum.dim(): + mask = mask.unsqueeze(1) + mask = mask.transpose(-1, -2) + mask = mask.expand_as(w_sum).contiguous().float() + return w_sum * mask + + +class SoftmaxAttention(nn.Module): + + def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): + similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) + .contiguous()) + + prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) + hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) + .contiguous(), + premise_mask) + + attended_premises = weighted_sum(hypothesis_batch, + prem_hyp_attn, + premise_mask) + attended_hypotheses = weighted_sum(premise_batch, + hyp_prem_attn, + hypothesis_mask) + + return attended_premises, attended_hypotheses \ No newline at end of file