You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graphfiles.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. """ Utilities function to manage graph files
  2. """
  3. def loadCT(filename):
  4. """load data from .ct file.
  5. Notes
  6. ------
  7. a typical example of data in .ct is like this:
  8. 3 2 <- number of nodes and edges
  9. 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
  10. 0.0000 0.0000 0.0000 C
  11. 0.0000 0.0000 0.0000 O
  12. 1 3 1 1 <- each line describes an edge : to, from,?, label
  13. 2 3 1 1
  14. """
  15. import networkx as nx
  16. from os.path import basename
  17. g = nx.Graph()
  18. with open(filename) as f:
  19. content = f.read().splitlines()
  20. g = nx.Graph(
  21. name = str(content[0]),
  22. filename = basename(filename)) # set name of the graph
  23. tmp = content[1].split(" ")
  24. if tmp[0] == '':
  25. nb_nodes = int(tmp[1]) # number of the nodes
  26. nb_edges = int(tmp[2]) # number of the edges
  27. else:
  28. nb_nodes = int(tmp[0])
  29. nb_edges = int(tmp[1])
  30. # patch for compatibility : label will be removed later
  31. for i in range(0, nb_nodes):
  32. tmp = content[i + 2].split(" ")
  33. tmp = [x for x in tmp if x != '']
  34. g.add_node(i, atom=tmp[3], label=tmp[3])
  35. for i in range(0, nb_edges):
  36. tmp = content[i + g.number_of_nodes() + 2].split(" ")
  37. tmp = [x for x in tmp if x != '']
  38. g.add_edge(
  39. int(tmp[0]) - 1,
  40. int(tmp[1]) - 1,
  41. bond_type=tmp[3].strip(),
  42. label=tmp[3].strip())
  43. # for i in range(0, nb_edges):
  44. # tmp = content[i + g.number_of_nodes() + 2]
  45. # tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
  46. # g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  47. # bond_type=tmp[3].strip(), label=tmp[3].strip())
  48. return g
  49. def loadGXL(filename):
  50. from os.path import basename
  51. import networkx as nx
  52. import xml.etree.ElementTree as ET
  53. tree = ET.parse(filename)
  54. root = tree.getroot()
  55. index = 0
  56. g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
  57. dic = {} # used to retrieve incident nodes of edges
  58. for node in root.iter('node'):
  59. dic[node.attrib['id']] = index
  60. labels = {}
  61. for attr in node.iter('attr'):
  62. labels[attr.attrib['name']] = attr[0].text
  63. if 'chem' in labels:
  64. labels['label'] = labels['chem']
  65. g.add_node(index, **labels)
  66. index += 1
  67. for edge in root.iter('edge'):
  68. labels = {}
  69. for attr in edge.iter('attr'):
  70. labels[attr.attrib['name']] = attr[0].text
  71. if 'valence' in labels:
  72. labels['label'] = labels['valence']
  73. g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
  74. return g
  75. def saveGXL(graph, filename, method='benoit'):
  76. if method == 'benoit':
  77. import xml.etree.ElementTree as ET
  78. root_node = ET.Element('gxl')
  79. attr = dict()
  80. attr['id'] = str(graph.graph['name'])
  81. attr['edgeids'] = 'true'
  82. attr['edgemode'] = 'undirected'
  83. graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
  84. for v in graph:
  85. current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
  86. for attr in graph.nodes[v].keys():
  87. cur_attr = ET.SubElement(
  88. current_node, 'attr', attrib={'name': attr})
  89. cur_value = ET.SubElement(cur_attr,
  90. graph.nodes[v][attr].__class__.__name__)
  91. cur_value.text = graph.nodes[v][attr]
  92. for v1 in graph:
  93. for v2 in graph[v1]:
  94. if (v1 < v2): # Non oriented graphs
  95. cur_edge = ET.SubElement(
  96. graph_node,
  97. 'edge',
  98. attrib={
  99. 'from': str(v1),
  100. 'to': str(v2)
  101. })
  102. for attr in graph[v1][v2].keys():
  103. cur_attr = ET.SubElement(
  104. cur_edge, 'attr', attrib={'name': attr})
  105. cur_value = ET.SubElement(
  106. cur_attr, graph[v1][v2][attr].__class__.__name__)
  107. cur_value.text = str(graph[v1][v2][attr])
  108. tree = ET.ElementTree(root_node)
  109. tree.write(filename)
  110. elif method == 'gedlib':
  111. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  112. pass
  113. # gxl_file = open(filename, 'w')
  114. # gxl_file.write("<?xml version=\"1.0\"?>\n")
  115. # gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  116. # gxl_file.write("<gxl>\n")
  117. # gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  118. # for v in graph:
  119. # gxl_file.write("<node id=\"_" + str(v) + "\">\n")
  120. # gxl_file.write("<attr name=\"chem\"><int>" + str(self.node_labels[node]) + "</int></attr>\n")
  121. # gxl_file.write("</node>\n")
  122. # for edge in self.edge_list:
  123. # gxl_file.write("<edge from=\"_" + str(edge[0]) + "\" to=\"_" + str(edge[1]) + "\">\n")
  124. # gxl_file.write("<attr name=\"valence\"><int>1</int></attr>\n")
  125. # gxl_file.write("</edge>\n")
  126. # gxl_file.write("</graph>\n")
  127. # gxl_file.write("</gxl>\n")
  128. # gxl_file.close()
def loadSDF(filename):
    """Load data from a structured data file (.sdf file).

    Returns a list of networkx.Graph, one per molecule in the file; each
    graph's name is taken from the molecule's header line.

    Notes
    ------
    A SDF file contains a group of molecules, represented in a similar way
    as in MOL format.
    Check http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx, 2018 for detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys

    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            # Line 0 of each record is the molecule name.
            g = nx.Graph(name=content[index].strip())  # set name of the graph
            # Line 3 is the counts line: atom count and bond count as
            # fixed-width 3-character fields.
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of the nodes
            nb_edges = int(tmp[3:6])  # number of the edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                # Columns 32-34 hold the atom symbol — assumes the MOL
                # fixed-width atom block layout; TODO confirm against data.
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                # Slice the bond line into 3-character fields:
                # from-atom, to-atom, bond type (1-based atom indices).
                # The comprehension's `i` is scoped to the comprehension and
                # does not clobber the outer loop variable.
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            # Skip past the record: header (4 lines) + atom block + bond block,
            # then advance to just after the '$$$$' record separator.
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # separator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Parameters
    ----------
    filename : str
        Path to the .mat file.
    extra_params : dict
        Must contain key 'am_sp_al_nl_el': a list of indices locating the
        adjacency matrix, sparse flag, adjacency list, node labels and edge
        labels inside each MATLAB struct entry.

    Returns
    -------
    (data, y) : (list of networkx.Graph, list)
        NOTE(review): `y` is only bound if the file contains a variable whose
        name starts with 'l' (the class-label vector); otherwise the final
        `return` raises NameError — confirm all datasets include it.

    Notes
    ------
    A MAT file contains a struct array containing graphs, and a column vector
    lx containing a class label for each graph.
    Check README in downloadable file in
    http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/, 2018 for detailed structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx

    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l':  # variable holding the class-label column vector
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':  # skip MATLAB metadata keys (__header__, ...)
            # order[1] == 0: adjacency matrix is not compressed and an
            # explicit edge-label list exists.
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge-label triples
                    for edge in el:
                        # edge = (from, to, label) with 1-based node indices.
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                # Edges come from a (sparse) adjacency matrix instead.
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    index_no0 = sam.nonzero()
                    # Each nonzero (col, row) pair becomes an edge; symmetric
                    # entries are merged by the undirected Graph.
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    return data, y
  233. def loadTXT(dirname_dataset):
  234. """Load graph data from a .txt file.
  235. Notes
  236. ------
  237. The graph data is loaded from separate files.
  238. Check README in downloadable file http://tiny.cc/PK_MLJ_data, 2018 for detailed structure.
  239. """
  240. import numpy as np
  241. import networkx as nx
  242. from os import listdir
  243. from os.path import dirname
  244. # load data file names
  245. for name in listdir(dirname_dataset):
  246. if '_A' in name:
  247. fam = dirname_dataset + '/' + name
  248. elif '_graph_indicator' in name:
  249. fgi = dirname_dataset + '/' + name
  250. elif '_graph_labels' in name:
  251. fgl = dirname_dataset + '/' + name
  252. elif '_node_labels' in name:
  253. fnl = dirname_dataset + '/' + name
  254. elif '_edge_labels' in name:
  255. fel = dirname_dataset + '/' + name
  256. elif '_edge_attributes' in name:
  257. fea = dirname_dataset + '/' + name
  258. elif '_node_attributes' in name:
  259. fna = dirname_dataset + '/' + name
  260. elif '_graph_attributes' in name:
  261. fga = dirname_dataset + '/' + name
  262. # this is supposed to be the node attrs, make sure to put this as the last 'elif'
  263. elif '_attributes' in name:
  264. fna = dirname_dataset + '/' + name
  265. content_gi = open(fgi).read().splitlines() # graph indicator
  266. content_am = open(fam).read().splitlines() # adjacency matrix
  267. content_gl = open(fgl).read().splitlines() # lass labels
  268. # create graphs and add nodes
  269. data = [nx.Graph(name=i) for i in range(0, len(content_gl))]
  270. if 'fnl' in locals():
  271. content_nl = open(fnl).read().splitlines() # node labels
  272. for i, line in enumerate(content_gi):
  273. # transfer to int first in case of unexpected blanks
  274. data[int(line) - 1].add_node(i, atom=str(int(content_nl[i])))
  275. else:
  276. for i, line in enumerate(content_gi):
  277. data[int(line) - 1].add_node(i)
  278. # add edges
  279. for line in content_am:
  280. tmp = line.split(',')
  281. n1 = int(tmp[0]) - 1
  282. n2 = int(tmp[1]) - 1
  283. # ignore edge weight here.
  284. g = int(content_gi[n1]) - 1
  285. data[g].add_edge(n1, n2)
  286. # add edge labels
  287. if 'fel' in locals():
  288. content_el = open(fel).read().splitlines()
  289. for index, line in enumerate(content_el):
  290. label = line.strip()
  291. n = [int(i) - 1 for i in content_am[index].split(',')]
  292. g = int(content_gi[n[0]]) - 1
  293. data[g].edges[n[0], n[1]]['bond_type'] = label
  294. # add node attributes
  295. if 'fna' in locals():
  296. content_na = open(fna).read().splitlines()
  297. for i, line in enumerate(content_na):
  298. attrs = [i.strip() for i in line.split(',')]
  299. g = int(content_gi[i]) - 1
  300. data[g].nodes[i]['attributes'] = attrs
  301. # add edge attributes
  302. if 'fea' in locals():
  303. content_ea = open(fea).read().splitlines()
  304. for index, line in enumerate(content_ea):
  305. attrs = [i.strip() for i in line.split(',')]
  306. n = [int(i) - 1 for i in content_am[index].split(',')]
  307. g = int(content_gi[n[0]]) - 1
  308. data[g].edges[n[0], n[1]]['attributes'] = attrs
  309. # load y
  310. y = [int(i) for i in content_gl]
  311. return data, y
  312. def loadDataset(filename, filename_y=None, extra_params=None):
  313. """load file list of the dataset.
  314. """
  315. from os.path import dirname, splitext
  316. dirname_dataset = dirname(filename)
  317. extension = splitext(filename)[1][1:]
  318. data = []
  319. y = []
  320. if extension == "ds":
  321. content = open(filename).read().splitlines()
  322. if filename_y is None or filename_y == '':
  323. for i in range(0, len(content)):
  324. tmp = content[i].split(' ')
  325. # remove the '#'s in file names
  326. data.append(
  327. loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
  328. y.append(float(tmp[1]))
  329. else: # y in a seperate file
  330. for i in range(0, len(content)):
  331. tmp = content[i]
  332. # remove the '#'s in file names
  333. data.append(
  334. loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
  335. content_y = open(filename_y).read().splitlines()
  336. # assume entries in filename and filename_y have the same order.
  337. for item in content_y:
  338. tmp = item.split(' ')
  339. # assume the 3rd entry in a line is y (for Alkane dataset)
  340. y.append(float(tmp[2]))
  341. elif extension == "cxl":
  342. import xml.etree.ElementTree as ET
  343. tree = ET.parse(filename)
  344. root = tree.getroot()
  345. data = []
  346. y = []
  347. for graph in root.iter('print'):
  348. mol_filename = graph.attrib['file']
  349. mol_class = graph.attrib['class']
  350. data.append(loadGXL(dirname_dataset + '/' + mol_filename))
  351. y.append(mol_class)
  352. elif extension == "sdf":
  353. import numpy as np
  354. from tqdm import tqdm
  355. import sys
  356. data = loadSDF(filename)
  357. y_raw = open(filename_y).read().splitlines()
  358. y_raw.pop(0)
  359. tmp0 = []
  360. tmp1 = []
  361. for i in range(0, len(y_raw)):
  362. tmp = y_raw[i].split(',')
  363. tmp0.append(tmp[0])
  364. tmp1.append(tmp[1].strip())
  365. y = []
  366. for i in tqdm(range(0, len(data)), desc='ajust data', file=sys.stdout):
  367. try:
  368. y.append(tmp1[tmp0.index(data[i].name)].strip())
  369. except ValueError: # if data[i].name not in tmp0
  370. data[i] = []
  371. data = list(filter(lambda a: a != [], data))
  372. elif extension == "mat":
  373. data, y = loadMAT(filename, extra_params)
  374. elif extension == 'txt':
  375. data, y = loadTXT(dirname_dataset)
  376. # print(len(y))
  377. # print(y)
  378. # print(data[0].nodes(data=True))
  379. # print('----')
  380. # print(data[0].edges(data=True))
  381. # for g in data:
  382. # print(g.nodes(data=True))
  383. # print('----')
  384. # print(g.edges(data=True))
  385. return data, y
  386. def saveDataset(Gn, y, gformat='gxl', group=None, filename='gfile'):
  387. """Save list of graphs.
  388. """
  389. import os
  390. dirname_ds = os.path.dirname(filename)
  391. if dirname_ds != '':
  392. dirname_ds += '/'
  393. if not os.path.exists(dirname_ds) :
  394. os.makedirs(dirname_ds)
  395. if group == 'xml' and gformat == 'gxl':
  396. with open(filename + '.xml', 'w') as fgroup:
  397. fgroup.write("<?xml version=\"1.0\"?>")
  398. fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"https://dbblumenthal.github.io/gedlib/GraphCollection_8dtd_source.html\">")
  399. fgroup.write("\n<GraphCollection>")
  400. for idx, g in enumerate(Gn):
  401. fname_tmp = "graph" + str(idx) + ".gxl"
  402. saveGXL(g, dirname_ds + fname_tmp)
  403. fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
  404. fgroup.write("\n</GraphCollection>")
  405. fgroup.close()
if __name__ == '__main__':
    # Smoke test: load the MUTAG dataset from a MATLAB .mat file and
    # re-save it as individual GXL files plus an XML collection index.
    # 'am_sp_al_nl_el' locates adjacency matrix / sparse flag / adjacency
    # list / node labels / edge labels inside each MATLAB struct entry.
    ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
    Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    saveDataset(Gn, y, group='xml', filename='temp/temp')

A Python package for graph kernels, graph edit distances and graph pre-image problem.