You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

run_single.py 3.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from gnn_tools.launcher import launch_graphmix_and_hetu_ps
  2. from gnn_tools.log import SharedTrainingStat
  3. from gnn_model.utils import get_norm_adj, prepare_data
  4. from gnn_model.model import sparse_model, dense_model
  5. import graphmix
  6. import hetu as ht
  7. import numpy as np
  8. import time
  9. import os
  10. import sys
  11. import argparse
  12. # usage
  13. # python3 run_single.py [-p data_path]
def train_main(args):
    """Per-worker training entry point, launched by launch_graphmix_and_hetu_ps.

    Trains a GNN (dense or sparse variant, chosen by ``args.dense``) on
    subgraphs sampled by the graphmix server, and accumulates train/eval
    accuracy and loss into the module-global ``stat``
    (a SharedTrainingStat created in ``__main__``).

    Args:
        args: parsed argparse namespace; reads hidden_size, num_epoch,
            batch_size, learning_rate, dense, and num_local_worker
            (NOTE(review): num_local_worker is not declared by the parser
            in this file — presumably injected by the launcher/config; verify).
    """
    cli = graphmix.Client()
    meta = cli.meta
    hidden_layer_size = args.hidden_size
    num_epoch = args.num_epoch
    rank = cli.rank()
    nrank = cli.num_worker()
    # Pin this worker to a local GPU by round-robin over local workers.
    ctx = ht.gpu(rank % args.num_local_worker)
    embedding_width = args.hidden_size
    # the last two is train label and other train mask
    num_int_feature = meta["int_feature"] - 2
    # sample some graphs
    # One pre-sampled graph per batch per global step; reused cyclically below.
    ngraph = meta["train_node"] // (args.batch_size * nrank)
    graphs = prepare_data(ngraph)
    # build model
    if args.dense:
        [loss, y, train_op], [mask_, norm_adj_] = dense_model(
            meta["float_feature"], args.hidden_size, meta["class"], args.learning_rate)
    else:
        [loss, y, train_op], [mask_, norm_adj_] = sparse_model(
            num_int_feature, args.hidden_size, meta["idx_max"], args.hidden_size, meta["class"], args.learning_rate)
    idx = 0
    graph = graphs[idx]
    idx = (idx + 1) % ngraph
    # Called twice on purpose: primes the data-loader pipeline so that the
    # loop below can always push the *next* graph while running the current
    # one (double-buffering). Do not collapse into one call.
    ht.GNNDataLoaderOp.step(graph)
    ht.GNNDataLoaderOp.step(graph)
    executor = ht.Executor([loss, y, train_op], ctx=ctx)
    nbatches = meta["train_node"] // (args.batch_size * nrank)
    for epoch in range(num_epoch):
        for _ in range(nbatches):
            # Prefetch the next graph while computing on the current one.
            graph_nxt = graphs[idx]
            idx = (idx + 1) % ngraph
            ht.GNNDataLoaderOp.step(graph_nxt)
            # i_feat[:, -1] == 1 marks training nodes; extra[:, 0] is a
            # validity mask (presumably "target node of this subgraph" —
            # TODO confirm against graphmix docs).
            train_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] == 1)
            eval_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] != 1)
            feed_dict = {
                norm_adj_: get_norm_adj(graph, ht.gpu(rank % args.num_local_worker)),
                mask_: train_mask
            }
            loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)
            # i_feat[:, -2] holds the ground-truth label (see "last two" note).
            acc = np.sum((y_predicted == graph.i_feat[:, -2]) * eval_mask)
            train_acc = np.sum(
                (y_predicted == graph.i_feat[:, -2]) * train_mask)
            # Report masked counts and mean per-node loss to the shared stats.
            stat.update(acc, eval_mask.sum(), np.sum(
                loss_val.asnumpy()*eval_mask)/eval_mask.sum())
            stat.update_train(train_acc, train_mask.sum(), np.sum(
                loss_val.asnumpy()*train_mask)/train_mask.sum())
            graph = graph_nxt
        stat.print(epoch)
  66. def server_init(server):
  67. batch_size = args.batch_size
  68. server.init_cache(0.1, graphmix.cache.LFUOpt)
  69. worker_per_server = server.num_worker() // server.num_server()
  70. server.add_sampler(graphmix.sampler.GraphSage, batch_size=batch_size,
  71. depth=2, width=2, thread=4 * worker_per_server, subgraph=True)
  72. server.is_ready()
  73. if __name__ == '__main__':
  74. parser = argparse.ArgumentParser()
  75. parser.add_argument("--config", default="config/single.yml")
  76. parser.add_argument("--path", "-p", required=True)
  77. parser.add_argument("--num_epoch", default=300, type=int)
  78. parser.add_argument("--hidden_size", default=128, type=int)
  79. parser.add_argument("--learning_rate", default=1, type=float)
  80. parser.add_argument("--batch_size", default=128, type=int)
  81. parser.add_argument("--dense", action="store_true")
  82. args = parser.parse_args()
  83. stat = SharedTrainingStat()
  84. launch_graphmix_and_hetu_ps(train_main, args, server_init=server_init)