Browse Source

update reproduction/matching/matching_bert.py

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
bf2920cba9
1 changed file with 5 additions and 6 deletions
  1. +5
    -6
      reproduction/matching/matching_bert.py

+ 5
- 6
reproduction/matching/matching_bert.py View File

@@ -8,8 +8,7 @@ from fastNLP.core.optimizer import AdamW
from fastNLP.embeddings import BertEmbedding from fastNLP.embeddings import BertEmbedding
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\
QNLIBertPipe, QuoraBertPipe QNLIBertPipe, QuoraBertPipe

from reproduction.matching.model.bert import BertForNLI
from fastNLP.models.bert import BertForSentenceMatching




# define hyper-parameters # define hyper-parameters
@@ -65,7 +64,7 @@ print(data_bundle) # print details in data_bundle
embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name)


# define model # define model
model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET]))


# define optimizer and callback # define optimizer and callback
optimizer = AdamW(lr=arg.lr, params=model.parameters()) optimizer = AdamW(lr=arg.lr, params=model.parameters())
@@ -76,11 +75,11 @@ if arg.task in ['snli']:
# evaluate test set in every epoch if task is snli. # evaluate test set in every epoch if task is snli.


# define trainer # define trainer
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model,
optimizer=optimizer, optimizer=optimizer,
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1, n_epochs=arg.n_epochs, print_every=-1,
dev_data=data_bundle.datasets[arg.dev_dataset_name],
dev_data=data_bundle.get_dataset(arg.dev_dataset_name),
metrics=AccuracyMetric(), metric_key='acc', metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())], device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1, check_code_level=-1,
@@ -92,7 +91,7 @@ trainer.train(load_best_model=True)


# define tester # define tester
tester = Tester( tester = Tester(
data=data_bundle.datasets[arg.test_dataset_name],
data=data_bundle.get_dataset(arg.test_dataset_name),
model=model, model=model,
metrics=AccuracyMetric(), metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,


Loading…
Cancel
Save