Browse Source

修改paddle fdl测试用例

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
81ef1ec4b7
1 changed files with 22 additions and 0 deletions
  1. +22
    -0
      tests/core/dataloaders/paddle_dataloader/test_fdl.py

+ 22
- 0
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -68,4 +68,26 @@ class TestPaddle:
collate_fn = Collator(backend='auto')
paddle_dl = DataLoader(ds, collate_fn=collate_fn)
for batch in paddle_dl:
print(batch)

def test_v4(self):
from paddle.io import DataLoader
from fastNLP import Collator
from paddle.io import Dataset
import paddle

class PaddleRandomMaxDataset(Dataset):
def __init__(self, num_samples, num_features):
self.x = paddle.randn((num_samples, num_features))
self.y = self.x.argmax(axis=-1)

def __len__(self):
return len(self.x)

def __getitem__(self, item):
return {"x": self.x[item], "y": self.y[item]}

ds = PaddleRandomMaxDataset(100, 2)
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4)
for batch in dl:
print(batch)

Loading…
Cancel
Save