You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.py 8.4 kB

4 years ago
3 years ago
4 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import hetu as ht
  2. import models
  3. import os
  4. import numpy as np
  5. import argparse
  6. import json
  7. import logging
  8. from time import time
  9. logging.basicConfig(level=logging.INFO,
  10. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  11. logger = logging.getLogger(__name__)
  12. def print_rank0(msg):
  13. if device_id == 0:
  14. logger.info(msg)
  15. if __name__ == "__main__":
  16. # argument parser
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--model', type=str, required=True,
  19. help='model to be tested')
  20. parser.add_argument('--dataset', type=str, required=True,
  21. help='dataset to be trained on')
  22. parser.add_argument('--batch-size', type=int,
  23. default=128, help='batch size')
  24. parser.add_argument('--learning-rate', type=float,
  25. default=0.1, help='learning rate')
  26. parser.add_argument('--opt', type=str, default='sgd',
  27. help='optimizer to be used, default sgd; sgd / momentum / adagrad / adam')
  28. parser.add_argument('--num-epochs', type=int,
  29. default=10, help='epoch number')
  30. parser.add_argument('--gpu', type=int, default=0,
  31. help='gpu to be used, -1 means cpu')
  32. parser.add_argument('--validate', action='store_true',
  33. help='whether to use validation')
  34. parser.add_argument('--timing', action='store_true',
  35. help='whether to time the training phase')
  36. parser.add_argument('--comm-mode', default=None, help='communication mode')
  37. args = parser.parse_args()
  38. global device_id
  39. device_id = 0
  40. print_rank0("Training {} on HETU".format(args.model))
  41. if args.comm_mode in ('AllReduce', 'Hybrid'):
  42. comm = ht.wrapped_mpi_nccl_init()
  43. device_id = comm.dev_id
  44. rank = comm.rank
  45. executor_ctx = ht.gpu(device_id % 8) if args.gpu >= 0 else ht.cpu(0)
  46. else:
  47. if args.gpu == -1:
  48. executor_ctx = ht.cpu(0)
  49. print_rank0('Use CPU.')
  50. else:
  51. executor_ctx = ht.gpu(args.gpu)
  52. print_rank0('Use GPU %d.' % args.gpu)
  53. if args.comm_mode in ('PS', 'Hybrid'):
  54. settings_file = open(os.path.join(os.path.abspath(
  55. os.path.dirname(__file__)), 'worker_conf%d.json' % args.gpu))
  56. settings = json.load(settings_file)
  57. for key in settings:
  58. if type(settings[key]) == str:
  59. os.environ[key] = settings[key]
  60. else:
  61. os.environ[key] = str(settings[key]) # type is str
  62. assert args.model in ['alexnet', 'cnn_3_layers', 'lenet', 'logreg', 'lstm', 'mlp', 'resnet18', 'resnet34', 'rnn', 'vgg16', 'vgg19'], \
  63. 'Model not supported!'
  64. model = eval('models.' + args.model)
  65. assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
  66. dataset = args.dataset
  67. assert args.opt in ['sgd', 'momentum', 'nesterov',
  68. 'adagrad', 'adam'], 'Optimizer not supported!'
  69. if args.opt == 'sgd':
  70. print_rank0('Use SGD Optimizer.')
  71. opt = ht.optim.SGDOptimizer(learning_rate=args.learning_rate)
  72. elif args.opt == 'momentum':
  73. print_rank0('Use Momentum Optimizer.')
  74. opt = ht.optim.MomentumOptimizer(learning_rate=args.learning_rate)
  75. elif args.opt == 'nesterov':
  76. print_rank0('Use Nesterov Momentum Optimizer.')
  77. opt = ht.optim.MomentumOptimizer(
  78. learning_rate=args.learning_rate, nesterov=True)
  79. elif args.opt == 'adagrad':
  80. print_rank0('Use AdaGrad Optimizer.')
  81. opt = ht.optim.AdaGradOptimizer(
  82. learning_rate=args.learning_rate, initial_accumulator_value=0.1)
  83. else:
  84. print_rank0('Use Adam Optimizer.')
  85. opt = ht.optim.AdamOptimizer(learning_rate=args.learning_rate)
  86. # data loading
  87. print_rank0('Loading %s data...' % dataset)
  88. if dataset == 'MNIST':
  89. datasets = ht.data.mnist()
  90. train_set_x, train_set_y = datasets[0]
  91. valid_set_x, valid_set_y = datasets[1]
  92. test_set_x, test_set_y = datasets[2]
  93. # train_set_x: (50000, 784), train_set_y: (50000, 10)
  94. # valid_set_x: (10000, 784), valid_set_y: (10000, 10)
  95. # x_shape = (args.batch_size, 784)
  96. # y_shape = (args.batch_size, 10)
  97. elif dataset == 'CIFAR10':
  98. train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
  99. num_class=10)
  100. if args.model == "mlp":
  101. train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
  102. valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
  103. # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 10)
  104. # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 10)
  105. # x_shape = (args.batch_size, 3, 32, 32)
  106. # y_shape = (args.batch_size, 10)
  107. elif dataset == 'CIFAR100':
  108. train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.normalize_cifar(
  109. num_class=100)
  110. # train_set_x: (50000, 3, 32, 32), train_set_y: (50000, 100)
  111. # valid_set_x: (10000, 3, 32, 32), valid_set_y: (10000, 100)
  112. else:
  113. raise NotImplementedError
  114. # model definition
  115. print_rank0('Building model {}'.format(args.model))
  116. x = ht.dataloader_op([
  117. ht.Dataloader(train_set_x, args.batch_size, 'train'),
  118. ht.Dataloader(valid_set_x, args.batch_size, 'validate'),
  119. ])
  120. y_ = ht.dataloader_op([
  121. ht.Dataloader(train_set_y, args.batch_size, 'train'),
  122. ht.Dataloader(valid_set_y, args.batch_size, 'validate'),
  123. ])
  124. if args.model in ['resnet18', 'resnet34', 'vgg16', 'vgg19'] and args.dataset == 'CIFAR100':
  125. loss, y = model(x, y_, 100)
  126. else:
  127. loss, y = model(x, y_)
  128. train_op = opt.minimize(loss)
  129. eval_nodes = {'train': [loss, y, y_, train_op], 'validate': [loss, y, y_]}
  130. executor = ht.Executor(eval_nodes, ctx=executor_ctx,
  131. comm_mode=args.comm_mode)
  132. n_train_batches = executor.get_batch_num('train')
  133. n_valid_batches = executor.get_batch_num('validate')
  134. # training
  135. print_rank0("Start training loop...")
  136. running_time = 0
  137. for i in range(args.num_epochs + 1):
  138. print_rank0("Epoch %d" % i)
  139. loss_all = 0
  140. batch_num = 0
  141. if args.timing:
  142. start = time()
  143. correct_predictions = []
  144. for minibatch_index in range(n_train_batches):
  145. loss_val, predict_y, y_val, _ = executor.run(
  146. 'train', eval_node_list=[loss, y, y_, train_op])
  147. # Loss for this minibatch
  148. predict_y = predict_y.asnumpy()
  149. y_val = y_val.asnumpy()
  150. loss_all += loss_val.asnumpy()
  151. batch_num += 1
  152. # Predict accuracy for this minibatch
  153. correct_prediction = np.equal(
  154. np.argmax(y_val, 1),
  155. np.argmax(predict_y, 1)).astype(np.float32)
  156. correct_predictions.extend(correct_prediction)
  157. loss_all /= batch_num
  158. accuracy = np.mean(correct_predictions)
  159. print_rank0("Train loss = %f" % loss_all)
  160. print_rank0("Train accuracy = %f" % accuracy)
  161. if args.timing:
  162. end = time()
  163. during_time = end - start
  164. print_rank0("Running time of current epoch = %fs" % (during_time))
  165. if i != 0:
  166. running_time += during_time
  167. if args.validate:
  168. val_loss_all = 0
  169. batch_num = 0
  170. correct_predictions = []
  171. for minibatch_index in range(n_valid_batches):
  172. loss_val, valid_y_predicted, y_val = executor.run(
  173. 'validate', eval_node_list=[loss, y, y_], convert_to_numpy_ret_vals=True)
  174. val_loss_all += loss_val
  175. batch_num += 1
  176. correct_prediction = np.equal(
  177. np.argmax(y_val, 1),
  178. np.argmax(valid_y_predicted, 1)).astype(np.float32)
  179. correct_predictions.extend(correct_prediction)
  180. val_loss_all /= batch_num
  181. accuracy = np.mean(correct_predictions)
  182. print_rank0("Validation loss = %f" % val_loss_all)
  183. print_rank0("Validation accuracy = %f" % accuracy)
  184. print_rank0("*"*50)
  185. print_rank0("Running time of total %d epoch = %fs" %
  186. (args.num_epochs, running_time))