
data_fetcher.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 20 14:25:49 2020

@author:
    Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr
    Luc Brun, luc.brun@ensicaen.fr
    Sebastien Bougleux, sebastien.bougleux@unicaen.fr
    Benoit Gaüzère, benoit.gauzere@insa-rouen.fr
    Linlin Jia, linlin.jia@insa-rouen.fr
"""
import os
import os.path as osp
import urllib.request
import urllib.error
import tarfile
from zipfile import ZipFile
from gklearn.utils.graphfiles import loadDataset
import torch.nn.functional as F
import networkx as nx
import torch
import random
import sys
from lxml import etree
import re
from tqdm import tqdm
from gklearn.dataset import DATABASES, DATASET_META


class DataFetcher():

    def __init__(self, name=None, root='datasets', reload=False, verbose=False):
        self._name = name
        self._root = root
        if not osp.exists(self._root):
            os.makedirs(self._root)
        self._reload = reload
        self._verbose = verbose
        # self.has_train_valid_test = {
        #     "Coil_Del": ('COIL-DEL/data/test.cxl', 'COIL-DEL/data/train.cxl', 'COIL-DEL/data/valid.cxl'),
        #     "Coil_Rag": ('COIL-RAG/data/test.cxl', 'COIL-RAG/data/train.cxl', 'COIL-RAG/data/valid.cxl'),
        #     "Fingerprint": ('Fingerprint/data/test.cxl', 'Fingerprint/data/train.cxl', 'Fingerprint/data/valid.cxl'),
        #     # "Grec": ('GREC/data/test.cxl', 'GREC/data/train.cxl', 'GREC/data/valid.cxl'),
        #     "Letter": {'HIGH': ('Letter/HIGH/test.cxl', 'Letter/HIGH/train.cxl', 'Letter/HIGH/validation.cxl'),
        #                'MED': ('Letter/MED/test.cxl', 'Letter/MED/train.cxl', 'Letter/MED/validation.cxl'),
        #                'LOW': ('Letter/LOW/test.cxl', 'Letter/LOW/train.cxl', 'Letter/LOW/validation.cxl')
        #                },
        #     "Mutagenicity": ('Mutagenicity/data/test.cxl', 'Mutagenicity/data/train.cxl', 'Mutagenicity/data/validation.cxl'),
        #     # "Pah": ['PAH/testset_0.ds', 'PAH/trainset_0.ds'],
        #     "Protein": ('Protein/data/test.cxl', 'Protein/data/train.cxl', 'Protein/data/valid.cxl'),
        #     # "Web": ('Web/data/test.cxl', 'Web/data/train.cxl', 'Web/data/valid.cxl')
        # }
        if self._name is None:
            if self._verbose:
                print('No dataset name entered. All possible datasets will be loaded.')
            self._name, self._path = [], []
            for idx, ds_name in enumerate(DATASET_META):
                if self._verbose:
                    print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ')
                self._name.append(ds_name)
                success = self.write_archive_file(ds_name)
                if success:
                    self._path.append(self.open_files(ds_name))
                else:
                    self._path.append(None)
                if self._verbose and self._path[-1] is not None and not self._reload:
                    print('Fetched.')
            if self._verbose:
                print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.')
        elif self._name not in DATASET_META:
            message = 'Invalid Dataset name "' + self._name + '".'
            message += '\nAvailable datasets are as follows: \n\n'
            message += '\n'.join(ds for ds in sorted(DATASET_META))
            raise ValueError(message)
        else:
            self.write_archive_file(self._name)
            self._path = self.open_files(self._name)
        # self.max_for_letter = 0
        # if mode == 'Pytorch':
        #     if self._name in self.data_to_use_in_datasets:
        #         Gs, y = self.dataset
        #         inputs, adjs, y = self.from_networkx_to_pytorch(Gs, y)
        #         # print(inputs, adjs)
        #         self.pytorch_dataset = inputs, adjs, y
        #     elif self._name == "Pah":
        #         self.pytorch_dataset = []
        #         test, train = self.dataset
        #         Gs_test, y_test = test
        #         Gs_train, y_train = train
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
        #     elif self._name in self.has_train_valid_test:
        #         self.pytorch_dataset = []
        #         # [G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
        #         test, train, valid = self.dataset
        #         Gs_test, y_test = test
        #         Gs_train, y_train = train
        #         Gs_valid, y_valid = valid
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid, y_valid))
        # #############
        # """
        # for G in Gs:
        #     for e in G.edges():
        #         print(G[e[0]])
        # """
        # ##############
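    # A minimal usage sketch of the constructor (assuming 'MUTAG' is a key of
    # DATASET_META and its archive URL is reachable):
    #
    #     fetcher = DataFetcher(name='MUTAG', root='datasets', verbose=True)
    #     print(fetcher.path)   # path of the first extracted archive member
    #
    # With name=None, every dataset listed in DATASET_META is fetched and
    # self._path becomes a list (None for datasets that could not be fetched).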
    def download_file(self, url):
        try:
            response = urllib.request.urlopen(url)
        except urllib.error.HTTPError:
            print('"' + url.split('/')[-1] + '" is not available or the URL is incorrect.')
            return
        except urllib.error.URLError:
            print('Network is unreachable.')
            return
        return response

    def write_archive_file(self, ds_name):
        path = osp.join(self._root, ds_name)
        url = DATASET_META[ds_name]['url']
        # filename_dir = osp.join(path, filename)
        if not osp.exists(path) or self._reload:
            response = self.download_file(url)
            if response is None:
                return False
            os.makedirs(path, exist_ok=True)
            with open(os.path.join(path, url.split('/')[-1]), 'wb') as outfile:
                outfile.write(response.read())
        return True

    def open_files(self, ds_name=None):
        if ds_name is None:
            ds_name = (self._name if isinstance(self._name, str) else self._name[0])
        filename = DATASET_META[ds_name]['url'].split('/')[-1]
        path = osp.join(self._root, ds_name)
        filename_archive = osp.join(path, filename)
        if filename.endswith('gz'):
            if tarfile.is_tarfile(filename_archive):
                with tarfile.open(filename_archive, 'r:gz') as tar:
                    if self._reload and self._verbose:
                        print(filename + ' Downloaded.')
                    tar.extractall(path=path)
                    return os.path.join(path, tar.getnames()[0])
        elif filename.endswith('.tar'):
            if tarfile.is_tarfile(filename_archive):
                with tarfile.open(filename_archive, 'r:') as tar:
                    if self._reload and self._verbose:
                        print(filename + ' Downloaded.')
                    tar.extractall(path=path)
                    return os.path.join(path, tar.getnames()[0])
        elif filename.endswith('.zip'):
            with ZipFile(filename_archive, 'r') as zip_ref:
                if self._reload and self._verbose:
                    print(filename + ' Downloaded.')
                zip_ref.extractall(path)
                return os.path.join(path, zip_ref.namelist()[0])
        else:
            raise ValueError(filename + ' Unsupported file.')
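    # write_archive_file() and open_files() can also be called directly for a
    # dataset other than the one passed to the constructor; a sketch (the
    # dataset name and network availability are assumptions):
    #
    #     fetcher = DataFetcher(name='Acyclic')
    #     if fetcher.write_archive_file('MAO'):      # download the MAO archive
    #         mao_path = fetcher.open_files('MAO')   # extract it and get the path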
    def get_all_ds_infos(self, database):
        """Get information of all datasets from a database.

        Parameters
        ----------
        database : string
            Name of the database (e.g., 'TUDataset'); see "gklearn.dataset.DATABASES".

        Returns
        -------
        infos : dict
            Information of each dataset in the database, keyed by dataset name.
        """
        if database.lower() == 'tudataset':
            infos = self.get_all_tud_ds_infos()
        elif database.lower() == 'iam':
            infos = {}  # @todo: the IAM database is not parsed yet.
        else:
            msg = 'Invalid Database name "' + database + '".'
            msg += '\nAvailable databases are as follows: \n\n'
            msg += '\n'.join(db for db in sorted(DATABASES))
            msg += '\nCheck "gklearn.dataset.DATASET_META" for more details.'
            raise ValueError(msg)
        return infos

    def get_all_tud_ds_infos(self):
        """Get information of all datasets from the database TUDataset.

        Returns
        -------
        infos : dict
            Information of each dataset, parsed from the TUDataset web page.
        """
        try:
            response = urllib.request.urlopen(DATABASES['tudataset'])
        except urllib.error.HTTPError:
            print('The URL of the database "TUDataset" is not available:\n' + DATABASES['tudataset'])
            return {}
        infos = {}
        # Get tables.
        h_str = response.read()
        tree = etree.HTML(h_str)
        tables = tree.xpath('//table')
        for table in tables:
            # Get the domain of the datasets.
            h2_nodes = table.getprevious()
            if h2_nodes is not None and h2_nodes.tag == 'h2':
                domain = h2_nodes.text.strip().lower()
            else:
                domain = ''
            # Get each line in the table.
            tr_nodes = table.xpath('tbody/tr')
            for tr in tr_nodes[1:]:
                # Get each element in the line.
                td_node = tr.xpath('td')
                # Task type.
                cls_txt = td_node[3].text.strip()
                if not cls_txt.startswith('R'):
                    class_number = int(cls_txt)
                    task_type = 'classification'
                else:
                    class_number = None
                    task_type = 'regression'
                # Node attributes.
                na_text = td_node[8].text.strip()
                if not na_text.startswith('+'):
                    node_attr_dim = 0
                else:
                    node_attr_dim = int(re.findall(r'\((.*)\)', na_text)[0])
                # Edge attributes.
                ea_text = td_node[10].text.strip()
                if ea_text == 'temporal':
                    edge_attr_dim = ea_text
                elif not ea_text.startswith('+'):
                    edge_attr_dim = 0
                else:
                    edge_attr_dim = int(re.findall(r'\((.*)\)', ea_text)[0])
                # Geometry.
                geo_txt = td_node[9].text.strip()
                if geo_txt == '–':
                    geometry = None
                else:
                    geometry = geo_txt
                infos[td_node[0].xpath('strong')[0].text.strip()] = {
                    'database': 'tudataset',
                    'reference': td_node[1].text.strip(),
                    'dataset_size': int(td_node[2].text.strip()),
                    'class_number': class_number,
                    'task_type': task_type,
                    'ave_node_num': float(td_node[4].text.strip()),
                    'ave_edge_num': float(td_node[5].text.strip()),
                    'node_labeled': td_node[6].text.strip() == '+',
                    'edge_labeled': td_node[7].text.strip() == '+',
                    'node_attr_dim': node_attr_dim,
                    'geometry': geometry,
                    'edge_attr_dim': edge_attr_dim,
                    'url': td_node[11].xpath('a')[0].attrib['href'].strip(),
                    'domain': domain
                }
        return infos
    def pretty_ds_infos(self, infos):
        """Get a string that pretty-prints the information of datasets.

        Parameters
        ----------
        infos : dict
            The datasets' information.

        Returns
        -------
        p_str : string
            The pretty print of the datasets' information.
        """
        p_str = '{\n'
        for key, val in infos.items():
            p_str += '\t\'' + str(key) + '\': {\n'
            for k, v in val.items():
                p_str += '\t\t\'' + str(k) + '\': '
                if isinstance(v, str):
                    p_str += '\'' + str(v) + '\',\n'
                else:
                    p_str += str(v) + ',\n'
            p_str += '\t},\n'
        p_str += '}'
        return p_str
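    # A sketch of the two inspection helpers above (requires network access to
    # the TUDataset site; only the shape of the result is illustrated):
    #
    #     fetcher = DataFetcher(name='Acyclic')
    #     infos = fetcher.get_all_ds_infos('TUDataset')
    #     print(fetcher.pretty_ds_infos(infos))
    #     # e.g. infos['MUTAG'] -> {'database': 'tudataset',
    #     #                         'task_type': 'classification', ...}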
    @property
    def path(self):
        return self._path
    # The methods below are kept from an earlier DataLoader implementation.
    # They rely on attributes (self.mode, self.option, self.gender, self.root,
    # self.name, self.info_dataset, self.pytorch_dataset,
    # self.has_train_valid_test, self.data_to_use_in_datasets,
    # self.max_for_letter) that are no longer set in __init__, so they are not
    # functional as-is.

    def dataset(self):
        if self.mode == "Tensorflow":
            return  # something
        if self.mode == "Pytorch":
            return self.pytorch_dataset
        return self.dataset

    def info(self):
        print(self.info_dataset[self._name])

    def iter_load_dataset(self, data):
        results = []
        for datasets in data:
            results.append(loadDataset(osp.join(self._root, self._name, datasets)))
        return results

    def load_dataset(self, list_files):
        if self._name == "Ptc":
            if type(self.option) != str or self.option.upper() not in ['FR', 'FM', 'MM', 'MR']:
                raise ValueError('option for Ptc dataset needs to be one of: \n fr fm mm mr')
            results = []
            results.append(loadDataset(osp.join(self.root, self._name, 'PTC/Test', self.gender + '.ds')))
            results.append(loadDataset(osp.join(self.root, self._name, 'PTC/Train', self.gender + '.ds')))
            return results
        if self.name == "Pah":
            maximum_sets = 0
            for file in list_files:
                if file.endswith('ds'):
                    maximum_sets = max(maximum_sets, int(file.split('_')[1].split('.')[0]))
            self.max_for_letter = maximum_sets
            if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
                raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
            data = self.has_train_valid_test["Pah"]
            data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds'
            data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds'
            return self.iter_load_dataset(data)
        if self.name == "Letter":
            if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]:
                data = self.has_train_valid_test["Letter"][self.option.upper()]
            else:
                message = "The parameter for Letter is incorrect; choose between: "
                message += "\nhigh med low"
                raise ValueError(message)
            return self.iter_load_dataset(data)
        if self.name in self.has_train_valid_test:  # common IAM dataset with train, valid and test
            data = self.has_train_valid_test[self.name]
            return self.iter_load_dataset(data)
        else:  # common dataset without train, valid and test; only a dataset.ds file
            data = self.data_to_use_in_datasets[self.name]
            if len(data) > 1 and data[0] in list_files and data[1] in list_files:  # case for Alkane
                return loadDataset(osp.join(self.root, self.name, data[0]), filename_y=osp.join(self.root, self.name, data[1]))
            if data in list_files:
                return loadDataset(osp.join(self.root, self.name, data))
    def build_dictionary(self, Gs):
        labels = set()
        # next line: from DeepGraphWithNNTorch
        # bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
        sizes = set()
        for G in Gs:
            for _, node in G.nodes(data=True):  # or: for node in nx.nodes(G)
                # print(_, node)
                labels.add(node["label"][0])  # labels.add(G.nodes[node]["label"][0])  # what do we use for IAM datasets (they don't have bond_type or even a label)?
                sizes.add(G.order())
        label_dict = {}
        # print("labels : ", labels, bond_type_number_maxi)
        for i, label in enumerate(labels):
            label_dict[label] = [0.] * len(labels)
            label_dict[label][i] = 1.
        return label_dict

    def from_networkx_to_pytorch(self, Gs, y):
        # example for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
        # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
        atom_to_onehot = self.build_dictionary(Gs)
        max_size = 30
        adjs = []
        inputs = []
        for i, G in enumerate(Gs):
            I = torch.eye(G.order(), G.order())
            # A = torch.Tensor(nx.adjacency_matrix(G).todense())
            # A = torch.Tensor(nx.to_numpy_matrix(G))
            A = torch.tensor(nx.to_scipy_sparse_matrix(G, dtype=int, weight='bond_type').todense(), dtype=torch.int)  # what do we use for IAM datasets (they don't have bond_type or even a label)?
            adj = F.pad(A, pad=(0, max_size - G.order(), 0, max_size - G.order()))  # add I now? if yes: F.pad(A + I, pad=(...))
            adjs.append(adj)
            f_0 = []
            for _, label in G.nodes(data=True):
                # print(_, label)
                cur_label = atom_to_onehot[label['label'][0]].copy()
                f_0.append(cur_label)
            X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size - G.order()))
            inputs.append(X)
        return inputs, adjs, y

    def from_pytorch_to_tensorflow(self, batch_size):
        # Draw the same random indices from inputs and targets by reseeding.
        seed = random.randrange(sys.maxsize)
        random.seed(seed)
        tf_inputs = random.sample(self.pytorch_dataset[0], batch_size)
        random.seed(seed)
        tf_y = random.sample(self.pytorch_dataset[2], batch_size)

    def from_networkx_to_tensor(self, G, dict):
        A = nx.to_numpy_matrix(G)
        lab = [dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
        return (torch.tensor(A).view(1, A.shape[0] * A.shape[1]), torch.tensor(lab))

    # dataset = self.open_files()
    # print(build_dictionary(Gs))
    # dic = {'C': 0, 'N': 1, 'O': 2}
    # A, labels = from_networkx_to_tensor(Gs[13], dic)
    # print(nx.to_numpy_matrix(Gs[13]), labels)
    # print(A, labels)

    # @todo: from_networkx_to_tensorflow
    # dataloader = DataLoader('Acyclic', root="database", option='high', mode="Pytorch")
    # dataloader.info()
    # inputs, adjs, y = dataloader.pytorch_dataset
    # """
    # test, train, valid = dataloader.dataset
    # Gs, y = test
    # Gs2, y2 = train
    # Gs3, y3 = valid
    # """
    # # Gs, y = dataloader.
    # # print(Gs, y)
    # """
    # Gs, y = dataloader.dataset
    # for G in Gs:
    #     for e in G.edges():
    #         print(G[e[0]])
    # """
    # # for e in Gs[13].edges():
    # #     print(Gs[13][e[0]])
    # # print(from_networkx_to_tensor(Gs[7], {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}))
    # # dataset.open_files()
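    # A hedged end-to-end sketch of the conversion helpers above (assumptions:
    # a chemistry dataset whose graphs carry a 'label' node attribute and a
    # 'bond_type' edge attribute, as MAO does; every graph has at most 30
    # nodes, the hard-coded max_size; the '.ds' file name is illustrative):
    #
    #     fetcher = DataFetcher(name='MAO')
    #     Gs, y = loadDataset(osp.join(fetcher.path, 'dataset.ds'))
    #     inputs, adjs, y = fetcher.from_networkx_to_pytorch(Gs, y)
    #     # inputs: one (30, n_labels) one-hot node-feature tensor per graph
    #     # adjs:   one (30, 30) padded bond-type adjacency tensor per graph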
  401. # import os
  402. # import os.path as osp
  403. # import urllib
  404. # import tarfile
  405. # from zipfile import ZipFile
  406. # from gklearn.utils.graphfiles import loadDataset
  407. # import torch
  408. # import torch.nn.functional as F
  409. # import networkx as nx
  410. # import matplotlib.pyplot as plt
  411. # import numpy as np
  412. #
  413. # def DataLoader(name,root = 'data',mode = "Networkx",downloadAll = False,reload = False,letter = "High",number = 0,gender = "MM"):
  414. # dir_name = "_".join(name.split("-"))
  415. # if not osp.exists(root) :
  416. # os.makedirs(root)
  417. # url = "https://brunl01.users.greyc.fr/CHEMISTRY/"
  418. # urliam = "https://iapr-tc15.greyc.fr/IAM/"
  419. # list_database = {
  420. # "Ace" : (url,"ACEDataset.tar"),
  421. # "Acyclic" : (url,"Acyclic.tar.gz"),
  422. # "Aids" : (urliam,"AIDS.zip"),
  423. # "Alkane" : (url,"alkane_dataset.tar.gz"),
  424. # "Chiral" : (url,"DatasetAcyclicChiral.tar"),
  425. # "Coil_Del" : (urliam,"COIL-DEL.zip"),
  426. # "Coil_Rag" : (urliam,"COIL-RAG.zip"),
  427. # "Fingerprint" : (urliam,"Fingerprint.zip"),
  428. # "Grec" : (urliam,"GREC.zip"),
  429. # "Letter" : (urliam,"Letter.zip"),
  430. # "Mao" : (url,"mao.tgz"),
  431. # "Monoterpenoides" : (url,"monoterpenoides.tar.gz"),
  432. # "Mutagenicity" : (urliam,"Mutagenicity.zip"),
  433. # "Pah" : (url,"PAH.tar.gz"),
  434. # "Protein" : (urliam,"Protein.zip"),
  435. # "Ptc" : (url,"ptc.tgz"),
  436. # "Steroid" : (url,"SteroidDataset.tar"),
  437. # "Vitamin" : (url,"DatasetVitamin.tar"),
  438. # "Web" : (urliam,"Web.zip")
  439. # }
  440. #
  441. # data_to_use_in_datasets = {
  442. # "Acyclic" : ("Acyclic/dataset_bps.ds"),
  443. # "Aids" : ("AIDS_A.txt"),
  444. # "Alkane" : ("Alkane/dataset.ds","Alkane/dataset_boiling_point_names.txt"),
  445. # "Mao" : ("MAO/dataset.ds"),
  446. # "Monoterpenoides" : ("monoterpenoides/dataset_10+.ds"), #('monoterpenoides/dataset.ds'),('monoterpenoides/dataset_9.ds'),('monoterpenoides/trainset_9.ds')
  447. #
  448. # }
  449. # has_train_valid_test = {
  450. # "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'),
  451. # "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'),
  452. # "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'),
  453. # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'),
  454. # "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'),
  455. # 'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'),
  456. # 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl')
  457. # },
  458. # "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'),
  459. # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'],
  460. # "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'),
  461. # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl')
  462. # }
  463. #
  464. # if not name :
  465. # raise ValueError("No dataset entered")
  466. # if name not in list_database:
  467. # message = "Invalid Dataset name " + name
  468. # message += '\n Available datasets are as follows : \n\n'
  469. # message += '\n'.join(database for database in list_database)
  470. # raise ValueError(message)
  471. #
  472. # def download_file(url,filename):
  473. # try :
  474. # response = urllib.request.urlopen(url + filename)
  475. # except urllib.error.HTTPError:
  476. # print(filename + " not available or incorrect http link")
  477. # return
  478. # return response
  479. #
  480. # def write_archive_file(root,database):
  481. # path = osp.join(root,database)
  482. # url,filename = list_database[database]
  483. # filename_dir = osp.join(path,filename)
  484. # if not osp.exists(filename_dir) or reload:
  485. # response = download_file(url,filename)
  486. # if response is None :
  487. # return
  488. # if not osp.exists(path) :
  489. # os.makedirs(path)
  490. # with open(filename_dir,'wb') as outfile :
  491. # outfile.write(response.read())
  492. #
  493. # if downloadAll :
  494. # print('Waiting...')
  495. # for database in list_database :
  496. # write_archive_file(root,database)
  497. # print('Downloading finished')
  498. # else:
  499. # write_archive_file(root,name)
  500. #
  501. # def iter_load_dataset(data):
  502. # results = []
  503. # for datasets in data :
  504. # results.append(loadDataset(osp.join(root,name,datasets)))
  505. # return results
  506. #
  507. # def load_dataset(list_files):
  508. # if name == "Ptc":
  509. # if gender.upper() not in ['FR','FM','MM','MR']:
  510. # raise ValueError('gender chosen needs to be one of \n fr fm mm mr')
  511. # results = []
  512. # results.append(loadDataset(osp.join(root,name,'PTC/Test',gender.upper() + '.ds')))
  513. # results.append(loadDataset(osp.join(root,name,'PTC/Train',gender.upper() + '.ds')))
  514. # return results
  515. # if name == "Pah":
  516. # maximum_sets = 0
  517. # for file in list_files:
  518. # if file.endswith('ds'):
  519. # maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0]))
  520. # if number > maximum_sets :
  521. # raise ValueError("Please select a dataset with number less than " + str(maximum_sets + 1))
  522. # data = has_train_valid_test["Pah"]
  523. # data[0] = has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(number) + '.ds'
  524. # data[1] = has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(number) + '.ds'
  525. # #print(data)
  526. # return iter_load_dataset(data)
  527. # if name == "Letter":
  528. # if letter.upper() in has_train_valid_test["Letter"]:
  529. # data = has_train_valid_test["Letter"][letter.upper()]
  530. # else:
  531. # message = "The parameter for letter is incorrect choose between : "
  532. # message += "\nhigh med low"
  533. # raise ValueError(message)
  534. # results = []
  535. # for datasets in data:
  536. # results.append(loadDataset(osp.join(root,name,datasets)))
  537. # return results
  538. # if name in has_train_valid_test : #common IAM dataset with train, valid and test
  539. # data = has_train_valid_test[name]
  540. # results = []
  541. # for datasets in data :
  542. # results.append(loadDataset(osp.join(root,name,datasets)))
  543. # return results
  544. # else: #common dataset without train,valid and test, only dataset.ds file
  545. # data = data_to_use_in_datasets[name]
  546. # if len(data) > 1 and data[0] in list_files and data[1] in list_files:
  547. # return loadDataset(osp.join(root,name,data[0]),filename_y = osp.join(root,name,data[1]))
  548. # if data in list_files:
  549. # return loadDataset(osp.join(root,name,data))
  550. # def open_files():
  551. # filename = list_database[name][1]
  552. # path = osp.join(root,name)
  553. # filename_archive = osp.join(root,name,filename)
  554. #
  555. # if filename.endswith('gz'):
  556. # if tarfile.is_tarfile(filename_archive):
  557. # with tarfile.open(filename_archive,"r:gz") as tar:
  558. # if reload:
  559. # print(filename + " Downloaded")
  560. # tar.extractall(path = path)
  561. # return load_dataset(tar.getnames())
  562. # #raise ValueError("dataset not available")
  563. #
  564. #
  565. # elif filename.endswith('.tar'):
  566. # if tarfile.is_tarfile(filename_archive):
  567. # with tarfile.open(filename_archive,"r:") as tar:
  568. # if reload :
  569. # print(filename + " Downloaded")
  570. # tar.extractall(path = path)
  571. # return load_dataset(tar.getnames())
  572. # elif filename.endswith('.zip'):
  573. # with ZipFile(filename_archive,"r") as zip_ref:
  574. # if reload :
  575. # print(filename + " Downloaded")
  576. # zip_ref.extractall(path)
  577. # return load_dataset(zip_ref.namelist())
  578. # else:
  579. # print(filename + " Unsupported file")
  580. # """
  581. # with tarfile.open(osp.join(root,name,list_database[name][1]),"r:gz") as files:
  582. # for file in files.getnames():
  583. # print(file)
  584. # """
  585. #
  586. # def build_dictionary(Gs):
  587. # labels = set()
  588. # bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
  589. # print(bond_type_number_maxi)
  590. # sizes = set()
  591. # for G in Gs :
  592. # for _,node in G.nodes(data = True): # or for node in nx.nodes(G)
  593. # #print(node)
  594. # labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0])
  595. # sizes.add(G.order())
  596. # if len(labels) >= bond_type_number_maxi:
  597. # break
  598. # label_dict = {}
  599. # for i,label in enumerate(labels):
  600. # label_dict[label] = [0.]*bond_type_number_maxi
  601. # label_dict[label][i] = 1.
  602. # return label_dict
  603. #
  604. # def from_networkx_to_pytorch(Gs):
  605. # #exemple : atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
  606. # # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
  607. # atom_to_onehot = build_dictionary(Gs)
  608. # max_size = 30
  609. # adjs = []
  610. # inputs = []
  611. # for i, G in enumerate(Gs):
  612. # I = torch.eye(G.order(), G.order())
  613. # A = torch.Tensor(nx.adjacency_matrix(G).todense())
  614. # A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int)
  615. # adj = F.pad(A+I, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ?
  616. # adjs.append(adj)
  617. # f_0 = []
  618. # for _, label in G.nodes(data=True):
  619. # #print(_,label)
  620. # cur_label = atom_to_onehot[label['label'][0]].copy()
  621. # f_0.append(cur_label)
  622. # X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order()))
  623. # inputs.append(X)
  624. # return inputs,adjs,y
  625. #
  626. # def from_networkx_to_tensor(G,dict):
  627. # A=nx.to_numpy_matrix(G)
  628. # lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
  629. # return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab))
  630. #
  631. # dataset= open_files()
  632. # #print(build_dictionary(Gs))
  633. # #dic={'C':0,'N':1,'O':2}
  634. # #A,labels=from_networkx_to_tensor(Gs[13],dic)
  635. # #print(nx.to_numpy_matrix(Gs[13]),labels)
  636. # #print(A,labels)
  637. #
  638. # """
  639. # for G in Gs :
  640. # for node in nx.nodes(G):
  641. # print(G.nodes[node])
  642. # """
  643. # if mode == "pytorch":
  644. # Gs,y = dataset
  645. # inputs,adjs,y = from_networkx_to_pytorch(Gs)
  646. # print(inputs,adjs)
  647. # return inputs,adjs,y
  648. #
  649. #
  650. # """
  651. # dic = dict()
  652. # for i,l in enumerate(label):
  653. # dic[l] = i
  654. # dic = {'C': 0, 'N': 1, 'O': 2}
  655. # A,labels=from_networkx_to_tensor(Gs[0],dic)
  656. # #print(A,labels)
  657. # return A,labels
  658. # """
  659. #
  660. # return dataset
  661. #
  662. # #open_files()
  663. #
  664. # def label_to_color(label):
  665. # if label == 'C':
  666. # return 0.1
  667. # elif label == 'O':
  668. # return 0.8
  669. #
  670. # def nodes_to_color_sequence(G):
  671. # return [label_to_color(c[1]['label'][0]) for c in G.nodes(data=True)]
  672. # ##############
  673. # """
  674. # dataset = DataLoader('Mao',root = "database")
  675. # print(dataset)
  676. # Gs,y = dataset
  677. # """
  678. # """
  679. # dataset = DataLoader('Alkane',root = "database") # Gs is empty here whereas y isn't -> not working
  680. # Gs,y = dataset
  681. # """
  682. # """
  683. # dataset = DataLoader('Acyclic', root = "database")
  684. # Gs,y = dataset
  685. # """
  686. # """
  687. # dataset = DataLoader('Monoterpenoides', root = "database")
  688. # Gs,y = dataset
  689. # """
  690. # """
  691. # dataset = DataLoader('Pah',root = 'database', number = 8)
  692. # test_set,train_set = dataset
  693. # Gs,y = test_set
  694. # Gs2,y2 = train_set
  695. # """
  696. # """
  697. # dataset = DataLoader('Coil_Del',root = "database")
  698. # test,train,valid = dataset
  699. # Gs,y = test
  700. # Gs2,y2 = train
  701. # Gs3, y3 = valid
  702. # """
  703. # """
  704. # dataset = DataLoader('Coil_Rag',root = "database")
  705. # test,train,valid = dataset
  706. # Gs,y = test
  707. # Gs2,y2 = train
  708. # Gs3, y3 = valid
  709. # """
  710. # """
  711. # dataset = DataLoader('Fingerprint',root = "database")
  712. # test,train,valid = dataset
  713. # Gs,y = test
  714. # Gs2,y2 = train
  715. # Gs3, y3 = valid
  716. # """
  717. # """
  718. # dataset = DataLoader('Grec',root = "database")
  719. # test,train,valid = dataset
  720. # Gs,y = test
  721. # Gs2,y2 = train
  722. # Gs3, y3 = valid
  723. # """
  724. # """
  725. # dataset = DataLoader('Letter',root = "database",letter = 'low') #high low med
  726. # test,train,valid = dataset
  727. # Gs,y = test
  728. # Gs2,y2 = train
  729. # Gs3, y3 = valid
  730. # """
  731. # """
  732. # dataset = DataLoader('Mutagenicity',root = "database")
  733. # test,train,valid = dataset
  734. # Gs,y = test
  735. # Gs2,y2 = train
  736. # Gs3, y3 = valid
  737. # """
  738. # """
  739. # dataset = DataLoader('Protein',root = "database")
  740. # test,train,valid = dataset
  741. # Gs,y = test
  742. # Gs2,y2 = train
  743. # Gs3, y3 = valid
  744. # """
  745. # """
  746. # dataset = DataLoader('Ptc', root = "database",gender = 'fm') # not working, Gs and y are empty perhaps issue coming from loadDataset
  747. # valid,train = dataset
  748. # Gs,y = valid
  749. # Gs2,y2 = train
  750. # """
  751. # """
  752. # dataset = DataLoader('Web', root = "database")
  753. # test,train,valid = dataset
  754. # Gs,y = test
  755. # Gs2,y2 = train
  756. # Gs3,y3 = valid
  757. # """
  758. # print(Gs,y)
  759. # print(len(dataset))
  760. # ##############
  761. # #print('edge max label',max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
  762. # G1 = Gs[13]
  763. # G2 = Gs[23]
  764. # """
  765. # nx.draw_networkx(G1,with_labels=True,node_color = nodes_to_color_sequence(G1),cmap='autumn')
  766. # plt.figure()
  767. # nx.draw_networkx(G2,with_labels=True,node_color = nodes_to_color_sequence(G2),cmap='autumn')
  768. # """
  769. # from pathlib import Path
  770. # DATA_PATH = Path("data")
  771. # def import_datasets():
  772. #
  773. # import urllib
  774. # import tarfile
  775. # from zipfile import ZipFile
  776. # URL = "https://brunl01.users.greyc.fr/CHEMISTRY/"
  777. # URLIAM = "https://iapr-tc15.greyc.fr/IAM/"
  778. #
  779. # LIST_DATABASE = {
  780. # "Pah" : (URL,"PAH.tar.gz"),
  781. # "Mao" : (URL,"mao.tgz"),
  782. # "Ptc" : (URL,"ptc.tgz"),
  783. # "Aids" : (URLIAM,"AIDS.zip"),
  784. # "Acyclic" : (URL,"Acyclic.tar.gz"),
  785. # "Alkane" : (URL,"alkane_dataset.tar.gz"),
  786. # "Chiral" : (URL,"DatasetAcyclicChiral.tar"),
  787. # "Vitamin" : (URL,"DatasetVitamin.tar"),
  788. # "Ace" : (URL,"ACEDataset.tar"),
  789. # "Steroid" : (URL,"SteroidDataset.tar"),
  790. # "Monoterpenoides" : (URL,"monoterpenoides.tar.gz"),
  791. # "Letter" : (URLIAM,"Letter.zip"),
  792. # "Grec" : (URLIAM,"GREC.zip"),
  793. # "Fingerprint" : (URLIAM,"Fingerprint.zip"),
  794. # "Coil_Rag" : (URLIAM,"COIL-RAG.zip"),
  795. # "Coil_Del" : (URLIAM,"COIL-DEL.zip"),
  796. # "Web" : (URLIAM,"Web.zip"),
  797. # "Mutagenicity" : (URLIAM,"Mutagenicity.zip"),
  798. # "Protein" : (URLIAM,"Protein.zip")
  799. # }
  800. # print("Select databases in the list. Select multiple, split by white spaces .\nWrite All to select all of them.\n")
  801. # print(', '.join(database for database in LIST_DATABASE))
  802. # print("Choice : ",end = ' ')
  803. # selected_databases = input().split()
  804. #
  805. # def download_file(url,filename):
  806. # try :
  807. # response = urllib.request.urlopen(url + filename)
  808. # except urllib.error.HTTPError:
  809. # print(filename + " not available or incorrect http link")
  810. # return
  811. # return response
  812. #
  813. # def write_archive_file(database):
  814. #
  815. # PATH = DATA_PATH / database
  816. # url,filename = LIST_DATABASE[database]
  817. # if not (PATH / filename).exists():
  818. # response = download_file(url,filename)
  819. # if response is None :
  820. # return
  821. # if not PATH.exists() :
  822. # PATH.mkdir(parents=True, exist_ok=True)
  823. # with open(PATH/filename,'wb') as outfile :
  824. # outfile.write(response.read())
  825. #
  826. # if filename[-2:] == 'gz':
  827. # if tarfile.is_tarfile(PATH/filename):
  828. # with tarfile.open(PATH/filename,"r:gz") as tar:
  829. # tar.extractall(path = PATH)
  830. # print(filename + ' Downloaded')
  831. # elif filename[-3:] == 'tar':
  832. # if tarfile.is_tarfile(PATH/filename):
  833. # with tarfile.open(PATH/filename,"r:") as tar:
  834. # tar.extractall(path = PATH)
  835. # print(filename + ' Downloaded')
  836. # elif filename[-3:] == 'zip':
  837. # with ZipFile(PATH/filename,"r") as zip_ref:
  838. # zip_ref.extractall(PATH)
  839. # print(filename + ' Downloaded')
  840. # else:
  841. # print("Unsupported file")
  842. # if 'All' in selected_databases:
  843. # print('Waiting...')
  844. # for database in LIST_DATABASE :
  845. # write_archive_file(database)
  846. # print('Finished')
  847. # else:
  848. # print('Waiting...')
  849. # for database in selected_databases :
  850. # if database in LIST_DATABASE :
  851. # write_archive_file(database)
  852. # print('Finished')
  853. # import_datasets()
  854. # class GraphFetcher(object):
  855. #
  856. #
  857. # def __init__(self, filename=None, filename_targets=None, **kwargs):
  858. # if filename is None:
  859. # self._graphs = None
  860. # self._targets = None
  861. # self._node_labels = None
  862. # self._edge_labels = None
  863. # self._node_attrs = None
  864. # self._edge_attrs = None
  865. # else:
  866. # self.load_dataset(filename, filename_targets=filename_targets, **kwargs)
  867. #
  868. # self._substructures = None
  869. # self._node_label_dim = None
  870. # self._edge_label_dim = None
  871. # self._directed = None
  872. # self._dataset_size = None
  873. # self._total_node_num = None
  874. # self._ave_node_num = None
  875. # self._min_node_num = None
  876. # self._max_node_num = None
  877. # self._total_edge_num = None
  878. # self._ave_edge_num = None
  879. # self._min_edge_num = None
  880. # self._max_edge_num = None
  881. # self._ave_node_degree = None
  882. # self._min_node_degree = None
  883. # self._max_node_degree = None
  884. # self._ave_fill_factor = None
  885. # self._min_fill_factor = None
  886. # self._max_fill_factor = None
  887. # self._node_label_nums = None
  888. # self._edge_label_nums = None
  889. # self._node_attr_dim = None
  890. # self._edge_attr_dim = None
  891. # self._class_number = None
  892. #
  893. #
  894. # def load_dataset(self, filename, filename_targets=None, **kwargs):
  895. # self._graphs, self._targets, label_names = load_dataset(filename, filename_targets=filename_targets, **kwargs)
  896. # self._node_labels = label_names['node_labels']
  897. # self._node_attrs = label_names['node_attrs']
  898. # self._edge_labels = label_names['edge_labels']
  899. # self._edge_attrs = label_names['edge_attrs']
  900. # self.clean_labels()
  901. #
  902. #
  903. # def load_graphs(self, graphs, targets=None):
  904. # # this has to be followed by set_labels().
  905. # self._graphs = graphs
  906. # self._targets = targets
  907. # # self.set_labels_attrs() # @todo
  908. #
  909. #
  910. # def load_predefined_dataset(self, ds_name):
  911. # current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
  912. # if ds_name == 'Acyclic':
  913. # ds_file = current_path + '../../datasets/Acyclic/dataset_bps.ds'
  914. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  915. # elif ds_name == 'AIDS':
  916. # ds_file = current_path + '../../datasets/AIDS/AIDS_A.txt'
  917. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  918. # elif ds_name == 'Alkane':
  919. # ds_file = current_path + '../../datasets/Alkane/dataset.ds'
  920. # fn_targets = current_path + '../../datasets/Alkane/dataset_boiling_point_names.txt'
  921. # self._graphs, self._targets, label_names = load_dataset(ds_file, filename_targets=fn_targets)
  922. # elif ds_name == 'COIL-DEL':
  923. # ds_file = current_path + '../../datasets/COIL-DEL/COIL-DEL_A.txt'
  924. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  925. # elif ds_name == 'COIL-RAG':
  926. # ds_file = current_path + '../../datasets/COIL-RAG/COIL-RAG_A.txt'
  927. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  928. # elif ds_name == 'COLORS-3':
  929. # ds_file = current_path + '../../datasets/COLORS-3/COLORS-3_A.txt'
  930. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  931. # elif ds_name == 'Cuneiform':
  932. # ds_file = current_path + '../../datasets/Cuneiform/Cuneiform_A.txt'
  933. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  934. # elif ds_name == 'DD':
  935. # ds_file = current_path + '../../datasets/DD/DD_A.txt'
  936. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  937. # elif ds_name == 'ENZYMES':
  938. # ds_file = current_path + '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
  939. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  940. # elif ds_name == 'Fingerprint':
  941. # ds_file = current_path + '../../datasets/Fingerprint/Fingerprint_A.txt'
  942. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  943. # elif ds_name == 'FRANKENSTEIN':
  944. # ds_file = current_path + '../../datasets/FRANKENSTEIN/FRANKENSTEIN_A.txt'
  945. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  946. # elif ds_name == 'Letter-high': # node non-symb
  947. # ds_file = current_path + '../../datasets/Letter-high/Letter-high_A.txt'
  948. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  949. # elif ds_name == 'Letter-low': # node non-symb
  950. # ds_file = current_path + '../../datasets/Letter-low/Letter-low_A.txt'
  951. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  952. # elif ds_name == 'Letter-med': # node non-symb
  953. # ds_file = current_path + '../../datasets/Letter-med/Letter-med_A.txt'
  954. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  955. # elif ds_name == 'MAO':
  956. # ds_file = current_path + '../../datasets/MAO/dataset.ds'
  957. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  958. # elif ds_name == 'Monoterpenoides':
  959. # ds_file = current_path + '../../datasets/Monoterpenoides/dataset_10+.ds'
  960. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  961. # elif ds_name == 'MUTAG':
  962. # ds_file = current_path + '../../datasets/MUTAG/MUTAG_A.txt'
  963. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  964. # elif ds_name == 'NCI1':
  965. # ds_file = current_path + '../../datasets/NCI1/NCI1_A.txt'
  966. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  967. # elif ds_name == 'NCI109':
  968. # ds_file = current_path + '../../datasets/NCI109/NCI109_A.txt'
  969. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  970. # elif ds_name == 'PAH':
  971. # ds_file = current_path + '../../datasets/PAH/dataset.ds'
  972. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  973. # elif ds_name == 'SYNTHETIC':
  974. # pass
  975. # elif ds_name == 'SYNTHETICnew':
  976. # ds_file = current_path + '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
  977. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  978. # elif ds_name == 'Synthie':
  979. # pass
  980. # else:
  981. # raise Exception('The dataset name "', ds_name, '" is not pre-defined.')
  982. #
  983. # self._node_labels = label_names['node_labels']
  984. # self._node_attrs = label_names['node_attrs']
  985. # self._edge_labels = label_names['edge_labels']
  986. # self._edge_attrs = label_names['edge_attrs']
  987. # self.clean_labels()
  988. #
  989. # def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
  990. # self._node_labels = node_labels
  991. # self._node_attrs = node_attrs
  992. # self._edge_labels = edge_labels
  993. # self._edge_attrs = edge_attrs
  994. #
  995. # def set_labels_attrs(self, node_labels=None, node_attrs=None, edge_labels=None, edge_attrs=None):
  996. # # @todo: remove labels which have only one possible values.
  997. # if node_labels is None:
  998. # self._node_labels = self._graphs[0].graph['node_labels']
  999. # # # graphs are considered node unlabeled if all nodes have the same label.
  1000. # # infos.update({'node_labeled': is_nl if node_label_num > 1 else False})
  1001. # if node_attrs is None:
  1002. # self._node_attrs = self._graphs[0].graph['node_attrs']
  1003. # # for G in Gn:
  1004. # # for n in G.nodes(data=True):
  1005. # # if 'attributes' in n[1]:
  1006. # # return len(n[1]['attributes'])
  1007. # # return 0
  1008. # if edge_labels is None:
  1009. # self._edge_labels = self._graphs[0].graph['edge_labels']
  1010. # # # graphs are considered edge unlabeled if all edges have the same label.
  1011. # # infos.update({'edge_labeled': is_el if edge_label_num > 1 else False})
  1012. # if edge_attrs is None:
  1013. # self._edge_attrs = self._graphs[0].graph['edge_attrs']
  1014. # # for G in Gn:
  1015. # # if nx.number_of_edges(G) > 0:
  1016. # # for e in G.edges(data=True):
  1017. # # if 'attributes' in e[2]:
  1018. # # return len(e[2]['attributes'])
  1019. # # return 0
  1020. #
  1021. #
  1022. # def get_dataset_infos(self, keys=None, params=None):
  1023. # """Computes and returns the structure and property information of the graph dataset.
  1024. #
  1025. # Parameters
  1026. # ----------
  1027. # keys : list, optional
  1028. # A list of strings which indicate which informations will be returned. The
  1029. # possible choices includes:
  1030. #
  1031. # 'substructures': sub-structures graphs contains, including 'linear', 'non
  1032. # linear' and 'cyclic'.
  1033. #
  1034. # 'node_label_dim': whether vertices have symbolic labels.
  1035. #
  1036. # 'edge_label_dim': whether egdes have symbolic labels.
  1037. #
  1038. # 'directed': whether graphs in dataset are directed.
  1039. #
  1040. # 'dataset_size': number of graphs in dataset.
  1041. #
  1042. # 'total_node_num': total number of vertices of all graphs in dataset.
  1043. #
  1044. # 'ave_node_num': average number of vertices of graphs in dataset.
  1045. #
  1046. # 'min_node_num': minimum number of vertices of graphs in dataset.
  1047. #
  1048. # 'max_node_num': maximum number of vertices of graphs in dataset.
  1049. #
  1050. # 'total_edge_num': total number of edges of all graphs in dataset.
  1051. #
  1052. # 'ave_edge_num': average number of edges of graphs in dataset.
  1053. #
  1054. # 'min_edge_num': minimum number of edges of graphs in dataset.
  1055. #
  1056. # 'max_edge_num': maximum number of edges of graphs in dataset.
  1057. #
  1058. # 'ave_node_degree': average vertex degree of graphs in dataset.
  1059. #
  1060. # 'min_node_degree': minimum vertex degree of graphs in dataset.
  1061. #
  1062. # 'max_node_degree': maximum vertex degree of graphs in dataset.
  1063. #
  1064. # 'ave_fill_factor': average fill factor (number_of_edges /
  1065. # (number_of_nodes ** 2)) of graphs in dataset.
  1066. #
  1067. # 'min_fill_factor': minimum fill factor of graphs in dataset.
  1068. #
  1069. # 'max_fill_factor': maximum fill factor of graphs in dataset.
  1070. #
  1071. # 'node_label_nums': list of numbers of symbolic vertex labels of graphs in dataset.
  1072. #
  1073. # 'edge_label_nums': list number of symbolic edge labels of graphs in dataset.
  1074. #
  1075. # 'node_attr_dim': number of dimensions of non-symbolic vertex labels.
  1076. # Extracted from the 'attributes' attribute of graph nodes.
  1077. #
  1078. # 'edge_attr_dim': number of dimensions of non-symbolic edge labels.
  1079. # Extracted from the 'attributes' attribute of graph edges.
  1080. #
  1081. # 'class_number': number of classes. Only available for classification problems.
  1082. #
  1083. # 'all_degree_entropy': the entropy of degree distribution of each graph.
  1084. #
  1085. # 'ave_degree_entropy': the average entropy of degree distribution of all graphs.
  1086. #
  1087. # All informations above will be returned if `keys` is not given.
  1088. #
  1089. # params: dict of dict, optional
  1090. # A dictinary which contains extra parameters for each possible
  1091. # element in ``keys``.
  1092. #
  1093. # Return
  1094. # ------
  1095. # dict
  1096. # Information of the graph dataset keyed by `keys`.
  1097. # """
  1098. # infos = {}
  1099. #
  1100. # if keys == None:
  1101. # keys = [
  1102. # 'substructures',
  1103. # 'node_label_dim',
  1104. # 'edge_label_dim',
  1105. # 'directed',
  1106. # 'dataset_size',
  1107. # 'total_node_num',
  1108. # 'ave_node_num',
  1109. # 'min_node_num',
  1110. # 'max_node_num',
  1111. # 'total_edge_num',
  1112. # 'ave_edge_num',
  1113. # 'min_edge_num',
  1114. # 'max_edge_num',
  1115. # 'ave_node_degree',
  1116. # 'min_node_degree',
  1117. # 'max_node_degree',
  1118. # 'ave_fill_factor',
  1119. # 'min_fill_factor',
  1120. # 'max_fill_factor',
  1121. # 'node_label_nums',
  1122. # 'edge_label_nums',
  1123. # 'node_attr_dim',
  1124. # 'edge_attr_dim',
  1125. # 'class_number',
  1126. # 'all_degree_entropy',
  1127. # 'ave_degree_entropy'
  1128. # ]
  1129. #
  1130. # # dataset size
  1131. # if 'dataset_size' in keys:
  1132. # if self._dataset_size is None:
#             self._dataset_size = self._get_dataset_size()
#         infos['dataset_size'] = self._dataset_size
#
#     # graph node number
#     if any(i in keys for i in ['total_node_num', 'ave_node_num', 'min_node_num', 'max_node_num']):
#         all_node_nums = self._get_all_node_nums()
#         if 'total_node_num' in keys:
#             if self._total_node_num is None:
#                 self._total_node_num = self._get_total_node_num(all_node_nums)
#             infos['total_node_num'] = self._total_node_num
#
#         if 'ave_node_num' in keys:
#             if self._ave_node_num is None:
#                 self._ave_node_num = self._get_ave_node_num(all_node_nums)
#             infos['ave_node_num'] = self._ave_node_num
#
#         if 'min_node_num' in keys:
#             if self._min_node_num is None:
#                 self._min_node_num = self._get_min_node_num(all_node_nums)
#             infos['min_node_num'] = self._min_node_num
#
#         if 'max_node_num' in keys:
#             if self._max_node_num is None:
#                 self._max_node_num = self._get_max_node_num(all_node_nums)
#             infos['max_node_num'] = self._max_node_num
#
#     # graph edge number
#     if any(i in keys for i in ['total_edge_num', 'ave_edge_num', 'min_edge_num', 'max_edge_num']):
#         all_edge_nums = self._get_all_edge_nums()
#         if 'total_edge_num' in keys:
#             if self._total_edge_num is None:
#                 self._total_edge_num = self._get_total_edge_num(all_edge_nums)
#             infos['total_edge_num'] = self._total_edge_num
#
#         if 'ave_edge_num' in keys:
#             if self._ave_edge_num is None:
#                 self._ave_edge_num = self._get_ave_edge_num(all_edge_nums)
#             infos['ave_edge_num'] = self._ave_edge_num
#
#         if 'max_edge_num' in keys:
#             if self._max_edge_num is None:
#                 self._max_edge_num = self._get_max_edge_num(all_edge_nums)
#             infos['max_edge_num'] = self._max_edge_num
#
#         if 'min_edge_num' in keys:
#             if self._min_edge_num is None:
#                 self._min_edge_num = self._get_min_edge_num(all_edge_nums)
#             infos['min_edge_num'] = self._min_edge_num
#
#     # label number
#     if 'node_label_dim' in keys:
#         if self._node_label_dim is None:
#             self._node_label_dim = self._get_node_label_dim()
#         infos['node_label_dim'] = self._node_label_dim
#
#     if 'node_label_nums' in keys:
#         if self._node_label_nums is None:
#             self._node_label_nums = {}
#             for node_label in self._node_labels:
#                 self._node_label_nums[node_label] = self._get_node_label_num(node_label)
#         infos['node_label_nums'] = self._node_label_nums
#
#     if 'edge_label_dim' in keys:
#         if self._edge_label_dim is None:
#             self._edge_label_dim = self._get_edge_label_dim()
#         infos['edge_label_dim'] = self._edge_label_dim
#
#     if 'edge_label_nums' in keys:
#         if self._edge_label_nums is None:
#             self._edge_label_nums = {}
#             for edge_label in self._edge_labels:
#                 self._edge_label_nums[edge_label] = self._get_edge_label_num(edge_label)
#         infos['edge_label_nums'] = self._edge_label_nums
#
#     if 'directed' in keys or 'substructures' in keys:
#         if self._directed is None:
#             self._directed = self._is_directed()
#         infos['directed'] = self._directed
#
#     # node degree
#     if any(i in keys for i in ['ave_node_degree', 'max_node_degree', 'min_node_degree']):
#         all_node_degrees = self._get_all_node_degrees()
#
#         if 'ave_node_degree' in keys:
#             if self._ave_node_degree is None:
#                 self._ave_node_degree = self._get_ave_node_degree(all_node_degrees)
#             infos['ave_node_degree'] = self._ave_node_degree
#
#         if 'max_node_degree' in keys:
#             if self._max_node_degree is None:
#                 self._max_node_degree = self._get_max_node_degree(all_node_degrees)
#             infos['max_node_degree'] = self._max_node_degree
#
#         if 'min_node_degree' in keys:
#             if self._min_node_degree is None:
#                 self._min_node_degree = self._get_min_node_degree(all_node_degrees)
#             infos['min_node_degree'] = self._min_node_degree
#
#     # fill factor
#     if any(i in keys for i in ['ave_fill_factor', 'max_fill_factor', 'min_fill_factor']):
#         all_fill_factors = self._get_all_fill_factors()
#
#         if 'ave_fill_factor' in keys:
#             if self._ave_fill_factor is None:
#                 self._ave_fill_factor = self._get_ave_fill_factor(all_fill_factors)
#             infos['ave_fill_factor'] = self._ave_fill_factor
#
#         if 'max_fill_factor' in keys:
#             if self._max_fill_factor is None:
#                 self._max_fill_factor = self._get_max_fill_factor(all_fill_factors)
#             infos['max_fill_factor'] = self._max_fill_factor
#
#         if 'min_fill_factor' in keys:
#             if self._min_fill_factor is None:
#                 self._min_fill_factor = self._get_min_fill_factor(all_fill_factors)
#             infos['min_fill_factor'] = self._min_fill_factor
#
#     if 'substructures' in keys:
#         if self._substructures is None:
#             self._substructures = self._get_substructures()
#         infos['substructures'] = self._substructures
#
#     if 'class_number' in keys:
#         if self._class_number is None:
#             self._class_number = self._get_class_number()
#         infos['class_number'] = self._class_number
#
#     if 'node_attr_dim' in keys:
#         if self._node_attr_dim is None:
#             self._node_attr_dim = self._get_node_attr_dim()
#         infos['node_attr_dim'] = self._node_attr_dim
#
#     if 'edge_attr_dim' in keys:
#         if self._edge_attr_dim is None:
#             self._edge_attr_dim = self._get_edge_attr_dim()
#         infos['edge_attr_dim'] = self._edge_attr_dim
#
#     # entropy of degree distribution.
#     if 'all_degree_entropy' in keys:
#         if params is not None and ('all_degree_entropy' in params) and ('base' in params['all_degree_entropy']):
#             base = params['all_degree_entropy']['base']
#         else:
#             base = None
#         infos['all_degree_entropy'] = self._compute_all_degree_entropy(base=base)
#
#     if 'ave_degree_entropy' in keys:
#         if params is not None and ('ave_degree_entropy' in params) and ('base' in params['ave_degree_entropy']):
#             base = params['ave_degree_entropy']['base']
#         else:
#             base = None
#         infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base))
#
#     return infos
#
#
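# --- Illustrative sketch (not part of the original module) -------------------
# The commented-out block above caches per-dataset statistics on demand. The
# hypothetical helper below shows, assuming `graphs` is a list of networkx
# graphs, how a few of those statistics can be computed directly; it is a
# sketch of the idea, not the library's own API.
def _example_graph_stats(graphs):
    """Compute a small subset of the dataset statistics gathered above."""
    import networkx as nx
    import numpy as np
    node_nums = [nx.number_of_nodes(g) for g in graphs]
    edge_nums = [nx.number_of_edges(g) for g in graphs]
    return {
        'dataset_size': len(graphs),
        'total_node_num': int(np.sum(node_nums)),
        'ave_node_num': float(np.mean(node_nums)),
        'min_node_num': int(np.amin(node_nums)),
        'max_node_num': int(np.amax(node_nums)),
        'total_edge_num': int(np.sum(edge_nums)),
        'ave_edge_num': float(np.mean(edge_nums)),
    }
# ------------------------------------------------------------------------------
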
# def print_graph_infos(self, infos):
#     from collections import OrderedDict
#     keys = list(infos.keys())
#     print(OrderedDict(sorted(infos.items(), key=lambda i: keys.index(i[0]))))
#
#
# def remove_labels(self, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
#     node_labels = [item for item in node_labels if item in self._node_labels]
#     edge_labels = [item for item in edge_labels if item in self._edge_labels]
#     node_attrs = [item for item in node_attrs if item in self._node_attrs]
#     edge_attrs = [item for item in edge_attrs if item in self._edge_attrs]
#     for g in self._graphs:
#         for nd in g.nodes():
#             for nl in node_labels:
#                 del g.nodes[nd][nl]
#             for na in node_attrs:
#                 del g.nodes[nd][na]
#         for ed in g.edges():
#             for el in edge_labels:
#                 del g.edges[ed][el]
#             for ea in edge_attrs:
#                 del g.edges[ed][ea]
#     if len(node_labels) > 0:
#         self._node_labels = [nl for nl in self._node_labels if nl not in node_labels]
#     if len(edge_labels) > 0:
#         self._edge_labels = [el for el in self._edge_labels if el not in edge_labels]
#     if len(node_attrs) > 0:
#         self._node_attrs = [na for na in self._node_attrs if na not in node_attrs]
#     if len(edge_attrs) > 0:
#         self._edge_attrs = [ea for ea in self._edge_attrs if ea not in edge_attrs]
#
#
# def clean_labels(self):
#     labels = []
#     for name in self._node_labels:
#         label = set()
#         for G in self._graphs:
#             label = label | set(nx.get_node_attributes(G, name).values())
#             if len(label) > 1:
#                 labels.append(name)
#                 break
#         if len(label) < 2:
#             for G in self._graphs:
#                 for nd in G.nodes():
#                     del G.nodes[nd][name]
#     self._node_labels = labels
#
#     labels = []
#     for name in self._edge_labels:
#         label = set()
#         for G in self._graphs:
#             label = label | set(nx.get_edge_attributes(G, name).values())
#             if len(label) > 1:
#                 labels.append(name)
#                 break
#         if len(label) < 2:
#             for G in self._graphs:
#                 for ed in G.edges():
#                     del G.edges[ed][name]
#     self._edge_labels = labels
#
#     labels = []
#     for name in self._node_attrs:
#         label = set()
#         for G in self._graphs:
#             label = label | set(nx.get_node_attributes(G, name).values())
#             if len(label) > 1:
#                 labels.append(name)
#                 break
#         if len(label) < 2:
#             for G in self._graphs:
#                 for nd in G.nodes():
#                     del G.nodes[nd][name]
#     self._node_attrs = labels
#
#     labels = []
#     for name in self._edge_attrs:
#         label = set()
#         for G in self._graphs:
#             label = label | set(nx.get_edge_attributes(G, name).values())
#             if len(label) > 1:
#                 labels.append(name)
#                 break
#         if len(label) < 2:
#             for G in self._graphs:
#                 for ed in G.edges():
#                     del G.edges[ed][name]
#     self._edge_attrs = labels
#
#
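# --- Illustrative sketch (not part of the original module) -------------------
# `clean_labels` above drops any node/edge label or attribute that takes fewer
# than two distinct values over the whole dataset, since a constant label adds
# no information. A minimal standalone version of the same check, assuming
# `graphs` is a list of networkx graphs (the helper name is hypothetical):
def _example_constant_node_labels(graphs, label_names):
    """Return the label names that are constant (or absent) across `graphs`."""
    import networkx as nx
    constant = []
    for name in label_names:
        values = set()
        for g in graphs:
            values |= set(nx.get_node_attributes(g, name).values())
            if len(values) > 1:
                break
        if len(values) < 2:
            constant.append(name)
    return constant
# ------------------------------------------------------------------------------
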
# def cut_graphs(self, range_):
#     self._graphs = [self._graphs[i] for i in range_]
#     if self._targets is not None:
#         self._targets = [self._targets[i] for i in range_]
#     self.clean_labels()
#
#
# def trim_dataset(self, edge_required=False):
#     if edge_required:
#         trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)]
#     else:
#         trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0]
#     idx = [p[0] for p in trimed_pairs]
#     self._graphs = [p[1] for p in trimed_pairs]
#     self._targets = [self._targets[i] for i in idx]
#     self.clean_labels()
#
#
# def copy(self):
#     dataset = Dataset()
#     graphs = [g.copy() for g in self._graphs] if self._graphs is not None else None
#     target = self._targets.copy() if self._targets is not None else None
#     node_labels = self._node_labels.copy() if self._node_labels is not None else None
#     node_attrs = self._node_attrs.copy() if self._node_attrs is not None else None
#     edge_labels = self._edge_labels.copy() if self._edge_labels is not None else None
#     edge_attrs = self._edge_attrs.copy() if self._edge_attrs is not None else None
#     dataset.load_graphs(graphs, target)
#     dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
#     # @todo: clean_labels and add other class members?
#     return dataset
#
#
# def get_all_node_labels(self):
#     node_labels = []
#     for g in self._graphs:
#         for n in g.nodes():
#             nl = tuple(g.nodes[n].items())
#             if nl not in node_labels:
#                 node_labels.append(nl)
#     return node_labels
#
#
# def get_all_edge_labels(self):
#     edge_labels = []
#     for g in self._graphs:
#         for e in g.edges():
#             el = tuple(g.edges[e].items())
#             if el not in edge_labels:
#                 edge_labels.append(el)
#     return edge_labels
#
#
# def _get_dataset_size(self):
#     return len(self._graphs)
#
#
# def _get_all_node_nums(self):
#     return [nx.number_of_nodes(G) for G in self._graphs]
#
#
# def _get_total_node_num(self, all_node_nums):
#     return np.sum(all_node_nums)
#
#
# def _get_ave_node_num(self, all_node_nums):
#     return np.mean(all_node_nums)
#
#
# def _get_min_node_num(self, all_node_nums):
#     return np.amin(all_node_nums)
#
#
# def _get_max_node_num(self, all_node_nums):
#     return np.amax(all_node_nums)
#
#
# def _get_all_edge_nums(self):
#     return [nx.number_of_edges(G) for G in self._graphs]
#
#
# def _get_total_edge_num(self, all_edge_nums):
#     return np.sum(all_edge_nums)
#
#
# def _get_ave_edge_num(self, all_edge_nums):
#     return np.mean(all_edge_nums)
#
#
# def _get_min_edge_num(self, all_edge_nums):
#     return np.amin(all_edge_nums)
#
#
# def _get_max_edge_num(self, all_edge_nums):
#     return np.amax(all_edge_nums)
#
#
# def _get_node_label_dim(self):
#     return len(self._node_labels)
#
#
# def _get_node_label_num(self, node_label):
#     nl = set()
#     for G in self._graphs:
#         nl = nl | set(nx.get_node_attributes(G, node_label).values())
#     return len(nl)
#
#
# def _get_edge_label_dim(self):
#     return len(self._edge_labels)
#
#
# def _get_edge_label_num(self, edge_label):
#     el = set()
#     for G in self._graphs:
#         el = el | set(nx.get_edge_attributes(G, edge_label).values())
#     return len(el)
#
#
# def _is_directed(self):
#     return nx.is_directed(self._graphs[0])
#
#
# def _get_all_node_degrees(self):
#     return [np.mean(list(dict(G.degree()).values())) for G in self._graphs]
#
#
# def _get_ave_node_degree(self, all_node_degrees):
#     return np.mean(all_node_degrees)
#
#
# def _get_max_node_degree(self, all_node_degrees):
#     return np.amax(all_node_degrees)
#
#
# def _get_min_node_degree(self, all_node_degrees):
#     return np.amin(all_node_degrees)
#
#
# def _get_all_fill_factors(self):
#     """Get the fill factor of each graph, i.e., the ratio of the number of
#     edges to the squared number of nodes (a measure of how densely the
#     adjacency matrix is filled).
#
#     Returns
#     -------
#     list[float]
#         List of fill factors for all graphs.
#     """
#     return [nx.number_of_edges(G) / (nx.number_of_nodes(G) ** 2) for G in self._graphs]
#
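# --- Illustrative sketch (not part of the original module) -------------------
# Worked example of the fill factor above: a graph with 5 nodes and 4 edges has
# a fill factor of 4 / 5**2 = 0.16. The doctest-style lines below only
# illustrate that computation.
#
# >>> import networkx as nx
# >>> g = nx.path_graph(5)  # 5 nodes, 4 edges
# >>> nx.number_of_edges(g) / nx.number_of_nodes(g) ** 2
# 0.16
# ------------------------------------------------------------------------------
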
# def _get_ave_fill_factor(self, all_fill_factors):
#     return np.mean(all_fill_factors)
#
#
# def _get_max_fill_factor(self, all_fill_factors):
#     return np.amax(all_fill_factors)
#
#
# def _get_min_fill_factor(self, all_fill_factors):
#     return np.amin(all_fill_factors)
#
#
# def _get_substructures(self):
#     subs = set()
#     for G in self._graphs:
#         degrees = list(dict(G.degree()).values())
#         if any(i == 2 for i in degrees):
#             subs.add('linear')
#         if np.amax(degrees) >= 3:
#             subs.add('non linear')
#         if 'linear' in subs and 'non linear' in subs:
#             break
#
#     if self._directed:
#         for G in self._graphs:
#             try:
#                 # nx.find_cycle raises NetworkXNoCycle when G has no cycle.
#                 if len(list(nx.find_cycle(G))) > 0:
#                     subs.add('cyclic')
#                     break
#             except nx.NetworkXNoCycle:
#                 continue
#     # else:
#     #     # @todo: this method does not work for big graphs with a large amount of edges like D&D; try a better way.
#     #     upper = np.amin([nx.number_of_edges(G) for G in Gn]) * 2 + 10
#     #     for G in Gn:
#     #         if (nx.number_of_edges(G) < upper):
#     #             cyc = list(nx.simple_cycles(G.to_directed()))
#     #             if any(len(i) > 2 for i in cyc):
#     #                 subs.add('cyclic')
#     #                 break
#     #     if 'cyclic' not in subs:
#     #         for G in Gn:
#     #             cyc = list(nx.simple_cycles(G.to_directed()))
#     #             if any(len(i) > 2 for i in cyc):
#     #                 subs.add('cyclic')
#     #                 break
#
#     return subs
#
#
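# --- Illustrative sketch (not part of the original module) -------------------
# The substructure detection above probes for cycles in directed graphs with
# nx.find_cycle, which raises NetworkXNoCycle instead of returning an empty
# list when no cycle exists. A small standalone helper (hypothetical name)
# making that behaviour explicit:
def _example_has_cycle(graph):
    """Return True if `graph` (a networkx graph) contains at least one cycle."""
    import networkx as nx
    try:
        nx.find_cycle(graph)
        return True
    except nx.NetworkXNoCycle:
        return False
# ------------------------------------------------------------------------------
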
# def _get_class_number(self):
#     return len(set(self._targets))
#
#
# def _get_node_attr_dim(self):
#     return len(self._node_attrs)
#
#
# def _get_edge_attr_dim(self):
#     return len(self._edge_attrs)
#
# def _compute_all_degree_entropy(self, base=None):
#     """Compute the entropy of the degree distribution of each graph.
#
#     Parameters
#     ----------
#     base : float, optional
#         The logarithmic base to use. The default is ``e`` (natural logarithm).
#
#     Returns
#     -------
#     degree_entropy : list of float
#         The entropy of the degree distribution of each graph.
#     """
#     from gklearn.utils.stats import entropy
#
#     degree_entropy = []
#     for g in self._graphs:
#         degrees = list(dict(g.degree()).values())
#         en = entropy(degrees, base=base)
#         degree_entropy.append(en)
#     return degree_entropy
#
#
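# --- Illustrative sketch (not part of the original module) -------------------
# The method above delegates to gklearn.utils.stats.entropy. As a rough,
# hypothetical equivalent, the entropy of a graph's degree distribution can be
# computed from the normalized degree histogram with scipy; this sketches the
# idea and is not necessarily identical to the library routine.
def _example_degree_entropy(graph, base=None):
    """Shannon entropy of the degree distribution of a networkx graph."""
    import numpy as np
    from scipy.stats import entropy
    degrees = [d for _, d in graph.degree()]
    _, counts = np.unique(degrees, return_counts=True)
    return entropy(counts / counts.sum(), base=base)
# ------------------------------------------------------------------------------
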
# @property
# def graphs(self):
#     return self._graphs
#
#
# @property
# def targets(self):
#     return self._targets
#
#
# @property
# def node_labels(self):
#     return self._node_labels
#
#
# @property
# def edge_labels(self):
#     return self._edge_labels
#
#
# @property
# def node_attrs(self):
#     return self._node_attrs
#
#
# @property
# def edge_attrs(self):
#     return self._edge_attrs
#
#
# def split_dataset_by_target(dataset):
#     from gklearn.preimage.utils import get_same_item_indices
#
#     graphs = dataset.graphs
#     targets = dataset.targets
#     datasets = []
#     idx_targets = get_same_item_indices(targets)
#     for key, val in idx_targets.items():
#         sub_graphs = [graphs[i] for i in val]
#         sub_dataset = Dataset()
#         sub_dataset.load_graphs(sub_graphs, [key] * len(val))
#         node_labels = dataset.node_labels.copy() if dataset.node_labels is not None else None
#         node_attrs = dataset.node_attrs.copy() if dataset.node_attrs is not None else None
#         edge_labels = dataset.edge_labels.copy() if dataset.edge_labels is not None else None
#         edge_attrs = dataset.edge_attrs.copy() if dataset.edge_attrs is not None else None
#         sub_dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
#         datasets.append(sub_dataset)
#         # @todo: clean_labels?
#     return datasets
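# --- Illustrative sketch (not part of the original module) -------------------
# `split_dataset_by_target` above groups graphs by their target value and
# builds one sub-dataset per class. The hypothetical helper below shows the
# core grouping step on plain lists, without depending on the Dataset class:
def _example_group_by_target(graphs, targets):
    """Map each distinct target value to the list of graphs having it."""
    groups = {}
    for g, t in zip(graphs, targets):
        groups.setdefault(t, []).append(g)
    return groups
# ------------------------------------------------------------------------------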
