|
|
@@ -34,10 +34,11 @@ class DataModule(pl.LightningDataModule): |
|
|
dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( |
|
|
dataset[:, self.config['dim_in']] = torch.cos(1.5 * dataset[:, 0]) * (dataset[:, 1] ** 2.0) + torch.cos( |
|
|
torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise |
|
|
torch.sin(dataset[:, 2] ** 3)) + torch.arctan(dataset[:, 4]) + noise |
|
|
assert (dataset[torch.isnan(dataset)].shape[0] == 0) |
|
|
assert (dataset[torch.isnan(dataset)].shape[0] == 0) |
|
|
|
|
|
written = [' '.join([str(temp) for temp in dataset[cou, :].tolist()]) for cou in range(dataset.shape[0])] |
|
|
|
|
|
|
|
|
with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: |
|
|
with open(self.dataset_path + '/dataset_list.txt', 'w', encoding='utf-8') as f: |
|
|
for line in range(self.config['dataset_len']): |
|
|
|
|
|
f.write(' '.join([str(temp) for temp in dataset[line].tolist()]) + '\n') |
|
|
|
|
|
|
|
|
for line in written: |
|
|
|
|
|
f.write(line + '\n') |
|
|
print('已生成新的数据list') |
|
|
print('已生成新的数据list') |
|
|
else: |
|
|
else: |
|
|
dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() |
|
|
dataset_list = open(self.dataset_path + '/dataset_list.txt').readlines() |
|
|
|