You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.4 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import hetu
  2. import graphmix
  3. import numpy as np
  4. from tqdm import tqdm
  5. def padding(graph, target_num_nodes):
  6. assert graph.num_nodes <= target_num_nodes
  7. graph.convert2coo()
  8. new_graph = graphmix.Graph(graph.edge_index, target_num_nodes)
  9. new_graph.tag = graph.tag
  10. new_graph.type = graph.type
  11. extra = target_num_nodes - graph.num_nodes
  12. new_graph.i_feat = np.concatenate(
  13. [graph.i_feat, np.tile(graph.i_feat[0], [extra, 1])])
  14. new_graph.f_feat = np.concatenate(
  15. [graph.f_feat, np.tile(graph.f_feat[0], [extra, 1])])
  16. if graph.extra.size:
  17. new_graph.extra = np.concatenate([graph.extra, np.zeros([extra, 1])])
  18. return new_graph
  19. def prepare_data(ngraph):
  20. cli = graphmix.Client()
  21. graphs = []
  22. for i in tqdm(range(ngraph)):
  23. query = cli.pull_graph()
  24. graph = cli.wait(query)
  25. graphs.append(graph)
  26. max_num_nodes = 0
  27. for i in range(ngraph):
  28. max_num_nodes = max(max_num_nodes, graphs[i].num_nodes)
  29. for i in range(ngraph):
  30. graphs[i] = padding(graphs[i], max_num_nodes)
  31. return graphs
  32. def get_norm_adj(graph, device, use_original_gcn_norm=False):
  33. norm = graph.gcn_norm(use_original_gcn_norm)
  34. mp_mat = hetu.ndarray.sparse_array(
  35. values=norm,
  36. indices=(graph.edge_index[1], graph.edge_index[0]),
  37. shape=(graph.num_nodes, graph.num_nodes),
  38. ctx=device
  39. )
  40. return mp_mat