run_tf_parallax.py
import os
import time
import argparse

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from sklearn import metrics

from autodist import AutoDist
from autodist.resource_spec import ResourceSpec
from autodist.strategy import PS, PSLoadBalancing, PartitionedPS, AllReduce, Parallax
from autodist.strategy.base import Strategy
from autodist.kernel.common.utils import get_op_name
from tensorflow.python.framework import ops


def pop_env():
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()
# Please DO NOT modify /etc/bash.bashrc to activate the conda environment;
# use python_venv in the spec yml file instead, and use the absolute path
# of the python file. Here we use the TF native partitioner instead of
# autodist's PartitionedPS.
class Parallaxx(PSLoadBalancing, AllReduce):
    """
    Modification of the original Parallax strategy that removes the replicas on CPUs.
    """

    def __init__(self, chunk_size=128, local_proxy_variable=False, sync=True, staleness=0):
        PSLoadBalancing.__init__(self, local_proxy_variable, sync, staleness)
        AllReduce.__init__(self, chunk_size)

    # pylint: disable=attribute-defined-outside-init
    def build(self, graph_item, resource_spec):
        """Generate the strategy."""
        expr = Strategy()
        # Replicate only on GPU devices; reductions go to the CPU devices.
        expr.graph_config.replicas.extend(
            [k for k, v in resource_spec.gpu_devices])
        reduction_device_names = [k for k, _ in resource_spec.cpu_devices]
        self.loads = {ps: 0.0 for ps in reduction_device_names}
        # For each variable, generate a variable synchronizer config.
        node_config = []
        for idx, var in enumerate(graph_item.trainable_var_op_to_var.values()):
            var_op_name = get_op_name(var.name)
            grad, _, _ = graph_item.var_op_name_to_grad_info[var_op_name]
            if isinstance(grad, ops.Tensor):  # this is a dense variable
                group_id = idx // self.chunk_size
                config = self._gen_all_reduce_node_config(
                    var.name, group=group_id)
            else:  # sparse updates go through the parameter servers
                # For the Parallax strategy, all PS vars are sparse, so no proxy is needed:
                # sparse variables are likely larger, so keeping copies would be costlier,
                # and usually each device only requires a small part of the overall variable.
                config = self._gen_ps_node_config(
                    var,
                    False,  # no local proxy variable
                    self._sync,
                    self._staleness
                )
            node_config.append(config)
        expr.node_config.extend(node_config)
        return expr
def train_criteo(model, args):
    resource_spec_file = os.path.join(os.path.dirname(
        __file__), 'settings', 'plx_local_spec.yml')
    autodist = AutoDist(resource_spec_file, Parallaxx())
    respec = ResourceSpec(resource_spec_file)
    if args.all:
        from models.load_data import process_all_criteo_data
        dense, sparse, all_labels = process_all_criteo_data()
        dense_feature, val_dense = dense
        sparse_feature, val_sparse = sparse
        labels, val_labels = all_labels
    else:
        from models.load_data import process_sampled_criteo_data
        dense_feature, sparse_feature, labels = process_sampled_criteo_data()

    # AutoDist will split the feeding data across replicas.
    batch_size = 128
    with tf.Graph().as_default() as g, autodist.scope():
        dense_input = tf.compat.v1.placeholder(tf.float32, [batch_size, 13])
        sparse_input = tf.compat.v1.placeholder(tf.int32, [batch_size, 26])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
        # Partition the embedding table across nodes with the TF native partitioner.
        embed_partitioner = tf.compat.v1.fixed_size_partitioner(
            len(respec.nodes), 0) if len(respec.nodes) > 1 else None
        loss, y, opt = model(dense_input, sparse_input,
                             y_, embed_partitioner, False)
        train_op = opt.minimize(loss)
        sess = autodist.create_distributed_session()
        # Preallocated buffers, filled in place each iteration.
        my_feed_dict = {
            dense_input: np.empty(shape=(batch_size, 13)),
            sparse_input: np.empty(shape=(batch_size, 26)),
            y_: np.empty(shape=(batch_size, 1)),
        }
        if args.all:
            raw_log_file = os.path.join(os.path.split(os.path.abspath(__file__))[
                0], 'logs', 'tf_plx_%s.log' % (args.model))
            print('Processing all data, log to', raw_log_file)
            log_file = open(raw_log_file, 'w')
            iterations = dense_feature.shape[0] // batch_size
            total_epoch = 11
            start_index = 0
            for ep in range(total_epoch):
                print("epoch %d" % ep)
                st_time = time.time()
                train_loss, train_acc, train_auc = [], [], []
                # Each "epoch" covers a tenth of the data; the last one also takes the remainder.
                for it in tqdm(range(iterations // 10 + (ep % 10 == 9) * (iterations % 10))):
                    my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                    my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                    my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]
                    start_index += batch_size
                    if start_index + batch_size > dense_feature.shape[0]:
                        start_index = 0
                    loss_val = sess.run(
                        [loss, y, y_, train_op], feed_dict=my_feed_dict)
                    pred_val = loss_val[1]
                    true_val = loss_val[2]
                    acc_val = np.equal(
                        true_val,
                        pred_val > 0.5)
                    train_loss.append(loss_val[0])
                    train_acc.append(acc_val)
                    train_auc.append(metrics.roc_auc_score(true_val, pred_val))
                tra_accuracy = np.mean(train_acc)
                tra_loss = np.mean(train_loss)
                tra_auc = np.mean(train_auc)
                en_time = time.time()
                train_time = en_time - st_time
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                    % (tra_loss, tra_accuracy, tra_auc, train_time)
                print(printstr)
                log_file.write(printstr + '\n')
                log_file.flush()
        else:
            iteration = dense_feature.shape[0] // batch_size
            epoch = 50
            for ep in range(epoch):
                print('epoch', ep)
                if ep == 5:
                    start = time.time()
                ep_st = time.time()
                train_loss = []
                train_acc = []
                for idx in range(iteration):
                    start_index = idx * batch_size
                    my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                    my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                    my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]
                    loss_val = sess.run(
                        [loss, y, y_, train_op], feed_dict=my_feed_dict)
                    pred_val = loss_val[1]
                    true_val = loss_val[2]
                    if pred_val.shape[1] == 1:  # for the criteo case
                        acc_val = np.equal(
                            true_val,
                            pred_val > 0.5)
                    else:
                        acc_val = np.equal(
                            np.argmax(pred_val, 1),
                            np.argmax(true_val, 1)).astype(np.float32)
                    train_loss.append(loss_val[0])
                    train_acc.append(acc_val)
                tra_accuracy = np.mean(train_acc)
                tra_loss = np.mean(train_loss)
                ep_en = time.time()
                print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                      % (tra_loss, tra_accuracy, ep_en - ep_st))
            print('all time:', (time.time() - start))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True,
                        help="model to be tested")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all data")
    args = parser.parse_args()
    raw_model = args.model
    import tf_models
    # Look up the model-building function in tf_models by name.
    model = getattr(tf_models, raw_model)
    dataset = raw_model.split('_')[-1]
    print('Model:', raw_model)
    if dataset == 'criteo':
        train_criteo(model, args)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    main()
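
The tf_models module resolved by --model is not shown in this file. From the call site in train_criteo, each builder must accept (dense_input, sparse_input, y_, embed_partitioner, training) and return (loss, y, opt): a scalar loss, per-example probabilities of shape [batch, 1], and a not-yet-minimized optimizer. A minimal sketch of that contract, with a hypothetical name wdl_criteo and assumed embedding sizes and optimizer (none of these are values from this repo):

import tensorflow as tf

def wdl_criteo(dense_input, sparse_input, y_, embed_partitioner, training):
    """Hypothetical builder matching the call site; the training flag is unused here."""
    num_embeddings, embed_dim = 1000000, 8  # assumed sizes
    with tf.compat.v1.variable_scope('embed', partitioner=embed_partitioner):
        table = tf.compat.v1.get_variable('E', [num_embeddings, embed_dim])
    # [batch, 26, embed_dim] -> [batch, 26 * embed_dim]
    sparse_vec = tf.reshape(
        tf.nn.embedding_lookup(table, sparse_input), [-1, 26 * embed_dim])
    x = tf.concat([dense_input, sparse_vec], axis=1)
    logits = tf.compat.v1.layers.dense(x, 1)
    y = tf.sigmoid(logits)
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=logits))
    opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
    return loss, y, opt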
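
settings/plx_local_spec.yml is also not included in the listing. For orientation, a minimal single-machine AutoDist resource spec might look roughly like the sketch below; the address and GPU indices are placeholders, and the exact schema should be checked against the AutoDist documentation:

# Possible shape of settings/plx_local_spec.yml (placeholders, not repo values):
EXAMPLE_PLX_LOCAL_SPEC = """\
nodes:
  - address: 127.0.0.1
    chief: true
    gpus: [0, 1]
"""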
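
Putting it together, a typical invocation would name a Criteo builder (the suffix after the last underscore selects the dataset branch in main) and optionally pass --all to train on the full dataset rather than the sampled subset, e.g. python run_tf_parallax.py --model wdl_criteo --all, where wdl_criteo stands in for whatever builders tf_models actually exports.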