import hetu as ht
from hetu_bert import BertForPreTraining
from bert_config import BertConfig
from load_data import DataLoader
import numpy as np
import time

''' Usage example:
In the directory Hetu/examples/nlp/bert/, run: python train_hetu_bert.py
'''
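
# Run on a single GPU; change device_id to target a different card.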
device_id = 6
executor_ctx = ht.gpu(device_id)

num_epochs = 1
lr = 1e-4
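
# BERT-Base sized configuration (12 layers, hidden size 768, 12 attention heads).
# Uncomment the dropout overrides to disable dropout.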
config = BertConfig(vocab_size=30522,
                    hidden_size=768,
                    num_hidden_layers=12,
                    num_attention_heads=12,
                    intermediate_size=3072,
                    max_position_embeddings=512,
                    #attention_probs_dropout_prob=0.0,
                    #hidden_dropout_prob=0.0,
                    batch_size=6)

model = BertForPreTraining(config=config)

batch_size = config.batch_size
seq_len = config.max_position_embeddings
vocab_size = config.vocab_size

dataloader = DataLoader(dataset='bookcorpus', doc_num=200, save_gap=200, batch_size=batch_size)
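# Fields present in every batch produced by the DataLoader.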
data_names = ['input_ids', 'token_type_ids', 'attention_mask', 'masked_lm_labels', 'next_sentence_label']

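# Non-trainable placeholder variables for the model inputs; their values are
# supplied through feed_dict on every executor.run call.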
input_ids = ht.Variable(name='input_ids', trainable=False)
token_type_ids = ht.Variable(name='token_type_ids', trainable=False)
attention_mask = ht.Variable(name='attention_mask', trainable=False)

masked_lm_labels = ht.Variable(name='masked_lm_labels_one_hot', trainable=False)
next_sentence_label = ht.Variable(name='next_sentence_label_one_hot', trainable=False)

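# Number of masked token positions in the current batch, used to average the
# MLM loss over the masked positions only.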
loss_position_sum = ht.Variable(name='loss_position_sum', trainable=False)

_, _, masked_lm_loss, next_sentence_loss = model(input_ids, token_type_ids, attention_mask, masked_lm_labels, next_sentence_label)

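# Average the MLM loss over the masked positions and the NSP loss over the batch.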
masked_lm_loss_mean = ht.div_op(ht.reduce_sum_op(masked_lm_loss, [0, 1]), loss_position_sum)
next_sentence_loss_mean = ht.reduce_mean_op(next_sentence_loss, [0])

loss = masked_lm_loss_mean + next_sentence_loss_mean
#opt = ht.optim.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-8)
opt = ht.optim.SGDOptimizer(learning_rate=lr)
train_op = opt.minimize(loss)

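# The executor evaluates the requested outputs on each run; including train_op
# applies the parameter update in the same pass.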
executor = ht.Executor([masked_lm_loss_mean, next_sentence_loss_mean, loss, train_op], ctx=executor_ctx, dynamic_memory=True)
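
# Build the batched epoch data before iterating over it.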
dataloader.make_epoch_data()
for ep in range(num_epochs):
    for i in range(dataloader.batch_num):
        batch_data = dataloader.get_batch(i)

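        # A masked_lm_label of -1 marks an unmasked position; count the masked
        # positions to normalize the MLM loss.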
        feed_dict = {
            input_ids: batch_data['input_ids'],
            token_type_ids: batch_data['token_type_ids'],
            attention_mask: batch_data['attention_mask'],
            masked_lm_labels: batch_data['masked_lm_labels'],
            next_sentence_label: batch_data['next_sentence_label'],
            loss_position_sum: np.array([(batch_data['masked_lm_labels'] != -1).sum()]),
        }

        start_time = time.time()
        results = executor.run(feed_dict=feed_dict)
        end_time = time.time()

        masked_lm_loss_mean_out = results[0].asnumpy()
        next_sentence_loss_mean_out = results[1].asnumpy()
        loss_out = results[2].asnumpy()

        print('[Epoch %d] (Iteration %d): Loss = %.3f, MLM_loss = %.3f, NSP_loss = %.6f, Time = %.3f' % (
            ep, i, loss_out, masked_lm_loss_mean_out, next_sentence_loss_mean_out, end_time - start_time))