You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

run_tf_horovod.py 7.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import numpy as np
  3. import tensorflow as tf
  4. import time
  5. import argparse
  6. from tqdm import tqdm
  7. from sklearn import metrics
  8. import horovod.tensorflow as hvd
  9. def pop_env():
  10. for k in ['https_proxy', 'http_proxy']:
  11. if k in os.environ:
  12. os.environ.pop(k)
  13. pop_env()
  14. # horovodrun -np 8 -H localhost:8 python run_tf_horovod.py --model
  15. # horovodrun -np 8 --start-timeout 300 -H daim116:4,daim117:4 python run_tf_horovod.py --model
  16. # if using multi nodes setting in conda, need to modify /etc/bash.bashrc
  17. # we can also use mpirun (default gloo):
  18. # ../build/_deps/openmpi-build/bin/mpirun -mca btl_tcp_if_include enp97s0f0 --bind-to none --map-by slot\
  19. # -x NCCL_SOCKET_IFNAME=enp97s0f0 -H daim117:8,daim118:8 --allow-run-as-root python run_tf_horovod.py --model
  20. def train_criteo(model, args):
  21. hvd.init()
  22. def get_current_shard(data):
  23. part_size = data.shape[0] // hvd.size()
  24. start = part_size * hvd.rank()
  25. end = start + part_size if hvd.rank() != hvd.size() - \
  26. 1 else data.shape[0]
  27. return data[start:end]
  28. if args.all:
  29. from models.load_data import process_all_criteo_data
  30. dense, sparse, all_labels = process_all_criteo_data()
  31. dense_feature = get_current_shard(dense[0])
  32. sparse_feature = get_current_shard(sparse[0])
  33. labels = get_current_shard(all_labels[0])
  34. val_dense = get_current_shard(dense[1])
  35. val_sparse = get_current_shard(sparse[1])
  36. val_labels = get_current_shard(all_labels[1])
  37. else:
  38. from models.load_data import process_sampled_criteo_data
  39. dense_feature, sparse_feature, labels = process_sampled_criteo_data()
  40. dense_feature = get_current_shard(dense_feature)
  41. sparse_feature = get_current_shard(sparse_feature)
  42. labels = get_current_shard(labels)
  43. batch_size = 128
  44. dense_input = tf.compat.v1.placeholder(tf.float32, [batch_size, 13])
  45. sparse_input = tf.compat.v1.placeholder(tf.int32, [batch_size, 26])
  46. y_ = y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
  47. loss, y, opt = model(dense_input, sparse_input, y_)
  48. global_step = tf.train.get_or_create_global_step()
  49. # here in DistributedOptimizer by default all tensor are reduced on GPU
  50. # can use device_sparse=xxx, device_dense=xxx to modify
  51. # if using device_sparse='/cpu:0', the performance degrades
  52. train_op = hvd.DistributedOptimizer(
  53. opt).minimize(loss, global_step=global_step)
  54. gpu_options = tf.compat.v1.GPUOptions(
  55. allow_growth=True, visible_device_list=str(hvd.local_rank()))
  56. # here horovod default use gpu to initialize, which will cause OOM
  57. hooks = [hvd.BroadcastGlobalVariablesHook(0, device='/cpu:0')]
  58. sess = tf.compat.v1.train.MonitoredTrainingSession(
  59. hooks=hooks, config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
  60. my_feed_dict = {
  61. dense_input: np.empty(shape=(batch_size, 13)),
  62. sparse_input: np.empty(shape=(batch_size, 26)),
  63. y_: np.empty(shape=(batch_size, 1)),
  64. }
  65. if args.all:
  66. raw_log_file = './logs/tf_hvd_%s_%d.log' % (
  67. args.model, hvd.local_rank())
  68. print('Processing all data, log to', raw_log_file)
  69. log_file = open(raw_log_file, 'w')
  70. iterations = dense_feature.shape[0] // batch_size
  71. total_epoch = 400
  72. start_index = 0
  73. for ep in range(total_epoch):
  74. print("epoch %d" % ep)
  75. st_time = time.time()
  76. train_loss, train_acc, train_auc = [], [], []
  77. for it in tqdm(range(iterations // 10 + (ep % 10 == 9) * (iterations % 10))):
  78. my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
  79. my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
  80. my_feed_dict[y_][:] = labels[start_index: start_index+batch_size]
  81. start_index += batch_size
  82. if start_index + batch_size > dense_feature.shape[0]:
  83. start_index = 0
  84. loss_val = sess.run([loss, y, y_, train_op],
  85. feed_dict=my_feed_dict)
  86. pred_val = loss_val[1]
  87. true_val = loss_val[2]
  88. acc_val = np.equal(
  89. true_val,
  90. pred_val > 0.5)
  91. train_loss.append(loss_val[0])
  92. train_acc.append(acc_val)
  93. train_auc.append(metrics.roc_auc_score(true_val, pred_val))
  94. tra_accuracy = np.mean(train_acc)
  95. tra_loss = np.mean(train_loss)
  96. tra_auc = np.mean(train_auc)
  97. en_time = time.time()
  98. train_time = en_time - st_time
  99. printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
  100. % (tra_loss, tra_accuracy, tra_auc, train_time)
  101. print(printstr)
  102. log_file.write(printstr + '\n')
  103. log_file.flush()
  104. else:
  105. iterations = dense_feature.shape[0] // batch_size
  106. epoch = 50
  107. for ep in range(epoch):
  108. print('epoch', ep)
  109. if ep == 5:
  110. start = time.time()
  111. ep_st = time.time()
  112. train_loss = []
  113. train_acc = []
  114. for idx in range(iterations):
  115. start_index = idx * batch_size
  116. my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
  117. my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
  118. my_feed_dict[y_][:] = labels[start_index: start_index+batch_size]
  119. loss_val = sess.run([loss, y, y_, train_op],
  120. feed_dict=my_feed_dict)
  121. pred_val = loss_val[1]
  122. true_val = loss_val[2]
  123. if pred_val.shape[1] == 1: # for criteo case
  124. acc_val = np.equal(
  125. true_val,
  126. pred_val > 0.5)
  127. else:
  128. acc_val = np.equal(
  129. np.argmax(pred_val, 1),
  130. np.argmax(true_val, 1)).astype(np.float32)
  131. train_loss.append(loss_val[0])
  132. train_acc.append(acc_val)
  133. tra_accuracy = np.mean(train_acc)
  134. tra_loss = np.mean(train_loss)
  135. ep_en = time.time()
  136. print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
  137. % (tra_loss, tra_accuracy, ep_en - ep_st))
  138. print('all time:', (time.time() - start))
  139. def main():
  140. parser = argparse.ArgumentParser()
  141. parser.add_argument("--model", type=str, required=True,
  142. help="model to be tested")
  143. parser.add_argument("--all", action="store_true",
  144. help="whether to use all data")
  145. args = parser.parse_args()
  146. raw_model = args.model
  147. import tf_models
  148. model = eval('tf_models.' + raw_model)
  149. dataset = raw_model.split('_')[-1]
  150. print('Model:', raw_model)
  151. train_criteo(model, args)
  152. if __name__ == '__main__':
  153. main()