
train_hetu_bert.py
from tqdm import tqdm
import os
import math
import logging
import hetu as ht
from hetu_bert import BertForPreTraining
from bert_config import BertConfig
from load_data import DataLoader
import numpy as np
import time

''' Usage example:
In dir Hetu/examples/nlp/bert/: python train_hetu_bert.py
'''

# Training setup: single GPU, one epoch, fixed learning rate.
device_id = 6
executor_ctx = ht.gpu(device_id)
num_epochs = 1
lr = 1e-4

# BERT-base configuration (12 layers, 12 heads, hidden size 768).
config = BertConfig(vocab_size=30522,
                    hidden_size=768,
                    num_hidden_layers=12,
                    num_attention_heads=12,
                    intermediate_size=3072,
                    max_position_embeddings=512,
                    # attention_probs_dropout_prob=0.0,
                    # hidden_dropout_prob=0.0,
                    batch_size=6)

model = BertForPreTraining(config=config)

batch_size = config.batch_size
seq_len = config.max_position_embeddings
vocab_size = config.vocab_size

# BookCorpus pre-training data.
dataloader = DataLoader(dataset='bookcorpus', doc_num=200, save_gap=200, batch_size=batch_size)
data_names = ['input_ids', 'token_type_ids', 'attention_mask', 'masked_lm_labels', 'next_sentence_label']

# Placeholder variables that are fed with a new batch on every iteration.
input_ids = ht.Variable(name='input_ids', trainable=False)
token_type_ids = ht.Variable(name='token_type_ids', trainable=False)
attention_mask = ht.Variable(name='attention_mask', trainable=False)
masked_lm_labels = ht.Variable(name='masked_lm_labels_one_hot', trainable=False)
next_sentence_label = ht.Variable(name='next_sentence_label_one_hot', trainable=False)
loss_position_sum = ht.Variable(name='loss_position_sum', trainable=False)

# Forward pass: the model returns per-position MLM loss and per-sample NSP loss.
_, _, masked_lm_loss, next_sentence_loss = model(input_ids, token_type_ids, attention_mask,
                                                 masked_lm_labels, next_sentence_label)

# Average the MLM loss over the masked positions only; average the NSP loss over the batch.
masked_lm_loss_mean = ht.div_op(ht.reduce_sum_op(masked_lm_loss, [0, 1]), loss_position_sum)
next_sentence_loss_mean = ht.reduce_mean_op(next_sentence_loss, [0])
loss = masked_lm_loss_mean + next_sentence_loss_mean

# opt = optimizer.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-8)
opt = ht.optim.SGDOptimizer(learning_rate=lr)
train_op = opt.minimize(loss)

executor = ht.Executor([masked_lm_loss_mean, next_sentence_loss_mean, loss, train_op],
                       ctx=executor_ctx, dynamic_memory=True)

dataloader.make_epoch_data()
for ep in range(num_epochs):
    for i in range(dataloader.batch_num):
        batch_data = dataloader.get_batch(i)
        feed_dict = {
            input_ids: batch_data['input_ids'],
            token_type_ids: batch_data['token_type_ids'],
            attention_mask: batch_data['attention_mask'],
            masked_lm_labels: batch_data['masked_lm_labels'],
            next_sentence_label: batch_data['next_sentence_label'],
            # Number of positions that actually carry an MLM label (label != -1).
            loss_position_sum: np.array([np.where(batch_data['masked_lm_labels'].reshape(-1) != -1)[0].shape[0]]),
        }
        start_time = time.time()
        results = executor.run(feed_dict=feed_dict)
        end_time = time.time()
        masked_lm_loss_mean_out = results[0].asnumpy()
        next_sentence_loss_mean_out = results[1].asnumpy()
        loss_out = results[2].asnumpy()
        print('[Epoch %d] (Iteration %d): Loss = %.3f, MLM_loss = %.3f, NSP_loss = %.6f, Time = %.3f'
              % (ep, i, loss_out, masked_lm_loss_mean_out, next_sentence_loss_mean_out, end_time - start_time))
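For reference, the loss_position_sum value fed above is what turns the summed masked-LM loss into a per-masked-token average: it counts how many positions in the batch carry a real MLM label (positions that were not masked are marked with -1). A minimal standalone NumPy sketch of that count, using a hypothetical toy label array rather than real DataLoader output:

import numpy as np

# Hypothetical toy batch of masked-LM labels, shape (batch_size, seq_len);
# -1 marks positions that were not masked and contribute no MLM loss.
masked_lm_labels = np.array([[-1, 1037, -1, 2154],
                             [2003, -1, -1, -1]])

# Same counting expression as in the training loop above.
loss_position_sum = np.array([np.where(masked_lm_labels.reshape(-1) != -1)[0].shape[0]])
print(loss_position_sum)  # [3] -> three masked positions are averaged over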

Distributed deep learning system