|
@@ -8,8 +8,7 @@ from fastNLP.core.optimizer import AdamW |
|
|
from fastNLP.embeddings import BertEmbedding |
|
|
from fastNLP.embeddings import BertEmbedding |
|
|
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ |
|
|
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ |
|
|
QNLIBertPipe, QuoraBertPipe |
|
|
QNLIBertPipe, QuoraBertPipe |
|
|
|
|
|
|
|
|
from reproduction.matching.model.bert import BertForNLI |
|
|
|
|
|
|
|
|
from fastNLP.models.bert import BertForSentenceMatching |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# define hyper-parameters |
|
|
# define hyper-parameters |
|
@@ -65,7 +64,7 @@ print(data_bundle) # print details in data_bundle |
|
|
embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) |
|
|
embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) |
|
|
|
|
|
|
|
|
# define model |
|
|
# define model |
|
|
model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET])) |
|
|
|
|
|
|
|
|
model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET])) |
|
|
|
|
|
|
|
|
# define optimizer and callback |
|
|
# define optimizer and callback |
|
|
optimizer = AdamW(lr=arg.lr, params=model.parameters()) |
|
|
optimizer = AdamW(lr=arg.lr, params=model.parameters()) |
|
@@ -76,11 +75,11 @@ if arg.task in ['snli']: |
|
|
# evaluate test set in every epoch if task is snli. |
|
|
# evaluate test set in every epoch if task is snli. |
|
|
|
|
|
|
|
|
# define trainer |
|
|
# define trainer |
|
|
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, |
|
|
|
|
|
|
|
|
trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model, |
|
|
optimizer=optimizer, |
|
|
optimizer=optimizer, |
|
|
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, |
|
|
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, |
|
|
n_epochs=arg.n_epochs, print_every=-1, |
|
|
n_epochs=arg.n_epochs, print_every=-1, |
|
|
dev_data=data_bundle.datasets[arg.dev_dataset_name], |
|
|
|
|
|
|
|
|
dev_data=data_bundle.get_dataset(arg.dev_dataset_name), |
|
|
metrics=AccuracyMetric(), metric_key='acc', |
|
|
metrics=AccuracyMetric(), metric_key='acc', |
|
|
device=[i for i in range(torch.cuda.device_count())], |
|
|
device=[i for i in range(torch.cuda.device_count())], |
|
|
check_code_level=-1, |
|
|
check_code_level=-1, |
|
@@ -92,7 +91,7 @@ trainer.train(load_best_model=True) |
|
|
|
|
|
|
|
|
# define tester |
|
|
# define tester |
|
|
tester = Tester( |
|
|
tester = Tester( |
|
|
data=data_bundle.datasets[arg.test_dataset_name], |
|
|
|
|
|
|
|
|
data=data_bundle.get_dataset(arg.test_dataset_name), |
|
|
model=model, |
|
|
model=model, |
|
|
metrics=AccuracyMetric(), |
|
|
metrics=AccuracyMetric(), |
|
|
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, |
|
|
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, |
|
|