| @@ -40,6 +40,7 @@ class Dataset(object): | |||||
| self._edge_attr_dim = None | self._edge_attr_dim = None | ||||
| self._class_number = None | self._class_number = None | ||||
| self._ds_name = None | self._ds_name = None | ||||
| self._task_type = None | |||||
| if inputs is None: | if inputs is None: | ||||
| self._graphs = None | self._graphs = None | ||||
| @@ -117,11 +118,16 @@ class Dataset(object): | |||||
| ds_file = [os.path.join(path, fn) for fn in load_files[0]] | ds_file = [os.path.join(path, fn) for fn in load_files[0]] | ||||
| fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None | fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None | ||||
| # Get extra_params. | |||||
| if 'extra_params' in DATASET_META[ds_name]: | if 'extra_params' in DATASET_META[ds_name]: | ||||
| kwargs = DATASET_META[ds_name]['extra_params'] | kwargs = DATASET_META[ds_name]['extra_params'] | ||||
| else: | else: | ||||
| kwargs = {} | kwargs = {} | ||||
| # Get the task type that is associated with the dataset. If it is classification, get the number of classes. | |||||
| self._get_task_type(ds_name) | |||||
| self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data | self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data | ||||
| self._node_labels = label_names['node_labels'] | self._node_labels = label_names['node_labels'] | ||||
| @@ -276,7 +282,8 @@ class Dataset(object): | |||||
| 'edge_attr_dim', | 'edge_attr_dim', | ||||
| 'class_number', | 'class_number', | ||||
| 'all_degree_entropy', | 'all_degree_entropy', | ||||
| 'ave_degree_entropy' | |||||
| 'ave_degree_entropy', | |||||
| 'class_type' | |||||
| ] | ] | ||||
| # dataset size | # dataset size | ||||
| @@ -408,7 +415,7 @@ class Dataset(object): | |||||
| if 'class_number' in keys: | if 'class_number' in keys: | ||||
| if self._class_number is None: | if self._class_number is None: | ||||
| self._class_number = self._get_class_number() | |||||
| self._class_number = self._get_class_num() | |||||
| infos['class_number'] = self._class_number | infos['class_number'] = self._class_number | ||||
| if 'node_attr_dim' in keys: | if 'node_attr_dim' in keys: | ||||
| @@ -437,6 +444,11 @@ class Dataset(object): | |||||
| base = None | base = None | ||||
| infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base)) | infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base)) | ||||
| if 'task_type' in keys: | |||||
| if self._task_type is None: | |||||
| self._task_type = self._get_task_type() | |||||
| infos['task_type'] = self._task_type | |||||
| return infos | return infos | ||||
| @@ -790,6 +802,13 @@ class Dataset(object): | |||||
| return degree_entropy | return degree_entropy | ||||
| def _get_task_type(self, ds_name): | |||||
| if 'task_type' in DATASET_META[ds_name]: | |||||
| self._task_type = DATASET_META[ds_name]['task_type'] | |||||
| if self._task_type == 'classification' and self._class_number is None and 'class_number' in DATASET_META[ds_name]: | |||||
| self._class_number = DATASET_META[ds_name]['class_number'] | |||||
| @property | @property | ||||
| def graphs(self): | def graphs(self): | ||||
| return self._graphs | return self._graphs | ||||