| @@ -14,7 +14,7 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader | |||||
| class Dataset(object): | class Dataset(object): | ||||
| def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs): | |||||
| def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', remove_null_graphs=True, clean_labels=True, reload=False, verbose=False, **kwargs): | |||||
| self._substructures = None | self._substructures = None | ||||
| self._node_label_dim = None | self._node_label_dim = None | ||||
| self._edge_label_dim = None | self._edge_label_dim = None | ||||
| @@ -82,6 +82,8 @@ class Dataset(object): | |||||
| else: | else: | ||||
| raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') | raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') | ||||
| if remove_null_graphs: | |||||
| self.trim_dataset(edge_required=False) | |||||
| def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): | def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): | ||||
| @@ -537,7 +539,7 @@ class Dataset(object): | |||||
| def trim_dataset(self, edge_required=False): | def trim_dataset(self, edge_required=False): | ||||
| if edge_required: | |||||
| if edge_required: # @todo: graphs that have nodes but no edges are dropped here, so node labels occurring only in those graphs may be removed. | |||||
| trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)] | trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)] | ||||
| else: | else: | ||||
| trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0] | trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0] | ||||