@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||||
class Predictor(object): | class Predictor(object): | ||||
""" | """ | ||||
An interface for predicting outputs based on trained models. | |||||
一个根据训练模型预测输出的预测器(Predictor) | |||||
It does not care about evaluations of the model, which is different from Tester. | |||||
This is a high-level model wrapper to be called by FastNLP. | |||||
This class does not share any operations with Trainer and Tester. | |||||
Currently, Predictor does not support GPU. | |||||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||||
""" | """ | ||||
def __init__(self, network): | def __init__(self, network): | ||||
@@ -30,18 +30,19 @@ class Predictor(object): | |||||
self.batch_size = 1 | self.batch_size = 1 | ||||
self.batch_output = [] | self.batch_output = [] | ||||
def predict(self, data, seq_len_field_name=None): | |||||
"""Perform inference using the trained model. | |||||
def predict(self, data: DataSet, seq_len_field_name=None): | |||||
"""用已经训练好的模型进行inference. | |||||
:param data: a DataSet object. | |||||
:param str seq_len_field_name: field name indicating sequence lengths | |||||
:return: list of batch outputs | |||||
:param fastNLP.DataSet data: 待预测的数据集 | |||||
:param str seq_len_field_name: 表示序列长度信息的field名字 | |||||
:return: dict dict里面的内容为模型预测的结果 | |||||
""" | """ | ||||
if not isinstance(data, DataSet): | if not isinstance(data, DataSet): | ||||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | ||||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | ||||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | ||||
prev_training = self.network.training | |||||
self.network.eval() | self.network.eval() | ||||
network_device = _get_model_device(self.network) | network_device = _get_model_device(self.network) | ||||
batch_output = defaultdict(list) | batch_output = defaultdict(list) | ||||
@@ -74,4 +75,5 @@ class Predictor(object): | |||||
else: | else: | ||||
batch_output[key].append(value) | batch_output[key].append(value) | ||||
self.network.train(prev_training) | |||||
return batch_output | return batch_output |
@@ -11,7 +11,7 @@ | |||||
## Matching (自然语言推理/句子匹配) | ## Matching (自然语言推理/句子匹配) | ||||
- still in progress | |||||
- [Matching 任务复现](matching/) | |||||
## Sequence Labeling (序列标注) | ## Sequence Labeling (序列标注) | ||||
@@ -1,32 +1,34 @@ | |||||
# Matching任务模型复现 | # Matching任务模型复现 | ||||
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。 | |||||
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | |||||
复现的模型有(按论文发表时间顺序排序): | 复现的模型有(按论文发表时间顺序排序): | ||||
- CNTN:复现代码(still in progress)[](). | |||||
- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | ||||
- ESIM:[复现代码](model/esim.py). | |||||
- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | |||||
论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | ||||
- DIIN:复现代码(still in progress)[](). | |||||
- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). | 论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). | ||||
- MwAN:复现代码(still in progress)[](). | |||||
- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). | 论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). | ||||
- BERT:[复现代码](model/bert.py). | |||||
- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py). | |||||
论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). | 论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). | ||||
# 数据集及复现结果汇总 | # 数据集及复现结果汇总 | ||||
使用fastNLP复现的结果vs论文汇报结果 | |||||
使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果 | |||||
'\-'表示我们仍未复现或者论文原文没有汇报 | '\-'表示我们仍未复现或者论文原文没有汇报 | ||||
model name | SNLI | MNLI | RTE | QNLI | Quora | model name | SNLI | MNLI | RTE | QNLI | Quora | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
CNTN ; [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | - | - | - | - | - | | |||||
ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs - | 57.04(dev) / - | 76.97(dev) / - | - | | |||||
DIIN [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||||
MwAN [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.5 vs 88.3 | - vs 78.5/77.7 | - | - | vs 89.12 | | |||||
CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||||
ESIM[代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | |||||
DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | |||||
MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | |||||
BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | | BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | | ||||
*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。 | |||||
# 数据集复现结果及其他主要模型对比 | # 数据集复现结果及其他主要模型对比 | ||||
## SNLI | ## SNLI | ||||
[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) | [Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) | ||||
@@ -42,7 +44,7 @@ Performance on Test set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
__performance__ | - | 88.13 | - | - | 90.6 | 91.16 | |||||
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||||
## MNLI | ## MNLI | ||||
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | ||||
@@ -58,11 +60,13 @@ Performance on Test set(matched/mismatched): | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | ||||
:---: | :---: | :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: | :---: | :---: | | ||||
__performance__ | - | - | - | - | - | | |||||
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||||
## RTE | ## RTE | ||||
Still in progress. | |||||
## QNLI | ## QNLI | ||||
### From GLUE baselines | ### From GLUE baselines | ||||
@@ -73,17 +77,24 @@ Performance on Test set: | |||||
model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo | model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo | ||||
:---: | :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: | :---: | | ||||
__performance__ | 74.6 | 74.3 | 75.5 | 79.8 | | __performance__ | 74.6 | 74.3 | 75.5 | 79.8 | | ||||
*这些LSTM-based的baseline是由QNLI官方实现并测试的。 | |||||
#### Transformer-based | #### Transformer-based | ||||
model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN | model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN | ||||
:---: | :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: | :---: | | ||||
__performance__ | 87.4 | 90.5 | 92.7 | 96.0 | | __performance__ | 87.4 | 90.5 | 92.7 | 96.0 | | ||||
### 基于fastNLP复现的结果 | ### 基于fastNLP复现的结果 | ||||
Performance on Dev set: | |||||
Performance on __Dev__ set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT | model name | CNTN | ESIM | DIIN | MwAN | BERT | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
__performance__ | - | 76.97 | - | - | - | |||||
__performance__ | - | 76.97 | - | 74.6 | - | |||||
## Quora | ## Quora | ||||
Still in progress. | |||||
@@ -5,8 +5,8 @@ from typing import Union, Dict | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.io.base_loader import DataInfo | |||||
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader | |||||
from fastNLP.io.base_loader import DataInfo, DataSetLoader | |||||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
from fastNLP.modules.encoder._bert import BertTokenizer | from fastNLP.modules.encoder._bert import BertTokenizer | ||||
@@ -348,6 +348,9 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
'dev_mismatched': 'dev_mismatched.tsv', | 'dev_mismatched': 'dev_mismatched.tsv', | ||||
'test_matched': 'test_matched.tsv', | 'test_matched': 'test_matched.tsv', | ||||
'test_mismatched': 'test_mismatched.tsv', | 'test_mismatched': 'test_mismatched.tsv', | ||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_matched与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||||
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
CSVLoader.__init__(self, sep='\t') | CSVLoader.__init__(self, sep='\t') | ||||
@@ -364,6 +367,10 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
if k in ds.get_field_names(): | if k in ds.get_field_names(): | ||||
ds.rename_field(k, v) | ds.rename_field(k, v) | ||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | parentheses_table = str.maketrans({'(': None, ')': None}) | ||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | ||||
@@ -0,0 +1,102 @@ | |||||
import random | |||||
import numpy as np | |||||
import torch | |||||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam | |||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader | |||||
from reproduction.matching.model.bert import BertForNLI | |||||
# define hyper-parameters
class BERTConfig:
    """Hyper-parameters and data-set split names for the BERT matching run."""

    task = 'snli'  # one of: snli, rte, qnli, mnli, quora
    batch_size_per_gpu = 6
    n_epochs = 6
    lr = 2e-5
    seq_len_type = 'bert'
    seed = 42
    # NOTE(review): these are keys into data_info.datasets; some loaders appear
    # to register different split names (e.g. MNLI's 'dev_mismatched') -- adjust
    # them for the chosen task.
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to store the model; None means the model is not saved
    bert_dir = 'path/to/bert/dir'  # directory holding the pre-trained BERT files


arg = BERTConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)

# With no visible GPU, `n_gpu * batch_size_per_gpu` is 0 and the device list is
# empty, so the script could not run. Fall back to one "worker" worth of batch
# and pass device=None.
# NOTE(review): device=None is expected to leave the model on its current
# device -- confirm against the installed fastNLP version.
batch_size = max(n_gpu, 1) * arg.batch_size_per_gpu
device = list(range(n_gpu)) if n_gpu > 0 else None

# load data set
# Every task is processed with the same options; only the loader class and the
# data directory differ.
task_loaders = {
    'snli': (SNLILoader, 'path/to/snli/data'),
    'rte': (RTELoader, 'path/to/rte/data'),
    'qnli': (QNLILoader, 'path/to/qnli/data'),
    'mnli': (MNLILoader, 'path/to/mnli/data'),
    'quora': (QuoraLoader, 'path/to/quora/data'),
}
if arg.task not in task_loaders:
    raise RuntimeError(f'NOT support {arg.task} task yet!')

loader_cls, data_path = task_loaders[arg.task]
data_info = loader_cls().process(
    paths=data_path, to_lower=True, seq_len_type=arg.seq_len_type,
    bert_tokenizer=arg.bert_dir, cut_text=512,
    get_index=True, concat='bert',
)

# define model
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)

# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=batch_size,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[arg.dev_dataset_name],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=device,
                  check_code_level=-1,
                  save_path=arg.save_path)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[arg.test_dataset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=batch_size,
    device=device,
)

# test model
tester.test()
@@ -1,47 +1,103 @@ | |||||
import argparse | |||||
import random | |||||
import numpy as np | |||||
import torch | import torch | ||||
from torch.optim import Adamax | |||||
from torch.optim.lr_scheduler import StepLR | |||||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||||
from fastNLP.core.callback import GradientClipCallback, LRScheduler | |||||
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | ||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader | |||||
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader | |||||
from reproduction.matching.model.esim import ESIMModel | from reproduction.matching.model.esim import ESIMModel | ||||
argument = argparse.ArgumentParser() | |||||
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') | |||||
argument.add_argument('--batch-size-per-gpu', type=int, default=128) | |||||
argument.add_argument('--n-epochs', type=int, default=100) | |||||
argument.add_argument('--lr', type=float, default=1e-4) | |||||
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') | |||||
argument.add_argument('--save-dir', type=str, default=None) | |||||
arg = argument.parse_args() | |||||
bert_dirs = 'path/to/bert/dir' | |||||
# define hyper-parameters | |||||
class ESIMConfig: | |||||
task = 'snli' | |||||
embedding = 'glove' | |||||
batch_size_per_gpu = 196 | |||||
n_epochs = 30 | |||||
lr = 2e-3 | |||||
seq_len_type = 'seq_len' | |||||
# seq_len表示在process的时候用len(words)来表示长度信息; | |||||
# mask表示用0/1掩码矩阵来表示长度信息; | |||||
seed = 42 | |||||
train_dataset_name = 'train' | |||||
dev_dataset_name = 'dev' | |||||
test_dataset_name = 'test' | |||||
save_path = None # 模型存储的位置,None表示不存储模型。 | |||||
arg = ESIMConfig() | |||||
# set random seed | |||||
random.seed(arg.seed) | |||||
np.random.seed(arg.seed) | |||||
torch.manual_seed(arg.seed) | |||||
n_gpu = torch.cuda.device_count() | |||||
if n_gpu > 0: | |||||
torch.cuda.manual_seed_all(arg.seed) | |||||
# load data set | # load data set | ||||
data_info = SNLILoader().process( | |||||
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||||
get_index=True, concat=False, | |||||
) | |||||
if arg.task == 'snli': | |||||
data_info = SNLILoader().process( | |||||
paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'rte': | |||||
data_info = RTELoader().process( | |||||
paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'qnli': | |||||
data_info = QNLILoader().process( | |||||
paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'mnli': | |||||
data_info = MNLILoader().process( | |||||
paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
elif arg.task == 'quora': | |||||
data_info = QuoraLoader().process( | |||||
paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||||
get_index=True, concat=False, | |||||
) | |||||
else: | |||||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||||
# load embedding | # load embedding | ||||
if arg.embedding == 'elmo': | if arg.embedding == 'elmo': | ||||
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | ||||
elif arg.embedding == 'glove': | elif arg.embedding == 'glove': | ||||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) | |||||
else: | else: | ||||
raise ValueError(f'now we only support elmo or glove embedding for esim model!') | |||||
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') | |||||
# define model | # define model | ||||
model = ESIMModel(embedding) | |||||
model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) | |||||
# define optimizer and callback | |||||
optimizer = Adamax(lr=arg.lr, params=model.parameters()) | |||||
scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍 | |||||
callbacks = [ | |||||
GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10) | |||||
LRScheduler(scheduler), | |||||
] | |||||
# define trainer | # define trainer | ||||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||||
optimizer=optimizer, | |||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | ||||
n_epochs=arg.n_epochs, print_every=-1, | n_epochs=arg.n_epochs, print_every=-1, | ||||
dev_data=data_info.datasets['dev'], | |||||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||||
metrics=AccuracyMetric(), metric_key='acc', | metrics=AccuracyMetric(), metric_key='acc', | ||||
device=[i for i in range(torch.cuda.device_count())], | device=[i for i in range(torch.cuda.device_count())], | ||||
check_code_level=-1, | check_code_level=-1, | ||||
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True) | |||||
# define tester | # define tester | ||||
tester = Tester( | tester = Tester( | ||||
data=data_info.datasets['test'], | |||||
data=data_info.datasets[arg.test_dataset_name], | |||||
model=model, | model=model, | ||||
metrics=AccuracyMetric(), | metrics=AccuracyMetric(), | ||||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | ||||
@@ -81,6 +81,7 @@ class ESIMModel(BaseModel): | |||||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | ||||
logits = torch.tanh(self.classifier(out)) | logits = torch.tanh(self.classifier(out)) | ||||
# logits = self.classifier(out) | |||||
if target is not None: | if target is not None: | ||||
loss_fct = CrossEntropyLoss() | loss_fct = CrossEntropyLoss() | ||||
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel): | |||||
return {Const.OUTPUT: logits} | return {Const.OUTPUT: logits} | ||||
def predict(self, **kwargs): | def predict(self, **kwargs): | ||||
return self.forward(**kwargs) | |||||
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||||
return {Const.OUTPUT: pred} | |||||
# input [batch_size, len , hidden] | # input [batch_size, len , hidden] | ||||
# mask [batch_size, len] (111...00) | # mask [batch_size, len] (111...00) | ||||
@@ -127,7 +129,7 @@ class BiRNN(nn.Module): | |||||
def forward(self, x, x_mask): | def forward(self, x, x_mask): | ||||
# Sort x | # Sort x | ||||
lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||||
lengths = x_mask.data.eq(1).long().sum(1) | |||||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | _, idx_sort = torch.sort(lengths, dim=0, descending=True) | ||||
_, idx_unsort = torch.sort(idx_sort, dim=0) | _, idx_unsort = torch.sort(idx_sort, dim=0) | ||||
lengths = list(lengths[idx_sort]) | lengths = list(lengths[idx_sort]) | ||||