
run_dist.py
from gnn_tools.launcher import launch_graphmix_and_hetu_ps
from gnn_model.utils import get_norm_adj, prepare_data
from gnn_model.model import sparse_model
from gnn_tools.log import SharedTrainingStat
import graphmix
import hetu as ht
import numpy as np
import argparse

# Usage: on each machine
#   python3 run_dist.py [configfile] [-p data_path]
def train_main(args):
    cli = graphmix.Client()
    meta = cli.meta
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = cli.rank()
    nrank = cli.num_worker()
    # args.num_local_worker is presumably supplied via the config file parsed by the launcher.
    ctx = ht.gpu(rank % args.num_local_worker)
    embedding_width = args.hidden_size
    # The last two integer features are the train label and the train mask.
    num_int_feature = meta["int_feature"] - 2
    # Pre-sample some graphs.
    ngraph = meta["train_node"] // (args.batch_size * nrank)
    graphs = prepare_data(ngraph)
    # Build the model.
    [loss, y, train_op], [mask_, norm_adj_] = sparse_model(
        num_int_feature, args.hidden_size, meta["idx_max"], args.hidden_size,
        meta["class"], args.learning_rate)
    idx = 0
    graph = graphs[idx]
    idx = (idx + 1) % ngraph
    # Feed the first sampled graph into the GNN data loader (stepped twice up front).
    ht.GNNDataLoaderOp.step(graph)
    ht.GNNDataLoaderOp.step(graph)
    executor = ht.Executor([loss, y, train_op], ctx=ctx, comm_mode='PS',
                           use_sparse_pull=False, cstable_policy=args.cache)
    nbatches = meta["train_node"] // (args.batch_size * nrank)
    for epoch in range(num_epoch):
        for _ in range(nbatches):
            # Prefetch the next sampled graph while training on the current one.
            graph_nxt = graphs[idx]
            idx = (idx + 1) % ngraph
            ht.GNNDataLoaderOp.step(graph_nxt)
            # Split the batch's target nodes into training nodes (train flag set)
            # and evaluation nodes (train flag not set).
            train_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] == 1)
            eval_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] != 1)
            feed_dict = {
                norm_adj_: get_norm_adj(graph, ht.gpu(rank % args.num_local_worker)),
                mask_: train_mask
            }
            loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            # Masked accuracy against the label stored in the second-to-last integer feature.
            acc = np.sum((y_predicted == graph.i_feat[:, -2]) * eval_mask)
            train_acc = np.sum(
                (y_predicted == graph.i_feat[:, -2]) * train_mask)
            stat.update(acc, eval_mask.sum(), np.sum(
                loss_val.asnumpy() * eval_mask) / eval_mask.sum())
            stat.update_train(train_acc, train_mask.sum(), np.sum(
                loss_val.asnumpy() * train_mask) / train_mask.sum())
            ht.get_worker_communicate().BarrierWorker()
            graph = graph_nxt
        if rank == 0:
            stat.print(epoch)
def server_init(server):
    batch_size = args.batch_size
    # Server-side cache with the LFUOpt replacement policy.
    server.init_cache(0.1, graphmix.cache.LFUOpt)
    worker_per_server = server.num_worker() // server.num_server()
    # Add a GraphSage neighborhood sampler (depth 2, width 2) that returns subgraphs.
    server.add_sampler(graphmix.sampler.GraphSage, batch_size=batch_size,
                       depth=2, width=2, thread=4 * worker_per_server, subgraph=True)
    server.is_ready()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("config")
    parser.add_argument("--path", "-p", required=True)
    parser.add_argument("--num_epoch", default=300, type=int)
    parser.add_argument("--hidden_size", default=128, type=int)
    parser.add_argument("--learning_rate", default=1, type=float)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--cache", default="LFUOpt", type=str)
    args = parser.parse_args()
    stat = SharedTrainingStat()
    launch_graphmix_and_hetu_ps(train_main, args, server_init=server_init)
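
For reference, here is a minimal standalone sketch of the masking and masked-accuracy logic used in the training loop above. The NumPy arrays are synthetic stand-ins for graph.extra[:, 0], the label column graph.i_feat[:, -2], the train-flag column graph.i_feat[:, -1], and the executor's predictions; none of these values come from the script itself.

import numpy as np

# Synthetic stand-ins (made-up values, for illustration only).
extra0 = np.array([1, 1, 1, 1, 0, 0, 1, 1])        # like graph.extra[:, 0]
labels = np.array([0, 1, 2, 1, 0, 2, 1, 0])        # like graph.i_feat[:, -2]
train_flag = np.array([1, 1, 0, 0, 1, 0, 0, 1])    # like graph.i_feat[:, -1]
y_predicted = np.array([0, 2, 2, 1, 0, 0, 1, 1])   # like the executor's argmax output

# Same masking as in train_main: train on flagged target nodes, evaluate on the rest.
train_mask = np.bitwise_and(extra0, train_flag == 1)
eval_mask = np.bitwise_and(extra0, train_flag != 1)

# Masked accuracy: correct predictions counted only over the masked nodes.
train_acc = np.sum((y_predicted == labels) * train_mask) / train_mask.sum()
eval_acc = np.sum((y_predicted == labels) * eval_mask) / eval_mask.sum()
print(train_acc, eval_acc)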
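
get_norm_adj is imported from gnn_model.utils and its implementation is not shown on this page, so the following is only an assumption about what such a helper typically computes: the symmetric GCN normalization D^-1/2 (A + I) D^-1/2 of the sampled subgraph's adjacency matrix. This dense NumPy sketch illustrates that normalization on a toy graph; it is not the actual helper and ignores sparsity and device placement (the real call also takes a GPU context).

import numpy as np

# Toy undirected graph with 4 nodes, given as an edge list.
edges = np.array([[0, 1], [1, 2], [2, 3], [3, 0]])
n = 4

# Dense adjacency with self-loops (A_hat = A + I).
adj = np.zeros((n, n))
adj[edges[:, 0], edges[:, 1]] = 1.0
adj[edges[:, 1], edges[:, 0]] = 1.0
adj += np.eye(n)

# Symmetric normalization D^-1/2 A_hat D^-1/2.
deg = adj.sum(axis=1)
d_inv_sqrt = 1.0 / np.sqrt(deg)
norm_adj = adj * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
print(norm_adj)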