You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_model_mlp.py 3.5 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import hetu as ht
  2. import time
  3. import argparse
  4. import numpy as np
  5. def fc(x, shape, name, with_relu=True, ctx=None):
  6. weight_save = np.load('std/' + name + '_weight.npy')
  7. bias_save = np.load('std/' + name + '_bias.npy')
  8. weight = ht.Variable(value=weight_save, name=name+'_weight', ctx=ctx)
  9. bias = ht.Variable(value=bias_save, name=name+'_bias', ctx=ctx)
  10. x = ht.matmul_op(x, weight)
  11. x = x + ht.broadcastto_op(bias, x)
  12. if with_relu:
  13. x = ht.relu_op(x)
  14. return x
if __name__ == "__main__":
    # Model-parallel MLP smoke test: builds a 3-layer MLP on MNIST with the
    # middle matmul split across GPUs 1-2 per the --split strategy, then times
    # a few training steps. (Indentation reconstructed from a flattened paste —
    # block nesting below is the reviewer's best reading; verify against the
    # original file.)
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=8, help='training steps')
    parser.add_argument('--warmup', type=int, default=2,
                        help='warm up steps excluded from timing')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.00001, help='learning rate')
    # --split chooses how the special matmul is partitioned: 'left' splits the
    # activation rows, 'right' splits the weight columns, 'middle' splits the
    # shared inner dimension.
    parser.add_argument('--split', type=str, default='left')
    args = parser.parse_args()
    assert args.split in ('left', 'right', 'middle')
    # dataset
    datasets = ht.data.mnist()
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]
    # NOTE(review): this hard-coded batch_size shadows the --batch-size CLI
    # argument, which is therefore never used — confirm whether intentional.
    batch_size = 10000
    batch_num = 5
    # Pre-slice the training set into batch_num fixed batches that the
    # training loop cycles through.
    value_x_list = []
    value_y_list = []
    for i in range(batch_num):
        start = i * batch_size
        ending = (i+1) * batch_size
        value_x_list.append(train_set_x[start:ending])
        value_y_list.append(train_set_y[start:ending])
    # model parallel
    with ht.context(ht.gpu(0)):
        # First FC layer (784 -> 1024) and the extra "special" weight both
        # live on GPU 0.
        x = ht.Variable(name="dataloader_x", trainable=False)
        activation = fc(x, (784, 1024), 'mlp_fc1', with_relu=True)
        weight_save = np.load('std/' + 'special_weight.npy')
        # NOTE(review): this variable reuses the name 'mlp_fc1_weight' already
        # taken by the fc('mlp_fc1', ...) weight above — confirm the framework
        # tolerates duplicate variable names.
        weight = ht.Variable(value=weight_save, name='mlp_fc1_weight')
        # dispatch(node, (r, c)) partitions the tensor into an r-by-c grid of
        # shards; duplicate=2 replicates the undivided tensor on both devices
        # (semantics per hetu's dispatch API — TODO confirm).
        if args.split == 'left':
            activation = ht.dispatch(activation, (2, 1))
            weight = ht.dispatch(weight, (1, 1), duplicate=2)
        elif args.split == 'right':
            activation = ht.dispatch(activation, (1, 1), duplicate=2)
            weight = ht.dispatch(weight, (1, 2))
        else:
            activation = ht.dispatch(activation, (1, 2))
            weight = ht.dispatch(weight, (2, 1))
    with ht.context((ht.gpu(1), ht.gpu(2))):
        # The partitioned matmul runs across GPUs 1 and 2; the trailing
        # dispatch to (1, 1) gathers the shards back into one tensor.
        activation = ht.matmul_op(activation, weight)
        activation = ht.dispatch(activation, (1, 1))
    with ht.context(ht.gpu(3)):
        # Remainder of the model (ReLU, 2048 -> 10 head, loss, optimizer)
        # lives on GPU 3.
        activation = ht.relu_op(activation)
        y_pred = fc(activation, (2048, 10), 'mlp_fc2', with_relu=False)
        y_ = ht.Variable(name="dataloader_y", trainable=False)
        loss = ht.softmaxcrossentropy_op(y_pred, y_)
        loss = ht.reduce_mean_op(loss, [0])
        opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
        train_op = opt.minimize(loss)
    executor = ht.Executor([loss, train_op])
    # training
    for step in range(args.steps):
        # Timing starts only after the warm-up steps have run.
        if step == args.warmup:
            start = time.time()
        loss_val, _ = executor.run(feed_dict={
            x: value_x_list[step % batch_num], y_: value_y_list[step % batch_num]}, convert_to_numpy_ret_vals=True)
        # Only the worker holding the loss (rank 3, i.e. GPU 3) prints.
        if executor.rank == 3:
            print('step:', step, 'loss:', loss_val)
    end = time.time()
    if executor.rank == 3:
        print("time elapsed for {} steps: {}s".format(
            args.steps-args.warmup, round(end-start, 3)))