@@ -7,15 +7,9 @@ Created on Tue Oct 20 14:25:49 2020
 Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr
 Luc Brun luc.brun@ensicaen.fr
 Sebastien Bougleux sebastien.bougleux@unicaen.fr
-benoit gaüzère benoit.gauzere@insa-rouen.fr
+Benoit Gaüzère benoit.gauzere@insa-rouen.fr
 Linlin Jia linlin.jia@insa-rouen.fr
 """
-import numpy as np
-import networkx as nx
-from gklearn.utils.graph_files import load_dataset
-import os
 import os
 import os.path as osp
 import urllib
@@ -29,299 +23,154 @@ import random
 import sys
 from lxml import etree
 import re
-from gklearn.dataset import DATABASES
+from tqdm import tqdm
+from gklearn.dataset import DATABASES, DATASET_META
 class DataFetcher():
-    def __init__(self,name='Ace',root = 'data',downloadAll = False,reload = False,mode = 'Networkx', option = None): # option : number, gender, letter
-        self.name = name
-        self.dir_name = "_".join(name.split("-"))
-        self.root = root
-        self.option = option
-        self.mode = mode
-        if not osp.exists(self.root) :
-            os.makedirs(self.root)
-        self.url = "https://brunl01.users.greyc.fr/CHEMISTRY/"
-        self.urliam = "https://iapr-tc15.greyc.fr/IAM/"
-        self.downloadAll = downloadAll
-        self.reload = reload
-        self.list_database = {
-            # "Ace" : (self.url,"ACEDataset.tar"),
-            # "Acyclic" : (self.url,"Acyclic.tar.gz"),
-            # "Aids" : (self.urliam,"AIDS.zip"),
-            # "Alkane" : (self.url,"alkane_dataset.tar.gz"),
-            # "Chiral" : (self.url,"DatasetAcyclicChiral.tar"),
-            # "Coil_Del" : (self.urliam,"COIL-DEL.zip"),
-            # "Coil_Rag" : (self.urliam,"COIL-RAG.zip"),
-            # "Fingerprint" : (self.urliam,"Fingerprint.zip"),
-            # "Grec" : (self.urliam,"GREC.zip"),
-            # "Letter" : (self.urliam,"Letter.zip"),
-            # "Mao" : (self.url,"mao.tgz"),
-            # "Monoterpenoides" : (self.url,"monoterpenoides.tar.gz"),
-            # "Mutagenicity" : (self.urliam,"Mutagenicity.zip"),
-            # "Pah" : (self.url,"PAH.tar.gz"),
-            # "Protein" : (self.urliam,"Protein.zip"),
-            # "Ptc" : (self.url,"ptc.tgz"),
-            # "Steroid" : (self.url,"SteroidDataset.tar"),
-            # "Vitamin" : (self.url,"DatasetVitamin.tar"),
-            # "Web" : (self.urliam,"Web.zip")
-        }
-        self.data_to_use_in_datasets = {
-            # "Acyclic" : ("Acyclic/dataset_bps.ds"),
-            # "Aids" : ("AIDS_A.txt"),
-            # "Alkane" : ("Alkane/dataset.ds","Alkane/dataset_boiling_point_names.txt"),
-            # "Mao" : ("MAO/dataset.ds"),
-            # "Monoterpenoides" : ("monoterpenoides/dataset_10+.ds"), #('monoterpenoides/dataset.ds'),('monoterpenoides/dataset_9.ds'),('monoterpenoides/trainset_9.ds')
-        }
-        self.has_train_valid_test = {
-            "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'),
-            "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'),
-            "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'),
-            # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'),
-            "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'),
-                        'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'),
-                        'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl')
-                        },
-            "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'),
-            # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'],
-            "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'),
-            # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl')
-        }
+    def __init__(self, name=None, root='datasets', reload=False, verbose=False):
+        self._name = name
+        self._root = root
+        if not osp.exists(self._root):
+            os.makedirs(self._root)
+        self._reload = reload
+        self._verbose = verbose
+        # self.has_train_valid_test = {
+        #     "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'),
+        #     "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'),
+        #     "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'),
+        #     # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'),
+        #     "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'),
+        #                 'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'),
+        #                 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl')
+        #                 },
+        #     "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'),
+        #     # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'],
+        #     "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'),
+        #     # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl')
+        # }
+        # if not self.name :
+        #     raise ValueError("No dataset entered" )
+        # if self.name not in self.list_database:
+        #     message = "Invalid Dataset name " + self.name
+        #     message += '\n Available datasets are as follows : \n\n'
+        #
+        #     message += '\n'.join(database for database in self.list_database)
+        #     raise ValueError(message)
+        # if self.downloadAll :
+        #     print('Waiting...')
+        #     for database in self.list_database :
+        #         self.write_archive_file(database)
+        #     print('Finished')
+        # else:
+        #     self.write_archive_file(self.name)
+        # self.max_for_letter = 0
+        # self.dataset = self.open_files()
-        self.info_dataset = {
-            # 'Ace' : "This dataset is not available yet",
-            # 'Acyclic' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic,root = ...') \nGs,y = dataloader.dataset ",
-            # 'Aids' : "This dataset is not available yet",
-            # 'Alkane' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic',root = ...) \nGs,y = dataloader.dataset ",
-            # 'Chiral' : "This dataset is not available yet",
-            # "Coil-Del" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Deg', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
-            # "Coil-Rag" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Rag', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid",
-            # "Fingerprint" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Fingerprint', root = ...). \ntest,train,valid = dataloader.dataset. \nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid",
-            # "Grec" : "This dataset has test,train,valid datasets. Write dataloader = DataLoader('Grec', root = ...). \ntest,train,valid = dataloader.dataset. \nGs_test,y_test = test\n Gs_train,y_train = train\n Gs_valid,y_valid = valid",
-            # "Letter" : "This dataset has test,train,valid datasets. Choose between high,low,med dataset. \ndataloader = DataLoader('Letter', root = ..., option = 'high') \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
-            # 'Mao' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Mao',root= ...) \nGs,y = dataloader.dataset ",
-            # 'Monoterpenoides': "This dataset isn't composed of valid, test, train dataset but one whole dataset\n Write dataloader = DataLoader('Monoterpenoides',root= ...) \nGs,y = dataloader.dataset ",
-            # 'Mutagenicity' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Mutagenicity', root = ...) \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test\n Gs_train,y_train = train \nGs_valid,y_valid = valid",
-            # 'Pah' : 'This dataset is composed of test and train datasets. '+ str(self.max_for_letter + 1) + ' datasets are available. \nChoose number between 0 and ' + str(self.max_for_letter) + "\ndataloader = DataLoader('Pah', root = ...,option = 0) \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n ",
-            # "Protein" : "This dataset has test,train,valid dataset. \ndataloader = DataLoader('Protein', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
-            # "Ptc" : "This dataset has test and train datasets. Select gender between mm, fm, mr, fr. \ndataloader = DataLoader('Ptc',root = ...,option = 'mm') \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train_,y_train = train",
-            # "Steroid" : "This dataset is not available yet",
-            # 'Vitamin' : "This dataset is not available yet",
-            # 'Web' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Web', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
-        }
-        if mode == "Pytorch":
-            if self.name in self.data_to_use_in_datasets :
-                Gs,y = self.dataset
-                inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y)
-                #print(inputs,adjs)
-                self.pytorch_dataset = inputs,adjs,y
-            elif self.name == "Pah":
-                self.pytorch_dataset = []
-                test,train = self.dataset
-                Gs_test,y_test = test
-                Gs_train,y_train = train
-                self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
-                self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
-            elif self.name in self.has_train_valid_test:
-                self.pytorch_dataset = []
-                #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
-                test,train,valid = self.dataset
-                Gs_test,y_test = test
-                Gs_train,y_train = train
-                Gs_valid,y_valid = valid
-                self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
-                self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
-                self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid))
-            #############
-            """
-            for G in Gs :
-                for e in G.edges():
-                    print(G[e[0]])
-            """
-            ##############
+        if self._name is None:
+            if self._verbose:
+                print('No dataset name entered. All possible datasets will be loaded.')
+            self._name, self._path = [], []
+            for idx, ds_name in enumerate(DATASET_META):
+                if self._verbose:
+                    print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ')
+                self._name.append(ds_name)
+                success = self.write_archive_file(ds_name)
+                if success:
+                    self._path.append(self.open_files(ds_name))
+                else:
+                    self._path.append(None)
+                if self._verbose and self._path[-1] is not None and not self._reload:
+                    print('Fetched.')
+            if self._verbose:
+                print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.')
+        elif self._name not in DATASET_META:
+            message = 'Invalid Dataset name "' + self._name + '".'
+            message += '\nAvailable datasets are as follows: \n\n'
+            message += '\n'.join(ds for ds in sorted(DATASET_META))
+            raise ValueError(message)
+        else:
+            self.write_archive_file(self._name)
+            self._path = self.open_files(self._name)
+        # self.max_for_letter = 0
+        # if mode == 'Pytorch':
+        #     if self._name in self.data_to_use_in_datasets :
+        #         Gs,y = self.dataset
+        #         inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y)
+        #         #print(inputs,adjs)
+        #         self.pytorch_dataset = inputs,adjs,y
+        #     elif self._name == "Pah":
+        #         self.pytorch_dataset = []
+        #         test,train = self.dataset
+        #         Gs_test,y_test = test
+        #         Gs_train,y_train = train
+        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
+        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
+        #     elif self._name in self.has_train_valid_test:
+        #         self.pytorch_dataset = []
+        #         #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
+        #         test,train,valid = self.dataset
+        #         Gs_test,y_test = test
+        #
+        #         Gs_train,y_train = train
+        #         Gs_valid,y_valid = valid
+        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
+        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
+        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid))
+        #     #############
+        #     """
+        #     for G in Gs :
+        #         for e in G.edges():
+        #             print(G[e[0]])
+        #     """
+        #     ##############
-    def download_file(self,url,filename):
+    def download_file(self, url):
         try :
-            response = urllib.request.urlopen(url + filename)
+            response = urllib.request.urlopen(url)
         except urllib.error.HTTPError:
-            print(filename + " not available or incorrect http link")
+            print('"', url.split('/')[-1], '" is not available or incorrect http link.')
+            return
+        except urllib.error.URLError:
+            print('Network is unreachable.')
             return
         return response
-    def write_archive_file(self,database):
-        path = osp.join(self.root,database)
-        url,filename = self.list_database[database]
-        filename_dir = osp.join(path,filename)
-        if not osp.exists(filename_dir) or self.reload:
-            response = self.download_file(url,filename)
-            if response is None :
-                return
-            if not osp.exists(path) :
-                os.makedirs(path)
-            with open(filename_dir,'wb') as outfile :
-                outfile.write(response.read())
-    def dataset(self):
-        if self.mode == "Tensorflow":
-            return #something
-        if self.mode == "Pytorch":
-            return self.pytorch_dataset
-        return self.dataset
-    def info(self):
-        print(self.info_dataset[self.name])
-    def iter_load_dataset(self,data):
-        results = []
-        for datasets in data :
-            results.append(loadDataset(osp.join(self.root,self.name,datasets)))
-        return results
-    def load_dataset(self,list_files):
-        if self.name == "Ptc":
-            if type(self.option) != str or self.option.upper() not in ['FR','FM','MM','MR']:
-                raise ValueError('option for Ptc dataset needs to be one of : \n fr fm mm mr')
-            results = []
-            results.append(loadDataset(osp.join(self.root,self.name,'PTC/Test',self.gender + '.ds')))
-            results.append(loadDataset(osp.join(self.root,self.name,'PTC/Train',self.gender + '.ds')))
-            return results
-        if self.name == "Pah":
-            maximum_sets = 0
-            for file in list_files:
-                if file.endswith('ds'):
-                    maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0]))
-            self.max_for_letter = maximum_sets
-            if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
-                raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
-            data = self.has_train_valid_test["Pah"]
-            data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds'
-            data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds'
-            return self.iter_load_dataset(data)
-        if self.name == "Letter":
-            if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]:
-                data = self.has_train_valid_test["Letter"][self.option.upper()]
-            else:
-                message = "The parameter for letter is incorrect choose between : "
-                message += "\nhigh med low"
-                raise ValueError(message)
-            return self.iter_load_dataset(data)
-        if self.name in self.has_train_valid_test : #common IAM dataset with train, valid and test
-            data = self.has_train_valid_test[self.name]
-            return self.iter_load_dataset(data)
-        else: #common dataset without train,valid and test, only dataset.ds file
-            data = self.data_to_use_in_datasets[self.name]
-            if len(data) > 1 and data[0] in list_files and data[1] in list_files: #case for Alkane
-                return loadDataset(osp.join(self.root,self.name,data[0]),filename_y = osp.join(self.root,self.name,data[1]))
-            if data in list_files:
-                return loadDataset(osp.join(self.root,self.name,data))
-    def open_files(self):
-        filename = self.list_database[self.name][1]
-        path = osp.join(self.root,self.name)
-        filename_archive = osp.join(path,filename)
+    def write_archive_file(self, ds_name):
+        path = osp.join(self._root, ds_name)
+        url = DATASET_META[ds_name]['url']
+        # filename_dir = osp.join(path,filename)
+        if not osp.exists(path) or self._reload:
+            response = self.download_file(url)
+            if response is None:
+                return False
+            os.makedirs(path, exist_ok=True)
+            with open(os.path.join(path, url.split('/')[-1]), 'wb') as outfile:
+                outfile.write(response.read())
+        return True
+    def open_files(self, ds_name=None):
+        if ds_name is None:
+            ds_name = (self._name if isinstance(self._name, str) else self._name[0])
+        filename = DATASET_META[ds_name]['url'].split('/')[-1]
+        path = osp.join(self._root, ds_name)
+        filename_archive = osp.join(path, filename)
         if filename.endswith('gz'):
             if tarfile.is_tarfile(filename_archive):
-                with tarfile.open(filename_archive,"r:gz") as tar:
-                    if self.reload:
-                        print(filename + " Downloaded")
-                    tar.extractall(path = path)
-                    return self.load_dataset(tar.getnames())
+                with tarfile.open(filename_archive, 'r:gz') as tar:
+                    if self._reload and self._verbose:
+                        print(filename + ' Downloaded.')
+                    tar.extractall(path = path)
+                    return os.path.join(path, tar.getnames()[0])
         elif filename.endswith('.tar'):
             if tarfile.is_tarfile(filename_archive):
-                with tarfile.open(filename_archive,"r:") as tar:
-                    if self.reload :
-                        print(filename + " Downloaded")
-                    tar.extractall(path = path)
-                    return self.load_dataset(tar.getnames())
+                with tarfile.open(filename_archive, 'r:') as tar:
+                    if self._reload and self._verbose:
+                        print(filename + ' Downloaded.')
+                    tar.extractall(path = path)
+                    return os.path.join(path, tar.getnames()[0])
         elif filename.endswith('.zip'):
-            with ZipFile(filename_archive,"r") as zip_ref:
-                if self.reload :
-                    print(filename + " Downloaded")
-                zip_ref.extractall(path)
-                return self.load_dataset(zip_ref.namelist())
+            with ZipFile(filename_archive, 'r') as zip_ref:
+                if self._reload and self._verbose:
+                    print(filename + ' Downloaded.')
+                zip_ref.extractall(path)
+                return os.path.join(path, zip_ref.namelist()[0])
         else:
-            print(filename + " Unsupported file")
-    def build_dictionary(self,Gs):
-        labels = set()
-        #next line : from DeepGraphWithNNTorch
-        #bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
-        sizes = set()
-        for G in Gs :
-            for _,node in G.nodes(data = True): # or for node in nx.nodes(G)
-                #print(_,node)
-                labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0]) #what do we use for IAM datasets (they don't have bond_type or event label) ?
-            sizes.add(G.order())
-        label_dict = {}
-        #print("labels : ", labels, bond_type_number_maxi)
-        for i,label in enumerate(labels):
-            label_dict[label] = [0.]*len(labels)
-            label_dict[label][i] = 1.
-        return label_dict
-    def from_networkx_to_pytorch(self,Gs,y):
-        #exemple for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
-        # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
-        atom_to_onehot = self.build_dictionary(Gs)
-        max_size = 30
-        adjs = []
-        inputs = []
-        for i, G in enumerate(Gs):
-            I = torch.eye(G.order(), G.order())
-            #A = torch.Tensor(nx.adjacency_matrix(G).todense())
-            #A = torch.Tensor(nx.to_numpy_matrix(G))
-            A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int) #what do we use for IAM datasets (they don't have bond_type or event label) ?
-            adj = F.pad(A, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ? if yes : F.pad(A + I,pad = (...))
-            adjs.append(adj)
-            f_0 = []
-            for _, label in G.nodes(data=True):
-                #print(_,label)
-                cur_label = atom_to_onehot[label['label'][0]].copy()
-                f_0.append(cur_label)
-            X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order()))
-            inputs.append(X)
-        return inputs,adjs,y
-    def from_pytorch_to_tensorflow(self,batch_size):
-        seed = random.randrange(sys.maxsize)
-        random.seed(seed)
-        tf_inputs = random.sample(self.pytorch_dataset[0],batch_size)
-        random.seed(seed)
-        tf_y = random.sample(self.pytorch_dataset[2],batch_size)
-    def from_networkx_to_tensor(self,G,dict):
-        A=nx.to_numpy_matrix(G)
-        lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
-        return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab))
+            raise ValueError(filename + ' Unsupported file.')
     def get_all_ds_infos(self, database):
         """Get information of all datasets from a database.
@@ -342,6 +191,7 @@ class DataFetcher():
             msg = 'Invalid Database name "' + database + '"'
             msg += '\n Available databases are as follows: \n\n'
             msg += '\n'.join(db for db in sorted(DATABASES))
+            msg += 'Check "gklearn.dataset.DATASET_META" for more details.'
             raise ValueError(msg)
         return infos
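
Note (not part of the diff): a minimal usage sketch of the refactored fetcher introduced above. The dataset name 'MAO' and the import path are illustrative assumptions; the names actually available are the keys of gklearn.dataset.DATASET_META.

    from gklearn.dataset.data_fetcher import DataFetcher  # assumed import path

    # Fetch a single dataset: its archive is downloaded into root/<name>/ and
    # extracted; write_archive_file() returns False when the download fails.
    fetcher = DataFetcher(name='MAO', root='datasets', reload=False, verbose=True)

    # Passing no name fetches every dataset listed in DATASET_META; failed
    # downloads are recorded as None in the internal list of paths.
    all_fetcher = DataFetcher(root='datasets', verbose=True)

The extracted locations are stored in the private attribute _path and exposed through the read-only 'path' property added further down in this diff.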
@@ -457,6 +307,146 @@ class DataFetcher():
         p_str += '}'
         return p_str
+    @property
+    def path(self):
+        return self._path
+    def dataset(self):
+        if self.mode == "Tensorflow":
+            return #something
+        if self.mode == "Pytorch":
+            return self.pytorch_dataset
+        return self.dataset
+    def info(self):
+        print(self.info_dataset[self._name])
+    def iter_load_dataset(self,data):
+        results = []
+        for datasets in data :
+            results.append(loadDataset(osp.join(self._root,self._name,datasets)))
+        return results
+    def load_dataset(self,list_files):
+        if self._name == "Ptc":
+            if type(self.option) != str or self.option.upper() not in ['FR','FM','MM','MR']:
+                raise ValueError('option for Ptc dataset needs to be one of : \n fr fm mm mr')
+            results = []
+            results.append(loadDataset(osp.join(self.root,self._name,'PTC/Test',self.gender + '.ds')))
+            results.append(loadDataset(osp.join(self.root,self._name,'PTC/Train',self.gender + '.ds')))
+            return results
+        if self.name == "Pah":
+            maximum_sets = 0
+            for file in list_files:
+                if file.endswith('ds'):
+                    maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0]))
+            self.max_for_letter = maximum_sets
+            if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
+                raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
+            data = self.has_train_valid_test["Pah"]
+            data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds'
+            data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds'
+            return self.iter_load_dataset(data)
+        if self.name == "Letter":
+            if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]:
+                data = self.has_train_valid_test["Letter"][self.option.upper()]
+            else:
+                message = "The parameter for letter is incorrect choose between : "
+                message += "\nhigh med low"
+                raise ValueError(message)
+            return self.iter_load_dataset(data)
+        if self.name in self.has_train_valid_test : #common IAM dataset with train, valid and test
+            data = self.has_train_valid_test[self.name]
+            return self.iter_load_dataset(data)
+        else: #common dataset without train,valid and test, only dataset.ds file
+            data = self.data_to_use_in_datasets[self.name]
+            if len(data) > 1 and data[0] in list_files and data[1] in list_files: #case for Alkane
+                return loadDataset(osp.join(self.root,self.name,data[0]),filename_y = osp.join(self.root,self.name,data[1]))
+            if data in list_files:
+                return loadDataset(osp.join(self.root,self.name,data))
+    def build_dictionary(self,Gs):
+        labels = set()
+        #next line : from DeepGraphWithNNTorch
+        #bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
+        sizes = set()
+        for G in Gs :
+            for _,node in G.nodes(data = True): # or for node in nx.nodes(G)
+                #print(_,node)
+                labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0]) #what do we use for IAM datasets (they don't have bond_type or event label) ?
+            sizes.add(G.order())
+        label_dict = {}
+        #print("labels : ", labels, bond_type_number_maxi)
+        for i,label in enumerate(labels):
+            label_dict[label] = [0.]*len(labels)
+            label_dict[label][i] = 1.
+        return label_dict
+    def from_networkx_to_pytorch(self,Gs,y):
+        #exemple for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
+        # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
+        atom_to_onehot = self.build_dictionary(Gs)
+        max_size = 30
+        adjs = []
+        inputs = []
+        for i, G in enumerate(Gs):
+            I = torch.eye(G.order(), G.order())
+            #A = torch.Tensor(nx.adjacency_matrix(G).todense())
+            #A = torch.Tensor(nx.to_numpy_matrix(G))
+            A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int) #what do we use for IAM datasets (they don't have bond_type or event label) ?
+            adj = F.pad(A, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ? if yes : F.pad(A + I,pad = (...))
+            adjs.append(adj)
+            f_0 = []
+            for _, label in G.nodes(data=True):
+                #print(_,label)
+                cur_label = atom_to_onehot[label['label'][0]].copy()
+                f_0.append(cur_label)
+            X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order()))
+            inputs.append(X)
+        return inputs,adjs,y
+    def from_pytorch_to_tensorflow(self,batch_size):
+        seed = random.randrange(sys.maxsize)
+        random.seed(seed)
+        tf_inputs = random.sample(self.pytorch_dataset[0],batch_size)
+        random.seed(seed)
+        tf_y = random.sample(self.pytorch_dataset[2],batch_size)
+    def from_networkx_to_tensor(self,G,dict):
+        A=nx.to_numpy_matrix(G)
+        lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
+        return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab))
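
Note (not part of the diff): the legacy build_dictionary()/from_networkx_to_pytorch() helpers kept at the bottom of the class encode each distinct node label as a one-hot vector before padding graphs to a fixed size of 30 nodes. A standalone sketch of that encoding step, with a hypothetical helper name and a toy two-node graph for illustration:

    import networkx as nx

    def build_label_onehot(graphs, label_key='label'):
        # Map every distinct first label value to a one-hot vector, mirroring
        # build_dictionary() above (sorted here for reproducibility, whereas the
        # original iterates over an unordered set).
        labels = sorted({data[label_key][0]
                         for G in graphs
                         for _, data in G.nodes(data=True)})
        return {lab: [1. if i == j else 0. for j in range(len(labels))]
                for i, lab in enumerate(labels)}

    G = nx.Graph()
    G.add_node(0, label=['C'])
    G.add_node(1, label=['O'])
    G.add_edge(0, 1, bond_type=1)
    print(build_label_onehot([G]))  # {'C': [1.0, 0.0], 'O': [0.0, 1.0]}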