Browse Source

新的paddle测试数据集

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
6ef912675b
1 changed files with 4 additions and 11 deletions
  1. +4
    -11
      tests/helpers/datasets/paddle_data.py

+ 4
- 11
tests/helpers/datasets/paddle_data.py View File

@@ -16,19 +16,12 @@ class PaddleNormalDataset(Dataset):


class PaddleRandomDataset(Dataset):
def __init__(self, num_of_data=1000, features=64, labels=10):
self.num_of_data = num_of_data
self.x = [
paddle.rand((features,))
for i in range(num_of_data)
]
self.y = [
paddle.rand((labels,))
for i in range(num_of_data)
]
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)

def __len__(self):
return self.num_of_data
return len(self.x)

def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}


Loading…
Cancel
Save