diff --git a/data_module.py b/data_module.py index 38ba325..831a695 100644 --- a/data_module.py +++ b/data_module.py @@ -34,10 +34,11 @@ class DataModule(pl.LightningDataModule): dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise assert (dataset[torch.isnan(dataset)].shape[0] == 0) + written = [' '.join([str(temp) for temp in dataset[cou, :].tolist()]) for cou in range(dataset.shape[0])] with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: - for line in range(self.config['dataset_len']): - f.write(' '.join([str(temp) for temp in dataset[line].tolist()]) + '\n') + for line in written: + f.write(line + '\n') print('已生成新的数据list') else: dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() diff --git a/utils.py b/utils.py index f3b2ffa..41e8499 100644 --- a/utils.py +++ b/utils.py @@ -11,9 +11,12 @@ def get_dataset_list(dataset_path): if not os.path.exists(dataset_path + '/dataset_list.txt'): all_list = glob.glob(dataset_path + '/labels' + '/*.png') random.shuffle(all_list) + all_list = [os.path.basename(item.replace('\\', '/')) for item in all_list] + written = all_list + with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: - for line in all_list: - f.write(os.path.basename(line.replace('\\', '/')) + '\n') + for line in written: + f.write(line + '\n') print('已生成新的数据list') return all_list else: