You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

treeletKernel.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. """
  2. @author: linlin
  3. @references: Gaüzère B, Brun L, Villemin D. Two new graphs kernels in chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
  4. """
  5. import sys
  6. sys.path.insert(0, "../")
  7. import time
  8. from collections import Counter
  9. from itertools import chain
  10. from functools import partial
  11. from multiprocessing import Pool
  12. from tqdm import tqdm
  13. import networkx as nx
  14. import numpy as np
  15. from pygraph.utils.graphdataset import get_dataset_attributes
  16. from pygraph.utils.parallel import parallel_gm
  17. def treeletkernel(*args,
  18. sub_kernel,
  19. node_label='atom',
  20. edge_label='bond_type',
  21. n_jobs=None,
  22. verbose=True):
  23. """Calculate treelet graph kernels between graphs.
  24. Parameters
  25. ----------
  26. Gn : List of NetworkX graph
  27. List of graphs between which the kernels are calculated.
  28. /
  29. G1, G2 : NetworkX graphs
  30. Two graphs between which the kernel is calculated.
  31. sub_kernel : function
  32. The sub-kernel between 2 real number vectors. Each vector counts the
  33. numbers of isomorphic treelets in a graph.
  34. node_label : string
  35. Node attribute used as label. The default node label is atom.
  36. edge_label : string
  37. Edge attribute used as label. The default edge label is bond_type.
  38. labeled : boolean
  39. Whether the graphs are labeled. The default is True.
  40. Return
  41. ------
  42. Kmatrix : Numpy matrix
  43. Kernel matrix, each element of which is the treelet kernel between 2 praphs.
  44. """
  45. # pre-process
  46. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  47. Kmatrix = np.zeros((len(Gn), len(Gn)))
  48. ds_attrs = get_dataset_attributes(Gn,
  49. attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
  50. node_label=node_label, edge_label=edge_label)
  51. labeled = False
  52. if ds_attrs['node_labeled'] or ds_attrs['edge_labeled']:
  53. labeled = True
  54. if not ds_attrs['node_labeled']:
  55. for G in Gn:
  56. nx.set_node_attributes(G, '0', 'atom')
  57. if not ds_attrs['edge_labeled']:
  58. for G in Gn:
  59. nx.set_edge_attributes(G, '0', 'bond_type')
  60. start_time = time.time()
  61. # ---- use pool.imap_unordered to parallel and track progress. ----
  62. # get all canonical keys of all graphs before calculating kernels to save
  63. # time, but this may cost a lot of memory for large dataset.
  64. pool = Pool(n_jobs)
  65. itr = zip(Gn, range(0, len(Gn)))
  66. if len(Gn) < 100 * n_jobs:
  67. chunksize = int(len(Gn) / n_jobs) + 1
  68. else:
  69. chunksize = 100
  70. canonkeys = [[] for _ in range(len(Gn))]
  71. getps_partial = partial(wrapper_get_canonkeys, node_label, edge_label,
  72. labeled, ds_attrs['is_directed'])
  73. if verbose:
  74. iterator = tqdm(pool.imap_unordered(getps_partial, itr, chunksize),
  75. desc='getting canonkeys', file=sys.stdout)
  76. else:
  77. iterator = pool.imap_unordered(getps_partial, itr, chunksize)
  78. for i, ck in iterator:
  79. canonkeys[i] = ck
  80. pool.close()
  81. pool.join()
  82. # compute kernels.
  83. def init_worker(canonkeys_toshare):
  84. global G_canonkeys
  85. G_canonkeys = canonkeys_toshare
  86. do_partial = partial(wrapper_treeletkernel_do, sub_kernel)
  87. parallel_gm(do_partial, Kmatrix, Gn, init_worker=init_worker,
  88. glbv=(canonkeys,), n_jobs=n_jobs, verbose=verbose)
  89. run_time = time.time() - start_time
  90. if verbose:
  91. print("\n --- treelet kernel matrix of size %d built in %s seconds ---"
  92. % (len(Gn), run_time))
  93. return Kmatrix, run_time
  94. def _treeletkernel_do(canonkey1, canonkey2, sub_kernel):
  95. """Calculate treelet graph kernel between 2 graphs.
  96. Parameters
  97. ----------
  98. canonkey1, canonkey2 : list
  99. List of canonical keys in 2 graphs, where each key is represented by a string.
  100. Return
  101. ------
  102. kernel : float
  103. Treelet Kernel between 2 graphs.
  104. """
  105. keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs
  106. vector1 = np.array([(canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys])
  107. vector2 = np.array([(canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys])
  108. kernel = np.sum(np.exp(-np.square(vector1 - vector2) / 2))
  109. # kernel = sub_kernel(vector1, vector2)
  110. return kernel
  111. def wrapper_treeletkernel_do(sub_kernel, itr):
  112. i = itr[0]
  113. j = itr[1]
  114. return i, j, _treeletkernel_do(G_canonkeys[i], G_canonkeys[j], sub_kernel)
def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
    """Generate canonical keys of all treelets in a graph.

    Parameters
    ----------
    G : NetworkX graphs
        The graph in which keys are generated.
    node_label : string
        node attribute used as label. The default node label is atom.
    edge_label : string
        edge attribute used as label. The default edge label is bond_type.
    labeled : boolean
        Whether the graphs are labeled. The default is True.
    is_directed : boolean
        Whether the graph is directed; forwarded to the linear-path search.

    Return
    ------
    canonkey/canonkey_l : dict
        For unlabeled graphs, canonkey is a dictionary which records amount of
        every tree pattern. For labeled graphs, canonkey_l is one which keeps
        track of amount of every treelet.

    Notes
    -----
    Pattern names ('7', '9', '10', '11', '12') and the one-character keys
    ('6', '8', '9', 'a', 'b', 'c', 'd') refer to the treelet rows of Table 1
    in the reference cited in the module docstring (Gaüzère et al., 2012).
    NOTE(review): this uses the legacy ``G.node`` attribute view, which
    suggests networkx < 2.4 — confirm against the project's pinned version.
    """
    patterns = {} # a dictionary which consists of lists of patterns for all graphlet.
    canonkey = {} # canonical key, a dictionary which records amount of every tree pattern.

    ### structural analysis ###
    ### In this section, a list of patterns is generated for each graphlet,
    ### where every pattern is represented by nodes ordered by Morgan's
    ### extended labeling.
    # linear patterns: paths of length 0..5, keyed by their length.
    patterns['0'] = G.nodes()
    canonkey['0'] = nx.number_of_nodes(G)
    for i in range(1, 6): # for i in range(1, 6):
        patterns[str(i)] = find_all_paths(G, i, is_directed)
        canonkey[str(i)] = len(patterns[str(i)])

    # n-star patterns: [center, neighbor1, ..., neighborN] for degree-N centers.
    patterns['3star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3]
    patterns['4star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4]
    patterns['5star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5]
    # n-star patterns
    canonkey['6'] = len(patterns['3star'])
    canonkey['8'] = len(patterns['4star'])
    canonkey['d'] = len(patterns['5star'])

    # pattern 7: a 3-star extended by one extra node hanging off one branch.
    patterns['7'] = [] # the 1st line of Table 1 in Ref [1]
    for pattern in patterns['3star']:
        for i in range(1, len(pattern)): # for each neighbor of node 0
            if G.degree(pattern[i]) >= 2:
                pattern_t = pattern[:]
                # set the node with degree >= 2 as the 4th node
                pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                for neighborx in G[pattern[i]]:
                    if neighborx != pattern[0]:
                        new_pattern = pattern_t + [neighborx]
                        patterns['7'].append(new_pattern)
    canonkey['7'] = len(patterns['7'])

    # pattern 11: a 4-star extended by one extra node hanging off one branch.
    patterns['11'] = [] # the 4th line of Table 1 in Ref [1]
    for pattern in patterns['4star']:
        for i in range(1, len(pattern)):
            if G.degree(pattern[i]) >= 2:
                pattern_t = pattern[:]
                # set the extended branch as the 5th node
                pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
                for neighborx in G[pattern[i]]:
                    if neighborx != pattern[0]:
                        new_pattern = pattern_t + [ neighborx ]
                        patterns['11'].append(new_pattern)
    canonkey['b'] = len(patterns['11'])

    # pattern 12: two degree->=3 nodes joined by an edge, each with two extra leaves.
    patterns['12'] = [] # the 5th line of Table 1 in Ref [1]
    rootlist = [] # a list of root nodes, whose extended labels are 3
    for pattern in patterns['3star']:
        if pattern[0] not in rootlist: # prevent to count the same pattern twice from each of the two root nodes
            rootlist.append(pattern[0])
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 3:
                    rootlist.append(pattern[i])
                    pattern_t = pattern[:]
                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                    # enumerate unordered pairs of the second root's other neighbors
                    for neighborx1 in G[pattern[i]]:
                        if neighborx1 != pattern[0]:
                            for neighborx2 in G[pattern[i]]:
                                if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
                                    new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                    # new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0]) ]
                                    patterns['12'].append(new_pattern)
    # each pattern 12 treelet is still found from both of its two roots; halve the count.
    canonkey['c'] = int(len(patterns['12']) / 2)

    # pattern 9: a 3-star whose center has two branches each extended by one node.
    patterns['9'] = [] # the 2nd line of Table 1 in Ref [1]
    for pattern in patterns['3star']:
        # unordered pairs of degree->=2 neighbors of the center
        for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \
            for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2 ]:
            pattern_t = pattern[:]
            # move nodes with extended labels 4 to specific position to correspond to their children
            pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
            pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
            for neighborx1 in G[pairs[0]]:
                if neighborx1 != pattern[0]:
                    for neighborx2 in G[pairs[1]]:
                        if neighborx2 != pattern[0]:
                            new_pattern = pattern_t + [neighborx1] + [neighborx2]
                            patterns['9'].append(new_pattern)
    canonkey['9'] = len(patterns['9'])

    # pattern 10: a 3-star with one branch extended by a path of two nodes.
    patterns['10'] = [] # the 3rd line of Table 1 in Ref [1]
    for pattern in patterns['3star']:
        for i in range(1, len(pattern)):
            if G.degree(pattern[i]) >= 2:
                for neighborx in G[pattern[i]]:
                    if neighborx != pattern[0] and G.degree(neighborx) >= 2:
                        pattern_t = pattern[:]
                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                        new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]
                        patterns['10'].extend(new_patterns)
    canonkey['a'] = len(patterns['10'])

    ### labeling information ###
    ### In this section, a list of canonical keys is generated for every
    ### pattern obtained in the structural analysis section above, which is a
    ### string corresponding to a unique treelet. A dictionary is built to keep
    ### track of the amount of every treelet.
    if labeled == True:
        canonkey_l = {} # canonical key, a dictionary which keeps track of amount of every treelet.

        # linear patterns: single nodes are keyed by '0' + node label.
        canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))
        for key in canonkey_t:
            canonkey_l['0' + key] = canonkey_t[key]

        for i in range(1, 6): # for i in range(1, 6):
            treelet = []
            for pattern in patterns[str(i)]:
                # interleave node and edge labels along the path
                canonlist = list(chain.from_iterable((G.node[node][node_label], \
                    G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))
                canonlist.append(G.node[pattern[-1]][node_label])
                canonkey_t = ''.join(canonlist)
                # a path read from either end is the same treelet; keep the
                # lexicographically smaller reading as the canonical one.
                canonkey_t = canonkey_t if canonkey_t < canonkey_t[::-1] else canonkey_t[::-1]
                treelet.append(str(i) + canonkey_t)
            canonkey_l.update(Counter(treelet))

        # n-star patterns: leaves are interchangeable, so sort their labels.
        for i in range(3, 6):
            treelet = []
            for pattern in patterns[str(i) + 'star']:
                canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:] ]
                canonlist.sort()
                canonkey_t = ('d' if i == 5 else str(i * 2)) + G.node[pattern[0]][node_label] + ''.join(canonlist)
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

        # pattern 7
        treelet = []
        for pattern in patterns['7']:
            # slots 1-2 are interchangeable leaves; slots 3-4 are the extended branch.
            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
            canonlist.sort()
            canonkey_t = '7' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
                + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]
            treelet.append(canonkey_t)
        canonkey_l.update(Counter(treelet))

        # pattern 11
        treelet = []
        for pattern in patterns['11']:
            # slots 1-3 are interchangeable leaves; slots 4-5 are the extended branch.
            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:4] ]
            canonlist.sort()
            canonkey_t = 'b' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[0]][edge_label] \
                + G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
            treelet.append(canonkey_t)
        canonkey_l.update(Counter(treelet))

        # pattern 10
        treelet = []
        for pattern in patterns['10']:
            canonkey4 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
            canonlist.sort()
            canonkey0 = ''.join(canonlist)
            canonkey_t = 'a' + G.node[pattern[3]][node_label] \
                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label] \
                + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
                + canonkey4 + canonkey0
            treelet.append(canonkey_t)
        canonkey_l.update(Counter(treelet))

        # pattern 12
        treelet = []
        for pattern in patterns['12']:
            canonlist0 = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
            canonlist0.sort()
            canonlist3 = [ G.node[leaf][node_label] + G[leaf][pattern[3]][edge_label] for leaf in pattern[4:6] ]
            canonlist3.sort()
            # 2 possible key can be generated from 2 nodes with extended label 3, select the one with lower lexicographic order.
            canonkey_t1 = 'c' + G.node[pattern[0]][node_label] \
                + ''.join(canonlist0) \
                + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
                + ''.join(canonlist3)
            canonkey_t2 = 'c' + G.node[pattern[3]][node_label] \
                + ''.join(canonlist3) \
                + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
                + ''.join(canonlist0)
            treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
        canonkey_l.update(Counter(treelet))

        # pattern 9
        treelet = []
        for pattern in patterns['9']:
            # the two extended branches (slots 2+4 and 3+5) are interchangeable;
            # order them by their label strings to get a canonical key.
            canonkey2 = G.node[pattern[4]][node_label] + G[pattern[4]][pattern[2]][edge_label]
            canonkey3 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[3]][edge_label]
            prekey2 = G.node[pattern[2]][node_label] + G[pattern[2]][pattern[0]][edge_label]
            prekey3 = G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label]
            if prekey2 + canonkey2 < prekey3 + canonkey3:
                canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
                    + prekey2 + prekey3 + canonkey2 + canonkey3
            else:
                canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
                    + prekey3 + prekey2 + canonkey3 + canonkey2
            treelet.append('9' + G.node[pattern[0]][node_label] + canonkey_t)
        canonkey_l.update(Counter(treelet))

        return canonkey_l

    return canonkey
  324. def wrapper_get_canonkeys(node_label, edge_label, labeled, is_directed, itr_item):
  325. g = itr_item[0]
  326. i = itr_item[1]
  327. return i, get_canonkeys(g, node_label, edge_label, labeled, is_directed)
  328. def find_paths(G, source_node, length):
  329. """Find all paths with a certain length those start from a source node.
  330. A recursive depth first search is applied.
  331. Parameters
  332. ----------
  333. G : NetworkX graphs
  334. The graph in which paths are searched.
  335. source_node : integer
  336. The number of the node from where all paths start.
  337. length : integer
  338. The length of paths.
  339. Return
  340. ------
  341. path : list of list
  342. List of paths retrieved, where each path is represented by a list of nodes.
  343. """
  344. if length == 0:
  345. return [[source_node]]
  346. path = [[source_node] + path for neighbor in G[source_node] \
  347. for path in find_paths(G, neighbor, length - 1) if source_node not in path]
  348. return path
  349. def find_all_paths(G, length, is_directed):
  350. """Find all paths with a certain length in a graph. A recursive depth first
  351. search is applied.
  352. Parameters
  353. ----------
  354. G : NetworkX graphs
  355. The graph in which paths are searched.
  356. length : integer
  357. The length of paths.
  358. Return
  359. ------
  360. path : list of list
  361. List of paths retrieved, where each path is represented by a list of nodes.
  362. """
  363. all_paths = []
  364. for node in G:
  365. all_paths.extend(find_paths(G, node, length))
  366. if not is_directed:
  367. # For each path, two presentations are retrieved from its two extremities.
  368. # Remove one of them.
  369. all_paths_r = [path[::-1] for path in all_paths]
  370. for idx, path in enumerate(all_paths[:-1]):
  371. for path2 in all_paths_r[idx+1::]:
  372. if path == path2:
  373. all_paths[idx] = []
  374. break
  375. all_paths = list(filter(lambda a: a != [], all_paths))
  376. return all_paths

A Python package for graph kernels, graph edit distances and graph pre-image problem.