You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

run_tf_local.py 7.7 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. import tensorflow as tf
  3. import time
  4. import argparse
  5. from tqdm import tqdm
  6. from sklearn import metrics
  def _criteo_fill_feed(feed_dict, columns, start_index, batch_size):
      """Copy rows [start_index, start_index + batch_size) into the feed buffers.

      ``columns`` is a sequence of ``(placeholder, source_array)`` pairs; each
      slice is written in place into ``feed_dict[placeholder]``.
      """
      stop = start_index + batch_size
      for placeholder, source in columns:
          feed_dict[placeholder][:] = source[start_index:stop]


  def _criteo_accuracy(true_val, pred_val):
      """Return per-sample correctness for one batch.

      A single prediction column is thresholded at 0.5. Multiple columns are
      compared by argmax against the labels.
      """
      if pred_val.shape[1] == 1:
          return np.equal(true_val, pred_val > 0.5)
      return np.equal(
          np.argmax(pred_val, 1),
          np.argmax(true_val, 1)).astype(np.float32)


  def _criteo_batch_auc(true_val, pred_val):
      """Return the ROC AUC of one batch, or None if the batch has one class.

      ``roc_auc_score`` raises ValueError when only one class is present, so
      such batches are skipped instead of aborting training.
      """
      if np.unique(true_val).size < 2:
          return None
      return metrics.roc_auc_score(true_val, pred_val)


  def train_criteo(model, args):
      """Train a TensorFlow Criteo model in a single local session.

      Args:
          model: callable invoked as ``model(dense_input, sparse_input, y_)``.
              It must return ``(loss, y, opt)``, where ``opt.minimize(loss)``
              builds the training op.
          args: parsed command-line namespace. ``args.all`` selects the full
              dataset (results are also logged to
              ``./logs/tf_local_<args.model>.log``). Otherwise the sampled
              dataset is used.

      With ``args.all``, 11 "epochs" are run, each covering a tenth of the
      batches, and loss/accuracy/AUC are printed and logged. Otherwise 50
      full passes are run, and the total wall time is printed, measured from
      the start of epoch 5.
      """
      if args.all:
          from models.load_data import process_all_criteo_data
          dense, sparse, all_labels = process_all_criteo_data()
          # Each pair is (train, val); only the first element is used here.
          dense_feature, _ = dense
          sparse_feature, _ = sparse
          labels, _ = all_labels
      else:
          from models.load_data import process_sampled_criteo_data
          dense_feature, sparse_feature, labels = process_sampled_criteo_data()

      batch_size = 128
      dense_input = tf.compat.v1.placeholder(tf.float32, [batch_size, 13])
      sparse_input = tf.compat.v1.placeholder(tf.int32, [batch_size, 26])
      y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
      loss, y, opt = model(dense_input, sparse_input, y_)
      train_op = opt.minimize(loss)

      init = tf.compat.v1.global_variables_initializer()
      gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
      sess = tf.compat.v1.Session(
          config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
      sess.run(init)

      # Feed buffers are allocated once, with the same dtypes as their
      # placeholders, and are refilled in place every step.
      my_feed_dict = {
          dense_input: np.empty(shape=(batch_size, 13), dtype=np.float32),
          sparse_input: np.empty(shape=(batch_size, 26), dtype=np.int32),
          y_: np.empty(shape=(batch_size, 1), dtype=np.float32),
      }
      columns = (
          (dense_input, dense_feature),
          (sparse_input, sparse_feature),
          (y_, labels),
      )
      num_samples = dense_feature.shape[0]

      if args.all:
          raw_log_file = './logs/tf_local_%s.log' % (args.model)
          print('Processing all data, log to', raw_log_file)
          iterations = num_samples // batch_size
          total_epoch = 11
          start_index = 0
          with open(raw_log_file, 'w') as log_file:
              for ep in range(total_epoch):
                  print("epoch %d" % ep)
                  st_time = time.time()
                  train_loss, train_acc, train_auc = [], [], []
                  # A tenth of the batches per "epoch"; every epoch with
                  # ep % 10 == 9 also runs the iterations % 10 leftover steps.
                  steps = iterations // 10 + (ep % 10 == 9) * (iterations % 10)
                  for _ in tqdm(range(steps)):
                      _criteo_fill_feed(my_feed_dict, columns, start_index, batch_size)
                      start_index += batch_size
                      # Wrap to the start before the next slice would be short.
                      if start_index + batch_size > num_samples:
                          start_index = 0
                      loss_val, pred_val, true_val, _ = sess.run(
                          [loss, y, y_, train_op], feed_dict=my_feed_dict)
                      train_loss.append(loss_val)
                      train_acc.append(_criteo_accuracy(true_val, pred_val))
                      batch_auc = _criteo_batch_auc(true_val, pred_val)
                      if batch_auc is not None:
                          train_auc.append(batch_auc)
                  tra_accuracy = np.mean(train_acc)
                  tra_loss = np.mean(train_loss)
                  tra_auc = np.mean(train_auc)
                  train_time = time.time() - st_time
                  printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                      % (tra_loss, tra_accuracy, tra_auc, train_time)
                  print(printstr)
                  log_file.write(printstr + '\n')
                  log_file.flush()
      else:
          iteration = num_samples // batch_size
          epoch = 50
          # Reset at epoch 5, so the reported total excludes epochs 0-4; the
          # initial value only matters if fewer than 6 epochs are run.
          start = time.time()
          for ep in range(epoch):
              print('epoch', ep)
              if ep == 5:
                  start = time.time()
              ep_st = time.time()
              train_loss = []
              train_acc = []
              for idx in range(iteration):
                  _criteo_fill_feed(my_feed_dict, columns, idx * batch_size, batch_size)
                  loss_val, pred_val, true_val, _ = sess.run(
                      [loss, y, y_, train_op], feed_dict=my_feed_dict)
                  train_loss.append(loss_val)
                  train_acc.append(_criteo_accuracy(true_val, pred_val))
              tra_accuracy = np.mean(train_acc)
              tra_loss = np.mean(train_loss)
              ep_en = time.time()
              print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                    % (tra_loss, tra_accuracy, ep_en - ep_st))
          print('all time:', (time.time() - start))
  def train_adult(model):
      """Train a TensorFlow wide-and-deep model on the Adult dataset locally.

      Args:
          model: callable invoked as ``model(X_deep, X_wide, y_)``. It must
              return ``(loss, y, train_op)``, where ``y`` holds two-column
              predictions that are compared with ``y_`` by argmax.

      Runs 50 epochs, printing per-epoch loss, accuracy and time. At the end
      it prints the wall time measured from the start of epoch 5.
      """
      batch_size = 128
      total_epoch = 50
      dim_wide = 809
      # One [batch_size, 1] placeholder per deep column: 8 int32 columns
      # followed by 4 float32 columns.
      X_deep = [tf.compat.v1.placeholder(tf.int32, [batch_size, 1])
                for _ in range(8)]
      X_deep += [tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
                 for _ in range(4)]
      X_wide = tf.compat.v1.placeholder(tf.float32, [batch_size, dim_wide])
      y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 2])
      loss, y, train_op = model(X_deep, X_wide, y_)

      # tf.compat.v1 session API: the bare tf.Session / tf.ConfigProto /
      # tf.GPUOptions names are not available in TensorFlow 2.x.
      init = tf.compat.v1.global_variables_initializer()
      gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
      sess = tf.compat.v1.Session(
          config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
      sess.run(init)

      from models.load_data import load_adult_data
      x_train_deep, x_train_wide, y_train = load_adult_data(return_val=False)
      # The trailing partial batch is dropped.
      iterations = x_train_deep.shape[0] // batch_size
      # Reset at epoch 5, so the reported total excludes epochs 0-4.
      start = time.time()
      for ep in range(total_epoch):
          print('epoch', ep)
          if ep == 5:
              start = time.time()
          ep_st = time.time()
          train_loss = []
          train_acc = []
          for it in range(iterations):
              lo = it * batch_size
              hi = lo + batch_size
              batch_x_deep = x_train_deep[lo:hi]
              # Feed column i to deep placeholder i. The previous code fed
              # column 1 to every placeholder. Assumes x_train_deep's column
              # order matches X_deep -- TODO confirm against load_adult_data.
              my_feed_dict = {
                  placeholder: np.array(batch_x_deep[:, i]).reshape(-1, 1)
                  for i, placeholder in enumerate(X_deep)
              }
              my_feed_dict[X_wide] = np.array(x_train_wide[lo:hi])
              my_feed_dict[y_] = y_train[lo:hi]
              loss_val, pred_val, true_val, _ = sess.run(
                  [loss, y, y_, train_op], feed_dict=my_feed_dict)
              acc_val = np.equal(
                  np.argmax(pred_val, 1),
                  np.argmax(true_val, 1)).astype(np.float32)
              train_loss.append(loss_val)
              train_acc.append(acc_val)
          tra_accuracy = np.mean(train_acc)
          tra_loss = np.mean(train_loss)
          ep_en = time.time()
          print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                % (tra_loss, tra_accuracy, ep_en - ep_st))
      print('all time:', (time.time() - start))
  def main():
      """Parse command-line options and run the matching training routine.

      ``--model`` names an attribute of the project's ``tf_models`` module; a
      dotted path is also accepted. The text after the model name's last
      underscore selects the dataset: 'criteo' or 'adult'.

      Raises:
          NotImplementedError: if the dataset suffix is not supported.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("--model", type=str, required=True,
                          help="model to be tested")
      parser.add_argument("--all", action="store_true",
                          help="whether to use all data")
      args = parser.parse_args()
      raw_model = args.model
      import tf_models
      # Resolve the name one attribute at a time instead of using eval().
      # The value comes from the command line and should only select a
      # model, never execute an arbitrary expression.
      model = tf_models
      for attr in raw_model.split('.'):
          model = getattr(model, attr)
      dataset = raw_model.split('_')[-1]
      print('Model:', raw_model)
      if dataset == 'criteo':
          train_criteo(model, args)
      elif dataset == 'adult':
          train_adult(model)
      else:
          raise NotImplementedError(
              'unsupported dataset %r derived from model %r; expected a model '
              'name ending in _criteo or _adult' % (dataset, raw_model))


  if __name__ == '__main__':
      main()

分布式深度学习系统

Contributors (1)