diff --git a/reproduction/matching/model/bert.py b/reproduction/matching/model/bert.py
index 6b13ce2a..9b3a78b2 100644
--- a/reproduction/matching/model/bert.py
+++ b/reproduction/matching/model/bert.py
@@ -1,13 +1,41 @@
+import torch
+import torch.nn as nn
+
+from fastNLP.core.const import Const
 from fastNLP.models import BaseModel
 from fastNLP.modules.encoder.bert import BertModel
 
 
-class BertForSNLI(BaseModel):
+class BertForNLI(BaseModel):
     # TODO: still in progress
 
-    def __init(self):
-        super(BertForSNLI, self).__init__()
+    def __init__(self, class_num=3, bert_dir=None):
+        super(BertForNLI, self).__init__()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel()
+        # infer the hidden size from the pooler's output layer
+        hidden_size = self.bert.pooler.dense.bias.size(-1)
+        self.classifier = nn.Linear(hidden_size, class_num)
+
+    def forward(self, words, seq_len1, seq_len2, target=None):
+        """
+        :param torch.Tensor words: [batch_size, seq_len] input_ids
+        :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
+        :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
+        :param torch.Tensor target: [batch]
+        :return: dict with Const.OUTPUT (logits) and, if target is given, Const.LOSS
+        """
+        _, pooled_output = self.bert(words, seq_len1, seq_len2)
+        logits = self.classifier(pooled_output)
+
+        if target is not None:
+            loss_func = torch.nn.CrossEntropyLoss()
+            loss = loss_func(logits, target)
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words, seq_len1, seq_len2, target=None):
+        return self.forward(words, seq_len1, seq_len2)
 
-    def forward(self, words, segment_id, seq_len):
-        pass
diff --git a/reproduction/matching/snli.py b/reproduction/matching/snli.py
new file mode 100644
index 00000000..b389aa11
--- /dev/null
+++ b/reproduction/matching/snli.py
@@ -0,0 +1,97 @@
+import os
+
+import torch
+
+from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric
+
+from reproduction.matching.data.SNLIDataLoader import SNLILoader
+from legacy.component.bert_tokenizer import BertTokenizer
+from reproduction.matching.model.bert import BertForNLI
+
+
+def preprocess_data(data: DataSet, bert_dir):
+    """
+    Preprocess the raw data set into the input format BERT needs.
+
+    :param data: a raw DataSet loaded by SNLILoader
+    :param bert_dir: path to the pre-trained BERT directory (must contain vocab.txt)
+    :return: the processed DataSet
+    """
+    tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))
+
+    # build the fastNLP vocabulary from BERT's own word-piece vocabulary
+    vocab = Vocabulary(padding=None, unknown=None)
+    with open(os.path.join(bert_dir, 'vocab.txt')) as f:
+        lines = f.readlines()
+    vocab_list = []
+    for line in lines:
+        vocab_list.append(line.strip())
+    vocab.add_word_lst(vocab_list)
+    vocab.build_vocab()
+    vocab.padding = '[PAD]'
+    vocab.unknown = '[UNK]'
+
+    # word-piece tokenize both sentences, then pack them as
+    # [CLS] sent1 [SEP] sent2 [SEP] with token-type ids and attention mask
+    for i in range(2):
+        data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
+                   new_field_name=Const.INPUTS(i))
+    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
+               new_field_name=Const.INPUT)
+    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+               new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
+
+    # truncate to BERT's maximum sequence length and map tokens to ids
+    max_len = 512
+    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
+    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
+    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
+
+    target_vocab = Vocabulary(padding=None, unknown=None)
+    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
+    target_vocab.build_vocab()
+    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
+
+    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
+    data.set_target(Const.TARGET)
+
+    return data
+
+
+bert_dirs = 'path/to/bert/dir'
+
+# load raw data sets
+train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
+dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
+test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')
+
+print('successfully loaded data sets!')
+
+train_data = preprocess_data(train_data, bert_dirs)
+dev_data = preprocess_data(dev_data, bert_dirs)
+test_data = preprocess_data(test_data, bert_dirs)
+
+model = BertForNLI(bert_dir=bert_dirs)
+
+trainer = Trainer(
+    train_data=train_data,
+    model=model,
+    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+    batch_size=torch.cuda.device_count() * 12,
+    n_epochs=4,
+    print_every=-1,
+    dev_data=dev_data,
+    metrics=AccuracyMetric(),
+    metric_key='acc',
+    device=[i for i in range(torch.cuda.device_count())],
+    check_code_level=-1
+)
+trainer.train(load_best_model=True)
+
+tester = Tester(
+    data=test_data,
+    model=model,
+    metrics=AccuracyMetric(),
+    batch_size=torch.cuda.device_count() * 12,
+    device=[i for i in range(torch.cuda.device_count())],
+)
+tester.test()
+
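
Note (not part of the patch): below is a minimal, self-contained sketch of the
sentence-pair packing that preprocess_data performs, i.e. how the fields passed
to BertForNLI as words (input_ids), seq_len1 (token_type_ids) and seq_len2
(attention_mask) relate to each other. The toy sentences and variable names are
illustrative only.

# Toy example mirroring the field construction above, using plain Python
# lists instead of fastNLP DataSet.apply calls.
premise = ['the', 'cat', 'sat']        # hypothetical word-piece tokens, sentence 1
hypothesis = ['a', 'cat', 'rested']    # hypothetical word-piece tokens, sentence 2

# words: [CLS] sentence1 [SEP] sentence2 [SEP]
words = ['[CLS]'] + premise + ['[SEP]'] + hypothesis + ['[SEP]']

# token_type_ids (passed as seq_len1): 0 for [CLS] + sentence1 + [SEP],
# 1 for sentence2 + the final [SEP]
token_type_ids = [0] * (len(premise) + 2) + [1] * (len(hypothesis) + 1)

# attention_mask (passed as seq_len2): 1 for every real token; padding,
# added later by the batching code, would be 0
attention_mask = [1] * len(token_type_ids)

assert len(words) == len(token_type_ids) == len(attention_mask)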