| @@ -18,6 +18,7 @@ from gklearn.ged.median import MedianGraphEstimator | |||||
| from gklearn.ged.median import constant_node_costs,mge_options_to_string | from gklearn.ged.median import constant_node_costs,mge_options_to_string | ||||
| from gklearn.gedlib import librariesImport, gedlibpy | from gklearn.gedlib import librariesImport, gedlibpy | ||||
| from gklearn.utils import Timer | from gklearn.utils import Timer | ||||
| from gklearn.utils.utils import get_graph_kernel_by_name | |||||
| # from gklearn.utils.dataset import Dataset | # from gklearn.utils.dataset import Dataset | ||||
| class MedianPreimageGenerator(PreimageGenerator): | class MedianPreimageGenerator(PreimageGenerator): | ||||
| @@ -81,7 +82,13 @@ class MedianPreimageGenerator(PreimageGenerator): | |||||
| def run(self): | def run(self): | ||||
| self.__set_graph_kernel_by_name() | |||||
| self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'], | |||||
| node_labels=self._dataset.node_labels, | |||||
| edge_labels=self._dataset.edge_labels, | |||||
| node_attrs=self._dataset.node_attrs, | |||||
| edge_attrs=self._dataset.edge_attrs, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| # record start time. | # record start time. | ||||
| start = time.time() | start = time.time() | ||||
| @@ -722,43 +729,6 @@ class MedianPreimageGenerator(PreimageGenerator): | |||||
| print('distance in kernel space for generalized median:', self.__k_dis_gen_median) | print('distance in kernel space for generalized median:', self.__k_dis_gen_median) | ||||
| print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) | print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) | ||||
| print('distance in kernel space for each graph in median set:', k_dis_median_set) | print('distance in kernel space for each graph in median set:', k_dis_median_set) | ||||
| def __set_graph_kernel_by_name(self): | |||||
| if self._kernel_options['name'] == 'ShortestPath': | |||||
| from gklearn.kernels import ShortestPath | |||||
| self._graph_kernel = ShortestPath(node_labels=self._dataset.node_labels, | |||||
| node_attrs=self._dataset.node_attrs, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| elif self._kernel_options['name'] == 'StructuralSP': | |||||
| from gklearn.kernels import StructuralSP | |||||
| self._graph_kernel = StructuralSP(node_labels=self._dataset.node_labels, | |||||
| edge_labels=self._dataset.edge_labels, | |||||
| node_attrs=self._dataset.node_attrs, | |||||
| edge_attrs=self._dataset.edge_attrs, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| elif self._kernel_options['name'] == 'PathUpToH': | |||||
| from gklearn.kernels import PathUpToH | |||||
| self._graph_kernel = PathUpToH(node_labels=self._dataset.node_labels, | |||||
| edge_labels=self._dataset.edge_labels, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| elif self._kernel_options['name'] == 'Treelet': | |||||
| from gklearn.kernels import Treelet | |||||
| self._graph_kernel = Treelet(node_labels=self._dataset.node_labels, | |||||
| edge_labels=self._dataset.edge_labels, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| elif self._kernel_options['name'] == 'WeisfeilerLehman': | |||||
| from gklearn.kernels import WeisfeilerLehman | |||||
| self._graph_kernel = WeisfeilerLehman(node_labels=self._dataset.node_labels, | |||||
| edge_labels=self._dataset.edge_labels, | |||||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||||
| **self._kernel_options) | |||||
| else: | |||||
| raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') | |||||
| # def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]): | # def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]): | ||||