You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 8.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
import hetu as ht
import models
import os
import numpy as np
import argparse
import json
import logging
from time import time

# Module-wide logger; rank-0 filtering is applied by print_rank0 below.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
  12. def print_rank0(msg):
  13. if device_id == 0:
  14. logger.info(msg)
  15. if __name__ == "__main__":
  16. # argument parser
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--model', type=str, required=True,
  19. help='model to be tested')
  20. parser.add_argument('--dataset', type=str, required=True,
  21. help='dataset to be trained on')
  22. parser.add_argument('--batch-size', type=int,
  23. default=128, help='batch size')
  24. parser.add_argument('--learning-rate', type=float,
  25. default=0.1, help='learning rate')
  26. parser.add_argument('--opt', type=str, default='sgd',
  27. help='optimizer to be used, default sgd; sgd / momentum / adagrad / adam')
  28. parser.add_argument('--num-epochs', type=int,
  29. default=10, help='epoch number')
  30. parser.add_argument('--gpu', type=int, default=0,
  31. help='gpu to be used, -1 means cpu')
  32. parser.add_argument('--validate', action='store_true',
  33. help='whether to use validation')
  34. parser.add_argument('--timing', action='store_true',
  35. help='whether to time the training phase')
  36. parser.add_argument('--comm-mode', default=None, help='communication mode')
  37. args = parser.parse_args()
  38. global device_id
  39. device_id = 0
  40. print_rank0("Training {} on HETU".format(args.model))
  41. if args.comm_mode in ('AllReduce', 'Hybrid'):
  42. comm, device_id = ht.mpi_nccl_init()
  43. executor_ctx = ht.gpu(device_id % 8) if args.gpu >= 0 else ht.cpu(0)
  44. else:
  45. if args.gpu == -1:
  46. executor_ctx = ht.cpu(0)
  47. print_rank0('Use CPU.')
  48. else:
  49. executor_ctx = ht.gpu(args.gpu)
  50. print_rank0('Use GPU %d.' % args.gpu)
  51. if args.comm_mode in ('PS', 'Hybrid'):
  52. settings_file = open(os.path.join(os.path.abspath(
  53. os.path.dirname(__file__)), 'worker_conf%d.json' % args.gpu))
  54. settings = json.load(settings_file)
  55. for key in settings:
  56. if type(settings[key]) == str:
  57. os.environ[key] = settings[key]
  58. else:
  59. os.environ[key] = str(settings[key]) # type is str
  60. assert args.model in ['alexnet', 'cnn_3_layers', 'lenet', 'logreg', 'lstm', 'mlp', 'resnet18', 'resnet34', 'rnn', 'vgg16', 'vgg19'], \
  61. 'Model not supported!'
  62. model = eval('models.' + args.model)
  63. assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
  64. dataset = args.dataset
  65. assert args.opt in ['sgd', 'momentum', 'nesterov',
  66. 'adagrad', 'adam'], 'Optimizer not supported!'
  67. if args.opt == 'sgd':
  68. print_rank0('Use SGD Optimizer.')
  69. opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
  70. elif args.opt == 'momentum':
  71. print_rank0('Use Momentum Optimizer.')
  72. opt = ht.optim.MomentumOptimizer(learning_rate=args.learning_rate)
  73. elif args.opt == 'nesterov':
  74. print_rank0('Use Nesterov Momentum Optimizer.')
  75. opt = ht.optim.MomentumOptimizer(
  76. learning_rate=args.learning_rate, nesterov=True)
  77. elif args.opt == 'adagrad':
  78. print_rank0('Use AdaGrad Optimizer.')
  79. opt = ht.optim.AdaGradOptimizer(
  80. learning_rate=args.learning_rate, initial_accumulator_value=0.1)
  81. else:
  82. print_rank0('Use Adam Optimizer.')
  83. opt = ht.optim.AdamOptimizer(learning_rate=args.learning_rate)
  84. # data loading
  85. print_rank0('Loading %s data...' % dataset)
  86. if dataset == 'MNIST':
  87. datasets = ht.data.mnist()
  88. train_set_x, train_set_y = datasets[0]
  89. valid_set_x, valid_set_y = datasets[1]
  90. test_set_x, test_set_y = datasets[2]
  91. # train_set_x: (50000, 784), train_set_y: (50000, 10)
  92. # valid_set_x: (10000, 784), valid_set_y: (10000, 10)
  93. # x_shape = (args.batch_size, 784)
  94. # y_shape = (args.batch_size, 10)
  95. elif dataset == 'CIFAR10':
  96. train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
  97. num_class=10)
  98. if args.model == "mlp":
  99. train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
  100. valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
  101. # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 10)
  102. # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 10)
  103. # x_shape = (args.batch_size, 3, 32, 32)
  104. # y_shape = (args.batch_size, 10)
  105. elif dataset == 'CIFAR100':
  106. train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
  107. num_class=100)
  108. # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 100)
  109. # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 100)
  110. else:
  111. raise NotImplementedError
  112. # model definition
  113. print_rank0('Building model {}'.format(args.model))
  114. x = ht.dataloader_op([
  115. ht.Dataloader(train_set_x, args.batch_size, 'train'),
  116. ht.Dataloader(valid_set_x, args.batch_size, 'validate'),
  117. ])
  118. y_ = ht.dataloader_op([
  119. ht.Dataloader(train_set_y, args.batch_size, 'train'),
  120. ht.Dataloader(valid_set_y, args.batch_size, 'validate'),
  121. ])
  122. if args.model in ['resnet18', 'resnet34', 'vgg16', 'vgg19'] and args.dataset == 'CIFAR100':
  123. loss, y = model(x, y_, 100)
  124. else:
  125. loss, y = model(x, y_)
  126. train_op = opt.minimize(loss)
  127. eval_nodes = {'train': [loss, y, y_, train_op], 'validate': [loss, y, y_]}
  128. executor = ht.Executor(eval_nodes, ctx=executor_ctx,
  129. comm_mode=args.comm_mode)
  130. n_train_batches = executor.get_batch_num('train')
  131. n_valid_batches = executor.get_batch_num('validate')
  132. # training
  133. print_rank0("Start training loop...")
  134. running_time = 0
  135. for i in range(args.num_epochs + 1):
  136. print_rank0("Epoch %d" % i)
  137. loss_all = 0
  138. batch_num = 0
  139. if args.timing:
  140. start = time()
  141. correct_predictions = []
  142. for minibatch_index in range(n_train_batches):
  143. loss_val, predict_y, y_val, _ = executor.run(
  144. 'train', eval_node_list=[loss, y, y_, train_op])
  145. # Loss for this minibatch
  146. predict_y = predict_y.asnumpy()
  147. y_val = y_val.asnumpy()
  148. loss_all += loss_val.asnumpy()
  149. batch_num += 1
  150. # Predict accuracy for this minibatch
  151. correct_prediction = np.equal(
  152. np.argmax(y_val, 1),
  153. np.argmax(predict_y, 1)).astype(np.float32)
  154. correct_predictions.extend(correct_prediction)
  155. loss_all /= batch_num
  156. accuracy = np.mean(correct_predictions)
  157. print_rank0("Train loss = %f" % loss_all)
  158. print_rank0("Train accuracy = %f" % accuracy)
  159. if args.timing:
  160. end = time()
  161. during_time = end - start
  162. print_rank0("Running time of current epoch = %fs" % (during_time))
  163. if i != 0:
  164. running_time += during_time
  165. if args.validate:
  166. val_loss_all = 0
  167. batch_num = 0
  168. correct_predictions = []
  169. for minibatch_index in range(n_valid_batches):
  170. loss_val, valid_y_predicted, y_val = executor.run(
  171. 'validate', eval_node_list=[loss, y, y_], convert_to_numpy_ret_vals=True)
  172. val_loss_all += loss_val
  173. batch_num += 1
  174. correct_prediction = np.equal(
  175. np.argmax(y_val, 1),
  176. np.argmax(valid_y_predicted, 1)).astype(np.float32)
  177. correct_predictions.extend(correct_prediction)
  178. val_loss_all /= batch_num
  179. accuracy = np.mean(correct_predictions)
  180. print_rank0("Validation loss = %f" % val_loss_all)
  181. print_rank0("Validation accuracy = %f" % accuracy)
  182. print_rank0("*"*50)
  183. print_rank0("Running time of total %d epoch = %fs" %
  184. (args.num_epochs, running_time))
  185. if args.comm_mode in ('AllReduce', 'Hybrid'):
  186. ht.mpi_nccl_finish(comm)

分布式深度学习系统

Contributors (1)