You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

spKernel.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. """
  2. @author: linlin
  3. @references: Borgwardt KM, Kriegel HP. Shortest-path kernels on graphs. In Data Mining, Fifth IEEE International Conference on 2005 Nov 27 (pp. 8-pp). IEEE.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. from tqdm import tqdm
  9. import time
  10. from itertools import combinations_with_replacement, product
  11. import networkx as nx
  12. import numpy as np
  13. from pygraph.utils.utils import getSPGraph
  14. from pygraph.utils.graphdataset import get_dataset_attributes
  15. def spkernel(*args, node_label='atom', edge_weight=None, node_kernels=None):
  16. """Calculate shortest-path kernels between graphs.
  17. Parameters
  18. ----------
  19. Gn : List of NetworkX graph
  20. List of graphs between which the kernels are calculated.
  21. /
  22. G1, G2 : NetworkX graphs
  23. 2 graphs between which the kernel is calculated.
  24. edge_weight : string
  25. Edge attribute name corresponding to the edge weight.
  26. node_kernels: dict
  27. A dictionary of kernel functions for nodes, including 3 items: 'symb' for symbolic node labels, 'nsymb' for non-symbolic node labels, 'mix' for both labels. The first 2 functions take two node labels as parameters, and the 'mix' function takes 4 parameters, a symbolic and a non-symbolic label for each the two nodes. Each label is in form of 2-D dimension array (n_samples, n_features). Each function returns an number as the kernel value. Ignored when nodes are unlabeled.
  28. Return
  29. ------
  30. Kmatrix : Numpy matrix
  31. Kernel matrix, each element of which is the sp kernel between 2 praphs.
  32. """
  33. # pre-process
  34. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  35. weight = None
  36. if edge_weight == None:
  37. print('\n None edge weight specified. Set all weight to 1.\n')
  38. else:
  39. try:
  40. some_weight = list(
  41. nx.get_edge_attributes(Gn[0], edge_weight).values())[0]
  42. if isinstance(some_weight, float) or isinstance(some_weight, int):
  43. weight = edge_weight
  44. else:
  45. print(
  46. '\n Edge weight with name %s is not float or integer. Set all weight to 1.\n'
  47. % edge_weight)
  48. except:
  49. print(
  50. '\n Edge weight with name "%s" is not found in the edge attributes. Set all weight to 1.\n'
  51. % edge_weight)
  52. ds_attrs = get_dataset_attributes(
  53. Gn,
  54. attr_names=['node_labeled', 'node_attr_dim', 'is_directed'],
  55. node_label=node_label)
  56. # remove graphs with no edges, as no sp can be found in their structures, so the kernel between such a graph and itself will be zero.
  57. len_gn = len(Gn)
  58. Gn = [(idx, G) for idx, G in enumerate(Gn) if nx.number_of_edges(G) != 0]
  59. idx = [G[0] for G in Gn]
  60. Gn = [G[1] for G in Gn]
  61. if len(Gn) != len_gn:
  62. print('\n %d graphs are removed as they don\'t contain edges.\n' %
  63. (len_gn - len(Gn)))
  64. start_time = time.time()
  65. # get shortest path graphs of Gn
  66. Gn = [
  67. getSPGraph(G, edge_weight=edge_weight)
  68. for G in tqdm(Gn, desc='getting sp graphs', file=sys.stdout)
  69. ]
  70. Kmatrix = np.zeros((len(Gn), len(Gn)))
  71. pbar = tqdm(
  72. total=((len(Gn) + 1) * len(Gn) / 2),
  73. desc='calculating kernels',
  74. file=sys.stdout)
  75. if ds_attrs['node_labeled']:
  76. # node symb and non-synb labeled
  77. if ds_attrs['node_attr_dim'] > 0:
  78. if ds_attrs['is_directed']:
  79. for i, j in combinations_with_replacement(
  80. range(0, len(Gn)), 2):
  81. for e1, e2 in product(
  82. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  83. if e1[2]['cost'] == e2[2]['cost']:
  84. kn = node_kernels['mix']
  85. try:
  86. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  87. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  88. j].nodes[e2[1]]
  89. kn1 = kn(n11[node_label], n21[node_label], [
  90. n11['attributes']
  91. ], [n21['attributes']]) * kn(
  92. n12[node_label], n22[node_label],
  93. [n12['attributes']], [n22['attributes']])
  94. Kmatrix[i][j] += kn1
  95. except KeyError: # missing labels or attributes
  96. pass
  97. Kmatrix[j][i] = Kmatrix[i][j]
  98. pbar.update(1)
  99. else:
  100. for i, j in combinations_with_replacement(
  101. range(0, len(Gn)), 2):
  102. for e1, e2 in product(
  103. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  104. if e1[2]['cost'] == e2[2]['cost']:
  105. kn = node_kernels['mix']
  106. try:
  107. # each edge walk is counted twice, starting from both its extreme nodes.
  108. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  109. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  110. j].nodes[e2[1]]
  111. kn1 = kn(n11[node_label], n21[node_label], [
  112. n11['attributes']
  113. ], [n21['attributes']]) * kn(
  114. n12[node_label], n22[node_label],
  115. [n12['attributes']], [n22['attributes']])
  116. kn2 = kn(n11[node_label], n22[node_label], [
  117. n11['attributes']
  118. ], [n22['attributes']]) * kn(
  119. n12[node_label], n21[node_label],
  120. [n12['attributes']], [n21['attributes']])
  121. Kmatrix[i][j] += kn1 + kn2
  122. except KeyError: # missing labels or attributes
  123. pass
  124. Kmatrix[j][i] = Kmatrix[i][j]
  125. pbar.update(1)
  126. # node symb labeled
  127. else:
  128. if ds_attrs['is_directed']:
  129. for i, j in combinations_with_replacement(
  130. range(0, len(Gn)), 2):
  131. for e1, e2 in product(
  132. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  133. if e1[2]['cost'] == e2[2]['cost']:
  134. kn = node_kernels['symb']
  135. try:
  136. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  137. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  138. j].nodes[e2[1]]
  139. kn1 = kn(n11[node_label],
  140. n21[node_label]) * kn(
  141. n12[node_label], n22[node_label])
  142. Kmatrix[i][j] += kn1
  143. except KeyError: # missing labels
  144. pass
  145. Kmatrix[j][i] = Kmatrix[i][j]
  146. pbar.update(1)
  147. else:
  148. for i, j in combinations_with_replacement(
  149. range(0, len(Gn)), 2):
  150. for e1, e2 in product(
  151. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  152. if e1[2]['cost'] == e2[2]['cost']:
  153. kn = node_kernels['symb']
  154. try:
  155. # each edge walk is counted twice, starting from both its extreme nodes.
  156. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  157. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  158. j].nodes[e2[1]]
  159. kn1 = kn(n11[node_label],
  160. n21[node_label]) * kn(
  161. n12[node_label], n22[node_label])
  162. kn2 = kn(n11[node_label],
  163. n22[node_label]) * kn(
  164. n12[node_label], n21[node_label])
  165. Kmatrix[i][j] += kn1 + kn2
  166. except KeyError: # missing labels
  167. pass
  168. Kmatrix[j][i] = Kmatrix[i][j]
  169. pbar.update(1)
  170. else:
  171. # node non-synb labeled
  172. if ds_attrs['node_attr_dim'] > 0:
  173. if ds_attrs['is_directed']:
  174. for i, j in combinations_with_replacement(
  175. range(0, len(Gn)), 2):
  176. for e1, e2 in product(
  177. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  178. if e1[2]['cost'] == e2[2]['cost']:
  179. kn = node_kernels['nsymb']
  180. try:
  181. # each edge walk is counted twice, starting from both its extreme nodes.
  182. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  183. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  184. j].nodes[e2[1]]
  185. kn1 = kn([n11['attributes']],
  186. [n21['attributes']]) * kn(
  187. [n12['attributes']],
  188. [n22['attributes']])
  189. Kmatrix[i][j] += kn1
  190. except KeyError: # missing attributes
  191. pass
  192. Kmatrix[j][i] = Kmatrix[i][j]
  193. pbar.update(1)
  194. else:
  195. for i, j in combinations_with_replacement(
  196. range(0, len(Gn)), 2):
  197. for e1, e2 in product(
  198. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  199. if e1[2]['cost'] == e2[2]['cost']:
  200. kn = node_kernels['nsymb']
  201. try:
  202. # each edge walk is counted twice, starting from both its extreme nodes.
  203. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  204. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  205. j].nodes[e2[1]]
  206. kn1 = kn([n11['attributes']],
  207. [n21['attributes']]) * kn(
  208. [n12['attributes']],
  209. [n22['attributes']])
  210. kn2 = kn([n11['attributes']],
  211. [n22['attributes']]) * kn(
  212. [n12['attributes']],
  213. [n21['attributes']])
  214. Kmatrix[i][j] += kn1 + kn2
  215. except KeyError: # missing attributes
  216. pass
  217. Kmatrix[j][i] = Kmatrix[i][j]
  218. pbar.update(1)
  219. # node unlabeled
  220. else:
  221. for i, j in combinations_with_replacement(range(0, len(Gn)), 2):
  222. for e1, e2 in product(
  223. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  224. if e1[2]['cost'] == e2[2]['cost']:
  225. Kmatrix[i][j] += 1
  226. Kmatrix[j][i] = Kmatrix[i][j]
  227. pbar.update(1)
  228. run_time = time.time() - start_time
  229. print(
  230. "\n --- shortest path kernel matrix of size %d built in %s seconds ---"
  231. % (len(Gn), run_time))
  232. return Kmatrix, run_time, idx

A Python package for graph kernels, graph edit distances and graph pre-image problem.