|
|
@@ -1,25 +0,0 @@ |
|
|
|
import sys |
|
|
|
import os |
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
os.environ["FASTNLP_BACKEND"] = "torch" |
|
|
|
sys.path.append("../../../../") |
|
|
|
|
|
|
|
import paddle |
|
|
|
from fastNLP.core.samplers import RandomSampler |
|
|
|
from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler |
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset |
|
|
|
|
|
|
|
dataset = PaddleNormalDataset(20) |
|
|
|
batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2) |
|
|
|
batch_sampler.sampler = RandomSampler(dataset, True) |
|
|
|
dataloader = paddle.io.DataLoader( |
|
|
|
dataset, |
|
|
|
batch_sampler=batch_sampler |
|
|
|
) |
|
|
|
|
|
|
|
forward_steps = 9 |
|
|
|
iter_dataloader = iter(dataloader) |
|
|
|
for _ in range(forward_steps): |
|
|
|
print(next(iter_dataloader)) |
|
|
|
print(dataloader.batch_sampler.sampler.during_iter) |