layer.py
import hetu as ht
from hetu import init


class GCN(object):
    def __init__(self, in_features, out_features, norm_adj, activation=None, dropout=0,
                 name="GCN", custom_init=None):
        if custom_init is not None:
            self.weight = ht.Variable(
                value=custom_init[0], name=name+"_Weight")
            self.bias = ht.Variable(value=custom_init[1], name=name+"_Bias")
        else:
            self.weight = init.xavier_uniform(
                shape=(in_features, out_features), name=name+"_Weight")
            self.bias = init.zeros(shape=(out_features,), name=name+"_Bias")
        # self.mp is a sparse matrix and should appear in feed_dict later
        self.mp = norm_adj
        self.activation = activation
        self.dropout = dropout
        self.output_width = out_features

    def __call__(self, x):
        """
        Build the computation graph, return the output node
        """
        if self.dropout > 0:
            # dropout_op takes the keep probability (1 - dropout rate)
            x = ht.dropout_op(x, 1 - self.dropout)
        # linear transform, then add the broadcast bias
        x = ht.matmul_op(x, self.weight)
        msg = x + ht.broadcastto_op(self.bias, x)
        # aggregate messages with the sparse normalized adjacency matrix
        x = ht.csrmm_op(self.mp, msg)
        if self.activation == "relu":
            x = ht.relu_op(x)
        elif self.activation is not None:
            raise NotImplementedError
        return x


class SageConv(object):
    def __init__(self, in_features, out_features, norm_adj, activation=None, dropout=0,
                 name="GCN", custom_init=None, mp_val=None):
        self.weight = init.xavier_uniform(
            shape=(in_features, out_features), name=name+"_Weight")
        self.bias = init.zeros(shape=(out_features,), name=name+"_Bias")
        self.weight2 = init.xavier_uniform(
            shape=(in_features, out_features), name=name+"_Weight")
        # self.mp is a sparse matrix and should appear in feed_dict later
        self.mp = norm_adj
        self.activation = activation
        self.dropout = dropout
        # output concatenates neighbor aggregation with transformed self features
        self.output_width = 2 * out_features

    def __call__(self, x):
        """
        Build the computation graph, return the output node
        """
        feat = x
        if self.dropout > 0:
            x = ht.dropout_op(x, 1 - self.dropout)
        # aggregate neighbor features, then transform and add bias
        x = ht.csrmm_op(self.mp, x)
        x = ht.matmul_op(x, self.weight)
        x = x + ht.broadcastto_op(self.bias, x)
        if self.activation == "relu":
            x = ht.relu_op(x)
        elif self.activation is not None:
            raise NotImplementedError
        # concatenate aggregated neighbors with the transformed self features
        return ht.concat_op(x, ht.matmul_op(feat, self.weight2), axis=1)
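
A minimal usage sketch (not part of layer.py) of how these layers might be stacked. It assumes `features` (dense node-feature node) and `norm_adj` (the sparse normalized adjacency later supplied through feed_dict, per the comment on self.mp) are Hetu graph nodes built elsewhere; the Cora-like dimensions are purely illustrative.

# Hypothetical inputs, assumed to be created elsewhere:
#   features - dense node-feature graph node
#   norm_adj - sparse normalized adjacency, fed via feed_dict at run time
hidden = GCN(in_features=1433, out_features=16, norm_adj=norm_adj,
             activation="relu", dropout=0.5, name="GCN1")
out = GCN(in_features=hidden.output_width, out_features=7,
          norm_adj=norm_adj, name="GCN2")
logits = out(hidden(features))  # builds the two-layer GCN computation graph

# SageConv drops in the same way; note its output_width is 2 * out_features
# because of the final concat_op.
sage = SageConv(in_features=1433, out_features=16, norm_adj=norm_adj,
                activation="relu", dropout=0.5, name="Sage1")
h = sage(features)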