
data_model_pipeline_mlp.py

import hetu as ht
import time
import argparse


def fc(x, shape, name, with_relu=True, ctx=None):
    # Fully connected layer: x @ weight + bias, optionally followed by ReLU.
    weight = ht.init.random_normal(
        shape=shape, stddev=0.04, name=name + '_weight', ctx=ctx)
    bias = ht.init.random_normal(
        shape=shape[-1:], stddev=0.04, name=name + '_bias', ctx=ctx)
    x = ht.matmul_op(x, weight)
    x = x + ht.broadcastto_op(bias, x)
    if with_relu:
        x = ht.relu_op(x)
    return x


if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=8, help='training steps')
    parser.add_argument('--warmup', type=int, default=2,
                        help='warm-up steps excluded from timing')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.00001, help='learning rate')
    parser.add_argument('--split', type=str, default='left',
                        help='left, middle, right')
    args = parser.parse_args()
    assert args.split in ('left', 'middle', 'right')

    # dataset: pre-slice MNIST into batch_num fixed batches of batch_size samples
    # (note that this hard-coded batch_size overrides the --batch-size argument)
    datasets = ht.data.mnist()
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]
    batch_size = 10000
    batch_num = 5
    value_x_list = []
    value_y_list = []
    for i in range(batch_num):
        start = i * batch_size
        ending = (i + 1) * batch_size
        value_x_list.append(train_set_x[start:ending])
        value_y_list.append(train_set_y[start:ending])

    # model parallel
    # First pipeline stage, replicated on gpu 0 and gpu 4 (data parallel).
    with ht.context([ht.gpu(0), ht.gpu(4)]):
        x = ht.Variable(name="dataloader_x", trainable=False)
        activation = fc(x, (784, 1024), 'mlp_fc1', with_relu=True)
        activation = fc(activation, (1024, 2048), 'mlp_fc2', with_relu=True)
        activation = fc(activation, (2048, 1024), 'mlp_fc3', with_relu=True)
        # Prepare the next matmul for a 2-way model-parallel split:
        # 'left' splits the activation rows, 'right' splits the weight columns,
        # 'middle' splits the shared inner dimension.
        if args.split == 'left':
            activation = ht.dispatch(activation, (2, 1))
            weight = ht.dispatch(ht.init.random_normal(
                shape=(1024, 2048), stddev=0.04, name='mlp_fc1_weight'), (1, 1), duplicate=2)
        elif args.split == 'right':
            activation = ht.dispatch(activation, (1, 1), duplicate=2)
            weight = ht.dispatch(ht.init.random_normal(
                shape=(1024, 2048), stddev=0.04, name='mlp_fc1_weight'), (1, 2))
        else:
            activation = ht.dispatch(activation, (1, 2))
            weight = ht.dispatch(ht.init.random_normal(
                shape=(1024, 2048), stddev=0.04, name='mlp_fc1_weight'), (2, 1))

    # Second pipeline stage: the split matmul, spread over two gpus per replica.
    with ht.context([(ht.gpu(1), ht.gpu(2)), (ht.gpu(5), ht.gpu(6))]):
        activation = ht.matmul_op(activation, weight)
        activation = ht.dispatch(activation, (1, 1))

    # Third pipeline stage: gather the result and finish the model on gpu 3 / gpu 7.
    with ht.context([ht.gpu(3), ht.gpu(7)]):
        activation = ht.relu_op(activation)
        activation = fc(activation, (2048, 2048), 'mlp_fc2', with_relu=True)
        activation = fc(activation, (2048, 1024), 'mlp_fc3', with_relu=True)
        y_pred = fc(activation, (1024, 10), 'mlp_fc3', with_relu=False)
        y_ = ht.Variable(name="dataloader_y", trainable=False)
        loss = ht.softmaxcrossentropy_op(y_pred, y_)
        loss = ht.reduce_mean_op(loss, [0])
        opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
        train_op = opt.minimize(loss)

    executor = ht.Executor([loss, train_op])

    # training
    for step in range(args.steps):
        if step == args.warmup:
            start = time.time()
        loss_val, _ = executor.run(feed_dict={
            x: value_x_list[step % batch_num],
            y_: value_y_list[step % batch_num]},
            convert_to_numpy_ret_vals=True)
        # Only one worker prints, to avoid duplicated logs.
        if executor.rank == 3:
            print('step:', step, 'loss:', loss_val)
    end = time.time()
    if executor.rank == 3:
        print("time elapsed for {} steps: {}s".format(
            args.steps - args.warmup, round(end - start, 3)))
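
For reference, below is a minimal NumPy sketch (not Hetu code; the array names and the reading of the dispatch tuples as per-dimension split counts are assumptions of mine) of the arithmetic behind the three --split options: 'left' partitions the activation rows, 'right' partitions the weight columns, and 'middle' partitions the shared inner dimension and sums the partial products.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 1024))     # activation block feeding the split matmul
W = rng.standard_normal((1024, 2048))  # stands in for the split 'mlp_fc1_weight'
full = A @ W

# 'left': split A by rows, duplicate W on both devices, concatenate row blocks.
left = np.concatenate([part @ W for part in np.split(A, 2, axis=0)], axis=0)

# 'right': duplicate A, split W by columns, concatenate column blocks.
right = np.concatenate([A @ part for part in np.split(W, 2, axis=1)], axis=1)

# 'middle': split A by columns and W by rows; summing the partial products
# plays the role of the reduction between the two devices.
middle = sum(a @ w for a, w in zip(np.split(A, 2, axis=1), np.split(W, 2, axis=0)))

assert np.allclose(full, left)
assert np.allclose(full, right)
assert np.allclose(full, middle)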