
test_model_mlp_base.py
import hetu as ht
import time
import argparse
import os
import numpy as np


def fc(x, shape, name, with_relu=True, rank=-1):
    weight_save = np.random.normal(0, 0.04, size=shape)
    bias_save = np.random.normal(0, 0.04, size=shape[-1:])
    weight = ht.Variable(value=weight_save, name=name+'_weight')
    bias = ht.Variable(value=bias_save, name=name+'_bias')
    global args
    if args.save and args.rank == rank:
        np.save('std/' + name + '_weight.npy', weight_save)
        np.save('std/' + name + '_bias.npy', bias_save)
    x = ht.matmul_op(x, weight)
    x = x + ht.broadcastto_op(bias, x)
    if with_relu:
        x = ht.relu_op(x)
    return x


if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--steps', type=int, default=8, help='training steps')
    parser.add_argument('--warmup', type=int, default=2,
                        help='warm up steps excluded from timing')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.00001, help='learning rate')
    parser.add_argument('--save', action='store_true')
    global args
    args = parser.parse_args()

    if args.save:
        comm = ht.wrapped_mpi_nccl_init()
        args.rank = comm.rank
        if args.rank == 0 and not os.path.exists('std'):
            os.mkdir('std')

    # dataset
    datasets = ht.data.mnist()
    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]
    batch_size = 10000
    batch_num = 5
    value_x_list = []
    value_y_list = []
    for i in range(batch_num):
        start = i * batch_size
        ending = (i+1) * batch_size
        value_x_list.append(train_set_x[start:ending])
        value_y_list.append(train_set_y[start:ending])

    # model parallel
    with ht.context(ht.gpu(0)):
        x = ht.Variable(name="dataloader_x", trainable=False)
        activation = fc(x, (784, 1024), 'mlp_fc1', with_relu=True, rank=0)

    with ht.context(ht.gpu(1)):
        weight_save = np.random.normal(0, 0.04, size=(1024, 2048))
        if args.save and args.rank == 1:
            np.save('std/' + 'special_weight.npy', weight_save)
        weight = ht.Variable(value=weight_save, name='mlp_fc1_weight')
        activation = ht.matmul_op(activation, weight)

    with ht.context(ht.gpu(2)):
        activation = ht.relu_op(activation)
        y_pred = fc(activation, (2048, 10), 'mlp_fc2', with_relu=False, rank=2)
        y_ = ht.Variable(name="dataloader_y", trainable=False)
        loss = ht.softmaxcrossentropy_op(y_pred, y_)
        loss = ht.reduce_mean_op(loss, [0])
        opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
        train_op = opt.minimize(loss)
        executor = ht.Executor([loss, train_op])

    # training
    for step in range(args.steps):
        if step == args.warmup:
            start = time.time()
        loss_val, _ = executor.run(
            feed_dict={x: value_x_list[step % batch_num],
                       y_: value_y_list[step % batch_num]},
            convert_to_numpy_ret_vals=True)
        if executor.rank == 2:
            print('step:', step, 'loss:', loss_val)
    end = time.time()
    if executor.rank == 2:
        print("time elapsed for {} steps: {}s".format(
            args.steps-args.warmup, round(end-start, 3)))
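
The script builds a three-stage model-parallel MLP on MNIST: the first fully connected layer on gpu(0), an extra 1024x2048 matmul on gpu(1), and the ReLU plus output layer, loss, and optimizer on gpu(2). How the worker processes are started is not shown here; presumably there is one process per GPU, since the script calls ht.wrapped_mpi_nccl_init when --save is given and checks executor.rank. The --save flag dumps the randomly drawn initial parameters under std/ so another run can reuse or verify the same initialization. Below is a minimal sketch of such a check, assuming a baseline run dumped its parameters to a hypothetical std_base/ directory using the same name + '_weight.npy' convention as fc; it is an illustration, not part of the test itself.

    import numpy as np

    # Hypothetical consistency check (not part of the script above): compare the
    # parameters dumped by the model-parallel run (std/) with those dumped by an
    # assumed single-device baseline run (std_base/).
    names = ['mlp_fc1_weight', 'mlp_fc1_bias', 'special_weight',
             'mlp_fc2_weight', 'mlp_fc2_bias']
    for name in names:
        parallel = np.load('std/' + name + '.npy')
        baseline = np.load('std_base/' + name + '.npy')
        assert np.allclose(parallel, baseline), name + ' differs between the two runs'
    print('all saved parameters match')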