# tf_launch_worker.py
import tensorflow as tf
import numpy as np
import argparse
import os
import time
import json
from sklearn import metrics
from tqdm import tqdm
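

# Proxy environment variables are removed before any cluster setup; leaving
# them set can route worker <-> PS gRPC traffic through the proxy.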
def pop_env():
    for k in ['https_proxy', 'http_proxy']:
        if k in os.environ:
            os.environ.pop(k)


pop_env()
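

# Data-parallel Criteo training with TF1-style parameter servers: every worker
# builds the same graph, keeps a contiguous shard of the data, and updates
# variables placed on the ps job.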
def train_criteo(model, cluster, task_id, nrank, args):
    def get_current_shard(data):
        part_size = data.shape[0] // nrank
        start = part_size * task_id
        end = start + part_size if task_id != nrank - 1 else data.shape[0]
        return data[start:end]
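
    # --all loads the full Criteo train/validation splits; otherwise only a
    # small sampled subset is used (no validation data in that case).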
    if args.all:
        from models.load_data import process_all_criteo_data
        dense, sparse, all_labels = process_all_criteo_data()
        dense_feature = get_current_shard(dense[0])
        sparse_feature = get_current_shard(sparse[0])
        labels = get_current_shard(all_labels[0])
        val_dense = get_current_shard(dense[1])
        val_sparse = get_current_shard(sparse[1])
        val_labels = get_current_shard(all_labels[1])
    else:
        from models.load_data import process_sampled_criteo_data
        dense_feature, sparse_feature, labels = process_sampled_criteo_data()
        dense_feature = get_current_shard(dense_feature)
        sparse_feature = get_current_shard(sparse_feature)
        labels = get_current_shard(labels)

    batch_size = 128
    worker_device = "/job:worker/task:%d/gpu:0" % (task_id)
    with tf.device(worker_device):
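        # Criteo inputs: 13 dense (continuous) and 26 sparse (categorical) features.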
        dense_input = tf.compat.v1.placeholder(tf.float32, [batch_size, 13])
        sparse_input = tf.compat.v1.placeholder(tf.int32, [batch_size, 26])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 1])
    with tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster)):
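        # Variables built here land on the ps job; with more than one server
        # the embedding table is sharded across them along dimension 0.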
        server_num = len(cluster.as_dict()['ps'])
        embed_partitioner = tf.compat.v1.fixed_size_partitioner(
            server_num, 0) if server_num > 1 else None
        loss, y, opt = model(dense_input, sparse_input, y_,
                             embed_partitioner, param_on_gpu=False)
        train_op = opt.minimize(loss)
    server = tf.compat.v1.train.Server(
        cluster, job_name="worker", task_index=task_id)
    init = tf.compat.v1.global_variables_initializer()
    sv = tf.compat.v1.train.Supervisor(
        is_chief=(task_id == 0),
        init_op=init,
        recovery_wait_secs=1)
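    # Talk only to the ps tasks and this worker, so session startup does not
    # block on the other workers.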
    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % task_id])
    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    # sess.run(init)
    if task_id == 0:
        writer = tf.compat.v1.summary.FileWriter('logs/board', sess.graph)
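    # Feed buffers are allocated once and overwritten in place every step.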
    my_feed_dict = {
        dense_input: np.empty(shape=(batch_size, 13)),
        sparse_input: np.empty(shape=(batch_size, 26)),
        y_: np.empty(shape=(batch_size, 1)),
    }

    if args.all:
        raw_log_file = './logs/tf_dist_%s_%d.log' % (args.model, task_id)
        print('Processing all data, log to', raw_log_file)
        log_file = open(raw_log_file, 'w')
        iterations = dense_feature.shape[0] // batch_size
        total_epoch = 21
        start_index = 0
        for ep in range(total_epoch):
            print("epoch %d" % ep)
            st_time = time.time()
            train_loss, train_acc, train_auc = [], [], []
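            # Each "epoch" covers a tenth of the shard; every tenth epoch also
            # takes the remainder, so ten epochs sweep the whole shard once.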
            for it in range(iterations // 10 + (ep % 10 == 9) * (iterations % 10)):
                my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]
                start_index += batch_size
                if start_index + batch_size > dense_feature.shape[0]:
                    start_index = 0
                loss_val = sess.run([loss, y, y_, train_op],
                                    feed_dict=my_feed_dict)
                pred_val = loss_val[1]
                true_val = loss_val[2]
                acc_val = np.equal(
                    true_val,
                    pred_val > 0.5)
                train_loss.append(loss_val[0])
                train_acc.append(acc_val)
                train_auc.append(metrics.roc_auc_score(true_val, pred_val))
            tra_accuracy = np.mean(train_acc)
            tra_loss = np.mean(train_loss)
            tra_auc = np.mean(train_auc)
            en_time = time.time()
            train_time = en_time - st_time
            if args.val:
                val_loss, val_acc, val_auc = [], [], []
                for it in range(val_dense.shape[0] // batch_size):
                    local_st = it * batch_size
                    my_feed_dict[dense_input][:] = val_dense[local_st: local_st + batch_size]
                    my_feed_dict[sparse_input][:] = val_sparse[local_st: local_st + batch_size]
                    my_feed_dict[y_][:] = val_labels[local_st: local_st + batch_size]
                    loss_val = sess.run([loss, y, y_], feed_dict=my_feed_dict)
                    pred_val = loss_val[1]
                    true_val = loss_val[2]
                    acc_val = np.equal(
                        true_val,
                        pred_val > 0.5)
                    val_loss.append(loss_val[0])
                    val_acc.append(acc_val)
                    val_auc.append(metrics.roc_auc_score(true_val, pred_val))
                v_accuracy = np.mean(val_acc)
                v_loss = np.mean(val_loss)
                v_auc = np.mean(val_auc)
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f, train_time: %.4f"\
                    % (tra_loss, tra_accuracy, tra_auc, v_loss, v_accuracy, v_auc, train_time)
            else:
                printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
                    % (tra_loss, tra_accuracy, tra_auc, train_time)
            print(printstr)
            log_file.write(printstr + '\n')
            log_file.flush()
    else:
        # sampled data: no validation here
        iteration = dense_feature.shape[0] // batch_size
        epoch = 10
        for ep in range(epoch):
            print('epoch', ep)
            if ep == 5:
                start = time.time()
            ep_st = time.time()
            train_loss = []
            train_acc = []
            for idx in range(iteration):
                start_index = idx * batch_size
                my_feed_dict[dense_input][:] = dense_feature[start_index: start_index + batch_size]
                my_feed_dict[sparse_input][:] = sparse_feature[start_index: start_index + batch_size]
                my_feed_dict[y_][:] = labels[start_index: start_index + batch_size]
                loss_val = sess.run([loss, y, y_, train_op],
                                    feed_dict=my_feed_dict)
                pred_val = loss_val[1]
                true_val = loss_val[2]
                if pred_val.shape[1] == 1:  # binary criteo case
                    acc_val = np.equal(
                        true_val,
                        pred_val > 0.5)
                else:
                    acc_val = np.equal(
                        np.argmax(pred_val, 1),
                        np.argmax(true_val, 1)).astype(np.float32)
                train_loss.append(loss_val[0])
                train_acc.append(acc_val)
            tra_accuracy = np.mean(train_acc)
            tra_loss = np.mean(train_loss)
            ep_en = time.time()
            print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
                  % (tra_loss, tra_accuracy, ep_en - ep_st))
        print("tensorflow: ", (time.time() - start))
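

# Wide & Deep style training on the Adult dataset; each worker again takes a
# contiguous shard. This path supports neither --val nor --all.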
def train_adult(model, cluster, task_id, nrank):
    from models.load_data import load_adult_data
    x_train_deep, x_train_wide, y_train = load_adult_data(return_val=False)
    part_size = len(x_train_deep) // nrank
    start = part_size * task_id
    end = start + part_size if task_id != nrank - 1 else len(x_train_deep)
    x_train_deep = x_train_deep[start:end]
    x_train_wide = x_train_wide[start:end]
    y_train = y_train[start:end]

    batch_size = 128
    total_epoch = 50
    dim_wide = 809

    worker_device = "/job:worker/task:%d/gpu:0" % (task_id)
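    # Deep inputs: 8 categorical (int) and 4 continuous (float) columns, each a
    # separate [batch, 1] placeholder; the wide input is an 809-dim float vector.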
    with tf.device(worker_device):
        X_deep = []
        for i in range(8):
            X_deep.append(tf.compat.v1.placeholder(tf.int32, [batch_size, 1]))
        for i in range(4):
            X_deep.append(tf.compat.v1.placeholder(
                tf.float32, [batch_size, 1]))
        X_wide = tf.compat.v1.placeholder(tf.float32, [batch_size, dim_wide])
        y_ = tf.compat.v1.placeholder(tf.float32, [batch_size, 2])
        loss, y, train_op, global_step = model(
            X_deep, X_wide, y_, cluster, task_id)
    with tf.device(
            tf.compat.v1.train.replica_device_setter(
                worker_device=worker_device,
                cluster=cluster)):
        server = tf.compat.v1.train.Server(
            cluster, job_name="worker", task_index=task_id)
        init = tf.compat.v1.global_variables_initializer()
        sv = tf.compat.v1.train.Supervisor(
            is_chief=(task_id == 0),
            init_op=init,
            recovery_wait_secs=1,
            global_step=global_step)
        sess_config = tf.compat.v1.ConfigProto(
            # allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps",
                            "/job:worker/task:%d" % task_id])
        sess = sv.prepare_or_wait_for_session(
            server.target, config=sess_config)
        sess.run(init)
    iterations = x_train_deep.shape[0] // batch_size
    for ep in range(total_epoch):
        print('epoch', ep)
        if ep == 5:
            start = time.time()
        ep_st = time.time()
        train_loss = []
        train_acc = []
        pre_index = 0
        for it in range(iterations):
            batch_x_deep = x_train_deep[pre_index:pre_index + batch_size]
            batch_x_wide = x_train_wide[pre_index:pre_index + batch_size]
            batch_y = y_train[pre_index:pre_index + batch_size]
            pre_index += batch_size
            my_feed_dict = dict()
            for i in range(12):
                # feed the i-th deep column into its matching placeholder
                my_feed_dict[X_deep[i]] = np.array(
                    batch_x_deep[:, i]).reshape(-1, 1)
            my_feed_dict[X_wide] = np.array(batch_x_wide)
            my_feed_dict[y_] = batch_y
            loss_val = sess.run([loss, y, y_, train_op],
                                feed_dict=my_feed_dict)
            acc_val = np.equal(
                np.argmax(loss_val[1], 1),
                np.argmax(loss_val[2], 1)).astype(np.float32)
            train_loss.append(loss_val[0])
            train_acc.append(acc_val)
        tra_accuracy = np.mean(train_acc)
        tra_loss = np.mean(train_loss)
        ep_en = time.time()
        print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
              % (tra_loss, tra_accuracy, ep_en - ep_st))
    print("tensorflow: ", (time.time() - start))
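

# Rough PS <-> worker bandwidth probe: repeatedly assign a small int32 tensor
# into a variable hosted on ps task 0 and report the observed transfer rate.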
def test_bandwidth(cluster, task_id):
    print('test bandwidth')
    iters = 1000
    params_size = 128 * 9
    ps_device = "/job:ps/task:0/cpu:0"
    worker_device = "/job:worker/task:%d/cpu:0" % (task_id)
    with tf.device(ps_device):
        dtype = tf.int32
        params = tf.compat.v1.get_variable("params", shape=[params_size], dtype=dtype,
                                           initializer=tf.zeros_initializer())
    with tf.device(tf.compat.v1.train.replica_device_setter(
            worker_device=worker_device,
            cluster=cluster)):
        update = tf.compat.v1.get_variable("update", shape=[params_size], dtype=dtype,
                                           initializer=tf.ones_initializer())
        add_op = params.assign(update)
    server = tf.compat.v1.train.Server(
        cluster, job_name="worker", task_index=task_id)
    init = tf.compat.v1.global_variables_initializer()
    sv = tf.compat.v1.train.Supervisor(
        is_chief=(task_id == 0),
        init_op=init,
        recovery_wait_secs=1)
    sess_config = tf.compat.v1.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % task_id])
    sess = sv.prepare_or_wait_for_session(
        server.target, config=sess_config)
    sess.run(init)
    # warm up
    for i in range(5):
        sess.run(add_op.op)
    start_time = time.time()
    for i in range(iters):
        sess.run(add_op.op)
    elapsed_time = time.time() - start_time
    # int32 elements are 4 bytes each
    ans = float(iters) * params_size * 4 / (1024 * 1024) / elapsed_time
    print("transfer rate: %f MB/s" % (ans))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True,
                        help="model to be tested")
    parser.add_argument("--rank", type=int, required=True,
                        help="rank of process")
    parser.add_argument(
        "--config", type=str, default='./settings/tf_dist_s1_w2.json', help="config file path")
    parser.add_argument("--val", action="store_true",
                        help="whether to use validation")
    parser.add_argument("--all", action="store_true",
                        help="whether to use all data")
    args = parser.parse_args()

    raw_model = args.model
    task_id = int(args.rank)
    raw_config = args.config
    with open(raw_config) as f:
        config = json.load(f)
    cluster = tf.train.ClusterSpec(config)
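    # Model names end in the dataset they target ('*_criteo' or '*_adult');
    # the special name 'band' runs the bandwidth probe instead.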
    if raw_model != 'band':
        import tf_models
        model = getattr(tf_models, raw_model)
        dataset = raw_model.split('_')[-1]
        print('Model:', raw_model)
        if dataset == 'criteo':
            train_criteo(model, cluster, task_id, len(config['worker']), args)
        elif dataset == 'adult':
            # the adult path supports neither --val nor --all
            train_adult(model, cluster, task_id, len(config['worker']))
        else:
            raise NotImplementedError
    else:
        test_bandwidth(cluster, task_id)


if __name__ == '__main__':
    main()