From bf2920cba98ff6c534ccbc5c16fe0942411cb360 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Mon, 16 Sep 2019 15:09:20 +0800 Subject: [PATCH] update reproduction/matching/matching_bert.py --- reproduction/matching/matching_bert.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py index 323d81a3..05377dff 100644 --- a/reproduction/matching/matching_bert.py +++ b/reproduction/matching/matching_bert.py @@ -8,8 +8,7 @@ from fastNLP.core.optimizer import AdamW from fastNLP.embeddings import BertEmbedding from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ QNLIBertPipe, QuoraBertPipe - -from reproduction.matching.model.bert import BertForNLI +from fastNLP.models.bert import BertForSentenceMatching # define hyper-parameters @@ -65,7 +64,7 @@ print(data_bundle) # print details in data_bundle embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) # define model -model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET])) +model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET])) # define optimizer and callback optimizer = AdamW(lr=arg.lr, params=model.parameters()) @@ -76,11 +75,11 @@ if arg.task in ['snli']: # evaluate test set in every epoch if task is snli. # define trainer -trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, +trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model, optimizer=optimizer, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, n_epochs=arg.n_epochs, print_every=-1, - dev_data=data_bundle.datasets[arg.dev_dataset_name], + dev_data=data_bundle.get_dataset(arg.dev_dataset_name), metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], check_code_level=-1, @@ -92,7 +91,7 @@ trainer.train(load_best_model=True) # define tester tester = Tester( - data=data_bundle.datasets[arg.test_dataset_name], + data=data_bundle.get_dataset(arg.test_dataset_name), model=model, metrics=AccuracyMetric(), batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,