import random
import numpy as np
import torch

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import WarmupCallback, EvaluateCallback
from fastNLP.core.optimizer import AdamW
from fastNLP.embeddings import BertEmbedding
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe, \
    QNLIBertPipe, QuoraBertPipe
from fastNLP.models.bert import BertForSentenceMatching

# define hyper-parameters
class BERTConfig:

    task = 'snli'

    batch_size_per_gpu = 6  # per-GPU batch size; the Trainer below multiplies it by the number of GPUs
    n_epochs = 6
    lr = 2e-5
    warm_up_rate = 0.1
    seed = 42
    save_path = None  # where to save the model; None means the model is not saved

    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'

    to_lower = True  # lowercase the input text
    tokenizer = 'spacy'  # tokenize with spaCy

    bert_model_dir_or_name = 'bert-base-uncased'

arg = BERTConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
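# Optional: for stricter run-to-run reproducibility on GPU, the standard PyTorch
# cuDNN flags below can also be set (not part of the original recipe, and they
# may slow training down):
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False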

# load the data set for the chosen task
if arg.task == 'snli':
    data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'rte':
    data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'qnli':
    data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'mnli':
    data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'quora':
    data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
else:
    raise RuntimeError(f'task {arg.task} is not supported yet!')

print(data_bundle)  # print details of the data bundle
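# For orientation: indexing a fastNLP DataSet returns a single Instance, so a
# processed sample can be inspected like this (field names are assigned by the pipe):
# print(data_bundle.get_dataset(arg.train_dataset_name)[0])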

# load the BERT embedding over the input vocabulary
embed = BertEmbedding(data_bundle.get_vocab(Const.INPUT), model_dir_or_name=arg.bert_model_dir_or_name)

# define the model: a BERT encoder with a classification head sized to the label vocabulary
model = BertForSentenceMatching(embed, num_labels=len(data_bundle.get_vocab(Const.TARGET)))

# define optimizer and callbacks
optimizer = AdamW(lr=arg.lr, params=model.parameters())
callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear')]

if arg.task in ['snli']:
    # for SNLI, also evaluate the test set at the end of every epoch
    callbacks.append(EvaluateCallback(data=data_bundle.get_dataset(arg.test_dataset_name)))
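# A common optional addition at this point is gradient clipping; fastNLP provides
# a callback for it. A sketch (commented out, not part of the original recipe):
# from fastNLP.core.callback import GradientClipCallback
# callbacks.append(GradientClipCallback(clip_value=1, clip_type='norm'))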

# define trainer; device and batch size fall back to CPU when no GPU is visible
trainer = Trainer(train_data=data_bundle.get_dataset(arg.train_dataset_name), model=model,
                  optimizer=optimizer,
                  batch_size=max(n_gpu, 1) * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_bundle.get_dataset(arg.dev_dataset_name),
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=list(range(n_gpu)) if n_gpu > 0 else None,
                  check_code_level=-1,
                  save_path=arg.save_path,
                  callbacks=callbacks)

# train the model; the best checkpoint on the dev set is loaded back when training ends
trainer.train(load_best_model=True)

# define tester for the final evaluation on the test set
tester = Tester(
    data=data_bundle.get_dataset(arg.test_dataset_name),
    model=model,
    metrics=AccuracyMetric(),
    batch_size=max(n_gpu, 1) * arg.batch_size_per_gpu,
    device=list(range(n_gpu)) if n_gpu > 0 else None,
)

# test the model
tester.test()
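# Tester.test() also returns the metric results as a nested dict
# (metric name -> {metric key -> value}), so they can be captured if needed:
# eval_results = tester.test()
# print(eval_results)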