Browse Source

修改得到数据list的方式

master
shenyan 4 years ago
parent
commit
f23c6ce854
2 changed files with 8 additions and 4 deletions
  1. +3
    -2
      data_module.py
  2. +5
    -2
      utils.py

+ 3
- 2
data_module.py View File

@@ -34,10 +34,11 @@ class DataModule(pl.LightningDataModule):
dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos(
torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise
assert (dataset[torch.isnan(dataset)].shape[0] == 0) assert (dataset[torch.isnan(dataset)].shape[0] == 0)
written = [' '.join([str(temp) for temp in dataset[cou, :].tolist()]) for cou in range(dataset.shape[0])]


with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in range(self.config['dataset_len']):
f.write(' '.join([str(temp) for temp in dataset[line].tolist()]) + '\n')
for line in written:
f.write(line + '\n')
print('已生成新的数据list') print('已生成新的数据list')
else: else:
dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines()


+ 5
- 2
utils.py View File

@@ -11,9 +11,12 @@ def get_dataset_list(dataset_path):
if not os.path.exists(dataset_path + '/dataset_list.txt'): if not os.path.exists(dataset_path + '/dataset_list.txt'):
all_list = glob.glob(dataset_path + '/labels' + '/*.png') all_list = glob.glob(dataset_path + '/labels' + '/*.png')
random.shuffle(all_list) random.shuffle(all_list)
all_list = [os.path.basename(item.replace('\\', '/')) for item in all_list]
written = all_list

with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: with open(dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f:
for line in all_list:
f.write(os.path.basename(line.replace('\\', '/')) + '\n')
for line in written:
f.write(line + '\n')
print('已生成新的数据list') print('已生成新的数据list')
return all_list return all_list
else: else:


Loading…
Cancel
Save