
run_tfworker.py
import os
import json
import numpy as np
import tensorflow as tf
import time
import argparse
from tqdm import tqdm
from tf_ncf import neural_mf
import heapq  # used to retrieve the top-K ranked items
import math


def pop_env():
    # Drop proxy settings from the environment, presumably so that gRPC
    # connections between the local jobs are not routed through a proxy.
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()


def getHitRatio(ranklist, gtItem):
    # HR@K: 1 if the ground-truth item appears anywhere in the ranked list.
    for item in ranklist:
        if item == gtItem:
            return 1
    return 0


def getNDCG(ranklist, gtItem):
    # NDCG@K with a single relevant item: log2 discount by its rank position.
    for i in range(len(ranklist)):
        item = ranklist[i]
        if item == gtItem:
            return math.log(2) / math.log(i + 2)
    return 0
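

# Illustrative example: for ranklist = [5, 3, 7] and gtItem = 3, getHitRatio
# returns 1 and getNDCG returns log(2) / log(3) ≈ 0.631, since the
# ground-truth item sits at 0-based position 1.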


class Logging(object):
    def __init__(self, path='logs/tflog.txt'):
        with open(path, 'w') as fw:
            fw.write('')
        self.path = path

    def write(self, s):
        print(s)
        with open(self.path, 'a') as fw:
            fw.write(s + '\n')
            fw.flush()


def train_ncf(cluster, rank, nrank, args):
    def validate():
        # Validation phase: each test user has a block of 100 candidate
        # items, and the ground-truth item is the first one in the block.
        hits, ndcgs = [], []
        for idx in range(testData.shape[0]):
            start_index = idx * 100
            my_feed_dict = {
                user_input: testUserInput[start_index:start_index + 100],
                item_input: testItemInput[start_index:start_index + 100],
            }
            predictions = sess.run([y], feed_dict=my_feed_dict)
            map_item_score = {
                testItemInput[start_index + i]: predictions[0][i]
                for i in range(100)}
            # Evaluate the top-K rank list.
            ranklist = heapq.nlargest(
                topK, map_item_score, key=map_item_score.get)
            hr = getHitRatio(ranklist, testItemInput[start_index])
            ndcg = getNDCG(ranklist, testItemInput[start_index])
            hits.append(hr)
            ndcgs.append(ndcg)
        hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
        return hr, ndcg

    def get_current_shard(data):
        # Each worker takes a contiguous 1/nrank slice of the data; the
        # last worker also takes any remainder.
        part_size = data.shape[0] // nrank
        start = part_size * rank
        end = start + part_size if rank != nrank - 1 else data.shape[0]
        return data[start:end]
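
    # Illustrative example: with nrank = 4 and 10 rows, part_size is 2, so
    # ranks 0-2 take rows [0:2], [2:4], [4:6] and rank 3 takes rows [6:10].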

    from movielens import getdata
    if args.all:
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'])
        trainItems = get_current_shard(trainData['item_input'])
        trainLabels = get_current_shard(trainData['labels'])
        testData = get_current_shard(testData)
    else:
        # Small-scale run: cap training at 1024000 samples and the test
        # set at 1470 users.
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'][:1024000])
        trainItems = get_current_shard(trainData['item_input'][:1024000])
        trainLabels = get_current_shard(trainData['labels'][:1024000])
        testData = get_current_shard(testData[:1470])
    # 100 candidate items per test user.
    testUserInput = np.repeat(
        np.arange(testData.shape[0], dtype=np.int32), 100)
    testItemInput = testData.reshape((-1,))
    num_users, num_items = {
        'ml-1m': (6040, 3706),
        'ml-20m': (138493, 26744),
        'ml-25m': (162541, 59047),
    }['ml-25m']
    batch_size = 1024
    num_negatives = 4
    topK = 10

    # Placeholders live on this worker's GPU; model variables are assigned
    # to the parameter servers by the replica device setter.
    worker_device = "/job:worker/task:%d/gpu:0" % rank
    with tf.device(worker_device):
        user_input = tf.compat.v1.placeholder(tf.int32, [None, ])
        item_input = tf.compat.v1.placeholder(tf.int32, [None, ])
        y_ = tf.compat.v1.placeholder(tf.float32, [None, ])
    with tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster)):
        # Shard the embedding tables across parameter servers when more
        # than one is configured.
        server_num = len(cluster.as_dict()['ps'])
        embed_partitioner = tf.compat.v1.fixed_size_partitioner(
            server_num, 0) if server_num > 1 else None
        loss, y, opt = neural_mf(
            user_input, item_input, y_, num_users, num_items, embed_partitioner)
        train_op = opt.minimize(loss)

    server = tf.compat.v1.train.Server(
        cluster, job_name="worker", task_index=rank)
    init = tf.compat.v1.global_variables_initializer()
    sv = tf.compat.v1.train.Supervisor(
        is_chief=(rank == 0),
        init_op=init,
        recovery_wait_secs=1)
    # Restrict this worker to see only the ps jobs and itself.
    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % rank])
    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    log = Logging(path='logs/tflog%d.txt' % rank)
    epoch = 7
    iterations = trainUsers.shape[0] // batch_size
    start = time.time()
    for ep in range(epoch):
        ep_st = time.time()
        log.write('epoch %d' % ep)
        train_loss = []
        for idx in tqdm(range(iterations)):
            start_index = idx * batch_size
            my_feed_dict = {
                user_input: trainUsers[start_index:start_index + batch_size],
                item_input: trainItems[start_index:start_index + batch_size],
                y_: trainLabels[start_index:start_index + batch_size],
            }
            loss_val = sess.run([loss, train_op], feed_dict=my_feed_dict)
            train_loss.append(loss_val[0])
        tra_loss = np.mean(train_loss)
        ep_en = time.time()
        # Validation phase.
        if args.val:
            hr, ndcg = validate()
            printstr = "train_loss: %.4f, HR: %.4f, NDCG: %.4f, train_time: %.4f" % (
                tra_loss, hr, ndcg, ep_en - ep_st)
        else:
            printstr = "train_loss: %.4f, train_time: %.4f" % (
                tra_loss, ep_en - ep_st)
        log.write(printstr)
    log.write('all time: %f' % (time.time() - start))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--val", action="store_true",
                        help="whether to perform validation")
    parser.add_argument("--rank", type=int, required=True,
                        help="rank of the worker process")
    parser.add_argument("--config", type=str,
                        default='../ctr/settings/tf_local_s1_w2.json',
                        help="config file path")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all the data")
    args = parser.parse_args()
    task_id = int(args.rank)
    config = json.load(open(args.config))
    cluster = tf.train.ClusterSpec(config)
    train_ncf(cluster, task_id, len(config['worker']), args)


if __name__ == '__main__':
    main()
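

# The cluster config is a plain JSON mapping of job names to address lists,
# which is exactly what tf.train.ClusterSpec accepts. A minimal sketch,
# assuming the default file's "s1_w2" layout (1 parameter server, 2 workers)
# and hypothetical localhost ports:
#
#     {
#         "ps": ["localhost:34567"],
#         "worker": ["localhost:34568", "localhost:34569"]
#     }
#
# Assumed invocation, one process per worker rank (this script only starts
# a "worker" server, so the ps process must be launched separately):
#
#     python run_tfworker.py --rank 0 --val
#     python run_tfworker.py --rank 1 --val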