You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dist_data_pipeline_mlp.py 2.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import hetu as ht
  2. import os
  3. import time
  4. import argparse
  5. import numpy as np
  6. import socket
  7. def fc(x, shape, name, with_relu=True):
  8. weight = ht.init.random_normal(shape, stddev=0.04, name=name+'_weight')
  9. bias = ht.init.random_normal(shape[-1:], stddev=0.04, name=name+'_bias')
  10. x = ht.matmul_op(x, weight)
  11. x = x + ht.broadcastto_op(bias, x)
  12. if with_relu:
  13. x = ht.relu_op(x)
  14. return x
  15. if __name__ == "__main__":
  16. # argument parser
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--warmup', type=int, default=1,
  19. help='warm up steps excluded from timing')
  20. parser.add_argument('--batch-size', type=int,
  21. default=10000, help='batch size')
  22. parser.add_argument('--learning-rate', type=float,
  23. default=0.01, help='learning rate')
  24. args = parser.parse_args()
  25. datasets = ht.data.mnist()
  26. train_set_x, train_set_y = datasets[0]
  27. valid_set_x, valid_set_y = datasets[1]
  28. test_set_x, test_set_y = datasets[2]
  29. with ht.context([ht.rgpu('daim117', 0), ht.rgpu('daim117', 1)]):
  30. x = ht.Variable(name="dataloader_x", trainable=False)
  31. activation = fc(x, (784, 1024), 'mlp_fc0', with_relu=True)
  32. with ht.context([ht.rgpu('daim117', 2), ht.rgpu('daim117', 3)]):
  33. activation = fc(activation, (1024, 1024), 'mlp_fc1', with_relu=True)
  34. activation = fc(activation, (1024, 1024), 'mlp_fc11', with_relu=True)
  35. with ht.context([ht.rgpu('daim118', 0), ht.rgpu('daim118', 1)]):
  36. activation = fc(activation, (1024, 1024), 'mlp_fc2', with_relu=True)
  37. activation = fc(activation, (1024, 1024), 'mlp_fc22', with_relu=True)
  38. with ht.context([ht.rgpu('daim118', 2), ht.rgpu('daim118', 3)]):
  39. y_pred = fc(activation, (1024, 10), 'mlp_fc3', with_relu=True)
  40. y_ = ht.Variable(name="dataloader_y", trainable=False)
  41. loss = ht.softmaxcrossentropy_op(y_pred, y_)
  42. loss = ht.reduce_mean_op(loss, [0])
  43. opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
  44. train_op = opt.minimize(loss)
  45. executor = ht.Executor([loss, train_op])
  46. print_ranks = [2, 3]
  47. hostname = socket.gethostname()
  48. # training
  49. steps = train_set_x.shape[0] // args.batch_size
  50. for step in range(steps):
  51. start = step * args.batch_size
  52. end = start + args.batch_size
  53. loss_val, _ = executor.run(feed_dict={
  54. x: train_set_x[start:end], y_: train_set_y[start:end]}, convert_to_numpy_ret_vals=True)
  55. if executor.local_rank in print_ranks and hostname == 'daim118':
  56. print('[step {}]: loss: {}'.format(step, loss_val[0]))