```python
import hetu as ht
from hetu import init


class GCN(object):
    def __init__(self, in_features, out_features, norm_adj, activation=None, dropout=0,
                 name="GCN", custom_init=None):
        if custom_init is not None:
            # Initialize the parameters from user-provided values.
            self.weight = ht.Variable(
                value=custom_init[0], name=name+"_Weight")
            self.bias = ht.Variable(value=custom_init[1], name=name+"_Bias")
        else:
            self.weight = init.xavier_uniform(
                shape=(in_features, out_features), name=name+"_Weight")
            self.bias = init.zeros(shape=(out_features,), name=name+"_Bias")
        # self.mp is a sparse matrix (the normalized adjacency) and should
        # appear in feed_dict later.
        self.mp = norm_adj
        self.activation = activation
        self.dropout = dropout
        self.output_width = out_features

    def __call__(self, x):
        """
        Build the computation graph and return the output node.
        """
        if self.dropout > 0:
            # dropout_op takes the keep probability, hence 1 - self.dropout.
            x = ht.dropout_op(x, 1 - self.dropout)
        x = ht.matmul_op(x, self.weight)
        msg = x + ht.broadcastto_op(self.bias, x)
        # Sparse-dense matmul: aggregate the transformed features over the
        # normalized adjacency.
        x = ht.csrmm_op(self.mp, msg)
        if self.activation == "relu":
            x = ht.relu_op(x)
        elif self.activation is not None:
            raise NotImplementedError
        return x
```
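A minimal sketch of how two of these layers might be stacked, assuming the usual two-layer GCN setup. The placeholder construction (plain `ht.Variable` nodes for the features and the normalized adjacency, both supplied through `feed_dict` at execution time, as the comment on `self.mp` indicates) and the Cora-like dimensions are assumptions for illustration, not something the listing above prescribes.

```python
in_feats, hidden, num_classes = 1433, 16, 7   # hypothetical (Cora-like) dimensions

# Placeholder nodes; their values (dense features, sparse normalized adjacency)
# are fed via feed_dict when the graph is executed. The exact placeholder API
# is an assumption here.
features = ht.Variable(name="features")
norm_adj = ht.Variable(name="norm_adj")

layer1 = GCN(in_feats, hidden, norm_adj, activation="relu", dropout=0.5, name="GCN1")
layer2 = GCN(layer1.output_width, num_classes, norm_adj, name="GCN2")

logits = layer2(layer1(features))   # output node of the built graph
```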
`SageConv` follows the same pattern, but aggregates over the adjacency before the linear transform and concatenates a separately transformed copy of the raw input features, which is why its `output_width` is `2 * out_features`:

```python
class SageConv(object):
    def __init__(self, in_features, out_features, norm_adj, activation=None, dropout=0,
                 name="SageConv", custom_init=None, mp_val=None):
        self.weight = init.xavier_uniform(
            shape=(in_features, out_features), name=name+"_Weight")
        self.bias = init.zeros(shape=(out_features,), name=name+"_Bias")
        # Use a distinct name for the second weight matrix so it does not
        # clash with self.weight.
        self.weight2 = init.xavier_uniform(
            shape=(in_features, out_features), name=name+"_Weight2")
        # self.mp is a sparse matrix (the normalized adjacency) and should
        # appear in feed_dict later.
        self.mp = norm_adj
        self.activation = activation
        self.dropout = dropout
        # The output concatenates two out_features-wide blocks.
        self.output_width = 2 * out_features

    def __call__(self, x):
        """
        Build the computation graph and return the output node.
        """
        feat = x
        if self.dropout > 0:
            x = ht.dropout_op(x, 1 - self.dropout)

        # Aggregate neighbor features first, then apply the linear transform.
        x = ht.csrmm_op(self.mp, x)
        x = ht.matmul_op(x, self.weight)
        x = x + ht.broadcastto_op(self.bias, x)
        if self.activation == "relu":
            x = ht.relu_op(x)
        elif self.activation is not None:
            raise NotImplementedError
        # GraphSAGE-style combination: concatenate the aggregated branch with a
        # linear transform of the original node features.
        return ht.concat_op(x, ht.matmul_op(feat, self.weight2), axis=1)
```
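A corresponding sketch of chaining two `SageConv` layers; as above, the placeholder nodes and dimensions are assumptions for illustration. The detail worth noting is that the next layer's `in_features` must be the previous layer's `output_width`, since the concatenation doubles the feature width.

```python
in_feats, hidden, num_classes = 1433, 16, 7   # hypothetical dimensions

features = ht.Variable(name="features")   # assumed placeholder, fed via feed_dict
norm_adj = ht.Variable(name="norm_adj")   # normalized sparse adjacency, fed via feed_dict

sage1 = SageConv(in_feats, hidden, norm_adj, activation="relu", dropout=0.5, name="Sage1")
sage2 = SageConv(sage1.output_width, num_classes, norm_adj, name="Sage2")

output = sage2(sage1(features))   # final width is 2 * num_classes due to the concat
```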