|
- import unittest
-
- from itertools import product
- import numpy as np
-
- from functools import partial
- from array import array
-
- from fastNLP.core.samplers.reproducible_sampler import RandomSampler
- from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
- from tests.helpers.datasets.torch_data import TorchNormalDataset
-
-
-
- class TestRandomSamplerYh(unittest.TestCase):
- def test_init(self):
- # 测试能否正确初始化
- dataset = TorchNormalDataset(num_of_data=100)
- sampler = RandomSampler(dataset)
- for i in sampler:
- pass
-
- def test_during_iter(self):
- dataset = TorchNormalDataset(num_of_data=100)
- sampler = RandomSampler(dataset)
- for i in sampler:
- with self.assertRaises(AssertionError):
- sampler.set_distributed(1, 0)
- break
-
- # should not raise
- for i in sampler:
- pass
- sampler.set_distributed(1, 0)
-
- def test_set_distributed(self):
- dataset = TorchNormalDataset(num_of_data=100)
- sampler = RandomSampler(dataset, shuffle=False)
- sampler.set_distributed(num_replicas=2, rank=0, pad=False)
- self.assertEqual(len(sampler), 50)
- count = 0
- for i in sampler:
- self.assertEqual(i%2, 0)
- count += 1
- self.assertEqual(count, 50)
-
- sampler.set_distributed(num_replicas=2, rank=1, pad=False)
- self.assertEqual(len(sampler), 50)
- count = 0
- for i in sampler:
- self.assertEqual(i%2, 1)
- count += 1
- self.assertEqual(count, 50)
-
- dataset = TorchNormalDataset(num_of_data=101)
- sampler = RandomSampler(dataset, shuffle=False)
- sampler.set_distributed(num_replicas=2, rank=0, pad=True)
- self.assertEqual(len(sampler), 51)
- count = 0
- for i in sampler:
- self.assertEqual(i%2, 0)
- count += 1
- self.assertEqual(count, 51)
-
- sampler.set_distributed(num_replicas=2, rank=1, pad=True)
- self.assertEqual(len(sampler), 51)
- count = 0
- for i in sampler:
- if i!=0:
- self.assertEqual(i%2, 1)
- count += 1
- self.assertEqual(count, 51)
-
- def test_state_dict_check_length(self):
- dataset = TorchNormalDataset(num_of_data=100)
- sampler = RandomSampler(dataset, shuffle=False)
- states = sampler.state_dict()
-
- new_ds = TorchNormalDataset(num_of_data=10)
- with self.assertRaises(AssertionError):
- new_sampler = RandomSampler(new_ds)
- new_sampler.load_state_dict(states)
-
- new_ds = TorchNormalDataset(num_of_data=100)
- new_sampler = RandomSampler(new_ds)
- new_sampler.load_state_dict(states)
-
- def test_state_dict(self):
- num_samples = 100
- dataset = TorchNormalDataset(num_of_data=num_samples)
- # 测试使用 前后shuffle不一致的load操作
- lst = [0]+np.random.randint(1, num_samples, size=3).tolist()
- for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
- lst):
- with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
- sampler = RandomSampler(dataset, shuffle=pre_shuffle)
- sampler.set_epoch(0)
- already_numbers = set()
- if num_consumed_samples>0:
- for i, j in enumerate(sampler, start=1):
- already_numbers.add(j)
- if i == num_consumed_samples:
- break
- self.assertEqual(len(already_numbers), num_consumed_samples)
-
- states = sampler.state_dict()
-
- new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
- new_sampler.load_state_dict(states)
- new_sampler.set_epoch(0)
- for i in new_sampler:
- self.assertNotIn(i, already_numbers)
-
- # 测试切换成多卡也没有问题
- other_rank_number = set()
- for rank in range(3):
- new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
- new_sampler.load_state_dict(states)
- new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
- new_sampler.set_epoch(0)
- count = 0
- for i in new_sampler:
- self.assertNotIn(i, other_rank_number)
- other_rank_number.add(i)
- self.assertNotIn(i, already_numbers)
- count += 1
-
- def test_state_dict_2(self):
- # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
- num_samples = 100
- dataset = TorchNormalDataset(num_of_data=num_samples)
- # 测试使用 前后shuffle不一致的load操作
- lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist()
- # lst = [30]
- for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
- lst):
- with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
- already_numbers = set()
- sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
- sampler.set_distributed(num_replicas=2, rank=0)
- sampler.set_epoch(0)
- if num_consumed_samples>0:
- for i, j in enumerate(sampler, start=1):
- already_numbers.add(j)
- if i == num_consumed_samples:
- break
- sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
- sampler.set_epoch(0)
- sampler.set_distributed(num_replicas=2, rank=1)
- if num_consumed_samples>0:
- for i, j in enumerate(sampler, start=1):
- already_numbers.add(j)
- if i == num_consumed_samples:
- break
- self.assertEqual(len(already_numbers), num_consumed_samples*2)
-
- states = sampler.state_dict()
-
- new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
- new_sampler.load_state_dict(states)
- new_sampler.set_epoch(0)
- for i in new_sampler:
- self.assertNotIn(i, already_numbers)
-
- # 测试切换成多卡也没有问题
- other_rank_number = set()
- for rank in range(3):
- new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
- new_sampler.load_state_dict(states)
- new_sampler.set_epoch(0)
- new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
- count = 0
- for i in new_sampler:
- self.assertNotIn(i, other_rank_number)
- other_rank_number.add(i)
- self.assertNotIn(i, already_numbers)
- count += 1
-
-
- class TestRandomSampler(unittest.TestCase):
- # 测试单卡;
- def test_seed_work_when_shuffle_is_true(self):
- data_length = 100
-
- torch_normal_data = TorchNormalDataset(num_of_data=data_length)
- for shuffle in [True, False]:
- iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle)
- # 迭代一些数据,但是不迭代完;
- iterable.set_epoch(1)
- iterator = iter(iterable)
- pre_data = []
- forward_steps = 30
- for _ in range(forward_steps):
- pre_data.append(next(iterator))
-
- # 看重新生成迭代器是否能够完全重置状态;
- iterator = iter(iterable)
- res = []
- for _ in range(forward_steps):
- res.append(next(iterator))
- assert pre_data == res
-
- # 测试断点重训;
- # 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的;
- def test_2(self):
- data_length = 100
- torch_normal_data = TorchNormalDataset(num_of_data=data_length)
- random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True)
-
- iterator = iter(random_sampler_1)
- # 第一轮
- random_sampler_1.set_epoch(0)
- first_epoch = []
- forward_steps = 30
- for _ in range(forward_steps):
- first_epoch.append(next(iterator))
-
- # 先提前保存断点重训的结果;
- state = random_sampler_1.state_dict()
-
- # 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确;
- first_left_data = []
- while True:
- try:
- first_left_data.append(next(iterator))
- except StopIteration:
- break
-
- # 第二轮
- random_sampler_1.set_epoch(1)
- iterator = iter(random_sampler_1)
- second_epoch = []
- for _ in range(forward_steps):
- second_epoch.append(next(iterator))
-
- assert first_epoch != second_epoch
-
- # 重新加载第一轮的状态,查看断点重训是否正确;
- random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True)
- random_sampler_2.load_state_dict(state)
- random_sampler_2.set_epoch(0)
- iterator = iter(random_sampler_2)
- re_first_epoch = []
- while True:
- try:
- re_first_epoch.append(next(iterator))
- except StopIteration:
- break
- assert re_first_epoch == first_left_data
-
- # 查看第二轮的结果是否也是和第一次的第二轮完全一致;
- random_sampler_2.set_epoch(1)
- iterator = iter(random_sampler_2)
- re_second_epoch = []
- for _ in range(forward_steps):
- re_second_epoch.append(next(iterator))
- assert re_second_epoch == second_epoch
-
- # 多卡;
- # 如果一个 sampler 还没有迭代完,我们又直接 iter(sampler) 那么是否正确(应当生成一个全新的 sampler)?
- def test_3(self):
- data_length = 100
-
- torch_normal_data = TorchNormalDataset(num_of_data=data_length)
- random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False)
- random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
- iterable_items = [random_sampler_1, random_sampler_2]
-
- world_size = 3
- for pad in {True, False}:
- for iterable in iterable_items:
- for rank in range(world_size):
- each_rank_iterable = iterable()
- each_rank_iterable.set_epoch(0)
- each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad)
- # 迭代一些数据,但是不迭代完;
- iterator = iter(each_rank_iterable)
- pre_data = []
- forward_steps = 10
- for _ in range(forward_steps):
- pre_data.append(next(iterator))
-
- # 看重新生成迭代器是否能够完全重置状态;
- iterator = iter(each_rank_iterable)
- res = []
- for _ in range(forward_steps):
- res.append(next(iterator))
- assert res == pre_data
-
- # 测试断点重训;
- # 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的;
- def test_4(self):
- data_length = 100
- torch_normal_data = TorchNormalDataset(num_of_data=data_length)
- random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
-
- world_size_1 = 2
- forward_steps = 10
-
- for pad in {True, False}:
- all_rank_state = {}
- all_rank_first_left_data = {}
- all_rank_second_epoch = {}
- for rank in range(world_size_1):
- each_rank_iterable = random_sampler_1()
- each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
- iterator = iter(each_rank_iterable)
- # 第一轮
- each_rank_iterable.set_epoch(0)
- first_epoch = []
- for _ in range(forward_steps):
- first_epoch.append(next(iterator))
-
- # 先提前保存断点重训的结果;
- all_rank_state[rank] = each_rank_iterable.state_dict()
-
- # 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确;
- first_left_data = []
- while True:
- try:
- first_left_data.append(next(iterator))
- except StopIteration:
- break
- all_rank_first_left_data[rank] = first_left_data
- # 第二轮
- each_rank_iterable.set_epoch(1)
- iterator = iter(each_rank_iterable)
- second_epoch = []
- for _ in range(forward_steps):
- second_epoch.append(next(iterator))
- all_rank_second_epoch[rank] = second_epoch
- assert first_epoch != second_epoch
-
- # 重新加载第一轮的状态,查看断点重训是否正确;
- random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
- for rank in range(world_size_1):
- each_rank_iterable = random_sampler_2()
- each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
- each_rank_iterable.load_state_dict(all_rank_state[rank])
- each_rank_iterable.set_epoch(0)
- iterator = iter(each_rank_iterable)
- re_first_epoch = []
- while True:
- try:
- re_first_epoch.append(next(iterator))
- except StopIteration:
- break
- assert re_first_epoch == all_rank_first_left_data[rank]
-
- # 查看第二轮的结果是否也是和第一次的第二轮完全一致;
- each_rank_iterable.set_epoch(1)
- iterator = iter(each_rank_iterable)
- re_second_epoch = []
- for _ in range(forward_steps):
- re_second_epoch.append(next(iterator))
- assert re_second_epoch == all_rank_second_epoch[rank]
-
- # todo 测试 ddp 时 world_size 改变的断点重训;
- def test_5(self):
- ...
-
-
|