You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

spKernel.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. """
  2. @author: linlin
  3. @references: Borgwardt KM, Kriegel HP. Shortest-path kernels on graphs. In Data Mining, Fifth IEEE International Conference on, 2005 Nov 27 (pp. 8-pp). IEEE.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. from tqdm import tqdm
  9. import time
  10. import networkx as nx
  11. import numpy as np
  12. from pygraph.utils.utils import getSPGraph
  13. from pygraph.utils.graphdataset import get_dataset_attributes
  14. def spkernel(*args, node_label='atom', edge_weight=None, node_kernels=None):
  15. """Calculate shortest-path kernels between graphs.
  16. Parameters
  17. ----------
  18. Gn : List of NetworkX graph
  19. List of graphs between which the kernels are calculated.
  20. /
  21. G1, G2 : NetworkX graphs
  22. 2 graphs between which the kernel is calculated.
  23. edge_weight : string
  24. Edge attribute name corresponding to the edge weight.
  25. node_kernels: dict
  26. A dictionary of kernel functions for nodes, including 3 items: 'symb' for symbolic node labels, 'nsymb' for non-symbolic node labels, 'mix' for both labels. The first 2 functions take two node labels as parameters, and the 'mix' function takes 4 parameters, a symbolic and a non-symbolic label for each the two nodes. Each label is in form of 2-D dimension array (n_samples, n_features). Each function returns an number as the kernel value. Ignored when nodes are unlabeled.
  27. Return
  28. ------
  29. Kmatrix : Numpy matrix
  30. Kernel matrix, each element of which is the sp kernel between 2 praphs.
  31. """
  32. # pre-process
  33. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  34. Gn = [nx.to_directed(G) for G in Gn]
  35. weight = None
  36. if edge_weight == None:
  37. print('\n None edge weight specified. Set all weight to 1.\n')
  38. else:
  39. try:
  40. some_weight = list(
  41. nx.get_edge_attributes(Gn[0], edge_weight).values())[0]
  42. if isinstance(some_weight, float) or isinstance(some_weight, int):
  43. weight = edge_weight
  44. else:
  45. print(
  46. '\n Edge weight with name %s is not float or integer. Set all weight to 1.\n'
  47. % edge_weight)
  48. except:
  49. print(
  50. '\n Edge weight with name "%s" is not found in the edge attributes. Set all weight to 1.\n'
  51. % edge_weight)
  52. ds_attrs = get_dataset_attributes(
  53. Gn,
  54. attr_names=['node_labeled', 'node_attr_dim', 'is_directed'],
  55. node_label=node_label)
  56. # remove graphs with no edges, as no sp can be found in their structures, so the kernel between such a graph and itself will be zero.
  57. len_gn = len(Gn)
  58. Gn = [(idx, G) for idx, G in enumerate(Gn) if nx.number_of_edges(G) != 0]
  59. idx = [G[0] for G in Gn]
  60. Gn = [G[1] for G in Gn]
  61. if len(Gn) != len_gn:
  62. print('\n %d graphs are removed as they don\'t contain edges.\n' %
  63. (len_gn - len(Gn)))
  64. start_time = time.time()
  65. # get shortest path graphs of Gn
  66. Gn = [
  67. getSPGraph(G, edge_weight=edge_weight)
  68. for G in tqdm(Gn, desc='getting sp graphs', file=sys.stdout)
  69. ]
  70. Kmatrix = np.zeros((len(Gn), len(Gn)))
  71. pbar = tqdm(
  72. total=((len(Gn) + 1) * len(Gn) / 2),
  73. desc='calculating kernels',
  74. file=sys.stdout)
  75. if ds_attrs['node_labeled']:
  76. # node symb and non-synb labeled
  77. if ds_attrs['node_attr_dim'] > 0:
  78. if ds_attrs['is_directed']:
  79. for i in range(0, len(Gn)):
  80. for j in range(i, len(Gn)):
  81. for e1 in Gn[i].edges(data=True):
  82. for e2 in Gn[j].edges(data=True):
  83. if e1[2]['cost'] == e2[2]['cost']:
  84. kn = node_kernels['mix']
  85. try:
  86. n11, n12, n21, n22 = Gn[i].nodes[e1[
  87. 0]], Gn[i].nodes[e1[1]], Gn[
  88. j].nodes[e2[0]], Gn[j].nodes[
  89. e2[1]]
  90. kn1 = kn(n11[node_label], n21[
  91. node_label], [n11['attributes']],
  92. [n21['attributes']]) * kn(
  93. n12[node_label],
  94. n22[node_label],
  95. [n12['attributes']],
  96. [n22['attributes']])
  97. Kmatrix[i][j] += kn1
  98. except KeyError: # missing labels or attributes
  99. pass
  100. Kmatrix[j][i] = Kmatrix[i][j]
  101. pbar.update(1)
  102. else:
  103. for i in range(0, len(Gn)):
  104. for j in range(i, len(Gn)):
  105. for e1 in Gn[i].edges(data=True):
  106. for e2 in Gn[j].edges(data=True):
  107. if e1[2]['cost'] == e2[2]['cost']:
  108. kn = node_kernels['mix']
  109. try:
  110. # each edge walk is counted twice, starting from both its extreme nodes.
  111. n11, n12, n21, n22 = Gn[i].nodes[e1[
  112. 0]], Gn[i].nodes[e1[1]], Gn[
  113. j].nodes[e2[0]], Gn[j].nodes[
  114. e2[1]]
  115. kn1 = kn(n11[node_label], n21[
  116. node_label], [n11['attributes']],
  117. [n21['attributes']]) * kn(
  118. n12[node_label],
  119. n22[node_label],
  120. [n12['attributes']],
  121. [n22['attributes']])
  122. kn2 = kn(n11[node_label], n22[
  123. node_label], [n11['attributes']],
  124. [n22['attributes']]) * kn(
  125. n12[node_label],
  126. n21[node_label],
  127. [n12['attributes']],
  128. [n21['attributes']])
  129. Kmatrix[i][j] += kn1 + kn2
  130. except KeyError: # missing labels or attributes
  131. pass
  132. Kmatrix[j][i] = Kmatrix[i][j]
  133. pbar.update(1)
  134. # node symb labeled
  135. else:
  136. if ds_attrs['is_directed']:
  137. for i in range(0, len(Gn)):
  138. for j in range(i, len(Gn)):
  139. for e1 in Gn[i].edges(data=True):
  140. for e2 in Gn[j].edges(data=True):
  141. if e1[2]['cost'] == e2[2]['cost']:
  142. kn = node_kernels['symb']
  143. try:
  144. n11, n12, n21, n22 = Gn[i].nodes[e1[
  145. 0]], Gn[i].nodes[e1[1]], Gn[
  146. j].nodes[e2[0]], Gn[j].nodes[
  147. e2[1]]
  148. kn1 = kn(n11[node_label],
  149. n21[node_label]) * kn(
  150. n12[node_label],
  151. n22[node_label])
  152. Kmatrix[i][j] += kn1
  153. except KeyError: # missing labels
  154. pass
  155. Kmatrix[j][i] = Kmatrix[i][j]
  156. pbar.update(1)
  157. else:
  158. for i in range(0, len(Gn)):
  159. for j in range(i, len(Gn)):
  160. for e1 in Gn[i].edges(data=True):
  161. for e2 in Gn[j].edges(data=True):
  162. if e1[2]['cost'] == e2[2]['cost']:
  163. kn = node_kernels['symb']
  164. try:
  165. # each edge walk is counted twice, starting from both its extreme nodes.
  166. n11, n12, n21, n22 = Gn[i].nodes[e1[
  167. 0]], Gn[i].nodes[e1[1]], Gn[
  168. j].nodes[e2[0]], Gn[j].nodes[
  169. e2[1]]
  170. kn1 = kn(n11[node_label],
  171. n21[node_label]) * kn(
  172. n12[node_label],
  173. n22[node_label])
  174. kn2 = kn(n11[node_label],
  175. n22[node_label]) * kn(
  176. n12[node_label],
  177. n21[node_label])
  178. Kmatrix[i][j] += kn1 + kn2
  179. except KeyError: # missing labels
  180. pass
  181. Kmatrix[j][i] = Kmatrix[i][j]
  182. pbar.update(1)
  183. else:
  184. # node non-synb labeled
  185. if ds_attrs['node_attr_dim'] > 0:
  186. if ds_attrs['is_directed']:
  187. for i in range(0, len(Gn)):
  188. for j in range(i, len(Gn)):
  189. for e1 in Gn[i].edges(data=True):
  190. for e2 in Gn[j].edges(data=True):
  191. if e1[2]['cost'] == e2[2]['cost']:
  192. kn = node_kernels['nsymb']
  193. try:
  194. # each edge walk is counted twice, starting from both its extreme nodes.
  195. n11, n12, n21, n22 = Gn[i].nodes[e1[
  196. 0]], Gn[i].nodes[e1[1]], Gn[
  197. j].nodes[e2[0]], Gn[j].nodes[
  198. e2[1]]
  199. kn1 = kn([n11['attributes']],
  200. [n21['attributes']]) * kn(
  201. [n12['attributes']],
  202. [n22['attributes']])
  203. Kmatrix[i][j] += kn1
  204. except KeyError: # missing attributes
  205. pass
  206. Kmatrix[j][i] = Kmatrix[i][j]
  207. pbar.update(1)
  208. else:
  209. for i in range(0, len(Gn)):
  210. for j in range(i, len(Gn)):
  211. for e1 in Gn[i].edges(data=True):
  212. for e2 in Gn[j].edges(data=True):
  213. if e1[2]['cost'] == e2[2]['cost']:
  214. kn = node_kernels['nsymb']
  215. try:
  216. # each edge walk is counted twice, starting from both its extreme nodes.
  217. n11, n12, n21, n22 = Gn[i].nodes[e1[
  218. 0]], Gn[i].nodes[e1[1]], Gn[
  219. j].nodes[e2[0]], Gn[j].nodes[
  220. e2[1]]
  221. kn1 = kn([n11['attributes']],
  222. [n21['attributes']]) * kn(
  223. [n12['attributes']],
  224. [n22['attributes']])
  225. kn2 = kn([n11['attributes']],
  226. [n22['attributes']]) * kn(
  227. [n12['attributes']],
  228. [n21['attributes']])
  229. Kmatrix[i][j] += kn1 + kn2
  230. except KeyError: # missing attributes
  231. pass
  232. Kmatrix[j][i] = Kmatrix[i][j]
  233. pbar.update(1)
  234. # node unlabeled
  235. else:
  236. for i in range(0, len(Gn)):
  237. for j in range(i, len(Gn)):
  238. for e1 in Gn[i].edges(data=True):
  239. for e2 in Gn[j].edges(data=True):
  240. if e1[2]['cost'] == e2[2]['cost']:
  241. Kmatrix[i][j] += 1
  242. Kmatrix[j][i] = Kmatrix[i][j]
  243. pbar.update(1)
  244. run_time = time.time() - start_time
  245. print(
  246. "\n --- shortest path kernel matrix of size %d built in %s seconds ---"
  247. % (len(Gn), run_time))
  248. return Kmatrix, run_time, idx

A Python package for graph kernels, graph edit distances and graph pre-image problem.