
run_tf_horovod.py
import os
import numpy as np
import tensorflow as tf
import tf_models
import time
import argparse
from tqdm import tqdm
from sklearn import metrics
import horovod.tensorflow as hvd
import hetu as ht
import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def print_rank0(msg):
    if rank % 8 == 0:
        logger.info(msg)


def pop_env():
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()
# Launch examples:
# horovodrun -np 8 -H localhost:8 python run_tf_horovod.py --model
# horovodrun -np 8 --start-timeout 300 -H daim116:4,daim117:4 python run_tf_horovod.py --model
# horovodrun -np 16 --start-timeout 3000 -H daim116:8,daim117:8 \
#     python /home/public/nxn/Athena-master/examples/cnn/run_tf_horovod.py --model tf_rnn
# For multi-node runs inside a conda environment, /etc/bash.bashrc needs to be modified.
# mpirun can also be used (default is gloo):
# ../build/_deps/openmpi-build/bin/mpirun -mca btl_tcp_if_include enp97s0f0 --bind-to none --map-by slot \
#     -x NCCL_SOCKET_IFNAME=enp97s0f0 -H daim117:8,daim118:8 --allow-run-as-root python run_tf_horovod.py --model
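# Note on the flags above: -mca btl_tcp_if_include (Open MPI) and NCCL_SOCKET_IFNAME
# (NCCL) pin control and collective traffic to a specific NIC (enp97s0f0 here);
# without them, multi-node runs may pick the wrong interface and hang or run slowly.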
'''
# Legacy standalone version, kept disabled inside this string literal:
def train(model, args):
    hvd.init()

    def get_current_shard(data):
        part_size = data.shape[0] // hvd.size()
        start = part_size * hvd.rank()
        end = start + part_size if hvd.rank() != hvd.size() - 1 else data.shape[0]
        return data[start:end]
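    # Each rank takes a contiguous slice; the last rank absorbs the remainder.
    # For example, with 50000 samples on 8 workers every rank gets 6250 rows;
    # with 10000 samples on 3 workers, part_size = 3333 and rank 2 gets
    # rows [6666, 10000), i.e. 3334 rows.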
    batch_size = 128
    if args.model == 'tf_resnet34':
        train_images, train_labels, test_images, \
            test_labels = ht.data.tf_normalize_cifar10()
        x = tf.compat.v1.placeholder(tf.float32, [batch_size, 32, 32, 3])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 10])
    else:
        datasets = ht.data.mnist()
        train_images, train_labels = datasets[0]
        test_images, test_labels = datasets[2]
        x = tf.compat.v1.placeholder(tf.float32, [batch_size, 784])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 10])
    n_train_batches = train_images.shape[0] // batch_size
    loss, y = model(x, y_)
    opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    global_step = tf.train.get_or_create_global_step()
    # In DistributedOptimizer all tensors are reduced on GPU by default;
    # override with device_sparse=... / device_dense=... if needed
    # (device_sparse='/cpu:0' degrades performance).
    train_op = hvd.DistributedOptimizer(opt).minimize(loss, global_step=global_step)
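    # DistributedOptimizer wraps the optimizer so gradients are averaged across
    # all ranks with an allreduce before being applied, keeping replicas in sync.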
    gpu_options = tf.compat.v1.GPUOptions(allow_growth=True,
                                          visible_device_list=str(hvd.local_rank()))
    # Horovod broadcasts initial variables on GPU by default, which can cause OOM,
    # so broadcast on CPU instead.
    hooks = [hvd.BroadcastGlobalVariablesHook(0, device='/cpu:0')]
    sess = tf.compat.v1.train.MonitoredTrainingSession(
        hooks=hooks, config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
    iterations = train_images.shape[0] // batch_size
    total_epoch = 10
    start_index = 0
    total_time = 0
    for ep in range(total_epoch + 1):
        print("epoch %d" % ep)
        st_time = time.time()
        train_loss, train_acc = [], []
        for it in range(n_train_batches):
            x_val = train_images[start_index: start_index + batch_size]
            y_val = train_labels[start_index: start_index + batch_size]
            start_index += batch_size
            if start_index + batch_size > train_images.shape[0]:
                start_index = 0
            loss_val = sess.run([loss, y, y_, train_op],
                                feed_dict={x: x_val, y_: y_val})
            pred_val = loss_val[1]
            true_val = loss_val[2]
            acc_val = np.equal(
                true_val,
                pred_val > 0.5)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
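            # Caveat: this compares one-hot labels elementwise against thresholded
            # predictions, which is not the same metric as the argmax accuracy
            # computed in the active training loop below.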
        tra_accuracy = np.mean(train_acc)
        tra_loss = np.mean(train_loss)
        en_time = time.time()
        train_time = en_time - st_time
        if ep != 0:
            total_time += train_time
        printstr = "train_loss: %.4f, train_acc: %.4f, train_time: %.4f" \
            % (tra_loss, tra_accuracy, train_time)
    print("training time:", total_time)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="model to be tested")
    parser.add_argument("--all", action="store_true", help="whether to use all data")
    args = parser.parse_args()
    raw_model = args.model
    import tf_models
    model = eval('tf_models.' + raw_model)
    print('Model:', raw_model)
    train(model, args)


if __name__ == '__main__':
    main()
'''
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    args = parser.parse_args()

    hvd.init()
    global rank
    rank = hvd.rank()
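    # hvd.rank() is the global rank across all nodes; hvd.local_rank() (used below
    # to pick a GPU) indexes processes within one node. print_rank0 therefore logs
    # from one rank in every eight, i.e. one per node in the 8-GPU setups above.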
    assert args.model in ['tf_cnn_3_layers', 'tf_lenet', 'tf_logreg', 'tf_lstm', 'tf_mlp',
                          'tf_resnet18', 'tf_resnet34', 'tf_rnn', 'tf_vgg16', 'tf_vgg19'], \
        'Model not supported now.'
    model = eval('tf_models.' + args.model)
    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset
    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = tf.train.GradientDescentOptimizer(
            learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9, use_nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = tf.train.AdagradOptimizer(learning_rate=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    if dataset == 'MNIST':
        datasets = ht.data.mnist()
        train_set_x, train_set_y = datasets[0]
        valid_set_x, valid_set_y = datasets[1]
        test_set_x, test_set_y = datasets[2]
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 784), train_set_y: (50000,)
        # valid_set_x: (10000, 784), valid_set_y: (10000,)
    elif dataset == 'CIFAR10':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=10)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        if args.model == "tf_mlp":
            train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
            valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    elif dataset == 'CIFAR100':
        train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
            num_class=100)
        n_train_batches = train_set_x.shape[0] // args.batch_size
        n_valid_batches = valid_set_x.shape[0] // args.batch_size
        # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
        # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
    else:
        raise NotImplementedError
    if dataset == 'MNIST':
        x = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=(None, 784), name='x')
        y_ = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=(None, 10), name='y_')
        loss, y = model(x, y_)
    elif dataset == 'CIFAR10':
        if args.model == "tf_mlp":
            x = tf.compat.v1.placeholder(
                dtype=tf.float32, shape=(None, 3072), name='x')
            y_ = tf.compat.v1.placeholder(
                dtype=tf.float32, shape=(None, 10), name='y_')
        else:
            x = tf.compat.v1.placeholder(
                dtype=tf.float32, shape=(None, 32, 32, 3), name='x')
            y_ = tf.compat.v1.placeholder(
                dtype=tf.float32, shape=(None, 10), name='y_')
        loss, y = model(x, y_, 10)
    elif dataset == 'CIFAR100':
        x = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=(None, 32, 32, 3), name='x')
        y_ = tf.compat.v1.placeholder(
            dtype=tf.float32, shape=(None, 100), name='y_')
        loss, y = model(x, y_, 100)
    global_step = tf.train.get_or_create_global_step()
    # In DistributedOptimizer all tensors are reduced on GPU by default;
    # override with device_sparse=... / device_dense=... if needed
    # (device_sparse='/cpu:0' degrades performance).
    train_op = hvd.DistributedOptimizer(
        opt).minimize(loss, global_step=global_step)
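    # Note: unlike get_current_shard in the disabled version above, this script
    # never shards the data; every rank iterates the full training set in the
    # same order, so the allreduce averages gradients from identical batches.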
    gpu_options = tf.compat.v1.GPUOptions(
        allow_growth=True, visible_device_list=str(hvd.local_rank()))
    # Horovod broadcasts initial variables on GPU by default, which can cause OOM,
    # so broadcast on CPU instead.
    hooks = [hvd.BroadcastGlobalVariablesHook(0, device='/cpu:0')]
    sess = tf.compat.v1.train.MonitoredTrainingSession(
        hooks=hooks, config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
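    # MonitoredTrainingSession initializes variables and runs the broadcast hook
    # (copying rank 0's weights to all ranks), which is why the manual
    # initializer call below stays commented out.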
    # sess.run(tf.compat.v1.global_variables_initializer())

    # training
    print_rank0("Start training loop...")
    running_time = 0
    for i in range(args.num_epochs + 1):
        print_rank0("Epoch %d" % i)
        loss_all = 0
        batch_num = 0
        if args.timing:
            start = time.time()
        correct_predictions = []
        for minibatch_index in range(n_train_batches):
            minibatch_start = minibatch_index * args.batch_size
            minibatch_end = (minibatch_index + 1) * args.batch_size
            x_val = train_set_x[minibatch_start:minibatch_end]
            y_val = train_set_y[minibatch_start:minibatch_end]
            loss_val, predict_y, _ = sess.run([loss, y, train_op],
                                              feed_dict={x: x_val, y_: y_val})
            correct_prediction = np.equal(
                np.argmax(y_val, 1),
                np.argmax(predict_y, 1)).astype(np.float32)
            correct_predictions.extend(correct_prediction)
            batch_num += 1
            loss_all += loss_val
        loss_all /= batch_num
        accuracy = np.mean(correct_predictions)
        print_rank0("Train loss = %f" % loss_all)
        print_rank0("Train accuracy = %f" % accuracy)
        if args.timing:
            end = time.time()
            print_rank0("Running time of current epoch = %fs" % (end - start))
            if i != 0:
                running_time += (end - start)
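        # Epoch 0 serves as warm-up: the loop runs num_epochs + 1 epochs and the
        # first one is excluded from the running-time total.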
        if args.validate:
            val_loss_all = 0
            batch_num = 0
            correct_predictions = []
            for minibatch_index in range(n_valid_batches):
                minibatch_start = minibatch_index * args.batch_size
                minibatch_end = (minibatch_index + 1) * args.batch_size
                valid_x_val = valid_set_x[minibatch_start:minibatch_end]
                valid_y_val = valid_set_y[minibatch_start:minibatch_end]
                loss_val, valid_y_predicted = sess.run(
                    [loss, y], feed_dict={x: valid_x_val, y_: valid_y_val})
                correct_prediction = np.equal(
                    np.argmax(valid_y_val, 1),
                    np.argmax(valid_y_predicted, 1)).astype(np.float32)
                correct_predictions.extend(correct_prediction)
                val_loss_all += loss_val
                batch_num += 1
            val_loss_all /= batch_num
            accuracy = np.mean(correct_predictions)
            print_rank0("Validation loss = %f" % val_loss_all)
            print_rank0("Validation accuracy = %f" % accuracy)
    print_rank0("*" * 50)
    print_rank0("Running time of total %d epochs = %fs" %
                (args.num_epochs, running_time))