@@ -7,15 +7,9 @@ Created on Tue Oct 20 14:25:49 2020
	Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr
	Luc Brun luc.brun@ensicaen.fr
	Sebastien Bougleux sebastien.bougleux@unicaen.fr
	Benoit Gaüzère benoit.gauzere@insa-rouen.fr
	Linlin Jia linlin.jia@insa-rouen.fr
"""
import numpy as np
import networkx as nx
# The alias keeps the legacy loaders at the end of this class importable:
# they call loadDataset(), while the module exposes load_dataset().
from gklearn.utils.graph_files import load_dataset as loadDataset
import os
import os.path as osp
import urllib
@@ -29,299 +23,154 @@ import random
import sys
from lxml import etree
import re
from tqdm import tqdm
from gklearn.dataset import DATABASES, DATASET_META


class DataFetcher():
    def __init__(self, name=None, root='datasets', reload=False, verbose=False):
        self._name = name
        self._root = root
        if not osp.exists(self._root):
            os.makedirs(self._root)
        self._reload = reload
        self._verbose = verbose
        # self.has_train_valid_test = {
        #     "Coil_Del" : ('COIL-DEL/data/test.cxl', 'COIL-DEL/data/train.cxl', 'COIL-DEL/data/valid.cxl'),
        #     "Coil_Rag" : ('COIL-RAG/data/test.cxl', 'COIL-RAG/data/train.cxl', 'COIL-RAG/data/valid.cxl'),
        #     "Fingerprint" : ('Fingerprint/data/test.cxl', 'Fingerprint/data/train.cxl', 'Fingerprint/data/valid.cxl'),
        #     # "Grec" : ('GREC/data/test.cxl', 'GREC/data/train.cxl', 'GREC/data/valid.cxl'),
        #     "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl', 'Letter/HIGH/train.cxl', 'Letter/HIGH/validation.cxl'),
        #                 'MED' : ('Letter/MED/test.cxl', 'Letter/MED/train.cxl', 'Letter/MED/validation.cxl'),
        #                 'LOW' : ('Letter/LOW/test.cxl', 'Letter/LOW/train.cxl', 'Letter/LOW/validation.cxl')
        #                 },
        #     "Mutagenicity" : ('Mutagenicity/data/test.cxl', 'Mutagenicity/data/train.cxl', 'Mutagenicity/data/validation.cxl'),
        #     # "Pah" : ['PAH/testset_0.ds', 'PAH/trainset_0.ds'],
        #     "Protein" : ('Protein/data/test.cxl', 'Protein/data/train.cxl', 'Protein/data/valid.cxl'),
        #     # "Web" : ('Web/data/test.cxl', 'Web/data/train.cxl', 'Web/data/valid.cxl')
        # }

        # if not self.name:
        #     raise ValueError('No dataset entered.')
        # if self.name not in self.list_database:
        #     message = 'Invalid Dataset name "' + self.name + '".'
        #     message += '\nAvailable datasets are as follows: \n\n'
        #     message += '\n'.join(database for database in self.list_database)
        #     raise ValueError(message)
        # if self.downloadAll:
        #     print('Waiting...')
        #     for database in self.list_database:
        #         self.write_archive_file(database)
        #     print('Finished.')
        # else:
        #     self.write_archive_file(self.name)
        # self.max_for_letter = 0
        # self.dataset = self.open_files()
        self.info_dataset = {
            # 'Ace': "This dataset is not available yet.",
            # 'Acyclic': "This dataset is not split into test/train/valid sets but comes as one whole dataset.\ndataloader = DataLoader('Acyclic', root=...)\nGs, y = dataloader.dataset",
            # 'Aids': "This dataset is not available yet.",
            # 'Alkane': "This dataset is not split into test/train/valid sets but comes as one whole dataset.\ndataloader = DataLoader('Alkane', root=...)\nGs, y = dataloader.dataset",
            # 'Chiral': "This dataset is not available yet.",
            # 'Coil-Del': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Coil-Del', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Coil-Rag': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Coil-Rag', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Fingerprint': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Fingerprint', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Grec': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Grec', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Letter': "This dataset has test, train and valid sets. Choose between the high, med and low variants.\ndataloader = DataLoader('Letter', root=..., option='high')\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Mao': "This dataset is not split into test/train/valid sets but comes as one whole dataset.\ndataloader = DataLoader('Mao', root=...)\nGs, y = dataloader.dataset",
            # 'Monoterpenoides': "This dataset is not split into test/train/valid sets but comes as one whole dataset.\ndataloader = DataLoader('Monoterpenoides', root=...)\nGs, y = dataloader.dataset",
            # 'Mutagenicity': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Mutagenicity', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Pah': "This dataset is split into test and train sets. " + str(self.max_for_letter + 1) + " splits are available; choose a number between 0 and " + str(self.max_for_letter) + ".\ndataloader = DataLoader('Pah', root=..., option=0)\ntest, train = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train",
            # 'Protein': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Protein', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
            # 'Ptc': "This dataset has test and train sets. Select a gender among mm, fm, mr, fr.\ndataloader = DataLoader('Ptc', root=..., option='mm')\ntest, train = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train",
            # 'Steroid': "This dataset is not available yet.",
            # 'Vitamin': "This dataset is not available yet.",
            # 'Web': "This dataset has test, train and valid sets.\ndataloader = DataLoader('Web', root=...)\ntest, train, valid = dataloader.dataset\nGs_test, y_test = test\nGs_train, y_train = train\nGs_valid, y_valid = valid",
        }
| if mode == "Pytorch": | |||
| if self.name in self.data_to_use_in_datasets : | |||
| Gs,y = self.dataset | |||
| inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y) | |||
| #print(inputs,adjs) | |||
| self.pytorch_dataset = inputs,adjs,y | |||
| elif self.name == "Pah": | |||
| self.pytorch_dataset = [] | |||
| test,train = self.dataset | |||
| Gs_test,y_test = test | |||
| Gs_train,y_train = train | |||
| self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) | |||
| self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) | |||
| elif self.name in self.has_train_valid_test: | |||
| self.pytorch_dataset = [] | |||
| #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs]) | |||
| test,train,valid = self.dataset | |||
| Gs_test,y_test = test | |||
| Gs_train,y_train = train | |||
| Gs_valid,y_valid = valid | |||
| self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) | |||
| self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) | |||
| self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid)) | |||
| ############# | |||
| """ | |||
| for G in Gs : | |||
| for e in G.edges(): | |||
| print(G[e[0]]) | |||
| """ | |||
| ############## | |||
        if self._name is None:
            if self._verbose:
                print('No dataset name entered. All possible datasets will be loaded.')
            self._name, self._path = [], []
            for idx, ds_name in enumerate(DATASET_META):
                if self._verbose:
                    print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ')
                self._name.append(ds_name)
                success = self.write_archive_file(ds_name)
                if success:
                    self._path.append(self.open_files(ds_name))
                else:
                    self._path.append(None)
                if self._verbose and self._path[-1] is not None and not self._reload:
                    print('Fetched.')

            if self._verbose:
                print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.')

        elif self._name not in DATASET_META:
            message = 'Invalid Dataset name "' + self._name + '".'
            message += '\nAvailable datasets are as follows: \n\n'
            message += '\n'.join(ds for ds in sorted(DATASET_META))
            raise ValueError(message)

        else:
            self.write_archive_file(self._name)
            self._path = self.open_files(self._name)
        # self.max_for_letter = 0
        # if mode == 'Pytorch':
        #     if self._name in self.data_to_use_in_datasets:
        #         Gs, y = self.dataset
        #         inputs, adjs, y = self.from_networkx_to_pytorch(Gs, y)
        #         # print(inputs, adjs)
        #         self.pytorch_dataset = inputs, adjs, y
        #     elif self._name == "Pah":
        #         self.pytorch_dataset = []
        #         test, train = self.dataset
        #         Gs_test, y_test = test
        #         Gs_train, y_train = train
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
        #     elif self._name in self.has_train_valid_test:
        #         self.pytorch_dataset = []
        #         # [G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
        #         test, train, valid = self.dataset
        #         Gs_test, y_test = test
        #         Gs_train, y_train = train
        #         Gs_valid, y_valid = valid
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
        #         self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid, y_valid))
        #         # for G in Gs:
        #         #     for e in G.edges():
        #         #         print(G[e[0]])
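
    # A minimal usage sketch, assuming DataFetcher is exported from gklearn.dataset
    # and that 'Acyclic' is one of the keys of DATASET_META (both are assumptions,
    # not guaranteed by this file alone):
    #
    #     from gklearn.dataset import DataFetcher
    #     fetcher = DataFetcher(name='Acyclic', root='datasets', verbose=True)
    #     print(fetcher.path)  # local directory of the extracted dataset
    #
    # With name=None, every dataset listed in DATASET_META is fetched and
    # fetcher.path becomes a list with one entry per dataset (None on failure).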
    def download_file(self, url):
        try:
            response = urllib.request.urlopen(url)
        except urllib.error.HTTPError:
            print('"', url.split('/')[-1], '" is not available or incorrect http link.')
            return
        except urllib.error.URLError:
            print('Network is unreachable.')
            return
        return response
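
    # Hedged sketch of how download_file() is driven from DATASET_META elsewhere
    # in this class; the 'Acyclic' key and its 'url' field are assumptions:
    #
    #     url = DATASET_META['Acyclic']['url']
    #     response = fetcher.download_file(url)
    #     if response is not None:
    #         archive_bytes = response.read()  # raw bytes of the archive file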
    def write_archive_file(self, ds_name):
        path = osp.join(self._root, ds_name)
        url = DATASET_META[ds_name]['url']
        if not osp.exists(path) or self._reload:
            response = self.download_file(url)
            if response is None:
                return False
            os.makedirs(path, exist_ok=True)
            with open(osp.join(path, url.split('/')[-1]), 'wb') as outfile:
                outfile.write(response.read())
        return True
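
    # Illustration of the resulting on-disk layout (the archive name is whatever
    # follows the last '/' of the dataset url, so the example file name below is
    # an assumption). __init__ already calls this; the sketch is only for clarity:
    #
    #     ok = fetcher.write_archive_file('Acyclic')
    #     # -> <root>/Acyclic/<archive>, e.g. datasets/Acyclic/Acyclic.tar.gz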
    def open_files(self, ds_name=None):
        if ds_name is None:
            ds_name = (self._name if isinstance(self._name, str) else self._name[0])
        filename = DATASET_META[ds_name]['url'].split('/')[-1]
        path = osp.join(self._root, ds_name)
        filename_archive = osp.join(path, filename)

        if filename.endswith('gz'):
            if tarfile.is_tarfile(filename_archive):
                with tarfile.open(filename_archive, 'r:gz') as tar:
                    if self._reload and self._verbose:
                        print(filename + ' Downloaded.')
                    tar.extractall(path=path)
                    return osp.join(path, tar.getnames()[0])
        elif filename.endswith('.tar'):
            if tarfile.is_tarfile(filename_archive):
                with tarfile.open(filename_archive, 'r:') as tar:
                    if self._reload and self._verbose:
                        print(filename + ' Downloaded.')
                    tar.extractall(path=path)
                    return osp.join(path, tar.getnames()[0])
        elif filename.endswith('.zip'):
            with ZipFile(filename_archive, 'r') as zip_ref:
                if self._reload and self._verbose:
                    print(filename + ' Downloaded.')
                zip_ref.extractall(path)
                return osp.join(path, zip_ref.namelist()[0])
        else:
            raise ValueError(filename + ' Unsupported file.')
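
    # Rough sketch of what open_files() returns: the path of the first archive
    # member, which for these datasets is usually the extracted top-level folder
    # (an assumption about how the archives are packaged):
    #
    #     ds_dir = fetcher.open_files('Acyclic')
    #     # e.g. 'datasets/Acyclic/Acyclic'; this path can then be handed to
    #     # gklearn's file loaders.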
    def get_all_ds_infos(self, database):
        """Get information of all datasets from a database.

@@ -342,6 +191,7 @@ class DataFetcher():
            msg = 'Invalid Database name "' + database + '".'
            msg += '\nAvailable databases are as follows: \n\n'
            msg += '\n'.join(db for db in sorted(DATABASES))
            msg += '\nCheck "gklearn.dataset.DATASET_META" for more details.'
            raise ValueError(msg)

        return infos
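
    # Hedged usage sketch; the structure of the returned value is an assumption
    # based on the surrounding code (DATABASES holds the recognised database names):
    #
    #     fetcher = DataFetcher(name='Acyclic', root='datasets')
    #     infos = fetcher.get_all_ds_infos(sorted(DATABASES)[0])
    #     # 'infos' maps each dataset of that database to its metadata.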
@@ -457,6 +307,146 @@ class DataFetcher():
        p_str += '}'

        return p_str
    @property
    def path(self):
        return self._path
    # ------------------------------------------------------------------
    # The methods below are legacy helpers kept from the former interface.
    # They rely on attributes of the old constructor (self.mode, self.option,
    # self.gender, self.pytorch_dataset, self.has_train_valid_test,
    # self.data_to_use_in_datasets, self.max_for_letter) which the current
    # __init__ no longer sets.
    # ------------------------------------------------------------------

    def dataset(self):
        if self.mode == 'Tensorflow':
            return  # something
        if self.mode == 'Pytorch':
            return self.pytorch_dataset
        return self.dataset  # note: returns the method itself; the old code stored the loaded data in an attribute of the same name.

    def info(self):
        print(self.info_dataset[self._name])

    def iter_load_dataset(self, data):
        results = []
        for datasets in data:
            results.append(loadDataset(osp.join(self._root, self._name, datasets)))
        return results

    def load_dataset(self, list_files):
        if self._name == 'Ptc':
            if type(self.option) != str or self.option.upper() not in ['FR', 'FM', 'MM', 'MR']:
                raise ValueError('option for Ptc dataset needs to be one of: fr, fm, mm, mr.')
            results = []
            # note: self.gender is never set anywhere; self.option was probably intended.
            results.append(loadDataset(osp.join(self._root, self._name, 'PTC/Test', self.gender + '.ds')))
            results.append(loadDataset(osp.join(self._root, self._name, 'PTC/Train', self.gender + '.ds')))
            return results
        if self._name == 'Pah':
            maximum_sets = 0
            for file in list_files:
                if file.endswith('ds'):
                    maximum_sets = max(maximum_sets, int(file.split('_')[1].split('.')[0]))
            self.max_for_letter = maximum_sets
            if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
                raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
            data = self.has_train_valid_test['Pah']
            data[0] = self.has_train_valid_test['Pah'][0].split('_')[0] + '_' + str(self.option) + '.ds'
            data[1] = self.has_train_valid_test['Pah'][1].split('_')[0] + '_' + str(self.option) + '.ds'
            return self.iter_load_dataset(data)
        if self._name == 'Letter':
            if type(self.option) == str and self.option.upper() in self.has_train_valid_test['Letter']:
                data = self.has_train_valid_test['Letter'][self.option.upper()]
            else:
                raise ValueError('The parameter for Letter is incorrect; choose between: high, med, low.')
            return self.iter_load_dataset(data)
        if self._name in self.has_train_valid_test:  # common IAM dataset with train, valid and test sets
            data = self.has_train_valid_test[self._name]
            return self.iter_load_dataset(data)
        else:  # common dataset without train/valid/test sets, only a dataset.ds file
            data = self.data_to_use_in_datasets[self._name]
            if len(data) > 1 and data[0] in list_files and data[1] in list_files:  # case for Alkane
                return loadDataset(osp.join(self._root, self._name, data[0]),
                                   filename_y=osp.join(self._root, self._name, data[1]))
            if data in list_files:
                return loadDataset(osp.join(self._root, self._name, data))
    def build_dictionary(self, Gs):
        labels = set()
        # next line: from DeepGraphWithNNTorch
        # bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
        sizes = set()
        for G in Gs:
            for _, node in G.nodes(data=True):  # or: for node in nx.nodes(G)
                # print(_, node)
                labels.add(node['label'][0])  # what should be used for IAM datasets (they have neither bond_type nor even label)?
            sizes.add(G.order())
        label_dict = {}
        # print("labels: ", labels, bond_type_number_maxi)
        for i, label in enumerate(labels):
            label_dict[label] = [0.] * len(labels)
            label_dict[label][i] = 1.
        return label_dict
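
    # Illustrative output of build_dictionary() for graphs whose node labels start
    # with 'C', 'N' and 'O' (the position of each 1 depends on set iteration order,
    # so the exact assignment may differ from run to run):
    #
    #     {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}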
    def from_networkx_to_pytorch(self, Gs, y):
        # example for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
        # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
        atom_to_onehot = self.build_dictionary(Gs)
        max_size = 30  # graphs are zero-padded to this order; larger graphs are not supported
        adjs = []
        inputs = []
        for i, G in enumerate(Gs):
            I = torch.eye(G.order(), G.order())
            # A = torch.Tensor(nx.adjacency_matrix(G).todense())
            # A = torch.Tensor(nx.to_numpy_matrix(G))
            A = torch.tensor(nx.to_scipy_sparse_matrix(G, dtype=int, weight='bond_type').todense(), dtype=torch.int)  # what should be used for IAM datasets (they have neither bond_type nor even label)?
            adj = F.pad(A, pad=(0, max_size - G.order(), 0, max_size - G.order()))  # add I now? if yes: F.pad(A + I, pad=(...))
            adjs.append(adj)
            f_0 = []
            for _, label in G.nodes(data=True):
                # print(_, label)
                cur_label = atom_to_onehot[label['label'][0]].copy()
                f_0.append(cur_label)
            X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size - G.order()))
            inputs.append(X)
        return inputs, adjs, y
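
    # Rough usage sketch (requires torch; node labels are assumed to be tuples or
    # strings whose first element is the atom symbol, and edges need a numeric
    # 'bond_type' attribute, as the code above expects):
    #
    #     G = nx.Graph()
    #     G.add_node(0, label=('C',))
    #     G.add_node(1, label=('O',))
    #     G.add_edge(0, 1, bond_type=1)
    #     inputs, adjs, y = fetcher.from_networkx_to_pytorch([G], [0.0])
    #     # inputs[0]: 30 x n_labels one-hot feature matrix (zero-padded)
    #     # adjs[0]:   30 x 30 padded adjacency matrix weighted by bond_type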
    def from_pytorch_to_tensorflow(self, batch_size):
        # The same seed is reused so that inputs and targets are sampled at
        # matching indices.
        seed = random.randrange(sys.maxsize)
        random.seed(seed)
        tf_inputs = random.sample(self.pytorch_dataset[0], batch_size)
        random.seed(seed)
        tf_y = random.sample(self.pytorch_dataset[2], batch_size)
        # The original function stopped here; returning the sampled batch is
        # presumably what was intended.
        return tf_inputs, tf_y
    def from_networkx_to_tensor(self, G, label_dict):  # parameter renamed from 'dict' to avoid shadowing the builtin
        A = nx.to_numpy_matrix(G)
        lab = [label_dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
        # The adjacency matrix is flattened into a 1 x (n*n) row vector.
        return (torch.tensor(A).view(1, A.shape[0] * A.shape[1]), torch.tensor(lab))
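

# A minimal, hedged smoke test (network access required; 'Acyclic' is assumed to
# be a valid DATASET_META key). The __main__ guard keeps importing this module
# free of side effects.
if __name__ == '__main__':
    fetcher = DataFetcher(name='Acyclic', root='datasets', verbose=True)
    print('Dataset extracted to:', fetcher.path)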