
run_parallax.py

import os
import numpy as np
import tensorflow as tf
import time
import argparse
from tqdm import tqdm
from tf_ncf import neural_mf
import heapq  # for retrieving the top-K ranked items
import math
from autodist import AutoDist
from autodist.resource_spec import ResourceSpec
from autodist.strategy import PS, PSLoadBalancing, PartitionedPS, AllReduce, Parallax
from autodist.strategy.base import Strategy
from autodist.kernel.common.utils import get_op_name
from tensorflow.python.framework import ops


def pop_env():
    # Drop proxy settings so worker-to-worker traffic is not routed through a proxy.
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()

# Please DO NOT modify /etc/bash.bashrc to activate the conda environment.
# Use python_venv in the spec yml file instead.
# Use the absolute path of the python file.
# Here we use the TF native partitioner instead of autodist's PartitionedPS.
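# A rough sketch of what the resource spec yml mentioned above may contain
# (illustrative only -- the exact key names and layout depend on the AutoDist
# version in use, so treat this as an assumption, not the actual file):
#
#   nodes:
#     - address: 127.0.0.1
#       gpus: [0, 1]
#       chief: true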

class Parallaxx(PSLoadBalancing, AllReduce):
    """
    Modified version of the original Parallax strategy that removes the
    variable replica on CPUs.
    """

    def __init__(self, chunk_size=128, local_proxy_variable=False, sync=True, staleness=0):
        PSLoadBalancing.__init__(self, local_proxy_variable, sync, staleness)
        AllReduce.__init__(self, chunk_size)

    # pylint: disable=attribute-defined-outside-init
    def build(self, graph_item, resource_spec):
        """Generate the strategy."""
        expr = Strategy()
        # For each variable, generate a variable synchronizer config.
        expr.graph_config.replicas.extend(
            [k for k, _ in resource_spec.gpu_devices])
        reduction_device_names = [k for k, _ in resource_spec.cpu_devices]
        self.loads = {ps: 0.0 for ps in reduction_device_names}
        # Generate the per-variable node config.
        node_config = []
        for idx, var in enumerate(graph_item.trainable_var_op_to_var.values()):
            var_op_name = get_op_name(var.name)
            grad, _, _ = graph_item.var_op_name_to_grad_info[var_op_name]
            if isinstance(grad, ops.Tensor):  # dense variable: all-reduce its gradient
                group_id = idx // self.chunk_size
                config = self._gen_all_reduce_node_config(
                    var.name, group=group_id)
            else:  # sparse updates: place the variable on a parameter server
                # For the Parallax strategy, all PS variables are sparse, so no
                # proxy variable is used: sparse variables are likely larger,
                # keeping copies would be costlier, and usually each device
                # only requires a small part of the overall variable.
                config = self._gen_ps_node_config(
                    var,
                    False,  # no local proxy variable
                    self._sync,
                    self._staleness
                )
            node_config.append(config)
        expr.node_config.extend(node_config)
        return expr
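
# Summary of how the strategy above assigns variables (restating the logic in
# build(), not new behavior): variables with dense gradients are synchronized
# with all-reduce and grouped by index into chunks of `chunk_size` (with the
# default chunk_size=128, variables 0-127 form group 0, 128-255 form group 1,
# and so on), while variables with sparse gradients (non-Tensor grads such as
# tf.IndexedSlices) are assigned to the CPU parameter servers chosen by
# PSLoadBalancing, without a local proxy replica.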

def getHitRatio(ranklist, gtItem):
    for item in ranklist:
        if item == gtItem:
            return 1
    return 0


def getNDCG(ranklist, gtItem):
    for i in range(len(ranklist)):
        item = ranklist[i]
        if item == gtItem:
            return math.log(2) / math.log(i + 2)
    return 0
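
# Illustrative sanity check for the two metrics above (hypothetical values,
# not part of the original script), kept behind a disabled flag so it never
# runs during training:
_DEMO_METRICS = False
if _DEMO_METRICS:
    _ranklist = [5, 3, 7]
    assert getHitRatio(_ranklist, 3) == 1  # ground-truth item appears in the list
    assert getHitRatio(_ranklist, 9) == 0  # ground-truth item missing
    # item 3 sits at 0-based rank 1, so NDCG = log(2) / log(1 + 2) ~= 0.631
    assert abs(getNDCG(_ranklist, 3) - math.log(2) / math.log(3)) < 1e-12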

class Logging(object):
    def __init__(self, path='logs/tflog.txt'):
        # Truncate the log file on construction; append to it on every write.
        with open(path, 'w') as fw:
            fw.write('')
        self.path = path

    def write(self, s):
        print(s)
        with open(self.path, 'a') as fw:
            fw.write(s + '\n')
            fw.flush()

def main():
    resource_spec_file = os.path.join(os.path.dirname(
        __file__), '../ctr/settings', 'plx_local_spec.yml')
    autodist = AutoDist(resource_spec_file, Parallaxx())
    respec = ResourceSpec(resource_spec_file)  # parsed spec (not used further below)

    def validate():
        # Validation phase: score 100 candidate items per user and rank them.
        hits, ndcgs = [], []
        for idx in range(num_users):
            start_index = idx * 100
            my_feed_dict = {
                user_input: testUserInput[start_index:start_index + 100],
                item_input: testItemInput[start_index:start_index + 100],
            }
            predictions = sess.run([y], feed_dict=my_feed_dict)
            map_item_score = {
                testItemInput[start_index + i]: predictions[0][i] for i in range(100)}
            # Evaluate the top-K ranked list against the ground-truth item.
            ranklist = heapq.nlargest(
                topK, map_item_score, key=map_item_score.get)
            hr = getHitRatio(ranklist, testItemInput[start_index])
            ndcg = getNDCG(ranklist, testItemInput[start_index])
            hits.append(hr)
            ndcgs.append(ndcg)
        hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
        return hr, ndcg
    from movielens import getdata
    trainData, testData = getdata('ml-25m', 'datasets')
    testUserInput = np.repeat(
        np.arange(testData.shape[0], dtype=np.int32), 100)
    testItemInput = testData.reshape((-1,))
    num_users, num_items = {
        'ml-1m': (6040, 3706),
        'ml-20m': (138493, 26744),
        'ml-25m': (162541, 59047),
    }['ml-25m']
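    # Note on the evaluation layout (inferred from the indexing in validate(),
    # not stated explicitly in the original script): each test user owns a
    # contiguous block of 100 entries in testItemInput, and the first entry of
    # a user's block (testItemInput[u * 100]) is the item treated as ground
    # truth, so HR@topK / NDCG@topK are computed over 100 candidates per user.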
    batch_size = 1024
    num_negatives = 4
    topK = 10
    with tf.Graph().as_default() as g, autodist.scope():
        user_input = tf.compat.v1.placeholder(tf.int32, [None, ])
        item_input = tf.compat.v1.placeholder(tf.int32, [None, ])
        y_ = tf.compat.v1.placeholder(tf.float32, [None, ])
        loss, y, opt = neural_mf(
            user_input, item_input, y_, num_users, num_items)
        train_op = opt.minimize(loss)
        sess = autodist.create_distributed_session()
        log = Logging(path=os.path.join(
            os.path.dirname(__file__), 'logs', 'tfplx.txt'))
        epoch = 7
        iterations = trainData['user_input'].shape[0] // batch_size
        start = time.time()
        for ep in range(epoch):
            ep_st = time.time()
            log.write('epoch %d' % ep)
            train_loss = []
            for idx in range(iterations):
                start_index = idx * batch_size
                my_feed_dict = {
                    user_input: trainData['user_input'][start_index:start_index + batch_size],
                    item_input: trainData['item_input'][start_index:start_index + batch_size],
                    y_: trainData['labels'][start_index:start_index + batch_size],
                }
                loss_val = sess.run([loss, train_op], feed_dict=my_feed_dict)
                train_loss.append(loss_val[0])
            tra_loss = np.mean(train_loss)
            ep_en = time.time()
            # Validation phase.
            hr, ndcg = validate()
            printstr = "train_loss: %.4f, HR: %.4f, NDCG: %.4f, train_time: %.4f" % (
                tra_loss, hr, ndcg, ep_en - ep_st)
            log.write(printstr)
        log.write('all time: %f' % (time.time() - start))

if __name__ == '__main__':
    main()