
update bert for nli in reproduction/matching

tags/v0.4.10
xuyige · 5 years ago
parent commit 8e82c91751
2 changed files with 130 additions and 5 deletions
  1. reproduction/matching/model/bert.py  +33 −5
  2. reproduction/matching/snli.py        +97 −0

reproduction/matching/model/bert.py  (+33 −5)
@@ -1,13 +1,41 @@
 import torch
 import torch.nn as nn

 from fastNLP.core.const import Const
 from fastNLP.models import BaseModel
 from fastNLP.modules.encoder.bert import BertModel


-class BertForSNLI(BaseModel):
+class BertForNLI(BaseModel):
     # TODO: still in progress

-    def __init(self):
-        super(BertForSNLI, self).__init__()
+    def __init__(self, class_num=3, bert_dir=None):
+        super(BertForNLI, self).__init__()
+        if bert_dir is not None:
+            self.bert = BertModel.from_pretrained(bert_dir)
+        else:
+            self.bert = BertModel()
+        hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1)
+        self.classifier = nn.Linear(hidden_size, class_num)

+    def forward(self, words, seq_len1, seq_len2, target=None):
+        """
+        :param torch.Tensor words: [batch_size, seq_len] input_ids
+        :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
+        :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
+        :param torch.Tensor target: [batch]
+        :return:
+        """
+        _, pooled_output = self.bert(words, seq_len1, seq_len2)
+        logits = self.classifier(pooled_output)
+
+        if target is not None:
+            loss_func = torch.nn.CrossEntropyLoss()
+            loss = loss_func(logits, target)
+            return {Const.OUTPUT: logits, Const.LOSS: loss}
+        return {Const.OUTPUT: logits}
+
+    def predict(self, words, seq_len1, seq_len2, target=None):
+        return self.forward(words, seq_len1, seq_len2)

-    def forward(self, words, segment_id, seq_len):
-        pass
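
For context, a minimal smoke test of the new class (a sketch, not part of the commit; it assumes the bare BertModel() in the else branch constructs with default config as that code implies, and uses made-up tensor values):

    import torch
    from reproduction.matching.model.bert import BertForNLI

    model = BertForNLI(class_num=3, bert_dir=None)  # random weights, no checkpoint

    words = torch.randint(0, 100, (2, 16))                 # input_ids
    token_type_ids = torch.zeros(2, 16, dtype=torch.long)  # passed as seq_len1
    attention_mask = torch.ones(2, 16, dtype=torch.long)   # passed as seq_len2
    target = torch.tensor([0, 2])

    out = model(words, token_type_ids, attention_mask, target)
    # out is {Const.OUTPUT: [2, 3] logits, Const.LOSS: scalar}; without target,
    # only Const.OUTPUT is returned, which is what predict() relies on.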

reproduction/matching/snli.py  (+97 −0, new file)

@@ -0,0 +1,97 @@
+import os
+
+import torch
+
+from fastNLP.core import Vocabulary, DataSet, Trainer, Tester, Const, Adam, AccuracyMetric
+
+from reproduction.matching.data.SNLIDataLoader import SNLILoader
+from legacy.component.bert_tokenizer import BertTokenizer
+from reproduction.matching.model.bert import BertForNLI
+
+
+def preprocess_data(data: DataSet, bert_dir):
+    """
+    Preprocess a data set into the inputs BERT needs.
+    :param data:
+    :param bert_dir:
+    :return:
+    """
+    tokenizer = BertTokenizer.from_pretrained(os.path.join(bert_dir, 'vocab.txt'))
+
+    vocab = Vocabulary(padding=None, unknown=None)
+    with open(os.path.join(bert_dir, 'vocab.txt')) as f:
+        lines = f.readlines()
+    vocab_list = []
+    for line in lines:
+        vocab_list.append(line.strip())
+    vocab.add_word_lst(vocab_list)
+    vocab.build_vocab()
+    vocab.padding = '[PAD]'
+    vocab.unknown = '[UNK]'
+
+    for i in range(2):
+        data.apply(lambda x: tokenizer.tokenize(" ".join(x[Const.INPUTS(i)])),
+                   new_field_name=Const.INPUTS(i))
+    data.apply(lambda x: ['[CLS]'] + x[Const.INPUTS(0)] + ['[SEP]'] + x[Const.INPUTS(1)] + ['[SEP]'],
+               new_field_name=Const.INPUT)
+    data.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+               new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), new_field_name=Const.INPUT_LENS(1))
+
+    max_len = 512
+    data.apply(lambda x: x[Const.INPUT][: max_len], new_field_name=Const.INPUT)
+    data.apply(lambda x: [vocab.to_index(w) for w in x[Const.INPUT]], new_field_name=Const.INPUT)
+    data.apply(lambda x: x[Const.INPUT_LENS(0)][: max_len], new_field_name=Const.INPUT_LENS(0))
+    data.apply(lambda x: x[Const.INPUT_LENS(1)][: max_len], new_field_name=Const.INPUT_LENS(1))
+
+    target_vocab = Vocabulary(padding=None, unknown=None)
+    target_vocab.add_word_lst(['neutral', 'contradiction', 'entailment'])
+    target_vocab.build_vocab()
+    data.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET)
+
+    data.set_input(Const.INPUT, Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET)
+    data.set_target(Const.TARGET)
+
+    return data
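
An aside before the training script below: on a toy premise/hypothesis pair, the field construction above reduces to standard single-sequence BERT packing. A hand-rolled sketch with made-up tokens (not real WordPiece output); note the commit reuses the Const.INPUT_LENS(0)/(1) field names to carry token_type_ids and attention_mask, matching the seq_len1/seq_len2 parameters of BertForNLI.forward:

    premise = ['a', 'man', 'sleeps']          # Const.INPUTS(0) after tokenization
    hypothesis = ['a', 'man', 'is', 'awake']  # Const.INPUTS(1)

    words = ['[CLS]'] + premise + ['[SEP]'] + hypothesis + ['[SEP]']
    segments = [0] * (len(premise) + 2) + [1] * (len(hypothesis) + 1)  # token_type_ids
    mask = [1] * len(segments)                                         # attention_mask

    assert len(words) == len(segments) == len(mask) == 10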


The remainder of the new file wires the data and model into fastNLP's Trainer and Tester:

+bert_dirs = 'path/to/bert/dir'
+
+# load raw data sets
+train_data = SNLILoader().load('./data/snli/snli_1.0_train.jsonl')
+dev_data = SNLILoader().load('./data/snli/snli_1.0_dev.jsonl')
+test_data = SNLILoader().load('./data/snli/snli_1.0_test.jsonl')
+
+print('successfully loaded data sets!')
+
+train_data = preprocess_data(train_data, bert_dirs)
+dev_data = preprocess_data(dev_data, bert_dirs)
+test_data = preprocess_data(test_data, bert_dirs)
+
+model = BertForNLI(bert_dir=bert_dirs)
+
+trainer = Trainer(
+    train_data=train_data,
+    model=model,
+    optimizer=Adam(lr=2e-5, model_params=model.parameters()),
+    batch_size=torch.cuda.device_count() * 12,
+    n_epochs=4,
+    print_every=-1,
+    dev_data=dev_data,
+    metrics=AccuracyMetric(),
+    metric_key='acc',
+    device=[i for i in range(torch.cuda.device_count())],
+    check_code_level=-1
+)
+trainer.train(load_best_model=True)
+
+tester = Tester(
+    data=test_data,
+    model=model,
+    metrics=AccuracyMetric(),
+    batch_size=torch.cuda.device_count() * 12,
+    device=[i for i in range(torch.cuda.device_count())],
+)
+tester.test()
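
One caveat in the script above: both batch_size and device scale with torch.cuda.device_count(), i.e. 12 samples per visible GPU under multi-GPU training, but on a CPU-only machine this yields batch_size=0 and an empty device list. A defensive variant (a sketch, not part of the commit; assumes device=None makes fastNLP fall back to its default device):

    import torch

    n_gpu = torch.cuda.device_count()
    batch_size = max(1, n_gpu) * 12                       # 12 samples per device, CPU fallback
    devices = list(range(n_gpu)) if n_gpu > 0 else None   # None: let fastNLP pick the device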


