
tf_launch_worker.py

import tensorflow as tf
import tf_models
import hetu as ht
import numpy as np
import argparse
import json
from time import time
import os
import logging

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def print_rank0(msg):
    # log only when task_id is a multiple of 8 (one worker per 8-GPU node)
    if task_id % 8 == 0:
        logger.info(msg)


def pop_env():
    # drop http(s) proxy environment variables before opening connections
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()

if __name__ == "__main__":
    # argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='model to be tested')
    parser.add_argument('--dataset', type=str, required=True,
                        help='dataset to be trained on')
    parser.add_argument('--batch-size', type=int,
                        default=128, help='batch size')
    parser.add_argument('--learning-rate', type=float,
                        default=0.1, help='learning rate')
    parser.add_argument('--opt', type=str, default='sgd',
                        help='optimizer to be used, default sgd; sgd / momentum / nesterov / adagrad / adam')
    parser.add_argument('--num-epochs', type=int,
                        default=20, help='epoch number')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu to be used, -1 means cpu')
    parser.add_argument('--validate', action='store_true',
                        help='whether to use validation')
    parser.add_argument('--timing', action='store_true',
                        help='whether to time the training phase')
    parser.add_argument("--rank", type=int, required=True,
                        help="rank of process")
    parser.add_argument(
        "--config", type=str, default='./settings/tf_dist_s1_w2.json', help="config file path")
    args = parser.parse_args()

    task_id = int(args.rank)
    print_rank0("task id %d" % (task_id))
    raw_config = args.config
    if args.gpu == -1:
        device = '/job:worker/task:%d/cpu:0' % (task_id)
        print_rank0('Use CPU.')
    else:
        device = "/job:worker/task:%d/gpu:%d" % (task_id, args.gpu)
        print_rank0('Use GPU %d.' % args.gpu)

    config = json.load(open(raw_config))
    cluster = tf.train.ClusterSpec(config)

    assert args.model in ['tf_cnn_3_layers', 'tf_lenet', 'tf_logreg', 'tf_lstm', 'tf_mlp',
                          'tf_resnet18', 'tf_resnet34', 'tf_rnn', 'tf_vgg16', 'tf_vgg19'], \
        'Model not supported now.'
    model = eval('tf_models.' + args.model)
    assert args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet']
    dataset = args.dataset

    assert args.opt in ['sgd', 'momentum', 'nesterov',
                        'adagrad', 'adam'], 'Optimizer not supported!'
    if args.opt == 'sgd':
        print_rank0('Use SGD Optimizer.')
        opt = tf.train.GradientDescentOptimizer(
            learning_rate=args.learning_rate)
    elif args.opt == 'momentum':
        print_rank0('Use Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9)
    elif args.opt == 'nesterov':
        print_rank0('Use Nesterov Momentum Optimizer.')
        opt = tf.train.MomentumOptimizer(
            learning_rate=args.learning_rate, momentum=0.9, use_nesterov=True)
    elif args.opt == 'adagrad':
        print_rank0('Use AdaGrad Optimizer.')
        opt = tf.train.AdagradOptimizer(learning_rate=args.learning_rate)
    else:
        print_rank0('Use Adam Optimizer.')
        opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    # place variables on the parameter servers and ops on this worker
    with tf.device(
        tf.compat.v1.train.replica_device_setter(
            worker_device=device,
            cluster=cluster)):
        # data loading
        print_rank0('Loading %s data...' % dataset)
        if dataset == 'MNIST':
            datasets = ht.data.mnist()
            train_set_x, train_set_y = datasets[0]
            valid_set_x, valid_set_y = datasets[1]
            test_set_x, test_set_y = datasets[2]
            n_train_batches = train_set_x.shape[0] // args.batch_size
            n_valid_batches = valid_set_x.shape[0] // args.batch_size
            # train_set_x: (50000, 784), train_set_y: (50000,)
            # valid_set_x: (10000, 784), valid_set_y: (10000,)
        elif dataset == 'CIFAR10':
            train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
                num_class=10)
            n_train_batches = train_set_x.shape[0] // args.batch_size
            n_valid_batches = valid_set_x.shape[0] // args.batch_size
            if args.model == "tf_mlp":
                train_set_x = train_set_x.reshape(train_set_x.shape[0], -1)
                valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], -1)
            # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
            # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
        elif dataset == 'CIFAR100':
            train_set_x, train_set_y, valid_set_x, valid_set_y = ht.data.tf_normalize_cifar(
                num_class=100)
            n_train_batches = train_set_x.shape[0] // args.batch_size
            n_valid_batches = valid_set_x.shape[0] // args.batch_size
            # train_set_x: (50000, 32, 32, 3), train_set_y: (50000,)
            # valid_set_x: (10000, 32, 32, 3), valid_set_y: (10000,)
        else:
            raise NotImplementedError

        # build placeholders and the model graph
        if dataset == 'MNIST':
            x = tf.placeholder(dtype=tf.float32, shape=(None, 784), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_)
        elif dataset == 'CIFAR10':
            if args.model == "tf_mlp":
                x = tf.placeholder(
                    dtype=tf.float32, shape=(None, 3072), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            else:
                x = tf.placeholder(dtype=tf.float32, shape=(
                    None, 32, 32, 3), name='x')
                y_ = tf.placeholder(
                    dtype=tf.float32, shape=(None, 10), name='y_')
            loss, y = model(x, y_, 10)
        elif dataset == 'CIFAR100':
            x = tf.placeholder(dtype=tf.float32, shape=(
                None, 32, 32, 3), name='x')
            y_ = tf.placeholder(dtype=tf.float32, shape=(None, 100), name='y_')
            loss, y = model(x, y_, 100)
        train_op = opt.minimize(loss)
    # start this worker's server and create the session via a Supervisor
    server = tf.train.Server(
        cluster, job_name="worker", task_index=task_id)
    init = tf.compat.v1.global_variables_initializer()
    sv = tf.train.Supervisor(
        is_chief=(task_id == 0),
        init_op=init,
        recovery_wait_secs=1)
    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % task_id])
    sess = sv.prepare_or_wait_for_session(
        server.target, config=sess_config)
    sess.run(init)
    # training
    print_rank0("Start training loop...")
    running_time = 0
    # epoch 0 serves as a warm-up and is excluded from the total running time
    for i in range(args.num_epochs + 1):
        print_rank0("Epoch %d" % i)
        loss_all = 0
        batch_num = 0
        if args.timing:
            start = time()
        correct_predictions = []
        for minibatch_index in range(n_train_batches):
            minibatch_start = minibatch_index * args.batch_size
            minibatch_end = (minibatch_index + 1) * args.batch_size
            x_val = train_set_x[minibatch_start:minibatch_end]
            y_val = train_set_y[minibatch_start:minibatch_end]
            loss_val, predict_y, _ = sess.run([loss, y, train_op],
                                              feed_dict={x: x_val, y_: y_val})
            correct_prediction = np.equal(
                np.argmax(y_val, 1),
                np.argmax(predict_y, 1)).astype(np.float32)
            correct_predictions.extend(correct_prediction)
            batch_num += 1
            loss_all += loss_val
        loss_all /= batch_num
        accuracy = np.mean(correct_predictions)
        print_rank0("Train loss = %f" % loss_all)
        print_rank0("Train accuracy = %f" % accuracy)
        if args.timing:
            end = time()
            print_rank0("Running time of current epoch = %fs" %
                        (end - start))
            if i != 0:
                running_time += (end - start)
        if args.validate:
            val_loss_all = 0
            batch_num = 0
            correct_predictions = []
            for minibatch_index in range(n_valid_batches):
                minibatch_start = minibatch_index * args.batch_size
                minibatch_end = (minibatch_index + 1) * args.batch_size
                valid_x_val = valid_set_x[minibatch_start:minibatch_end]
                valid_y_val = valid_set_y[minibatch_start:minibatch_end]
                loss_val, valid_y_predicted = sess.run(
                    [loss, y], feed_dict={x: valid_x_val, y_: valid_y_val})
                correct_prediction = np.equal(
                    np.argmax(valid_y_val, 1),
                    np.argmax(valid_y_predicted, 1)).astype(np.float32)
                correct_predictions.extend(correct_prediction)
                # accumulate this batch's validation loss
                val_loss_all += loss_val
                batch_num += 1
            val_loss_all /= batch_num
            accuracy = np.mean(correct_predictions)
            print_rank0("Validation loss = %f" % val_loss_all)
            print_rank0("Validation accuracy = %f" % accuracy)
    print_rank0("*" * 50)
    print_rank0("Running time of total %d epochs = %fs" %
                (args.num_epochs, running_time))
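
For context, tf.train.ClusterSpec(config) expects the parsed JSON to map job names to lists of host:port addresses. The default file name tf_dist_s1_w2.json suggests one parameter server and two workers; the sketch below writes such a config and shows how each worker would be launched with its own rank. The addresses and exact job layout are assumptions for illustration, not taken from this repository.

import json

# Hypothetical cluster layout (1 ps + 2 workers), inferred from the default
# config name; all addresses are placeholders.
cluster_config = {
    "ps": ["127.0.0.1:34568"],
    "worker": ["127.0.0.1:34569", "127.0.0.1:34570"],
}
with open('./settings/tf_dist_s1_w2.json', 'w') as f:
    json.dump(cluster_config, f, indent=2)

# Each worker process is then started with its own rank, e.g.:
#   python tf_launch_worker.py --model tf_mlp --dataset MNIST --rank 0
#   python tf_launch_worker.py --model tf_mlp --dataset MNIST --rank 1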
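
The device_filters above reference a /job:ps process that this script never starts, so the variables placed by replica_device_setter need a separately launched parameter-server task. A minimal sketch of such a launcher, assuming the same config file (an illustration only; the repository may ship its own equivalent script):

import json
import tensorflow as tf

# Load the same cluster description used by the workers.
config = json.load(open('./settings/tf_dist_s1_w2.json'))
cluster = tf.train.ClusterSpec(config)

# Start the parameter-server task and block, serving variables to workers.
server = tf.train.Server(cluster, job_name="ps", task_index=0)
server.join()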

Distributed deep learning system
