| @@ -18,6 +18,7 @@ from gklearn.ged.median import MedianGraphEstimator | |||
| from gklearn.ged.median import constant_node_costs,mge_options_to_string | |||
| from gklearn.gedlib import librariesImport, gedlibpy | |||
| from gklearn.utils import Timer | |||
| from gklearn.utils.utils import get_graph_kernel_by_name | |||
| # from gklearn.utils.dataset import Dataset | |||
| class MedianPreimageGenerator(PreimageGenerator): | |||
| @@ -81,7 +82,13 @@ class MedianPreimageGenerator(PreimageGenerator): | |||
| def run(self): | |||
| self.__set_graph_kernel_by_name() | |||
| self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'], | |||
| node_labels=self._dataset.node_labels, | |||
| edge_labels=self._dataset.edge_labels, | |||
| node_attrs=self._dataset.node_attrs, | |||
| edge_attrs=self._dataset.edge_attrs, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| # record start time. | |||
| start = time.time() | |||
| @@ -722,43 +729,6 @@ class MedianPreimageGenerator(PreimageGenerator): | |||
| print('distance in kernel space for generalized median:', self.__k_dis_gen_median) | |||
| print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) | |||
| print('distance in kernel space for each graph in median set:', k_dis_median_set) | |||
| def __set_graph_kernel_by_name(self): | |||
| if self._kernel_options['name'] == 'ShortestPath': | |||
| from gklearn.kernels import ShortestPath | |||
| self._graph_kernel = ShortestPath(node_labels=self._dataset.node_labels, | |||
| node_attrs=self._dataset.node_attrs, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| elif self._kernel_options['name'] == 'StructuralSP': | |||
| from gklearn.kernels import StructuralSP | |||
| self._graph_kernel = StructuralSP(node_labels=self._dataset.node_labels, | |||
| edge_labels=self._dataset.edge_labels, | |||
| node_attrs=self._dataset.node_attrs, | |||
| edge_attrs=self._dataset.edge_attrs, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| elif self._kernel_options['name'] == 'PathUpToH': | |||
| from gklearn.kernels import PathUpToH | |||
| self._graph_kernel = PathUpToH(node_labels=self._dataset.node_labels, | |||
| edge_labels=self._dataset.edge_labels, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| elif self._kernel_options['name'] == 'Treelet': | |||
| from gklearn.kernels import Treelet | |||
| self._graph_kernel = Treelet(node_labels=self._dataset.node_labels, | |||
| edge_labels=self._dataset.edge_labels, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| elif self._kernel_options['name'] == 'WeisfeilerLehman': | |||
| from gklearn.kernels import WeisfeilerLehman | |||
| self._graph_kernel = WeisfeilerLehman(node_labels=self._dataset.node_labels, | |||
| edge_labels=self._dataset.edge_labels, | |||
| ds_infos=self._dataset.get_dataset_infos(keys=['directed']), | |||
| **self._kernel_options) | |||
| else: | |||
| raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') | |||
| # def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]): | |||