Browse Source

set_dist_repro_dataloader测试例的完善

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
cf19062fb2
4 changed files with 129 additions and 53 deletions
  1. +25
    -0
      tests/core/drivers/paddle_driver/test.py
  2. +21
    -0
      tests/core/drivers/paddle_driver/test2.py
  3. +39
    -28
      tests/core/drivers/paddle_driver/test_fleet.py
  4. +44
    -25
      tests/core/drivers/paddle_driver/test_single_device.py

+ 25
- 0
tests/core/drivers/paddle_driver/test.py View File

@@ -0,0 +1,25 @@
import sys
import os
import warnings
warnings.filterwarnings("ignore")
os.environ["FASTNLP_BACKEND"] = "torch"
sys.path.append("../../../../")

import paddle
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler
from tests.helpers.datasets.paddle_data import PaddleNormalDataset

dataset = PaddleNormalDataset(20)
batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = paddle.io.DataLoader(
dataset,
batch_sampler=batch_sampler
)

forward_steps = 9
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
print(next(iter_dataloader))
print(dataloader.batch_sampler.sampler.during_iter)

+ 21
- 0
tests/core/drivers/paddle_driver/test2.py View File

@@ -0,0 +1,21 @@
import torch
# from torch.utils.data import DataLoader, Dataset
import paddle
from paddle.io import Dataset, DataLoader
paddle.device.set_device("cpu")
class NormalDataset(Dataset):
def __init__(self, num_of_data=1000):
self.num_of_data = num_of_data
self._data = list(range(num_of_data))

def __len__(self):
return self.num_of_data

def __getitem__(self, item):
return self._data[item]
dataset = NormalDataset(20)
dataloader = DataLoader(dataset, batch_size=2, use_buffer_reader=False)
for i, b in enumerate(dataloader):
print(b)
if i >= 2:
break

+ 39
- 28
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -117,12 +117,13 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
@@ -133,12 +134,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
sampler = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
sampler = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
@@ -171,14 +173,15 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler
时的表现
"""
dataloader = DataLoader(
self.dataset,
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle),
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
@@ -195,12 +198,13 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
@@ -222,11 +226,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
@@ -238,14 +243,15 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler
的表现
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

@@ -258,13 +264,14 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -276,16 +283,17 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
@@ -293,7 +301,7 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
dist.barrier()

"""
@@ -302,13 +310,14 @@ class TestSetDistReproDataloader:
"""

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -320,18 +329,19 @@ class TestSetDistReproDataloader:
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler
的表现
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -349,11 +359,12 @@ class TestSetDistReproDataloader:
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle):
"""
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)


+ 44
- 25
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,4 +1,5 @@
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest
from pathlib import Path
@@ -283,30 +284,32 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dist_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现
应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)
dist = RandomSampler(self.dataset, shuffle=True)
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomSampler(self.dataset, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
@@ -316,16 +319,21 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.sampler is dist
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False)
batch_sampler=RandomBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

@@ -335,15 +343,16 @@ class TestSetDistReproDataloder:
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self):
@pytest.mark.parametrize("shuffle", ([True, False]))
def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle):
"""
测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现
应该返回新的 dataloader,且其余各项设置和原来相同
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle)
batch_sampler.sampler = RandomSampler(self.dataset, shuffle)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
@@ -355,11 +364,11 @@ class TestSetDistReproDataloder:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle

self.check_set_dist_repro_dataloader(dataloader, replaced_loader)
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)

def check_set_dist_repro_dataloader(self, dataloader, replaced_loader):
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
"""
测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确
"""
@@ -378,9 +387,6 @@ class TestSetDistReproDataloder:
# 加载 num_consumed_samples_array,设置正确取出的 batch 数目
num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None)

import time
time.sleep(5)

# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
@@ -389,16 +395,29 @@ class TestSetDistReproDataloder:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.load_state_dict(sampler_states)
# 重新改造 dataloader
new_loader = DataLoader(
dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size,
drop_last=False,
)
)
new_loader.batch_sampler.load_state_dict(sampler_states)
else:
batch_size = replaced_loader.batch_sampler.batch_size
num_consumed_batches = num_consumed_batches * batch_size
if num_consumed_samples_array is not None:
sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches]
else:
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states)
replaced_loader.batch_sampler.sampler.set_epoch(0)
for idx, batch in enumerate(replaced_loader):
# 重新构造 dataloader
batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size)
batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle)
new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler)
new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
for idx, batch in enumerate(new_loader):
left_idxes.update(batch)

assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)


Loading…
Cancel
Save