You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

commonWalkKernel.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. """
  2. @author: linlin
  3. @references:
  4. [1] Thomas Gärtner, Peter Flach, and Stefan Wrobel. On graph kernels:
  5. Hardness results and efficient alternatives. Learning Theory and Kernel
  6. Machines, pages 129–143, 2003.
  7. """
  8. import sys
  9. import time
  10. from tqdm import tqdm
  11. from collections import Counter
  12. from itertools import combinations_with_replacement
  13. from functools import partial
  14. from multiprocessing import Pool
  15. #import traceback
  16. import networkx as nx
  17. import numpy as np
  18. sys.path.insert(0, "../")
  19. from pygraph.utils.utils import direct_product
  20. from pygraph.utils.graphdataset import get_dataset_attributes
  21. def commonwalkkernel(*args,
  22. node_label='atom',
  23. edge_label='bond_type',
  24. n=None,
  25. weight=1,
  26. compute_method=None,
  27. n_jobs=None):
  28. """Calculate common walk graph kernels between graphs.
  29. Parameters
  30. ----------
  31. Gn : List of NetworkX graph
  32. List of graphs between which the kernels are calculated.
  33. /
  34. G1, G2 : NetworkX graphs
  35. 2 graphs between which the kernel is calculated.
  36. node_label : string
  37. node attribute used as label. The default node label is atom.
  38. edge_label : string
  39. edge attribute used as label. The default edge label is bond_type.
  40. n : integer
  41. Longest length of walks. Only useful when applying the 'brute' method.
  42. weight: integer
  43. Weight coefficient of different lengths of walks, which represents beta
  44. in 'exp' method and gamma in 'geo'.
  45. compute_method : string
  46. Method used to compute walk kernel. The Following choices are
  47. available:
  48. 'exp' : exponential serial method applied on the direct product graph,
  49. as shown in reference [1]. The time complexity is O(n^6) for graphs
  50. with n vertices.
  51. 'geo' : geometric serial method applied on the direct product graph, as
  52. shown in reference [1]. The time complexity is O(n^6) for graphs with n
  53. vertices.
  54. 'brute' : brute force, simply search for all walks and compare them.
  55. Return
  56. ------
  57. Kmatrix : Numpy matrix
  58. Kernel matrix, each element of which is a common walk kernel between 2
  59. graphs.
  60. """
  61. compute_method = compute_method.lower()
  62. # arrange all graphs in a list
  63. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  64. Kmatrix = np.zeros((len(Gn), len(Gn)))
  65. ds_attrs = get_dataset_attributes(
  66. Gn,
  67. attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
  68. node_label=node_label, edge_label=edge_label)
  69. if not ds_attrs['node_labeled']:
  70. for G in Gn:
  71. nx.set_node_attributes(G, '0', 'atom')
  72. if not ds_attrs['edge_labeled']:
  73. for G in Gn:
  74. nx.set_edge_attributes(G, '0', 'bond_type')
  75. if not ds_attrs['is_directed']: # convert
  76. Gn = [G.to_directed() for G in Gn]
  77. start_time = time.time()
  78. # ---- use pool.imap_unordered to parallel and track progress. ----
  79. pool = Pool(n_jobs)
  80. itr = zip(combinations_with_replacement(Gn, 2),
  81. combinations_with_replacement(range(0, len(Gn)), 2))
  82. len_itr = int(len(Gn) * (len(Gn) + 1) / 2)
  83. if len_itr < 1000 * n_jobs:
  84. chunksize = int(len_itr / n_jobs) + 1
  85. else:
  86. chunksize = 1000
  87. # direct product graph method - exponential
  88. if compute_method == 'exp':
  89. do_partial = partial(wrapper_cw_exp, node_label, edge_label, weight)
  90. # direct product graph method - geometric
  91. elif compute_method == 'geo':
  92. do_partial = partial(wrapper_cw_geo, node_label, edge_label, weight)
  93. for i, j, kernel in tqdm(
  94. pool.imap_unordered(do_partial, itr, chunksize),
  95. desc='calculating kernels',
  96. file=sys.stdout):
  97. Kmatrix[i][j] = kernel
  98. Kmatrix[j][i] = kernel
  99. pool.close()
  100. pool.join()
  101. # # ---- direct running, normally use single CPU core. ----
  102. # # direct product graph method - exponential
  103. # itr = combinations_with_replacement(range(0, len(Gn)), 2)
  104. # if compute_method == 'exp':
  105. # for gs in tqdm(itr, desc='calculating kernels', file=sys.stdout):
  106. # i, j, Kmatrix[i][j] = _commonwalkkernel_exp(Gn, node_label,
  107. # edge_label, weight, gs)
  108. # Kmatrix[j][i] = Kmatrix[i][j]
  109. #
  110. # # direct product graph method - geometric
  111. # elif compute_method == 'geo':
  112. # for gs in tqdm(itr, desc='calculating kernels', file=sys.stdout):
  113. # i, j, Kmatrix[i][j] = _commonwalkkernel_geo(Gn, node_label,
  114. # edge_label, weight, gs)
  115. # Kmatrix[j][i] = Kmatrix[i][j]
  116. #
  117. # # search all paths use brute force.
  118. # elif compute_method == 'brute':
  119. # n = int(n)
  120. # # get all paths of all graphs before calculating kernels to save time, but this may cost a lot of memory for large dataset.
  121. # all_walks = [
  122. # find_all_walks_until_length(Gn[i], n, node_label, edge_label)
  123. # for i in range(0, len(Gn))
  124. # ]
  125. #
  126. # for i in range(0, len(Gn)):
  127. # for j in range(i, len(Gn)):
  128. # Kmatrix[i][j] = _commonwalkkernel_brute(
  129. # all_walks[i],
  130. # all_walks[j],
  131. # node_label=node_label,
  132. # edge_label=edge_label)
  133. # Kmatrix[j][i] = Kmatrix[i][j]
  134. run_time = time.time() - start_time
  135. print(
  136. "\n --- kernel matrix of common walk kernel of size %d built in %s seconds ---"
  137. % (len(Gn), run_time))
  138. return Kmatrix, run_time
  139. def _commonwalkkernel_exp(g1, g2, node_label, edge_label, beta):
  140. """Calculate walk graph kernels up to n between 2 graphs using exponential
  141. series.
  142. Parameters
  143. ----------
  144. Gn : List of NetworkX graph
  145. List of graphs between which the kernels are calculated.
  146. node_label : string
  147. Node attribute used as label.
  148. edge_label : string
  149. Edge attribute used as label.
  150. beta : integer
  151. Weight.
  152. ij : tuple of integer
  153. Index of graphs between which the kernel is computed.
  154. Return
  155. ------
  156. kernel : float
  157. The common walk Kernel between 2 graphs.
  158. """
  159. # get tensor product / direct product
  160. gp = direct_product(g1, g2, node_label, edge_label)
  161. A = nx.adjacency_matrix(gp).todense()
  162. # print(A)
  163. # from matplotlib import pyplot as plt
  164. # nx.draw_networkx(G1)
  165. # plt.show()
  166. # nx.draw_networkx(G2)
  167. # plt.show()
  168. # nx.draw_networkx(gp)
  169. # plt.show()
  170. # print(G1.nodes(data=True))
  171. # print(G2.nodes(data=True))
  172. # print(gp.nodes(data=True))
  173. # print(gp.edges(data=True))
  174. ew, ev = np.linalg.eig(A)
  175. # print('ew: ', ew)
  176. # print(ev)
  177. # T = np.matrix(ev)
  178. # print('T: ', T)
  179. # T = ev.I
  180. D = np.zeros((len(ew), len(ew)))
  181. for i in range(len(ew)):
  182. D[i][i] = np.exp(beta * ew[i])
  183. # print('D: ', D)
  184. # print('hshs: ', T.I * D * T)
  185. # print(np.exp(-2))
  186. # print(D)
  187. # print(np.exp(weight * D))
  188. # print(ev)
  189. # print(np.linalg.inv(ev))
  190. exp_D = ev * D * ev.T
  191. # print(exp_D)
  192. # print(np.exp(weight * A))
  193. # print('-------')
  194. return exp_D.sum()
  195. def wrapper_cw_exp(node_label, edge_label, beta, itr_item):
  196. g1 = itr_item[0][0]
  197. g2 = itr_item[0][1]
  198. i = itr_item[1][0]
  199. j = itr_item[1][1]
  200. return i, j, _commonwalkkernel_exp(g1, g2, node_label, edge_label, beta)
  201. def _commonwalkkernel_geo(g1, g2, node_label, edge_label, gamma):
  202. """Calculate common walk graph kernels up to n between 2 graphs using
  203. geometric series.
  204. Parameters
  205. ----------
  206. Gn : List of NetworkX graph
  207. List of graphs between which the kernels are calculated.
  208. node_label : string
  209. Node attribute used as label.
  210. edge_label : string
  211. Edge attribute used as label.
  212. gamma: integer
  213. Weight.
  214. ij : tuple of integer
  215. Index of graphs between which the kernel is computed.
  216. Return
  217. ------
  218. kernel : float
  219. The common walk Kernel between 2 graphs.
  220. """
  221. # get tensor product / direct product
  222. gp = direct_product(g1, g2, node_label, edge_label)
  223. A = nx.adjacency_matrix(gp).todense()
  224. mat = np.identity(len(A)) - gamma * A
  225. try:
  226. return mat.I.sum()
  227. except np.linalg.LinAlgError:
  228. return np.nan
  229. def wrapper_cw_geo(node_label, edge_label, gama, itr_item):
  230. g1 = itr_item[0][0]
  231. g2 = itr_item[0][1]
  232. i = itr_item[1][0]
  233. j = itr_item[1][1]
  234. return i, j, _commonwalkkernel_geo(g1, g2, node_label, edge_label, gama)
  235. def _commonwalkkernel_brute(walks1,
  236. walks2,
  237. node_label='atom',
  238. edge_label='bond_type',
  239. labeled=True):
  240. """Calculate walk graph kernels up to n between 2 graphs.
  241. Parameters
  242. ----------
  243. walks1, walks2 : list
  244. List of walks in 2 graphs, where for unlabeled graphs, each walk is
  245. represented by a list of nodes; while for labeled graphs, each walk is
  246. represented by a string consists of labels of nodes and edges on that
  247. walk.
  248. node_label : string
  249. node attribute used as label. The default node label is atom.
  250. edge_label : string
  251. edge attribute used as label. The default edge label is bond_type.
  252. labeled : boolean
  253. Whether the graphs are labeled. The default is True.
  254. Return
  255. ------
  256. kernel : float
  257. Treelet Kernel between 2 graphs.
  258. """
  259. counts_walks1 = dict(Counter(walks1))
  260. counts_walks2 = dict(Counter(walks2))
  261. all_walks = list(set(walks1 + walks2))
  262. vector1 = [(counts_walks1[walk] if walk in walks1 else 0)
  263. for walk in all_walks]
  264. vector2 = [(counts_walks2[walk] if walk in walks2 else 0)
  265. for walk in all_walks]
  266. kernel = np.dot(vector1, vector2)
  267. return kernel
  268. # this method find walks repetively, it could be faster.
  269. def find_all_walks_until_length(G,
  270. length,
  271. node_label='atom',
  272. edge_label='bond_type',
  273. labeled=True):
  274. """Find all walks with a certain maximum length in a graph.
  275. A recursive depth first search is applied.
  276. Parameters
  277. ----------
  278. G : NetworkX graphs
  279. The graph in which walks are searched.
  280. length : integer
  281. The maximum length of walks.
  282. node_label : string
  283. node attribute used as label. The default node label is atom.
  284. edge_label : string
  285. edge attribute used as label. The default edge label is bond_type.
  286. labeled : boolean
  287. Whether the graphs are labeled. The default is True.
  288. Return
  289. ------
  290. walk : list
  291. List of walks retrieved, where for unlabeled graphs, each walk is
  292. represented by a list of nodes; while for labeled graphs, each walk
  293. is represented by a string consists of labels of nodes and edges on
  294. that walk.
  295. """
  296. all_walks = []
  297. # @todo: in this way, the time complexity is close to N(d^n+d^(n+1)+...+1), which could be optimized to O(Nd^n)
  298. for i in range(0, length + 1):
  299. new_walks = find_all_walks(G, i)
  300. if new_walks == []:
  301. break
  302. all_walks.extend(new_walks)
  303. if labeled == True: # convert paths to strings
  304. walk_strs = []
  305. for walk in all_walks:
  306. strlist = [
  307. G.node[node][node_label] +
  308. G[node][walk[walk.index(node) + 1]][edge_label]
  309. for node in walk[:-1]
  310. ]
  311. walk_strs.append(''.join(strlist) + G.node[walk[-1]][node_label])
  312. return walk_strs
  313. return all_walks
  314. def find_walks(G, source_node, length):
  315. """Find all walks with a certain length those start from a source node. A
  316. recursive depth first search is applied.
  317. Parameters
  318. ----------
  319. G : NetworkX graphs
  320. The graph in which walks are searched.
  321. source_node : integer
  322. The number of the node from where all walks start.
  323. length : integer
  324. The length of walks.
  325. Return
  326. ------
  327. walk : list of list
  328. List of walks retrieved, where each walk is represented by a list of
  329. nodes.
  330. """
  331. return [[source_node]] if length == 0 else \
  332. [[source_node] + walk for neighbor in G[source_node]
  333. for walk in find_walks(G, neighbor, length - 1)]
  334. def find_all_walks(G, length):
  335. """Find all walks with a certain length in a graph. A recursive depth first
  336. search is applied.
  337. Parameters
  338. ----------
  339. G : NetworkX graphs
  340. The graph in which walks are searched.
  341. length : integer
  342. The length of walks.
  343. Return
  344. ------
  345. walk : list of list
  346. List of walks retrieved, where each walk is represented by a list of
  347. nodes.
  348. """
  349. all_walks = []
  350. for node in G:
  351. all_walks.extend(find_walks(G, node, length))
  352. # The following process is not carried out according to the original article
  353. # all_paths_r = [ path[::-1] for path in all_paths ]
  354. # # For each path, two presentation are retrieved from its two extremities. Remove one of them.
  355. # for idx, path in enumerate(all_paths[:-1]):
  356. # for path2 in all_paths_r[idx+1::]:
  357. # if path == path2:
  358. # all_paths[idx] = []
  359. # break
  360. # return list(filter(lambda a: a != [], all_paths))
  361. return all_walks

A Python package for graph kernels, graph edit distances and the graph pre-image problem.