diff --git a/tests/core/drivers/paddle_driver/test.py b/tests/core/drivers/paddle_driver/test.py new file mode 100644 index 00000000..5455a230 --- /dev/null +++ b/tests/core/drivers/paddle_driver/test.py @@ -0,0 +1,25 @@ +import sys +import os +import warnings +warnings.filterwarnings("ignore") +os.environ["FASTNLP_BACKEND"] = "torch" +sys.path.append("../../../../") + +import paddle +from fastNLP.core.samplers import RandomSampler +from fastNLP.core.drivers.paddle_driver.utils import replace_sampler, replace_batch_sampler +from tests.helpers.datasets.paddle_data import PaddleNormalDataset + +dataset = PaddleNormalDataset(20) +batch_sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=2) +batch_sampler.sampler = RandomSampler(dataset, True) +dataloader = paddle.io.DataLoader( + dataset, + batch_sampler=batch_sampler +) + +forward_steps = 9 +iter_dataloader = iter(dataloader) +for _ in range(forward_steps): + print(next(iter_dataloader)) +print(dataloader.batch_sampler.sampler.during_iter) diff --git a/tests/core/drivers/paddle_driver/test2.py b/tests/core/drivers/paddle_driver/test2.py new file mode 100644 index 00000000..aaa3150e --- /dev/null +++ b/tests/core/drivers/paddle_driver/test2.py @@ -0,0 +1,21 @@ +import torch +# from torch.utils.data import DataLoader, Dataset +import paddle +from paddle.io import Dataset, DataLoader +paddle.device.set_device("cpu") +class NormalDataset(Dataset): + def __init__(self, num_of_data=1000): + self.num_of_data = num_of_data + self._data = list(range(num_of_data)) + + def __len__(self): + return self.num_of_data + + def __getitem__(self, item): + return self._data[item] +dataset = NormalDataset(20) +dataloader = DataLoader(dataset, batch_size=2, use_buffer_reader=False) +for i, b in enumerate(dataloader): + print(b) + if i >= 2: + break diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 434e9e5b..de98f9c5 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -117,12 +117,13 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) - batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) assert not (replaced_loader is dataloader) @@ -133,12 +134,13 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) - sampler = RandomSampler(self.dataset, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + sampler = RandomSampler(self.dataset, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) assert not (replaced_loader is dataloader) @@ -171,14 +173,15 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler 时的表现 """ dataloader = DataLoader( self.dataset, - batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle), ) dataloader.batch_sampler.set_distributed( num_replicas=self.driver.world_size, @@ -195,12 +198,13 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) batch_sampler.sampler.set_distributed( num_replicas=self.driver.world_size, rank=self.driver.global_rank @@ -222,11 +226,12 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) assert replaced_loader is dataloader @@ -238,14 +243,15 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler 的表现 """ dataloader = DataLoader( dataset=self.dataset, - batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) ) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) @@ -258,13 +264,14 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -276,16 +283,17 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert replaced_loader.batch_sampler.batch_size == 2 - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) @@ -293,7 +301,7 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle dist.barrier() """ @@ -302,13 +310,14 @@ class TestSetDistReproDataloader: """ @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler 的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -320,18 +329,19 @@ class TestSetDistReproDataloader: assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) assert replaced_loader.batch_sampler.batch_size == 2 - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler 的表现 """ batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) + batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -349,11 +359,12 @@ class TestSetDistReproDataloader: dist.barrier() @magic_argv_env_context - def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self, shuffle): """ 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 """ - dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index fd947c73..ebd4721b 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -1,4 +1,5 @@ import os +from re import S os.environ["FASTNLP_BACKEND"] = "paddle" import pytest from pathlib import Path @@ -283,30 +284,32 @@ class TestSetDistReproDataloder: assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ - dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) + dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) assert replaced_loader.batch_sampler is dist - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def test_set_dist_repro_dataloader_with_dist_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dist_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler """ - dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) - dist = RandomSampler(self.dataset, shuffle=True) + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) + dist = RandomSampler(self.dataset, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) assert not (replaced_loader is dataloader) @@ -316,16 +319,21 @@ class TestSetDistReproDataloder: assert replaced_loader.batch_sampler.sampler is dist assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dataloader_reproducible_batch_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 """ dataloader = DataLoader( dataset=self.dataset, - batch_sampler=RandomBatchSampler(BatchSampler(self.dataset, batch_size=4), 4, False) + batch_sampler=RandomBatchSampler( + BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), + batch_size=4, + drop_last=False, + ) ) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) @@ -335,15 +343,16 @@ class TestSetDistReproDataloder: assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self): + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_set_dist_repro_dataloader_with_dataloader_reproducible_sampler(self, shuffle): """ 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 应该返回新的 dataloader,且其余各项设置和原来相同 """ - batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) - batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2, shuffle=shuffle) + batch_sampler.sampler = RandomSampler(self.dataset, shuffle) dataloader = DataLoader( self.dataset, batch_sampler=batch_sampler @@ -355,11 +364,11 @@ class TestSetDistReproDataloder: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) assert replaced_loader.batch_sampler.batch_size == 2 - assert replaced_loader.batch_sampler.sampler.shuffle == True + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle - self.check_set_dist_repro_dataloader(dataloader, replaced_loader) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) - def check_set_dist_repro_dataloader(self, dataloader, replaced_loader): + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): """ 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ @@ -378,9 +387,6 @@ class TestSetDistReproDataloder: # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - import time - time.sleep(5) - # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): @@ -389,16 +395,29 @@ class TestSetDistReproDataloder: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] else: sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size - replaced_loader.batch_sampler.load_state_dict(sampler_states) + # 重新改造 dataloader + new_loader = DataLoader( + dataset=replaced_loader.dataset, + batch_sampler=RandomBatchSampler( + BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), + batch_size=batch_size, + drop_last=False, + ) + ) + new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size + num_consumed_batches = num_consumed_batches * batch_size if num_consumed_samples_array is not None: sampler_states["num_consumed_samples"] = num_consumed_samples_array[num_consumed_batches] else: sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size - replaced_loader.batch_sampler.sampler.load_state_dict(sampler_states) - replaced_loader.batch_sampler.sampler.set_epoch(0) - for idx, batch in enumerate(replaced_loader): + # 重新构造 dataloader + batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) + batch_sampler.sampler = RandomSampler(replaced_loader.dataset, shuffle=shuffle) + new_loader = DataLoader(replaced_loader.dataset, batch_sampler=batch_sampler) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): left_idxes.update(batch) assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)