complex_pipeline_mlp.py

import hetu as ht
from hetu import stream
from hetu import init
import os
import sys
import json
import time
import argparse
import numpy as np
import logging

np.random.seed(123)

def convert_to_one_hot(vals, max_val=0):
    """Helper method to convert label array to one-hot array."""
    if max_val == 0:
        max_val = vals.max() + 1
    one_hot_vals = np.zeros((vals.size, max_val))
    one_hot_vals[np.arange(vals.size), vals] = 1
    return one_hot_vals
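
# Quick worked example of the helper above:
#   convert_to_one_hot(np.array([0, 2, 1]), max_val=3)
#   -> [[1., 0., 0.],
#       [0., 0., 1.],
#       [0., 1., 0.]]
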
def fc(x, shape, name, with_relu=True, ctx=None):
    """Fully-connected layer: x @ weight + bias, with optional ReLU."""
    weight = init.random_normal(
        shape=shape, stddev=0.04, name=name+'_weight', ctx=ctx)
    bias = init.random_normal(
        shape=shape[-1:], stddev=0.04, name=name+'_bias', ctx=ctx)
    x = ht.matmul_op(x, weight)
    x = x + ht.broadcastto_op(bias, x)
    if with_relu:
        x = ht.relu_op(x)
    return x

if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=8, help='training steps')
    parser.add_argument('--warmup', type=int, default=2,
                        help='warm up steps excluded from timing')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.00001, help='learning rate')
    args = parser.parse_args()

    # init communicator and optimizer for all ranks
    comm = ht.wrapped_mpi_nccl_init()
    device_id = comm.dev_id
    print("mpi_nccl init for gpu device: {}".format(device_id))
    executor_ctx = ht.gpu(device_id)
    opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)

    # init logger
    logger = logging.getLogger()
    ch = logging.StreamHandler()
    formatter = logging.Formatter('[rank{}, PID{}]'.format(
        device_id, os.getpid()) + ' %(asctime)s: %(message)s')
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    log = logger.warning

    # nccl communication stream for pipeline_send/receive
    communicate_stream = stream.create_stream_handle(executor_ctx)
    # dataset: fixed-size slices of the MNIST training set
    # (note: args.batch_size is not used here)
    datasets = ht.data.mnist()
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]
    batch_size = 10000
    batch_num = 5
    value_x_list = []
    value_y_list = []
    for i in range(batch_num):
        start = i * batch_size
        ending = (i + 1) * batch_size
        value_x_list.append(train_set_x[start:ending])
        value_y_list.append(train_set_y[start:ending])

    # placeholders fed at run time on rank 0 (x) and rank 7 (y_)
    x = ht.Variable(name="dataloader_x", trainable=False)
    y_ = ht.Variable(name="dataloader_y", trainable=False)
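
    # Pipeline layout built below (one stage per rank, ranks 0-7):
    #   rank 0   : owns the input x, runs the first three FC layers, sends its
    #              activation to rank 1 and receives the matching gradient back.
    #   ranks 1-6: receive an activation from the previous rank, run three FC
    #              layers, send the new activation forward, then send the
    #              gradient of their input back to the previous rank.
    #   rank 7   : receives the final activation, computes the softmax
    #              cross-entropy loss against y_, and sends the gradient back
    #              to rank 6.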
    # model parallel: each rank builds only its own pipeline stage
    if comm.myRank.value == 0:
        # rank 0: first stage, owns the input x
        # forward
        activation = fc(x, (784, 1024), 'mlp_fc1', with_relu=True,
                        ctx=ht.gpu(comm.localRank.value))
        activation = fc(activation, (1024, 2048), 'mlp_fc2',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        activation = fc(activation, (2048, 1024), 'mlp_fc3',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        activation_send_op = ht.pipeline_send_op(
            activation, 1, comm, stream=communicate_stream)
        # backward: the gradient w.r.t. the sent activation comes back from rank 1
        gradient_receive_op = ht.pipeline_receive_op(
            1, comm, ctx=executor_ctx, stream=communicate_stream)
        required_vars = opt.get_var_list(activation)
        opt.params = required_vars
        grads = ht.gradients(activation, required_vars,
                             insert_grad=gradient_receive_op)
        train_op = ht.optim.OptimizerOp(grads, opt)
        executor = ht.Executor(
            [activation_send_op, train_op], ctx=executor_ctx)
    elif comm.myRank.value != 7:
        # ranks 1 to 6: middle stages
        previous_rank = comm.myRank.value - 1
        next_rank = comm.myRank.value + 1
        # 1. receive activation from previous rank
        activation_receive_op = ht.pipeline_receive_op(
            previous_rank, comm, ctx=executor_ctx, stream=communicate_stream)
        # forward
        activation = fc(activation_receive_op, (1024, 2048), 'mlp_fc1',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        activation = fc(activation, (2048, 2048), 'mlp_fc2',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        activation = fc(activation, (2048, 1024), 'mlp_fc3',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        # 2. send activation to next rank
        activation_send_op = ht.pipeline_send_op(
            activation, next_rank, comm, ctx=executor_ctx, stream=communicate_stream)
        # 3. receive gradients from next rank
        gradient_receive_op = ht.pipeline_receive_op(
            next_rank, comm, ctx=executor_ctx, stream=communicate_stream)
        # backward
        required_vars = opt.get_var_list(activation)
        opt.params = required_vars
        required_vars = [activation_receive_op] + required_vars
        # grads[0] is the gradient w.r.t. the received activation (sent back below);
        # grads[1:] are the gradients of this stage's own parameters.
        grads = ht.gradients(activation, required_vars,
                             insert_grad=gradient_receive_op)
        train_op = ht.optim.OptimizerOp(grads[1:], opt)
        # 4. send gradients to previous rank
        sendback_grad_op = ht.pipeline_send_op(
            grads[0], previous_rank, comm, stream=communicate_stream)
        executor = ht.Executor(
            [activation_send_op, sendback_grad_op, train_op], ctx=executor_ctx)
    else:
        # rank 7: last stage, computes the loss
        activation_receive_op = ht.pipeline_receive_op(
            6, comm, ctx=executor_ctx, stream=communicate_stream)
        # forward
        activation = fc(activation_receive_op, (1024, 2048), 'mlp_fc1',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        activation = fc(activation, (2048, 1024), 'mlp_fc2',
                        with_relu=True, ctx=ht.gpu(comm.localRank.value))
        y_pred = fc(activation, (1024, 10), 'mlp_fc3', with_relu=False)
        loss = ht.softmaxcrossentropy_op(y_pred, y_)
        loss = ht.reduce_mean_op(loss, [0])
        # backward
        required_vars = opt.get_var_list(loss)
        opt.params = required_vars
        required_vars = [activation_receive_op] + required_vars
        # grads[0] is the gradient w.r.t. the received activation; grads[1:] are
        # the gradients of this stage's own parameters.
        grads = ht.gradients(loss, required_vars)
        train_op = ht.optim.OptimizerOp(grads[1:], opt)
        # send the gradient w.r.t. the received activation back to rank 6
        sendback_grad_op = ht.pipeline_send_op(
            grads[0], 6, comm, stream=communicate_stream)
        executor = ht.Executor(
            [loss, sendback_grad_op, train_op], ctx=executor_ctx)
    # training
    for step in range(args.steps):
        if step == args.warmup:
            start = time.time()
        if comm.myRank.value == 0:
            log("step {}:".format(step))
            executor.run(feed_dict={x: value_x_list[step % batch_num]})
            log("gpu0 ok")
        elif comm.myRank.value == 7:
            loss_val, _, _ = executor.run(
                feed_dict={y_: value_y_list[step % batch_num]},
                convert_to_numpy_ret_vals=True)
            log("gpu7 ok, loss: {}".format(loss_val[0]))
        else:
            executor.run()
            log("gpu{} ok".format(comm.myRank.value))
        # comm.stream.sync()
        if communicate_stream:
            communicate_stream.sync()
    end = time.time()
    log("time elapsed for {} steps: {}s".format(
        args.steps - args.warmup, round(end - start, 3)))
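
# Launch sketch (assumption: an MPI launcher such as mpirun is used; the exact
# command depends on your MPI/Hetu installation). The rank branches above
# expect exactly 8 processes, one GPU per process:
#   mpirun -np 8 python complex_pipeline_mlp.py --steps 8 --warmup 2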