
run_dist_hybrid.py
from gnn_tools.launcher import launch_graphmix_and_hetu_ps
from gnn_model.utils import get_norm_adj, prepare_data
from gnn_model.model import sparse_model
import graphmix
import hetu as ht
from hetu.communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
import numpy as np
import time
import os
import sys
import multiprocessing
import argparse

# usage:
#   worker: mpirun -np 4 --allow-run-as-root python3 run_dist_hybrid.py [configfile] [-p data_path]
#   server: python3 run_dist_hybrid.py [configfile] [-p data_path] --server


# Accumulates per-batch loss/accuracy counts and all-reduces them across
# workers at the end of each epoch.
class TrainStat():
    def __init__(self, comm):
        self.file = open("log.txt", "w")
        self.train_stat = np.zeros(4)  # [batch count, correct, total nodes, loss sum]
        self.test_stat = np.zeros(4)
        self.count = 0
        self.time = []
        self.comm = comm

    def update_test(self, cnt, total, loss):
        self.test_stat += [1, cnt, total, loss]

    def update_train(self, cnt, total, loss):
        self.train_stat += [1, cnt, total, loss]

    def sync_and_clear(self):
        self.count += 1
        # all-reduce the local statistics over all workers
        train_stat = ht.array(self.train_stat, ht.cpu())
        test_stat = ht.array(self.test_stat, ht.cpu())
        self.comm.dlarrayNcclAllReduce(
            train_stat, train_stat, ncclDataType_t.ncclFloat32, ncclRedOp_t.ncclSum, self.comm.stream)
        self.comm.dlarrayNcclAllReduce(
            test_stat, test_stat, ncclDataType_t.ncclFloat32, ncclRedOp_t.ncclSum, self.comm.stream)
        self.comm.stream.sync()
        train_stat, test_stat = train_stat.asnumpy(), test_stat.asnumpy()
        printstr = "epoch {}: test loss: {:.3f} test acc: {:.3f} train loss: {:.3f} train acc: {:.3f}".format(
            self.count,
            test_stat[3] / test_stat[0],
            test_stat[1] / test_stat[2],
            train_stat[3] / train_stat[0],
            train_stat[1] / train_stat[2],
        )
        logstr = "{} {} {} {}".format(
            test_stat[3] / test_stat[0],
            test_stat[1] / test_stat[2],
            train_stat[3] / train_stat[0],
            train_stat[1] / train_stat[2],
        )
        self.time.append(time.time())
        # only rank 0 prints and writes the log file
        if self.comm.device_id.value == 0:
            print(printstr, flush=True)
            print(logstr, file=self.file, flush=True)
            if len(self.time) > 3:
                epoch_time = np.array(self.time[1:]) - np.array(self.time[:-1])
                print(
                    "epoch time: {:.3f}+-{:.3f}".format(np.mean(epoch_time), np.var(epoch_time)))
        self.train_stat[:] = 0
        self.test_stat[:] = 0


def train_main(args):
    comm = ht.wrapped_mpi_nccl_init()
    device_id = comm.dev_id
    cli = graphmix.Client()
    meta = cli.meta
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = cli.rank()
    nrank = cli.num_worker()
    ctx = ht.gpu(device_id)
    embedding_width = args.hidden_size
    # the last two int features are the train label and the train mask
    num_int_feature = meta["int_feature"] - 2
    # sample some subgraphs in advance
    ngraph = 10 * meta["train_node"] // (args.batch_size * nrank)
    graphs = prepare_data(ngraph)
    # build model
    [loss, y, train_op], [mask_, norm_adj_] = sparse_model(
        num_int_feature, args.hidden_size, meta["idx_max"], args.hidden_size, meta["class"], args.learning_rate)
    idx = 0
    graph = graphs[idx]
    idx = (idx + 1) % ngraph
    # feed the first graph into the GNN data loader
    ht.GNNDataLoaderOp.step(graph)
    ht.GNNDataLoaderOp.step(graph)
    executor = ht.Executor([loss, y, train_op], ctx=ctx, comm_mode='Hybrid',
                           use_sparse_pull=False, cstable_policy=args.cache)
    nbatches = meta["train_node"] // (args.batch_size * nrank)
    train_state = TrainStat(comm)
    for epoch in range(num_epoch):
        for _ in range(nbatches):
            # prefetch the next sampled subgraph while training on the current one
            graph_nxt = graphs[idx]
            idx = (idx + 1) % ngraph
            ht.GNNDataLoaderOp.step(graph_nxt)
            # split nodes of the current subgraph into train and eval sets
            train_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] == 1)
            eval_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] != 1)
            feed_dict = {
                norm_adj_: get_norm_adj(graph, ht.gpu(device_id)),
                mask_: train_mask
            }
            loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            acc = np.sum((y_predicted == graph.i_feat[:, -2]) * eval_mask)
            train_acc = np.sum(
                (y_predicted == graph.i_feat[:, -2]) * train_mask)
            train_state.update_test(acc, eval_mask.sum(), np.sum(
                loss_val.asnumpy() * eval_mask) / eval_mask.sum())
            train_state.update_train(train_acc, train_mask.sum(), np.sum(
                loss_val.asnumpy() * train_mask) / train_mask.sum())
            ht.get_worker_communicate().BarrierWorker()
            graph = graph_nxt
        train_state.sync_and_clear()


def server_init(server):
    batch_size = args.batch_size
    # enable an LFU cache and a 2-hop GraphSage sampler on the GraphMix server
    server.init_cache(0.1, graphmix.cache.LFUOpt)
    worker_per_server = server.num_worker() // server.num_server()
    server.add_sampler(graphmix.sampler.GraphSage, batch_size=batch_size,
                       depth=2, width=2, thread=4 * worker_per_server, subgraph=True)
    server.is_ready()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("config")
    parser.add_argument("--path", "-p", required=True)
    parser.add_argument("--num_epoch", default=300, type=int)
    parser.add_argument("--hidden_size", default=128, type=int)
    parser.add_argument("--learning_rate", default=1, type=float)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--cache", default="LFUOpt", type=str)
    parser.add_argument("--server", action="store_true")
    args = parser.parse_args()
    if args.server:
        launch_graphmix_and_hetu_ps(
            train_main, args, server_init, hybrid_config="server")
    else:
        launch_graphmix_and_hetu_ps(
            train_main, args, server_init, hybrid_config="worker")