diff --git a/tests/helpers/datasets/paddle_data.py b/tests/helpers/datasets/paddle_data.py index 1505e72d..f00c8d95 100644 --- a/tests/helpers/datasets/paddle_data.py +++ b/tests/helpers/datasets/paddle_data.py @@ -16,19 +16,12 @@ class PaddleNormalDataset(Dataset): class PaddleRandomDataset(Dataset): - def __init__(self, num_of_data=1000, features=64, labels=10): - self.num_of_data = num_of_data - self.x = [ - paddle.rand((features,)) - for i in range(num_of_data) - ] - self.y = [ - paddle.rand((labels,)) - for i in range(num_of_data) - ] + def __init__(self, num_samples, num_features): + self.x = paddle.randn((num_samples, num_features)) + self.y = self.x.argmax(axis=-1) def __len__(self): - return self.num_of_data + return len(self.x) def __getitem__(self, item): return {"x": self.x[item], "y": self.y[item]}