|
- import hetu
- import graphmix
- import numpy as np
- from tqdm import tqdm
-
-
- def padding(graph, target_num_nodes):
- assert graph.num_nodes <= target_num_nodes
- graph.convert2coo()
- new_graph = graphmix.Graph(graph.edge_index, target_num_nodes)
- new_graph.tag = graph.tag
- new_graph.type = graph.type
- extra = target_num_nodes - graph.num_nodes
- new_graph.i_feat = np.concatenate(
- [graph.i_feat, np.tile(graph.i_feat[0], [extra, 1])])
- new_graph.f_feat = np.concatenate(
- [graph.f_feat, np.tile(graph.f_feat[0], [extra, 1])])
- if graph.extra.size:
- new_graph.extra = np.concatenate([graph.extra, np.zeros([extra, 1])])
- return new_graph
-
-
- def prepare_data(ngraph):
- cli = graphmix.Client()
- graphs = []
- for i in tqdm(range(ngraph)):
- query = cli.pull_graph()
- graph = cli.wait(query)
- graphs.append(graph)
- max_num_nodes = 0
- for i in range(ngraph):
- max_num_nodes = max(max_num_nodes, graphs[i].num_nodes)
- for i in range(ngraph):
- graphs[i] = padding(graphs[i], max_num_nodes)
- return graphs
-
-
- def get_norm_adj(graph, device, use_original_gcn_norm=False):
- norm = graph.gcn_norm(use_original_gcn_norm)
- mp_mat = hetu.ndarray.sparse_array(
- values=norm,
- indices=(graph.edge_index[1], graph.edge_index[0]),
- shape=(graph.num_nodes, graph.num_nodes),
- ctx=device
- )
- return mp_mat
|