from gnn_tools.launcher import launch_graphmix_and_hetu_ps
from gnn_tools.log import SharedTrainingStat
from gnn_model.utils import get_norm_adj, prepare_data
from gnn_model.model import sparse_model, dense_model
import graphmix
import hetu as ht
import numpy as np
import argparse

# usage: python3 run_single.py -p <data_path> [options]


def train_main(args):
    cli = graphmix.Client()
    meta = cli.meta
    num_epoch = args.num_epoch
    rank = cli.rank()
    nrank = cli.num_worker()
    ctx = ht.gpu(rank % args.num_local_worker)

    # The last two integer columns of i_feat hold the label and the train
    # mask, so they are not fed to the model as input features.
    num_int_feature = meta["int_feature"] - 2

    # Pre-sample enough subgraphs for one full pass over the training nodes.
    ngraph = meta["train_node"] // (args.batch_size * nrank)
    graphs = prepare_data(ngraph)

    # Build the model: dense (float-feature) or sparse (embedding-lookup).
    if args.dense:
        [loss, y, train_op], [mask_, norm_adj_] = dense_model(
            meta["float_feature"], args.hidden_size, meta["class"],
            args.learning_rate)
    else:
        [loss, y, train_op], [mask_, norm_adj_] = sparse_model(
            num_int_feature, args.hidden_size, meta["idx_max"],
            args.hidden_size, meta["class"], args.learning_rate)

    # Push the first graph twice to fill both slots of the double-buffered
    # data-loader pipeline before the executor starts.
    idx = 0
    graph = graphs[idx]
    idx = (idx + 1) % ngraph
    ht.GNNDataLoaderOp.step(graph)
    ht.GNNDataLoaderOp.step(graph)

    executor = ht.Executor([loss, y, train_op], ctx=ctx)
    nbatches = ngraph  # one pass over the pre-sampled graphs per epoch
    for epoch in range(num_epoch):
        for _ in range(nbatches):
            # Prefetch the next graph while the current one is trained on.
            graph_nxt = graphs[idx]
            idx = (idx + 1) % ngraph
            ht.GNNDataLoaderOp.step(graph_nxt)

            # graph.extra[:, 0] flags the valid target nodes of the sampled
            # subgraph; the last i_feat column splits them into train/eval.
            train_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] == 1)
            eval_mask = np.bitwise_and(
                graph.extra[:, 0], graph.i_feat[:, -1] != 1)
            feed_dict = {
                norm_adj_: get_norm_adj(graph, ctx),
                mask_: train_mask,
            }
            loss_val, y_predicted, _ = executor.run(feed_dict=feed_dict)
            y_predicted = y_predicted.asnumpy().argmax(axis=1)

            # Count correct predictions against the label column
            # (i_feat[:, -2]) and accumulate the masked per-node losses
            # into the shared statistics.
            eval_acc = np.sum(
                (y_predicted == graph.i_feat[:, -2]) * eval_mask)
            train_acc = np.sum(
                (y_predicted == graph.i_feat[:, -2]) * train_mask)
            stat.update(eval_acc, eval_mask.sum(), np.sum(
                loss_val.asnumpy() * eval_mask) / eval_mask.sum())
            stat.update_train(train_acc, train_mask.sum(), np.sum(
                loss_val.asnumpy() * train_mask) / train_mask.sum())
            graph = graph_nxt
        stat.print(epoch)


def server_init(server):
    # Runs on each graphmix server process; `args` is inherited from the
    # launching process.
    batch_size = args.batch_size
    server.init_cache(0.1, graphmix.cache.LFUOpt)
    worker_per_server = server.num_worker() // server.num_server()
    server.add_sampler(graphmix.sampler.GraphSage, batch_size=batch_size,
                       depth=2, width=2, thread=4 * worker_per_server,
                       subgraph=True)
    server.is_ready()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/single.yml")
    parser.add_argument("--path", "-p", required=True)
    parser.add_argument("--num_epoch", default=300, type=int)
    parser.add_argument("--hidden_size", default=128, type=int)
    parser.add_argument("--learning_rate", default=1, type=float)
    parser.add_argument("--batch_size", default=128, type=int)
    # GPUs used per machine; train_main reads args.num_local_worker, so a
    # default is provided here in case the launcher does not populate it
    # from the config file.
    parser.add_argument("--num_local_worker", default=1, type=int)
    parser.add_argument("--dense", action="store_true")
    args = parser.parse_args()
    stat = SharedTrainingStat()
    launch_graphmix_and_hetu_ps(train_main, args, server_init=server_init)
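
# ---------------------------------------------------------------------------
# Example invocations (a sketch: the dataset path below is hypothetical and
# must point at a directory prepared for graphmix, and config/single.yml is
# assumed to describe the worker/server layout that
# launch_graphmix_and_hetu_ps expects):
#
#   python3 run_single.py -p ~/graph_data/reddit                  # sparse model
#   python3 run_single.py -p ~/graph_data/reddit --dense          # dense model
#   python3 run_single.py -p ~/graph_data/reddit --batch_size 256 --num_epoch 100
# ---------------------------------------------------------------------------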