You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

run_wdl.py 5.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import hetu as ht
  2. from hetu.launcher import launch
  3. import os
  4. import numpy as np
  5. import yaml
  6. import time
  7. import argparse
  8. from tqdm import tqdm
  9. from sklearn import metrics
  10. from models import load_data, wdl_adult
  11. def worker(args):
  12. def train(iterations, auc_enabled=True, tqdm_enabled=False):
  13. localiter = tqdm(range(iterations)
  14. ) if tqdm_enabled else range(iterations)
  15. train_loss = []
  16. train_acc = []
  17. if auc_enabled:
  18. train_auc = []
  19. for it in localiter:
  20. loss_val, predict_y, y_val, _ = executor.run(
  21. 'train', convert_to_numpy_ret_vals=True)
  22. acc_val = np.equal(
  23. np.argmax(y_val, 1),
  24. np.argmax(predict_y, 1)).astype(np.float32)
  25. train_loss.append(loss_val[0])
  26. train_acc.append(acc_val)
  27. if auc_enabled:
  28. train_auc.append(metrics.roc_auc_score(y_val, predict_y))
  29. if auc_enabled:
  30. return np.mean(train_loss), np.mean(train_acc), np.mean(train_auc)
  31. else:
  32. return np.mean(train_loss), np.mean(train_acc)
  33. def validate(iterations, tqdm_enabled=False):
  34. localiter = tqdm(range(iterations)
  35. ) if tqdm_enabled else range(iterations)
  36. test_loss = []
  37. test_acc = []
  38. test_auc = []
  39. for it in localiter:
  40. loss_val, test_y_predicted, y_test_val = executor.run(
  41. 'validate', convert_to_numpy_ret_vals=True)
  42. correct_prediction = np.equal(
  43. np.argmax(y_test_val, 1),
  44. np.argmax(test_y_predicted, 1)).astype(np.float32)
  45. test_loss.append(loss_val[0])
  46. test_acc.append(correct_prediction)
  47. test_auc.append(metrics.roc_auc_score(
  48. y_test_val, test_y_predicted))
  49. return np.mean(test_loss), np.mean(test_acc), np.mean(test_auc)
  50. batch_size = 128
  51. ctx = {
  52. 'local': 'gpu:0',
  53. 'lps': 'cpu:0,gpu:0,gpu:1,gpu:2,gpu:7',
  54. 'lhy': 'cpu:0,gpu:1,gpu:2,gpu:3,gpu:6',
  55. 'rps': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3',
  56. 'rhy': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3'
  57. }[args.config]
  58. dense_param_ctx = {'local': 'gpu:0', 'lps': 'cpu:0,gpu:0,gpu:1,gpu:2,gpu:7', 'lhy': 'gpu:1,gpu:2,gpu:3,gpu:6',
  59. 'rps': 'cpu:0;daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3',
  60. 'rhy': 'daim118:gpu:0;daim118:gpu:2;daim118:gpu:4;daim118:gpu:6;daim117:gpu:1;daim117:gpu:3'}[args.config]
  61. with ht.context(ctx):
  62. x_train_deep, x_train_wide, y_train, x_test_deep, x_test_wide, y_test = load_data.load_adult_data()
  63. dense_input = [
  64. ht.dataloader_op([
  65. [x_train_deep[:, i], batch_size, 'train'],
  66. [x_test_deep[:, i], batch_size, 'validate'],
  67. ]) for i in range(12)
  68. ]
  69. sparse_input = ht.dataloader_op([
  70. [x_train_wide, batch_size, 'train'],
  71. [x_test_wide, batch_size, 'validate'],
  72. ])
  73. y_ = ht.dataloader_op([
  74. [y_train, batch_size, 'train'],
  75. [y_test, batch_size, 'validate'],
  76. ])
  77. print("Data loaded.")
  78. loss, prediction, y_, train_op = wdl_adult.wdl_adult(
  79. dense_input, sparse_input, y_, dense_param_ctx)
  80. eval_nodes = {'train': [loss, prediction, y_, train_op]}
  81. if args.val:
  82. print('Validation enabled...')
  83. eval_nodes['validate'] = [loss, prediction, y_]
  84. executor = ht.Executor(eval_nodes,
  85. cstable_policy=args.cache, bsp=args.bsp, cache_bound=args.bound, seed=123)
  86. total_epoch = args.nepoch if args.nepoch > 0 else 50
  87. for ep in range(total_epoch):
  88. if ep == 5:
  89. start = time.time()
  90. print("epoch %d" % ep)
  91. ep_st = time.time()
  92. train_loss, train_acc = train(
  93. executor.get_batch_num('train'), auc_enabled=False)
  94. ep_en = time.time()
  95. if args.val:
  96. val_loss, val_acc, val_auc = validate(
  97. executor.get_batch_num('validate'))
  98. print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f, test_loss: %.4f, test_acc: %.4f, test_auc: %.4f"
  99. % (train_loss, train_acc, ep_en - ep_st, val_loss, val_acc, val_auc))
  100. else:
  101. print("train_loss: %.4f, train_acc: %.4f, train_time: %.4f"
  102. % (train_loss, train_acc, ep_en - ep_st))
  103. print('all time:', time.time() - start)
  104. if __name__ == '__main__':
  105. parser = argparse.ArgumentParser()
  106. parser.add_argument('--config', type=str, default='local',
  107. help='[local, lps(localps), lhy(localhybrid), rps(remoteps), rhy]')
  108. parser.add_argument("--val", action="store_true",
  109. help="whether to use validation")
  110. parser.add_argument("--all", action="store_true",
  111. help="whether to use all data")
  112. parser.add_argument("--bsp", action="store_true",
  113. help="whether to use bsp instead of asp")
  114. parser.add_argument("--cache", default=None, help="cache policy")
  115. parser.add_argument("--bound", default=100, help="cache bound")
  116. parser.add_argument("--nepoch", type=int, default=-1,
  117. help="num of epochs, each train 1/10 data")
  118. args = parser.parse_args()
  119. worker(args)