
run_mlp.py 4.9 kB

import hetu as ht
from models import MLP
import numpy as np
import argparse
from time import time

if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='local',
                        help='[local, lps(localps), lar(localallreduce), rps(remoteps), rar]')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=10, help='epoch number')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    args = parser.parse_args()
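
    # Typical invocations, using only the flags defined above (illustrative):
    #   python run_mlp.py --config local --opt adam --validate --timing
    #   python run_mlp.py --config lar --batch-size 256 --num-epochs 20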

    dataset = 'MNIST'
    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print('Use SGD Optimizer.')
        opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print('Use Momentum Optimizer.')
        opt = ht.optim.MomentumOptimizer(learning_rate=args.learning_rate)
    elif args.opt == 'nesterov':
        print('Use Nesterov Momentum Optimizer.')
        opt = ht.optim.MomentumOptimizer(
            learning_rate=args.learning_rate, nesterov=True)
    elif args.opt == 'adagrad':
        print('Use AdaGrad Optimizer.')
        opt = ht.optim.AdaGradOptimizer(
            learning_rate=args.learning_rate, initial_accumulator_value=0.1)
    else:
        print('Use Adam Optimizer.')
        opt = ht.optim.AdamOptimizer(learning_rate=args.learning_rate)
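    # Note: only the learning rate (and AdaGrad's initial accumulator) is set
    # explicitly above; momentum/beta values are presumably left at hetu's
    # library defaults. 'nesterov' reuses MomentumOptimizer with nesterov=True.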

    # data loading
    print('Loading %s data...' % dataset)
    if dataset == 'MNIST':
        datasets = ht.data.mnist()
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
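        # (the test split is loaded here but never used below)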
    # train_set_x: (50000, 784), train_set_y: (50000,)
    # valid_set_x: (10000, 784), valid_set_y: (10000,)
    # x_shape = (args.batch_size, 784)
    # y_shape = (args.batch_size, 10)

    # model definition
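    # Device configurations (decoded from the --config help string):
    #   local = single GPU; lps = local parameter server, where the first
    #   context, ht.cpu(0), presumably acts as the PS; lar = local allreduce;
    #   rps/rar = the remote variants, where a string like 'daim118:gpu:0'
    #   appears to name a GPU on a specific host.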
    ctx = {
        'local': ht.gpu(0),
        'lps': [ht.cpu(0), ht.gpu(0), ht.gpu(1), ht.gpu(4), ht.gpu(5)],
        'lar': [ht.gpu(1), ht.gpu(2), ht.gpu(3), ht.gpu(6)],
        'rps': ['cpu:0', 'daim118:gpu:0', 'daim118:gpu:2', 'daim118:gpu:4',
                'daim118:gpu:6', 'daim117:gpu:1', 'daim117:gpu:3'],
        'rar': ['daim118:gpu:0', 'daim118:gpu:2', 'daim118:gpu:4',
                'daim118:gpu:6', 'daim117:gpu:1', 'daim117:gpu:3']
    }[args.config]

    with ht.context(ctx):
        print('Building model...')
        x = ht.dataloader_op([
            ht.Dataloader(train_set_x, args.batch_size, 'train'),
            ht.Dataloader(valid_set_x, args.batch_size, 'validate'),
        ])
        y_ = ht.dataloader_op([
            ht.Dataloader(train_set_y, args.batch_size, 'train'),
            ht.Dataloader(valid_set_y, args.batch_size, 'validate'),
        ])
        loss, y = MLP.mlp(x, y_)
        train_op = opt.minimize(loss)
        executor = ht.Executor(
            {'train': [loss, y, train_op], 'validate': [loss, y, y_]})
        n_train_batches = executor.get_batch_num('train')
        n_valid_batches = executor.get_batch_num('validate')
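        # Note (inferred from the loop below): the Executor holds two named
        # subgraphs, and each executor.run(<name>) call appears to consume one
        # minibatch from the dataloaders registered under that name.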

        # training
        print("Start training loop...")
        for i in range(args.num_epochs):
            print("Epoch %d" % i)
            loss_all = 0
            if args.timing:
                start = time()
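            # loss_val is presumably the batch-mean loss; weighting it by
            # last_batch_size and dividing by len(train_set_x) below yields
            # the per-sample epoch loss (the final batch may be smaller than
            # --batch-size when the dataset size is not divisible by it)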
            for minibatch_index in range(n_train_batches):
                loss_val, predict_y, _ = executor.run('train')
                loss_val = loss_val.asnumpy()
                loss_all += loss_val * x.dataloaders['train'].last_batch_size
            loss_all /= len(train_set_x)
            print("Loss = %f" % loss_all)
            if args.timing:
                end = time()
                print("Time = %f" % (end - start))
            if args.validate:
                correct_predictions = []
                for minibatch_index in range(n_valid_batches):
                    loss_val, valid_y_predicted, y_val = executor.run(
                        'validate', convert_to_numpy_ret_vals=True)
                    correct_prediction = np.equal(
                        np.argmax(y_val, 1),
                        np.argmax(valid_y_predicted, 1)).astype(np.float32)
                    correct_predictions.extend(correct_prediction)
                accuracy = np.mean(correct_predictions)
                print("Validation accuracy = %f" % accuracy)