You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

run_tf.py 4.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
import argparse
import heapq  # for retrieval topK
import math
import os
import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from tf_ncf import neural_mf
  9. def getHitRatio(ranklist, gtItem):
  10. for item in ranklist:
  11. if item == gtItem:
  12. return 1
  13. return 0
  14. def getNDCG(ranklist, gtItem):
  15. for i in range(len(ranklist)):
  16. item = ranklist[i]
  17. if item == gtItem:
  18. return math.log(2) / math.log(i+2)
  19. return 0
  20. class Logging(object):
  21. def __init__(self, path='logs/tflog.txt'):
  22. with open(path, 'w') as fw:
  23. fw.write('')
  24. self.path = path
  25. def write(self, s):
  26. print(s)
  27. with open(self.path, 'a') as fw:
  28. fw.write(s + '\n')
  29. fw.flush()
  30. def main():
  31. parser = argparse.ArgumentParser()
  32. parser.add_argument("--val", action="store_true",
  33. help="whether to perform validation")
  34. parser.add_argument("--all", action="store_true",
  35. help="whether to use all data")
  36. args = parser.parse_args()
  37. def validate():
  38. # validate phase
  39. hits, ndcgs = [], []
  40. for idx in range(num_users):
  41. start_index = idx * 100
  42. my_feed_dict = {
  43. user_input: testUserInput[start_index:start_index+100],
  44. item_input: testItemInput[start_index:start_index+100],
  45. }
  46. predictions = sess.run([y], feed_dict=my_feed_dict)
  47. map_item_score = {
  48. testItemInput[start_index+i]: predictions[0][i] for i in range(100)}
  49. # Evaluate top rank list
  50. ranklist = heapq.nlargest(
  51. topK, map_item_score, key=map_item_score.get)
  52. hr = getHitRatio(ranklist, testItemInput[start_index])
  53. ndcg = getNDCG(ranklist, testItemInput[start_index])
  54. hits.append(hr)
  55. ndcgs.append(ndcg)
  56. hr, ndcg = np.array(hits).mean(), np.array(ndcgs).mean()
  57. return hr, ndcg
  58. from movielens import getdata
  59. if args.all:
  60. trainData, testData = getdata('ml-25m', 'datasets')
  61. trainUsers = trainData['user_input']
  62. trainItems = trainData['item_input']
  63. trainLabels = trainData['labels']
  64. testData = testData
  65. testUserInput = np.repeat(
  66. np.arange(testData.shape[0], dtype=np.int32), 100)
  67. testItemInput = testData.reshape((-1,))
  68. else:
  69. trainData, testData = getdata('ml-25m', 'datasets')
  70. trainUsers = trainData['user_input'][:1024000]
  71. trainItems = trainData['item_input'][:1024000]
  72. trainLabels = trainData['labels'][:1024000]
  73. testData = testData[:1470]
  74. testUserInput = np.repeat(
  75. np.arange(testData.shape[0], dtype=np.int32), 100)
  76. testItemInput = testData.reshape((-1,))
  77. num_users, num_items = {
  78. 'ml-1m': (6040, 3706),
  79. 'ml-20m': (138493, 26744),
  80. 'ml-25m': (162541, 59047),
  81. }['ml-25m']
  82. batch_size = 1024
  83. num_negatives = 4
  84. topK = 10
  85. user_input = tf.compat.v1.placeholder(tf.int32, [None, ])
  86. item_input = tf.compat.v1.placeholder(tf.int32, [None, ])
  87. y_ = tf.compat.v1.placeholder(tf.float32, [None, ])
  88. loss, y, opt = neural_mf(user_input, item_input, y_, num_users, num_items)
  89. train_op = opt.minimize(loss)
  90. init = tf.compat.v1.global_variables_initializer()
  91. gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
  92. sess = tf.compat.v1.Session(
  93. config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
  94. sess.run(init)
  95. log = Logging()
  96. epoch = 7
  97. iterations = trainUsers.shape[0] // batch_size
  98. start = time.time()
  99. for ep in range(epoch):
  100. ep_st = time.time()
  101. log.write('epoch %d' % ep)
  102. train_loss = []
  103. for idx in range(iterations):
  104. start_index = idx * batch_size
  105. my_feed_dict = {
  106. user_input: trainUsers[start_index:start_index+batch_size],
  107. item_input: trainItems[start_index:start_index+batch_size],
  108. y_: trainLabels[start_index:start_index+batch_size],
  109. }
  110. loss_val = sess.run([loss, train_op], feed_dict=my_feed_dict)
  111. train_loss.append(loss_val[0])
  112. tra_loss = np.mean(train_loss)
  113. ep_en = time.time()
  114. # validate phase
  115. if args.val:
  116. hr, ndcg = validate()
  117. printstr = "train_loss: %.4f, HR: %.4f, NDCF: %.4f, train_time: %.4f" % (
  118. tra_loss, hr, ndcg, ep_en - ep_st)
  119. else:
  120. printstr = "train_loss: %.4f, train_time: %.4f" % (
  121. tra_loss, ep_en - ep_st)
  122. log.write(printstr)
  123. log.write('all time:%f' % (time.time() - start))
  124. if __name__ == '__main__':
  125. main()