From 8e7a604b29c8a7df69d1388e4f787ca52150ad83 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Tue, 2 Jul 2019 13:34:26 +0800
Subject: [PATCH 1/2] update documents in predictor

---
 fastNLP/core/predictor.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index ce016bb6..2d6a7380 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device
 
 class Predictor(object):
     """
-    An interface for predicting outputs based on trained models.
+    A predictor that generates outputs from a trained model.
 
-    It does not care about evaluations of the model, which is different from Tester.
-    This is a high-level model wrapper to be called by FastNLP.
-    This class does not share any operations with Trainer and Tester.
-    Currently, Predictor does not support GPU.
+    Unlike Tester, Predictor does not care about evaluation metrics; it only runs inference.
+    It is a high-level model wrapper to be called by fastNLP, and it shares no operations with Trainer or Tester.
+
+    :param torch.nn.Module network: the model used to perform prediction
     """
 
     def __init__(self, network):
@@ -30,18 +30,19 @@ class Predictor(object):
         self.batch_size = 1
         self.batch_output = []
 
-    def predict(self, data, seq_len_field_name=None):
-        """Perform inference using the trained model.
+    def predict(self, data: DataSet, seq_len_field_name=None):
+        """Run inference with the trained model.
 
-        :param data: a DataSet object.
-        :param str seq_len_field_name: field name indicating sequence lengths
-        :return: list of batch outputs
+        :param fastNLP.DataSet data: the dataset to predict on
+        :param str seq_len_field_name: name of the field that holds sequence-length information
+        :return: dict containing the model's prediction results
         """
         if not isinstance(data, DataSet):
             raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
         if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
             raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
 
+        prev_training = self.network.training
         self.network.eval()
         network_device = _get_model_device(self.network)
         batch_output = defaultdict(list)
@@ -74,4 +75,5 @@ class Predictor(object):
             else:
                 batch_output[key].append(value)
 
+        self.network.train(prev_training)
         return batch_output
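Not part of the patch: for reviewers unfamiliar with this class, a minimal usage sketch of the patched `Predictor` follows. The toy model and field names are invented for illustration, and the `DataSet`/`set_input` calls assume the fastNLP API of this era.

```python
import torch

from fastNLP import DataSet
from fastNLP.core.predictor import Predictor


class ToyModel(torch.nn.Module):
    """Any torch.nn.Module whose forward() returns a dict should work here."""

    def forward(self, words):
        return {'pred': words.float().mean(dim=1)}


data = DataSet({'words': [[1, 2, 3], [4, 5, 6]]})
data.set_input('words')  # assumption: predict() batches over the input fields

predictor = Predictor(ToyModel())
# With this patch, predict() runs the model in eval mode and afterwards restores
# the previous training/eval state via self.network.train(prev_training).
output = predictor.predict(data)
print(output['pred'])  # a list with one tensor per batch
```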
From 1bc780a2bd06c50ed703589bc7bf7a22ccea1774 Mon Sep 17 00:00:00 2001
From: xuyige
Date: Tue, 2 Jul 2019 13:38:05 +0800
Subject: [PATCH 2/2] update framework in matching tasks

---
 reproduction/README.md                        |   2 +-
 reproduction/matching/README.md               |  41 ++++---
 .../matching/data/MatchingDataLoader.py       |  11 +-
 reproduction/matching/matching_bert.py        | 102 ++++++++++++++++++
 reproduction/matching/matching_esim.py        | 102 ++++++++++++++----
 reproduction/matching/model/esim.py           |   6 +-
 6 files changed, 221 insertions(+), 43 deletions(-)
 create mode 100644 reproduction/matching/matching_bert.py

diff --git a/reproduction/README.md b/reproduction/README.md
index bb21c067..92652fb4 100644
--- a/reproduction/README.md
+++ b/reproduction/README.md
@@ -11,7 +11,7 @@
 
 ## Matching (Natural Language Inference / Sentence Matching)
 
-- still in progress
+- [Matching task reproduction](matching/)
 
 ## Sequence Labeling
 
diff --git a/reproduction/matching/README.md b/reproduction/matching/README.md
index 899d7e9b..056b0212 100644
--- a/reproduction/matching/README.md
+++ b/reproduction/matching/README.md
@@ -1,32 +1,34 @@
 # Reproducing models for Matching tasks
-Here we use fastNLP to reproduce several well-known models for Matching tasks, aiming to match the performance reported in their papers.
+Here we use fastNLP to reproduce several well-known models for Matching tasks, aiming to match the performance reported in their papers. The evaluation metric for all of these tasks is accuracy (%).
 The reproduced models are (ordered by paper publication date):
 
-- CNTN: reproduction code (still in progress) []().
+- CNTN: model code (still in progress) [](); training code (still in progress) []().
 Paper link: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
-- ESIM: [reproduction code](model/esim.py).
+- ESIM: [model code](model/esim.py); [training code](matching_esim.py).
 Paper link: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
-- DIIN: reproduction code (still in progress) []().
+- DIIN: model code (still in progress) [](); training code (still in progress) []().
 Paper link: [Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf).
-- MwAN: reproduction code (still in progress) []().
+- MwAN: model code (still in progress) [](); training code (still in progress) []().
 Paper link: [Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf).
-- BERT: [reproduction code](model/bert.py).
+- BERT: [model code](model/bert.py); [training code](matching_bert.py).
 Paper link: [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).
 
 # Datasets and summary of reproduced results
-Results reproduced with fastNLP vs results reported in the papers
+Results reproduced with fastNLP vs results reported in the papers; the number before "vs" is the fastNLP reproduction.
 '\-' means the result has not been reproduced yet, or was not reported in the original paper.
 
 model name | SNLI | MNLI | RTE | QNLI | Quora
 :---: | :---: | :---: | :---: | :---: | :---:
-CNTN ; [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | - | - | - | - | - |
-ESIM [code](model/bert.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs - | 57.04(dev) / - | 76.97(dev) / - | - |
-DIIN [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
-MwAN [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.5 vs 88.3 | - vs 78.5/77.7 | - | - | vs 89.12 |
+CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
+ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
+DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
+MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
 BERT (BASE version) [code](model/bert.py); [paper](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4 | 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - |
 
+*The ESIM results of 72.4/72.1 on MNLI were reproduced by the MNLI organizers; the original ESIM paper did not report results on MNLI.
+
 # Per-dataset results and comparison with other major models
 
 ## SNLI
 [Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/)
 
 ### Results reproduced with fastNLP
 Performance on Test set:
 
 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
 :---: | :---: | :---: | :---: | :---: | :---: | :---:
 __performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
 
 ## MNLI
 [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
 
 ### Results reproduced with fastNLP
 Performance on Test set (matched/mismatched):
 
 model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
 :---: | :---: | :---: | :---: | :---: | :---: |
 __performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
 
 ## RTE
 
+Still in progress.
+
 ## QNLI
 
 ### From GLUE baselines
 Performance on Test set:
 
 #### LSTM-based
 model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo
 :---: | :---: | :---: | :---: | :---: |
 __performance__ | 74.6 | 74.3 | 75.5 | 79.8 |
+
+*These LSTM-based baselines were implemented and evaluated by the QNLI organizers.
+
 #### Transformer-based
 model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN
 :---: | :---: | :---: | :---: | :---: |
 __performance__ | 87.4 | 90.5 | 92.7 | 96.0 |
+
 ### Results reproduced with fastNLP
-Performance on Dev set:
+Performance on __Dev__ set:
 
 model name | CNTN | ESIM | DIIN | MwAN | BERT
 :---: | :---: | :---: | :---: | :---: | :---:
 __performance__ | - | 76.97 | - | 74.6 | -
 
 ## Quora
+
+Still in progress.
+
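All numbers in the tables above are accuracies produced by fastNLP's `AccuracyMetric` (used by the training scripts later in this patch). As a self-contained reference, here is a sketch of what that metric computes; the `evaluate`/`get_metric` call pattern is assumed from this fastNLP version:

```python
import torch
from fastNLP.core import AccuracyMetric

metric = AccuracyMetric()

# Three examples, two classified correctly: rows are class scores,
# and the argmax over the last dimension is taken as the prediction.
pred = torch.tensor([[0.9, 0.1],
                     [0.2, 0.8],
                     [0.7, 0.3]])
target = torch.tensor([0, 1, 1])

metric.evaluate(pred=pred, target=target)
print(metric.get_metric())  # expected: {'acc': 0.666667}, i.e. 2 of 3 correct
```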
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 749b16c8..20a63d75 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -5,8 +5,8 @@ from typing import Union, Dict
 
 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataInfo
-from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader
+from fastNLP.io.base_loader import DataInfo, DataSetLoader
+from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer
 
@@ -348,6 +348,9 @@ class MNLILoader(MatchingLoader, CSVLoader):
             'dev_mismatched': 'dev_mismatched.tsv',
             'test_matched': 'test_matched.tsv',
             'test_mismatched': 'test_mismatched.tsv',
+            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+            # test_0.9_matched and test_0.9_mismatched belong to MNLI version 0.9 (data source: Kaggle)
         }
         MatchingLoader.__init__(self, paths=paths)
         CSVLoader.__init__(self, sep='\t')
@@ -364,6 +367,10 @@ class MNLILoader(MatchingLoader, CSVLoader):
             if k in ds.get_field_names():
                 ds.rename_field(k, v)
 
+        if Const.TARGET in ds.get_field_names():
+            if ds[0][Const.TARGET] == 'hidden':
+                ds.delete_field(Const.TARGET)
+
         parentheses_table = str.maketrans({'(': None, ')': None})
 
         ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
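The `parentheses_table` transform at the end of this hunk exists because SNLI/MNLI distribute each sentence as a binary-parse string; deleting the brackets and splitting on whitespace recovers plain tokens. A standalone illustration (the sample string is invented):

```python
# Same table the loader builds: mapping '(' and ')' to None deletes those characters.
parentheses_table = str.maketrans({'(': None, ')': None})

binary_parse = '( ( A man ) ( ( is ( eating food ) ) . ) )'
tokens = binary_parse.translate(parentheses_table).strip().split()
print(tokens)  # ['A', 'man', 'is', 'eating', 'food', '.']
```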
diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py
new file mode 100644
index 00000000..75112d5a
--- /dev/null
+++ b/reproduction/matching/matching_bert.py
@@ -0,0 +1,102 @@
import random
import numpy as np
import torch

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam

from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
    MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.bert import BertForNLI


# define hyper-parameters
class BERTConfig:

    task = 'snli'
    batch_size_per_gpu = 6
    n_epochs = 6
    lr = 2e-5
    seq_len_type = 'bert'
    seed = 42
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to save the trained model; None means do not save it
    bert_dir = 'path/to/bert/dir'  # directory holding the pre-trained BERT parameter files


arg = BERTConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)

# load data set
if arg.task == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'quora':
    data_info = QuoraLoader().process(
        paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
else:
    raise RuntimeError(f'Task {arg.task} is not supported yet!')

# define model
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)

# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1,
                  save_path=arg.save_path)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()

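In the script above, `concat='bert'` and `seq_len_type='bert'` ask the loader to pack both sentences of a pair into one BERT-style sequence. The loader's exact field layout is internal to `MatchingDataLoader`; the sketch below only illustrates the conventional packing this presumably corresponds to (token and segment-id conventions are taken from the BERT paper, not read out of the loader):

```python
# Hypothetical illustration of BERT-style sentence-pair packing.
words1 = ['a', 'man', 'is', 'eating']
words2 = ['someone', 'is', 'having', 'a', 'meal']

tokens = ['[CLS]'] + words1 + ['[SEP]'] + words2 + ['[SEP]']
# Segment (token type) ids: 0 over the first segment, 1 over the second.
segment_ids = [0] * (len(words1) + 2) + [1] * (len(words2) + 1)

assert len(tokens) == len(segment_ids)
print(list(zip(tokens, segment_ids)))
```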
diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py
index 3da6141f..d878608f 100644
--- a/reproduction/matching/matching_esim.py
+++ b/reproduction/matching/matching_esim.py
@@ -1,47 +1,103 @@
-import argparse
+import random
+import numpy as np
 import torch
+from torch.optim import Adamax
+from torch.optim.lr_scheduler import StepLR
 
-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
+from fastNLP.core.callback import GradientClipCallback, LRScheduler
 from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
 
-from reproduction.matching.data.MatchingDataLoader import SNLILoader
+from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
+    MNLILoader, QNLILoader, QuoraLoader
 from reproduction.matching.model.esim import ESIMModel
 
-argument = argparse.ArgumentParser()
-argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
-argument.add_argument('--batch-size-per-gpu', type=int, default=128)
-argument.add_argument('--n-epochs', type=int, default=100)
-argument.add_argument('--lr', type=float, default=1e-4)
-argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
-argument.add_argument('--save-dir', type=str, default=None)
-arg = argument.parse_args()
-
-bert_dirs = 'path/to/bert/dir'
+# define hyper-parameters
+class ESIMConfig:
+
+    task = 'snli'
+    embedding = 'glove'
+    batch_size_per_gpu = 196
+    n_epochs = 30
+    lr = 2e-3
+    seq_len_type = 'seq_len'
+    # 'seq_len' means length information is stored as len(words) during processing;
+    # 'mask' means length information is stored as a 0/1 mask matrix.
+    seed = 42
+    train_dataset_name = 'train'
+    dev_dataset_name = 'dev'
+    test_dataset_name = 'test'
+    save_path = None  # where to save the trained model; None means do not save it
+
+
+arg = ESIMConfig()
+
+# set random seed
+random.seed(arg.seed)
+np.random.seed(arg.seed)
+torch.manual_seed(arg.seed)
+
+n_gpu = torch.cuda.device_count()
+if n_gpu > 0:
+    torch.cuda.manual_seed_all(arg.seed)
 
 # load data set
-data_info = SNLILoader().process(
-    paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
-    get_index=True, concat=False,
-)
+if arg.task == 'snli':
+    data_info = SNLILoader().process(
+        paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+        get_index=True, concat=False,
+    )
+elif arg.task == 'rte':
+    data_info = RTELoader().process(
+        paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
+        get_index=True, concat=False,
+    )
+elif arg.task == 'qnli':
+    data_info = QNLILoader().process(
+        paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+        get_index=True, concat=False,
+    )
+elif arg.task == 'mnli':
+    data_info = MNLILoader().process(
+        paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
+        get_index=True, concat=False,
+    )
+elif arg.task == 'quora':
+    data_info = QuoraLoader().process(
+        paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
+        get_index=True, concat=False,
+    )
+else:
+    raise RuntimeError(f'Task {arg.task} is not supported yet!')
 
 # load embedding
 if arg.embedding == 'elmo':
     embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
 elif arg.embedding == 'glove':
-    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
+    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
 else:
-    raise ValueError(f'now we only support elmo or glove embedding for esim model!')
+    raise RuntimeError(f'Embedding {arg.embedding} is not supported yet!')
 
 # define model
-model = ESIMModel(embedding)
+model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))
+
+# define optimizer and callback
+optimizer = Adamax(lr=arg.lr, params=model.parameters())
+scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve the learning rate every 10 epochs
+
+callbacks = [
+    GradientClipCallback(clip_value=10),  # equivalent to torch.nn.utils.clip_grad_norm_ with max_norm=10
+    LRScheduler(scheduler),
+]
 
 # define trainer
-trainer = Trainer(train_data=data_info.datasets['train'], model=model,
-                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
+                  optimizer=optimizer,
                   batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                   n_epochs=arg.n_epochs, print_every=-1,
-                  dev_data=data_info.datasets['dev'],
+                  dev_data=data_info.datasets[arg.dev_dataset_name],
                   metrics=AccuracyMetric(), metric_key='acc',
                   device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1,
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True)
 
 # define tester
 tester = Tester(
-    data=data_info.datasets['test'],
+    data=data_info.datasets[arg.test_dataset_name],
     model=model,
     metrics=AccuracyMetric(),
     batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
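The two `seq_len_type` options documented in `ESIMConfig` above are two standard encodings of length information for a padded batch; a minimal illustration (tensor contents invented):

```python
import torch

# A padded batch of two sequences with true lengths 3 and 2 (0 is the pad index).
words = torch.tensor([[4, 7, 2, 0],
                      [5, 9, 0, 0]])

# seq_len_type='seq_len': one integer per instance, i.e. len(words) before padding.
seq_len = torch.tensor([3, 2])

# seq_len_type='mask': a 0/1 matrix of the same shape as the padded batch.
mask = (words != 0).long()
print(mask)  # tensor([[1, 1, 1, 0],
             #         [1, 1, 0, 0]])
```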
diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py
index d55034e7..187e565d 100644
--- a/reproduction/matching/model/esim.py
+++ b/reproduction/matching/model/esim.py
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel):
         out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
         logits = torch.tanh(self.classifier(out))
+        # logits = self.classifier(out)
 
         if target is not None:
             loss_fct = CrossEntropyLoss()
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel):
         return {Const.OUTPUT: logits}
 
     def predict(self, **kwargs):
-        return self.forward(**kwargs)
+        pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
+        return {Const.OUTPUT: pred}
 
     # input [batch_size, len, hidden]
     # mask  [batch_size, len] (111...00)
@@ -127,7 +129,7 @@ class BiRNN(nn.Module):
 
     def forward(self, x, x_mask):
         # Sort x
-        lengths = x_mask.data.eq(1).long().sum(1).squeeze()
+        lengths = x_mask.data.eq(1).long().sum(1)
         _, idx_sort = torch.sort(lengths, dim=0, descending=True)
         _, idx_unsort = torch.sort(idx_sort, dim=0)
         lengths = list(lengths[idx_sort])
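A note on the `.squeeze()` removal in `BiRNN.forward` above: `sum(1)` already yields a 1-D vector of lengths, and squeezing it collapses the batch-size-1 case to a 0-dim tensor, which breaks the later `list(lengths[idx_sort])` call. A quick demonstration of the edge case:

```python
import torch

x_mask = torch.tensor([[1, 1, 1, 0]])  # a batch containing a single padded sequence

lengths_old = x_mask.data.eq(1).long().sum(1).squeeze()  # before the patch
lengths_new = x_mask.data.eq(1).long().sum(1)            # after the patch

print(lengths_old.shape)  # torch.Size([]) -- 0-dim; iterating over it raises a TypeError
print(lengths_new.shape)  # torch.Size([1]) -- stays a proper 1-D batch of lengths
```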