
ADD file via upload

pull/1/head
pfgqbl2ej 4 years ago
parent commit e3a02afa4b
1 changed file with 87 additions and 0 deletions
examples/nlp/bert/train_hetu_bert.py  +87  -0

@@ -0,0 +1,87 @@
from tqdm import tqdm
import os
import math
import logging
import hetu as ht
from hetu_bert import BertForPreTraining
from bert_config import BertConfig
from load_data import DataLoader
import numpy as np
import time

''' Usage example:
In the directory Hetu/examples/nlp/bert/, run: python train_hetu_bert.py
'''
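# Script outline: build BertForPreTraining, declare feed-dict placeholder
# Variables for the batch tensors, combine the masked-LM and next-sentence
# losses, and train with the Hetu executor on a single GPU.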

device_id = 6
executor_ctx = ht.gpu(device_id)

num_epochs = 1
lr = 1e-4

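# BERT-base hyper-parameters (12 layers, 12 heads, hidden size 768);
# uncommenting the dropout args below disables dropout.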
config = BertConfig(vocab_size=30522,
                    hidden_size=768,
                    num_hidden_layers=12,
                    num_attention_heads=12,
                    intermediate_size=3072,
                    max_position_embeddings=512,
                    # attention_probs_dropout_prob=0.0,
                    # hidden_dropout_prob=0.0,
                    batch_size=6)

model = BertForPreTraining(config=config)

batch_size = config.batch_size
seq_len = config.max_position_embeddings
vocab_size = config.vocab_size

dataloader = DataLoader(dataset='bookcorpus', doc_num=200, save_gap=200, batch_size=batch_size)
data_names = ['input_ids', 'token_type_ids', 'attention_mask', 'masked_lm_labels', 'next_sentence_label']  # fields provided by each batch

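# Non-trainable placeholder Variables; their values are supplied via feed_dict each step.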
input_ids = ht.Variable(name='input_ids', trainable=False)
token_type_ids = ht.Variable(name='token_type_ids', trainable=False)
attention_mask = ht.Variable(name='attention_mask', trainable=False)

masked_lm_labels = ht.Variable(name='masked_lm_labels_one_hot', trainable=False)
next_sentence_label = ht.Variable(name='next_sentence_label_one_hot', trainable=False)

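# Count of valid masked positions in the batch, used to average the masked-LM loss.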
loss_position_sum = ht.Variable(name='loss_position_sum', trainable=False)

_,_, masked_lm_loss, next_sentence_loss = model(input_ids, token_type_ids, attention_mask, masked_lm_labels, next_sentence_label)

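# Sum the MLM loss over the batch and sequence dims, then divide by the number
# of masked positions; average the NSP loss over the batch.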
masked_lm_loss_mean = ht.div_op(ht.reduce_sum_op(masked_lm_loss, [0,1]), loss_position_sum)
next_sentence_loss_mean = ht.reduce_mean_op(next_sentence_loss, [0])

loss = masked_lm_loss_mean + next_sentence_loss_mean
# opt = ht.optim.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-8)
opt = ht.optim.SGDOptimizer(learning_rate=lr)
train_op = opt.minimize(loss)

executor = ht.Executor([masked_lm_loss_mean, next_sentence_loss_mean, loss, train_op], ctx=executor_ctx, dynamic_memory=True)


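# Training loop: prepare epoch data, then run one executor step per batch;
# each run returns the fetched loss tensors and applies the parameter update.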
dataloader.make_epoch_data()
for ep in range(num_epochs):
    for i in range(dataloader.batch_num):
        batch_data = dataloader.get_batch(i)

        feed_dict = {
            input_ids: batch_data['input_ids'],
            token_type_ids: batch_data['token_type_ids'],
            attention_mask: batch_data['attention_mask'],
            masked_lm_labels: batch_data['masked_lm_labels'],
            next_sentence_label: batch_data['next_sentence_label'],
            # non-masked positions are labeled -1, so count the labels != -1
            loss_position_sum: np.array([np.where(batch_data['masked_lm_labels'].reshape(-1) != -1)[0].shape[0]]),
        }
        start_time = time.time()
        results = executor.run(feed_dict=feed_dict)
        end_time = time.time()

        masked_lm_loss_mean_out = results[0].asnumpy()
        next_sentence_loss_mean_out = results[1].asnumpy()
        loss_out = results[2].asnumpy()

        print('[Epoch %d] (Iteration %d): Loss = %.3f, MLM_loss = %.3f, NSP_loss = %.6f, Time = %.3f' % (
            ep, i, loss_out, masked_lm_loss_mean_out, next_sentence_loss_mean_out, end_time - start_time))


