|
|
@@ -68,4 +68,26 @@ class TestPaddle: |
|
|
|
collate_fn = Collator(backend='auto') |
|
|
|
paddle_dl = DataLoader(ds, collate_fn=collate_fn) |
|
|
|
for batch in paddle_dl: |
|
|
|
print(batch) |
|
|
|
|
|
|
|
def test_v4(self): |
|
|
|
from paddle.io import DataLoader |
|
|
|
from fastNLP import Collator |
|
|
|
from paddle.io import Dataset |
|
|
|
import paddle |
|
|
|
|
|
|
|
class PaddleRandomMaxDataset(Dataset): |
|
|
|
def __init__(self, num_samples, num_features): |
|
|
|
self.x = paddle.randn((num_samples, num_features)) |
|
|
|
self.y = self.x.argmax(axis=-1) |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.x) |
|
|
|
|
|
|
|
def __getitem__(self, item): |
|
|
|
return {"x": self.x[item], "y": self.y[item]} |
|
|
|
|
|
|
|
ds = PaddleRandomMaxDataset(100, 2) |
|
|
|
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) |
|
|
|
for batch in dl: |
|
|
|
print(batch) |