@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device

class Predictor(object):
    """
    An interface for predicting outputs based on trained models.

    Unlike Tester, a Predictor does not care about evaluation metrics; it only runs inference.
    It is a high-level model wrapper called by fastNLP and shares no operations with Trainer or Tester.

    :param torch.nn.Module network: the model used to perform prediction
    """
    def __init__(self, network):
@@ -30,18 +30,19 @@ class Predictor(object):
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: DataSet, seq_len_field_name=None):
        """Run inference with a trained model.

        :param fastNLP.DataSet data: the dataset to predict on
        :param str seq_len_field_name: name of the field holding sequence-length information
        :return: dict whose contents are the model's predictions
        """
        if not isinstance(data, DataSet):
            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

        prev_training = self.network.training
        self.network.eval()
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
@@ -74,4 +75,5 @@ class Predictor(object):
                else:
                    batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
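
A minimal usage sketch for the updated interface, assuming a model already trained with fastNLP (`trained_model` and the field names below are hypothetical):

```python
from fastNLP import DataSet
from fastNLP.core.predictor import Predictor

# A tiny DataSet with the same input fields the model was trained on.
data = DataSet({'words': [[1, 5, 8], [2, 3, 4, 7]]})
data.apply(lambda ins: len(ins['words']), new_field_name='seq_len')
data.set_input('words', 'seq_len')

predictor = Predictor(network=trained_model)
# Returns a dict mapping each output name (e.g. 'pred') to the collected batch outputs.
outputs = predictor.predict(data, seq_len_field_name='seq_len')
```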
@@ -11,7 +11,7 @@

## Matching (Natural Language Inference / Sentence Matching)

- [Matching task reproduction](matching/)

## Sequence Labeling
@@ -1,32 +1,34 @@
# Reproduction of Matching Task Models

Here we use fastNLP to reproduce several well-known models for Matching tasks, aiming to match the performance reported in the original papers. The evaluation metric for all of these tasks is accuracy (%).

The reproduced models are (sorted by paper publication date):

- CNTN: model code (still in progress)[](); training code (still in progress)[]().
  Paper link: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
- ESIM: [model code](model/esim.py); [training code](matching_esim.py).
  Paper link: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
- DIIN: model code (still in progress)[](); training code (still in progress)[]().
  Paper link: [Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf).
- MwAN: model code (still in progress)[](); training code (still in progress)[]().
  Paper link: [Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf).
- BERT: [model code](model/bert.py); [training code](matching_bert.py).
  Paper link: [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).
# Datasets and Summary of Reproduced Results

fastNLP-reproduced results vs results reported in the papers; the first number in each cell is the fastNLP result.

'\-' means we have not reproduced the result yet, or the original paper did not report it.

model name | SNLI | MNLI | RTE | QNLI | Quora
:---: | :---: | :---: | :---: | :---: | :---:
CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
BERT (BASE version) [code](model/bert.py); [paper](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4 | 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - |

*The ESIM results of 72.4/72.1 on MNLI were reproduced by the MNLI organizers; the original ESIM paper did not report results on MNLI.
# Per-dataset Reproduced Results and Comparison with Other Major Models
## SNLI

[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/)

@@ -42,7 +44,7 @@ Performance on Test set:

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
:---: | :---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16

## MNLI

[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -58,11 +60,13 @@ Performance on Test set (matched/mismatched):

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base |
:---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |

## RTE

Still in progress.
## QNLI

### From GLUE baselines

@@ -73,17 +77,24 @@ Performance on Test set:

model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo |
:---: | :---: | :---: | :---: | :---: |
__performance__ | 74.6 | 74.3 | 75.5 | 79.8 |

*These LSTM-based baselines were implemented and tested by the QNLI organizers.

#### Transformer-based

model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN |
:---: | :---: | :---: | :---: | :---: |
__performance__ | 87.4 | 90.5 | 92.7 | 96.0 |
### Results reproduced with fastNLP

Performance on __Dev__ set:

model name | CNTN | ESIM | DIIN | MwAN | BERT
:---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 76.97 | - | 74.6 | -
## Quora

Still in progress.
@@ -5,8 +5,8 @@ from typing import Union, Dict

from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.modules.encoder._bert import BertTokenizer
@@ -348,6 +348,9 @@ class MNLILoader(MatchingLoader, CSVLoader):

            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
            # test_0.9_matched and test_0.9_mismatched come from MNLI version 0.9 (data source: Kaggle)
        }

        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
@@ -364,6 +367,10 @@ class MNLILoader(MatchingLoader, CSVLoader):

                if k in ds.get_field_names():
                    ds.rename_field(k, v)
            if Const.TARGET in ds.get_field_names():
                # unlabeled test sets carry the placeholder label 'hidden'; drop the target field then
                if ds[0][Const.TARGET] == 'hidden':
                    ds.delete_field(Const.TARGET)

            parentheses_table = str.maketrans({'(': None, ')': None})

            ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                     new_field_name=Const.INPUTS(0))
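
For context, the `translate` call above strips the parse-tree parentheses from SNLI/MNLI "binary parse" inputs and then tokenizes on whitespace. A standalone sketch of the same transform (the example string is illustrative):

```python
# Strip parse-tree parentheses, then tokenize on whitespace.
parentheses_table = str.maketrans({'(': None, ')': None})

binary_parse = "( ( The cat ) ( sat ( on ( the mat ) ) ) )"
tokens = binary_parse.translate(parentheses_table).strip().split()
print(tokens)  # ['The', 'cat', 'sat', 'on', 'the', 'mat']
```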
@@ -0,0 +1,102 @@
import random

import numpy as np
import torch

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam

from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
    MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.bert import BertForNLI


# define hyper-parameters
class BERTConfig:

    task = 'snli'
    batch_size_per_gpu = 6
    n_epochs = 6
    lr = 2e-5
    seq_len_type = 'bert'
    seed = 42
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to save the model; None means the model is not saved
    bert_dir = 'path/to/bert/dir'  # directory containing the pre-trained BERT weight files

arg = BERTConfig()
# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
# load data set
if arg.task == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
elif arg.task == 'quora':
    data_info = QuoraLoader().process(
        paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
        bert_tokenizer=arg.bert_dir, cut_text=512,
        get_index=True, concat='bert',
    )
else:
    raise RuntimeError(f'NOT support {arg.task} task yet!')
# define model
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)

# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,  # assumes at least one GPU
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1,
                  save_path=arg.save_path)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()
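
One caveat: `batch_size` and `device` above assume at least one visible GPU; on a CPU-only machine `torch.cuda.device_count()` is 0 and the batch size collapses to 0. A hedged sketch of a device-agnostic variant (assuming fastNLP's Trainer also accepts `'cpu'` for `device`):

```python
# Device-agnostic fallback (sketch): use the CPU when no GPU is visible.
n_gpu = torch.cuda.device_count()
batch_size = max(1, n_gpu) * arg.batch_size_per_gpu
device = list(range(n_gpu)) if n_gpu > 0 else 'cpu'
```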
@@ -1,47 +1,103 @@
import random

import numpy as np
import torch
from torch.optim import Adamax
from torch.optim.lr_scheduler import StepLR

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import GradientClipCallback, LRScheduler
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
    MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.esim import ESIMModel
# define hyper-parameters
class ESIMConfig:

    task = 'snli'
    embedding = 'glove'
    batch_size_per_gpu = 196
    n_epochs = 30
    lr = 2e-3
    seq_len_type = 'seq_len'
    # 'seq_len' means length information is encoded as len(words) during processing;
    # 'mask' means it is encoded as a 0/1 mask matrix (see the illustration below)
    seed = 42
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to save the model; None means the model is not saved

arg = ESIMConfig()
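
# Illustration (sketch) of the two length encodings for a batch padded to length 5:
#   seq_len: [3, 5, 2]                    # one integer per sentence
#   mask:    [[1, 1, 1, 0, 0],
#             [1, 1, 1, 1, 1],
#             [1, 1, 0, 0, 0]]            # 0/1 matrix; 1 marks real tokens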
# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
# load data set
if arg.task == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
elif arg.task == 'quora':
    data_info = QuoraLoader().process(
        paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
        get_index=True, concat=False,
    )
else:
    raise RuntimeError(f'NOT support {arg.task} task yet!')
# load embedding
if arg.embedding == 'elmo':
    embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
else:
    raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')

# define model
model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))
# define optimizer and callback
optimizer = Adamax(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve the learning rate every 10 epochs

callbacks = [
    GradientClipCallback(clip_value=10),  # equivalent to torch.nn.utils.clip_grad_norm_ with max_norm=10
    LRScheduler(scheduler),
]
# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
                  optimizer=optimizer,
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1,
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel):

        out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
        logits = torch.tanh(self.classifier(out))
        # logits = self.classifier(out)

        if target is not None:
            loss_fct = CrossEntropyLoss()
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel):
        return {Const.OUTPUT: logits}

    def predict(self, **kwargs):
        # return the predicted class index rather than the raw logits
        pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
        return {Const.OUTPUT: pred}

    # input: [batch_size, len, hidden]
    # mask:  [batch_size, len] (1s for real tokens, then 0s for padding)
@@ -127,7 +129,7 @@ class BiRNN(nn.Module):

    def forward(self, x, x_mask):
        # Sort x by descending sequence length so it can be packed for the RNN
        lengths = x_mask.data.eq(1).long().sum(1)
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)

        lengths = list(lengths[idx_sort])
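
For context, the sort/unsort idiom above prepares variable-length sequences for `torch.nn.utils.rnn.pack_padded_sequence`, which expects lengths in descending order (under its default `enforce_sorted=True`); `idx_unsort` is the inverse permutation that restores the original batch order afterwards. A minimal self-contained sketch of the same pattern, with made-up sizes:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.randn(3, 5, 8)                  # [batch, max_len, hidden]
x_mask = torch.tensor([[1, 1, 1, 0, 0],
                       [1, 1, 1, 1, 1],
                       [1, 1, 0, 0, 0]])  # 1 marks real tokens, 0 marks padding

lengths = x_mask.eq(1).long().sum(1)      # tensor([3, 5, 2])
_, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)   # inverse permutation of idx_sort

packed = pack_padded_sequence(x[idx_sort], lengths[idx_sort].tolist(), batch_first=True)
# ... run an RNN over `packed` here ...
padded, _ = pad_packed_sequence(packed, batch_first=True)
restored = padded[idx_unsort]             # rows back in the original batch order
```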