From 16cec4bd99d55aec8cceb06890c3f1ea5506dcce Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Fri, 15 Apr 2022 09:07:35 +0000 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=B8=8D=E5=BF=85=E8=A6=81?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/drivers/paddle_driver/test.py | 25 ----------------------- tests/core/drivers/paddle_driver/test2.py | 21 ------------------- 2 files changed, 46 deletions(-) delete mode 100644 tests/core/drivers/paddle_driver/test.py delete mode 100644 tests/core/drivers/paddle_driver/test2.py diff --git a/tests/core/drivers/paddle_driver/test.py b/tests/core/drivers/paddle_driver/test.py deleted file mode 100644 index 5455a230..00000000 --- a/tests/core/drivers/paddle_driver/test.py +++ /dev/null @@ -1,25 +0,0 @@ -import sys -import os -import warnings -warnings.filterwarnings("ignore") -os.environ["FASTNLP_BACKEND"] = "torch" -sys.path.append("../../../../") - -import paddle -from fastNLP.core.samplers import RandomSampler -from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler -from tests.helpers.datasets.paddle_data import PaddleNormalDataset - -dataset = PaddleNormalDataset(20) -batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2) -batch_sampler.sampler = RandomSampler(dataset, True) -dataloader = paddle.io.DataLoader( - dataset, - batch_sampler=batch_sampler -) - -forward_steps = 9 -iter_dataloader = iter(dataloader) -for _ in range(forward_steps): - print(next(iter_dataloader)) -print(dataloader.batch_sampler.sampler.during_iter) diff --git a/tests/core/drivers/paddle_driver/test2.py b/tests/core/drivers/paddle_driver/test2.py deleted file mode 100644 index aaa3150e..00000000 --- a/tests/core/drivers/paddle_driver/test2.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -# from torch.utils.data import DataLoader, Dataset -import paddle -from paddle.io import Dataset, DataLoader -paddle.device.set_device("cpu") -class NormalDataset(Dataset): - def __init__(self, num_of_data=1000): - self.num_of_data = num_of_data - self._data = list(range(num_of_data)) - - def __len__(self): - return self.num_of_data - - def __getitem__(self, item): - return self._data[item] -dataset = NormalDataset(20) -dataloader = DataLoader(dataset, batch_size=2, use_buffer_reader=False) -for i, b in enumerate(dataloader): - print(b) - if i >= 2: - break