You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

run_hetu.py 10 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import hetu as ht
  2. from hetu.launcher import launch
  3. import os
  4. import os.path as osp
  5. import numpy as np
  6. import yaml
  7. import time
  8. import argparse
  9. from tqdm import tqdm
  10. from sklearn import metrics
  11. def worker(args):
  12. def train(iterations, auc_enabled=True, tqdm_enabled=False):
  13. localiter = tqdm(range(iterations)
  14. ) if tqdm_enabled else range(iterations)
  15. train_loss = []
  16. train_acc = []
  17. if auc_enabled:
  18. train_auc = []
  19. for it in localiter:
  20. loss_val, predict_y, y_val, _ = executor.run(
  21. 'train', convert_to_numpy_ret_vals=True)
  22. if y_val.shape[1] == 1: # for criteo case
  23. acc_val = np.equal(
  24. y_val,
  25. predict_y > 0.5).astype(np.float32)
  26. else:
  27. acc_val = np.equal(
  28. np.argmax(y_val, 1),
  29. np.argmax(predict_y, 1)).astype(np.float32)
  30. train_loss.append(loss_val[0])
  31. train_acc.append(acc_val)
  32. if auc_enabled:
  33. train_auc.append(metrics.roc_auc_score(y_val, predict_y))
  34. if auc_enabled:
  35. return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
  36. else:
  37. return np.mean(train_loss), np.mean(train_acc)
  38. def validate(iterations, tqdm_enabled=False):
  39. localiter = tqdm(range(iterations)
  40. ) if tqdm_enabled else range(iterations)
  41. test_loss = []
  42. test_acc = []
  43. test_auc = []
  44. for it in localiter:
  45. loss_val, test_y_predicted, y_test_val = executor.run(
  46. 'validate', convert_to_numpy_ret_vals=True)
  47. if y_test_val.shape[1] == 1: # for criteo case
  48. correct_prediction = np.equal(
  49. y_test_val,
  50. test_y_predicted > 0.5).astype(np.float32)
  51. else:
  52. correct_prediction = np.equal(
  53. np.argmax(y_test_val, 1),
  54. np.argmax(test_y_predicted, 1)).astype(np.float32)
  55. test_loss.append(loss_val[0])
  56. test_acc.append(correct_prediction)
  57. test_auc.append(metrics.roc_auc_score(
  58. y_test_val, test_y_predicted))
  59. return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)
  60. def get_current_shard(data):
  61. if args.comm is not None:
  62. part_size = data.shape[0] // nrank
  63. start = part_size * rank
  64. end = start + part_size if rank != nrank - 1 else data.shape[0]
  65. return data[start:end]
  66. else:
  67. return data
  68. batch_size = 128
  69. dataset = args.dataset
  70. model = args.model
  71. device_id = 0
  72. if args.comm == 'PS':
  73. rank = ht.get_worker_communicate().rank()
  74. nrank = int(os.environ['DMLC_NUM_WORKER'])
  75. device_id = rank % 8
  76. elif args.comm == 'Hybrid':
  77. comm = ht.wrapped_mpi_nccl_init()
  78. device_id = comm.dev_id
  79. rank = comm.rank
  80. nrank = int(os.environ['DMLC_NUM_WORKER'])
  81. if dataset == 'criteo':
  82. # define models for criteo
  83. if args.all:
  84. from models.load_data import process_all_criteo_data
  85. dense, sparse, labels = process_all_criteo_data(
  86. return_val=args.val)
  87. elif args.val:
  88. from models.load_data import process_head_criteo_data
  89. dense, sparse, labels = process_head_criteo_data(return_val=True)
  90. else:
  91. from models.load_data import process_sampled_criteo_data
  92. dense, sparse, labels = process_sampled_criteo_data()
  93. if isinstance(dense, tuple):
  94. dense_input = ht.dataloader_op([[get_current_shard(dense[0]), batch_size, 'train'], [
  95. get_current_shard(dense[1]), batch_size, 'validate']])
  96. sparse_input = ht.dataloader_op([[get_current_shard(sparse[0]), batch_size, 'train'], [
  97. get_current_shard(sparse[1]), batch_size, 'validate']])
  98. y_ = ht.dataloader_op([[get_current_shard(labels[0]), batch_size, 'train'], [
  99. get_current_shard(labels[1]), batch_size, 'validate']])
  100. else:
  101. dense_input = ht.dataloader_op(
  102. [[get_current_shard(dense), batch_size, 'train']])
  103. sparse_input = ht.dataloader_op(
  104. [[get_current_shard(sparse), batch_size, 'train']])
  105. y_ = ht.dataloader_op(
  106. [[get_current_shard(labels), batch_size, 'train']])
  107. elif dataset == 'adult':
  108. from models.load_data import load_adult_data
  109. x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = load_adult_data()
  110. dense_input = [
  111. ht.dataloader_op([
  112. [get_current_shard(x_train_deep[:, i]), batch_size, 'train'],
  113. [get_current_shard(x_test_deep[:, i]), batch_size, 'validate'],
  114. ]) for i in range(12)
  115. ]
  116. sparse_input = ht.dataloader_op([
  117. [get_current_shard(x_train_wide), batch_size, 'train'],
  118. [get_current_shard(x_test_wide), batch_size, 'validate'],
  119. ])
  120. y_ = ht.dataloader_op([
  121. [get_current_shard(y_train), batch_size, 'train'],
  122. [get_current_shard(y_test), batch_size, 'validate'],
  123. ])
  124. else:
  125. raise NotImplementedError
  126. print("Data loaded.")
  127. loss, prediction, y_, train_op = model(dense_input, sparse_input, y_)
  128. eval_nodes = {'train': [loss, prediction, y_, train_op]}
  129. if args.val:
  130. print('Validation enabled...')
  131. eval_nodes['validate'] = [loss, prediction, y_]
  132. executor_log_path = osp.join(osp.dirname(osp.abspath(__file__)), 'logs')
  133. executor = ht.Executor(eval_nodes, ctx=ht.gpu(device_id),
  134. comm_mode=args.comm, cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound, seed=123, log_path=executor_log_path)
  135. if args.all and dataset == 'criteo':
  136. print('Processing all data...')
  137. file_path = '%s_%s' % ({None: 'local', 'PS': 'ps', 'Hybrid': 'hybrid'}[
  138. args.comm], args.raw_model)
  139. file_path += '%d.log' % rank if args.comm else '.log'
  140. file_path = osp.join(osp.dirname(
  141. osp.abspath(__file__)), 'logs', file_path)
  142. log_file = open(file_path, 'w')
  143. total_epoch = args.nepoch if args.nepoch > 0 else 11
  144. for ep in range(total_epoch):
  145. print("ep: %d" % ep)
  146. ep_st = time.time()
  147. train_loss, train_acc, train_auc = train(executor.get_batch_num(
  148. 'train') // 10 + (ep % 10 == 9) * (executor.get_batch_num('train') % 10), tqdm_enabled=True)
  149. ep_en = time.time()
  150. if args.val:
  151. val_loss, val_acc, val_auc = validate(
  152. executor.get_batch_num('validate'))
  153. printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f, train_time: %.4f"\
  154. % (train_loss, train_acc, train_auc, val_loss, val_acc, val_auc, ep_en - ep_st)
  155. else:
  156. printstr = "train_loss: %.4f, train_acc: %.4f, train_auc: %.4f, train_time: %.4f"\
  157. % (train_loss, train_acc, train_auc, ep_en - ep_st)
  158. print(printstr)
  159. log_file.write(printstr + '\n')
  160. log_file.flush()
  161. else:
  162. total_epoch = args.nepoch if args.nepoch > 0 else 50
  163. for ep in range(total_epoch):
  164. if ep == 5:
  165. start = time.time()
  166. print("epoch %d" % ep)
  167. ep_st = time.time()
  168. train_loss, train_acc = train(
  169. executor.get_batch_num('train'), auc_enabled=False)
  170. ep_en = time.time()
  171. if args.val:
  172. val_loss, val_acc, val_auc = validate(
  173. executor.get_batch_num('validate'))
  174. print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
  175. % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc, val_auc))
  176. else:
  177. print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
  178. % (train_loss, train_acc, ep_en - ep_st))
  179. print('all time:', time.time() - start)
  180. if __name__ == '__main__':
  181. parser = argparse.ArgumentParser()
  182. parser.add_argument("--model", type=str, required=True,
  183. help="model to be tested")
  184. parser.add_argument("--val", action="store_true",
  185. help="whether to use validation")
  186. parser.add_argument("--all", action="store_true",
  187. help="whether to use all data")
  188. parser.add_argument("--comm", default=None,
  189. help="whether to use distributed setting, can be None, AllReduce, PS, Hybrid")
  190. parser.add_argument("--bsp", action="store_true",
  191. help="whether to use bsp instead of asp")
  192. parser.add_argument("--cache", default=None, help="cache policy")
  193. parser.add_argument("--bound", default=100, help="cache bound")
  194. parser.add_argument("--config", type=str, default=osp.join(osp.dirname(
  195. osp.abspath(__file__)), "./settings/local_s1_w4.yml"), help="configuration for ps")
  196. parser.add_argument("--nepoch", type=int, default=-1,
  197. help="num of epochs, each train 1/10 data")
  198. args = parser.parse_args()
  199. import models
  200. print('Model:', args.model)
  201. model = eval('models.' + args.model)
  202. args.dataset = args.model.split('_')[-1]
  203. args.raw_model = args.model
  204. args.model = model
  205. if args.comm is None:
  206. worker(args)
  207. elif args.comm == 'Hybrid':
  208. settings = yaml.load(open(args.config).read(), Loader=yaml.FullLoader)
  209. value = settings['shared']
  210. os.environ['DMLC_ROLE'] = 'worker'
  211. for k, v in value.items():
  212. os.environ[k] = str(v)
  213. worker(args)
  214. elif args.comm == 'PS':
  215. launch(worker, args)
  216. else:
  217. raise NotImplementedError