
update reproduction/matching/matching_bert.py

tags/v0.4.10
Yige Xu committed 5 years ago
commit bf2920cba9
1 changed file with 5 additions and 6 deletions
reproduction/matching/matching_bert.py

@@ -8,8 +8,7 @@ from fastNLP.core.optimizer import AdamW
 from fastNLP.embeddings import BertEmbedding
 from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\
     QNLIBertPipe, QuoraBertPipe
-
-from reproduction.matching.model.bert import BertForNLI
+from fastNLP.models.bert import BertForSentenceMatching
 
 
 # define hyper-parameters
@@ -65,7 +64,7 @@ print(data_bundle) # print details in data_bundle
 embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name)
 
 # define model
-model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
+model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET]))
 
 # define optimizer and callback
 optimizer = AdamW(lr=arg.lr, params=model.parameters())
@@ -76,11 +75,11 @@ if arg.task in ['snli']:
     # evaluate test set in every epoch if task is snli.
 
 # define trainer
-trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
+trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model,
                   optimizer=optimizer,
                   batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                   n_epochs=arg.n_epochs, print_every=-1,
-                  dev_data=data_bundle.datasets[arg.dev_dataset_name],
+                  dev_data=data_bundle.get_dataset(arg.dev_dataset_name),
                   metrics=AccuracyMetric(), metric_key='acc',
                   device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1,
@@ -92,7 +91,7 @@ trainer.train(load_best_model=True)
 
 # define tester
 tester = Tester(
-    data=data_bundle.datasets[arg.test_dataset_name],
+    data=data_bundle.get_dataset(arg.test_dataset_name),
     model=model,
     metrics=AccuracyMetric(),
     batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
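
In short, this commit switches the script from the reproduction-local BertForNLI model to fastNLP's built-in BertForSentenceMatching (the keyword argument changes from class_num to num_labels) and replaces dict-style data_bundle.datasets[...] access with the DataBundle.get_dataset() accessor. A minimal sketch of the resulting usage follows, assuming the fastNLP API exactly as it appears in this diff; the pipe class, BERT model name, and split names below are illustrative assumptions, not part of the commit:

# Sketch only: mirrors the post-commit API (BertForSentenceMatching + get_dataset()).
# RTEBertPipe, 'en-base-uncased', and the split names are illustrative assumptions.
from fastNLP import Trainer, AccuracyMetric, Const
from fastNLP.embeddings import BertEmbedding
from fastNLP.models.bert import BertForSentenceMatching
from fastNLP.io.pipe.matching import RTEBertPipe

data_bundle = RTEBertPipe().process_from_file()  # load and preprocess an NLI-style matching task

embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name='en-base-uncased')
model = BertForSentenceMatching(embed, num_labels=len(data_bundle.vocabs[Const.TARGET]))

trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model,  # get_dataset() replaces .datasets[...]
                  dev_data=data_bundle.get_dataset('dev'),
                  metrics=AccuracyMetric(), metric_key='acc')
trainer.train(load_best_model=True)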

