You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

run_hetu.py 6.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import hetu as ht
  2. from hetu.launcher import launch
  3. import os
  4. import numpy as np
  5. import yaml
  6. import time
  7. import math
  8. import argparse
  9. from tqdm import tqdm
  10. from hetu_ncf import neural_mf
  11. import heapq # for retrieval topK
  12. def getHitRatio(ranklist, gtItem):
  13. for item in ranklist:
  14. if item == gtItem:
  15. return 1
  16. return 0
  17. def getNDCG(ranklist, gtItem):
  18. for i in range(len(ranklist)):
  19. item = ranklist[i]
  20. if item == gtItem:
  21. return math.log(2) / math.log(i+2)
  22. return 0
  23. class Logging(object):
  24. def __init__(self, path='logs/hetulog.txt'):
  25. with open(path, 'w') as fw:
  26. fw.write('')
  27. self.path = path
  28. def write(self, s):
  29. print(s)
  30. with open(self.path, 'a') as fw:
  31. fw.write(s + '\n')
  32. fw.flush()
def worker(args):
    """Run one NCF training worker on MovieLens ml-25m.

    Loads (optionally sharded) train/test data, builds the neural MF model,
    trains for a fixed number of epochs on the chosen device, and optionally
    evaluates HR@K / NDCG@K after each epoch.

    args: parsed CLI namespace; reads .comm, .all, .val, .cache, .bsp, .bound.
    """
    def validate():
        # Leave-one-out ranking evaluation. The stride-100 indexing below
        # assumes each test row holds 100 candidate items with the
        # ground-truth item first — TODO confirm against movielens.getdata.
        hits, ndcgs = [], []
        for idx in range(testData.shape[0]):
            start_index = idx * 100
            # Each executor.run consumes the next validate batch of size 100.
            predictions = executor.run(
                'validate', convert_to_numpy_ret_vals=True)
            map_item_score = {
                testItemInput[start_index + i]: predictions[0][i] for i in range(100)}
            # Ground-truth item is the first of the user's 100 candidates.
            gtItem = testItemInput[start_index]
            # Evaluate top rank list
            ranklist = heapq.nlargest(
                topK, map_item_score, key=map_item_score.get)
            hr = getHitRatio(ranklist, gtItem)
            ndcg = getNDCG(ranklist, gtItem)
            hits.append(hr)
            ndcgs.append(ndcg)
        hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
        return hr, ndcg

    def get_current_shard(data):
        # In distributed mode each worker takes an equal contiguous slice;
        # the last rank absorbs the remainder rows. Local mode keeps it all.
        if args.comm is not None:
            part_size = data.shape[0] // nrank
            start = part_size * rank
            end = start + part_size if rank != nrank - 1 else data.shape[0]
            return data[start:end]
        else:
            return data

    device_id = 0
    if args.comm == 'PS':
        rank = ht.get_worker_communicate().rank()
        nrank = int(os.environ['DMLC_NUM_WORKER'])
        # Map worker rank onto a GPU; assumes up to 8 GPUs per node — TODO confirm.
        device_id = rank % 8
    elif args.comm == 'Hybrid':
        comm = ht.wrapped_mpi_nccl_init()
        device_id = comm.dev_id
        rank = comm.rank
        nrank = int(os.environ['DMLC_NUM_WORKER'])
    # Imported here (not at module top) — presumably to defer dataset-module
    # cost/side effects until a worker actually runs.
    from movielens import getdata
    if args.all:
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'])
        trainItems = get_current_shard(trainData['item_input'])
        trainLabels = get_current_shard(trainData['labels'])
        testData = get_current_shard(testData)
        # One user id repeated per candidate item (100 candidates per user).
        testUserInput = np.repeat(
            np.arange(testData.shape[0], dtype=np.int32), 100)
        testItemInput = testData.reshape((-1,))
    else:
        # Reduced run: first 1,024,000 training samples and 1,470 test rows.
        trainData, testData = getdata('ml-25m', 'datasets')
        trainUsers = get_current_shard(trainData['user_input'][:1024000])
        trainItems = get_current_shard(trainData['item_input'][:1024000])
        trainLabels = get_current_shard(trainData['labels'][:1024000])
        testData = get_current_shard(testData[:1470])
        testUserInput = np.repeat(
            np.arange(testData.shape[0], dtype=np.int32), 100)
        testItemInput = testData.reshape((-1,))
    # Hard-wired to ml-25m; the other entries document known dataset sizes.
    num_users, num_items = {
        'ml-1m': (6040, 3706),
        'ml-20m': (138493, 26744),
        'ml-25m': (162541, 59047),
    }['ml-25m']
    # assert not args.all or num_users == testData.shape[0]
    batch_size = 1024
    # NOTE(review): num_negatives is unused in this function — presumably a
    # leftover from dataset generation; verify before removing.
    num_negatives = 4
    topK = 10
    # Dataloaders pair a training stream (batch_size) with a validation
    # stream of exactly 100 items (one user's candidate list per batch).
    user_input = ht.dataloader_op([
        ht.Dataloader(trainUsers, batch_size, 'train'),
        ht.Dataloader(testUserInput, 100, 'validate'),
    ])
    item_input = ht.dataloader_op([
        ht.Dataloader(trainItems, batch_size, 'train'),
        ht.Dataloader(testItemInput, 100, 'validate'),
    ])
    y_ = ht.dataloader_op([
        ht.Dataloader(trainLabels.reshape((-1, 1)), batch_size, 'train'),
    ])
    loss, y, train_op = neural_mf(
        user_input, item_input, y_, num_users, num_items)
    executor = ht.Executor({'train': [loss, train_op], 'validate': [y]}, ctx=ht.gpu(device_id),
                           comm_mode=args.comm, cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound, seed=123)
    # One log file per communication mode; distributed runs add a per-rank suffix.
    path = 'logs/hetulog_%s' % ({None: 'local',
                                 'PS': 'ps', 'Hybrid': 'hybrid'}[args.comm])
    path += '_%d.txt' % rank if args.comm else '.txt'
    log = Logging(path=path)
    epoch = 7
    start = time.time()
    for ep in range(epoch):
        ep_st = time.time()
        log.write('epoch %d' % ep)
        train_loss = []
        for idx in tqdm(range(executor.get_batch_num('train'))):
            loss_val = executor.run('train', convert_to_numpy_ret_vals=True)
            train_loss.append(loss_val[0])
        tra_loss = np.mean(train_loss)
        ep_en = time.time()
        # validate phase
        if args.val:
            hr, ndcg = validate()
            printstr = "train_loss: %.4f, HR: %.4f, NDCF: %.4f, train_time: %.4f" % (
                tra_loss, hr, ndcg, ep_en - ep_st)
        else:
            printstr = "train_loss: %.4f, train_time: %.4f" % (
                tra_loss, ep_en - ep_st)
        log.write(printstr)
    log.write('all time: %f' % (time.time() - start))
  138. if __name__ == '__main__':
  139. parser = argparse.ArgumentParser()
  140. parser.add_argument("--val", action="store_true",
  141. help="whether to perform validation")
  142. parser.add_argument("--all", action="store_true",
  143. help="whether to use all data, default to use 1024000 training data")
  144. parser.add_argument("--comm", default=None,
  145. help="whether to use distributed setting, can be None, AllReduce, PS, Hybrid")
  146. parser.add_argument("--bsp", action="store_true",
  147. help="whether to use bsp instead of asp")
  148. parser.add_argument("--cache", default=None, help="cache policy")
  149. parser.add_argument("--bound", default=100, help="cache bound")
  150. parser.add_argument(
  151. "--config", type=str, default="./settings/local_s1_w4.yml", help="configuration for ps")
  152. args = parser.parse_args()
  153. if args.comm is None:
  154. worker(args)
  155. elif args.comm == 'Hybrid':
  156. settings = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)
  157. value = settings['shared']
  158. os.environ['DMLC_ROLE'] = 'worker'
  159. for k, v in value.items():
  160. os.environ[k] = str(v)
  161. worker(args)
  162. elif args.comm == 'PS':
  163. launch(worker, args)
  164. else:
  165. raise NotImplementedError