| @@ -40,6 +40,7 @@ class Dataset(object): | |||
| self._edge_attr_dim = None | |||
| self._class_number = None | |||
| self._ds_name = None | |||
| self._task_type = None | |||
| if inputs is None: | |||
| self._graphs = None | |||
| @@ -117,11 +118,16 @@ class Dataset(object): | |||
| ds_file = [os.path.join(path, fn) for fn in load_files[0]] | |||
| fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None | |||
| # Get extra_params. | |||
| if 'extra_params' in DATASET_META[ds_name]: | |||
| kwargs = DATASET_META[ds_name]['extra_params'] | |||
| else: | |||
| kwargs = {} | |||
| # Get the task type that is associated with the dataset. If it is classification, get the number of classes. | |||
| self._get_task_type(ds_name) | |||
| self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data | |||
| self._node_labels = label_names['node_labels'] | |||
| @@ -276,7 +282,8 @@ class Dataset(object): | |||
| 'edge_attr_dim', | |||
| 'class_number', | |||
| 'all_degree_entropy', | |||
| 'ave_degree_entropy' | |||
| 'ave_degree_entropy', | |||
| 'class_type' | |||
| ] | |||
| # dataset size | |||
| @@ -408,7 +415,7 @@ class Dataset(object): | |||
| if 'class_number' in keys: | |||
| if self._class_number is None: | |||
| self._class_number = self._get_class_number() | |||
| self._class_number = self._get_class_num() | |||
| infos['class_number'] = self._class_number | |||
| if 'node_attr_dim' in keys: | |||
| @@ -437,6 +444,11 @@ class Dataset(object): | |||
| base = None | |||
| infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base)) | |||
| if 'task_type' in keys: | |||
| if self._task_type is None: | |||
| self._task_type = self._get_task_type() | |||
| infos['task_type'] = self._task_type | |||
| return infos | |||
| @@ -790,6 +802,13 @@ class Dataset(object): | |||
| return degree_entropy | |||
| def _get_task_type(self, ds_name): | |||
| if 'task_type' in DATASET_META[ds_name]: | |||
| self._task_type = DATASET_META[ds_name]['task_type'] | |||
| if self._task_type == 'classification' and self._class_number is None and 'class_number' in DATASET_META[ds_name]: | |||
| self._class_number = DATASET_META[ds_name]['class_number'] | |||
| @property | |||
| def graphs(self): | |||
| return self._graphs | |||