|
- import numpy as np
- import pytest
-
- from functools import partial
- from itertools import chain
-
- from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
- from tests.helpers.datasets.torch_data import TorchNormalDataset
-
-
class TestRandomSamplerYh:
    """Tests for ``RandomSampler``: construction, distributed splitting, and
    checkpoint (``state_dict``/``load_state_dict``) resumption."""

    def test_init(self):
        # Check the sampler can be constructed and iterated to exhaustion.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            pass

    def test_during_iter(self):
        # set_distributed must raise while an iteration is still in progress...
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset)
        for i in sampler:
            with pytest.raises(AssertionError):
                sampler.set_distributed(1, 0)
            break

        # should not raise
        for i in sampler:
            pass
        sampler.set_distributed(1, 0)

    def test_set_distributed(self):
        # With shuffle=False and 2 replicas, rank r should receive exactly the
        # indices i with i % 2 == r.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=False)
        assert len(sampler)==50
        count = 0
        for i in sampler:
            assert i%2==0
            count += 1
        assert count == 50

        sampler.set_distributed(num_replicas=2, rank=1, pad=False)
        assert len(sampler)==50
        count = 0
        for i in sampler:
            assert i%2==1
            count += 1
        assert count==50

        # Odd-sized dataset with pad=True: both ranks see 51 samples; the
        # shorter rank is padded with a repeated index.
        dataset = TorchNormalDataset(num_of_data=101)
        sampler = RandomSampler(dataset, shuffle=False)
        sampler.set_distributed(num_replicas=2, rank=0, pad=True)
        assert len(sampler)==51
        count = 0
        for i in sampler:
            assert i%2==0
            count += 1
        assert count == 51

        sampler.set_distributed(num_replicas=2, rank=1, pad=True)
        assert len(sampler) == 51
        count = 0
        for i in sampler:
            # Index 0 may appear on rank 1 as the padding element.
            if i!=0:
                assert i%2==1
            count += 1
        assert count == 51

    def test_state_dict_check_length(self):
        # Loading a state_dict taken on a dataset of a different length must fail.
        dataset = TorchNormalDataset(num_of_data=100)
        sampler = RandomSampler(dataset, shuffle=False)
        states = sampler.state_dict()

        new_ds = TorchNormalDataset(num_of_data=10)
        with pytest.raises(AssertionError):
            new_sampler = RandomSampler(new_ds)
            new_sampler.load_state_dict(states)

        # Same length: loading must succeed.
        new_ds = TorchNormalDataset(num_of_data=100)
        new_sampler = RandomSampler(new_ds)
        new_sampler.load_state_dict(states)

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('pre_shuffle', [True, False])
    @pytest.mark.parametrize('post_shuffle', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        # Test loading a state_dict when the shuffle flag differs before and after saving.
        sampler = RandomSampler(dataset, shuffle=pre_shuffle)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples>0:
            # Consume the first num_consumed_samples indices and remember them.
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()

        # After reloading, none of the already-consumed indices may be replayed.
        new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            assert i not in already_numbers

        # Switching to a multi-card setup after loading should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            for i in new_sampler:
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            # With pad=True a single padded index may repeat; otherwise none may.
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1 # padding may duplicate at most one index across ranks

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('pre_shuffle', [True, False])
    @pytest.mark.parametrize('post_shuffle', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
    def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
        # Resume from a multi-card state on a single card, or on a different number of cards.
        num_samples = 100
        dataset = TorchNormalDataset(num_of_data=num_samples)
        # Consume samples on both of two ranks (same seed), recording everything seen.
        already_numbers = set()
        sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        # Ranks must not overlap, so exactly 2 * num_consumed_samples were seen.
        assert len(already_numbers) == num_consumed_samples*2

        states = sampler.state_dict()

        # Single-card resume: no already-consumed index may be replayed.
        new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            assert i not in already_numbers

        # Switching to a different multi-card setup should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            for i in new_sampler:
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1 # padding may duplicate at most one index across ranks
-
-
class TestRandomSampler:
    """Tests for ``RandomSampler``: iterator reset, epoch-to-epoch shuffling, and
    checkpoint resumption, on single card and multi-card."""

    # Single card;
    def test_seed_work_when_shuffle_is_true(self):
        data_length = 100

        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        for shuffle in [True, False]:
            iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle)
            # Iterate part of the data without exhausting the sampler;
            iterable.set_epoch(1)
            iterator = iter(iterable)
            pre_data = []
            forward_steps = 30
            for _ in range(forward_steps):
                pre_data.append(next(iterator))

            # Re-creating the iterator must fully reset the state;
            iterator = iter(iterable)
            res = []
            for _ in range(forward_steps):
                res.append(next(iterator))
            assert pre_data == res

    # Checkpoint resumption;
    # with shuffle, the next epoch must differ from the previous one; and after a
    # resume, both runs' next epochs must be identical;
    def test_2(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True)

        iterator = iter(random_sampler_1)
        # First epoch
        random_sampler_1.set_epoch(0)
        first_epoch = []
        forward_steps = 30
        for _ in range(forward_steps):
            first_epoch.append(next(iterator))

        # Save the checkpoint state before the epoch finishes;
        state = random_sampler_1.state_dict()

        # Record the remainder of the first epoch to verify the resume against;
        first_left_data = []
        while True:
            try:
                first_left_data.append(next(iterator))
            except StopIteration:
                break

        # Second epoch
        random_sampler_1.set_epoch(1)
        iterator = iter(random_sampler_1)
        second_epoch = []
        for _ in range(forward_steps):
            second_epoch.append(next(iterator))

        assert first_epoch != second_epoch

        # Reload the first-epoch state and verify the resumed iteration matches;
        random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True)
        random_sampler_2.load_state_dict(state)
        random_sampler_2.set_epoch(0)
        iterator = iter(random_sampler_2)
        re_first_epoch = []
        while True:
            try:
                re_first_epoch.append(next(iterator))
            except StopIteration:
                break
        assert re_first_epoch == first_left_data

        # The second epoch after the resume must match the original second epoch exactly;
        random_sampler_2.set_epoch(1)
        iterator = iter(random_sampler_2)
        re_second_epoch = []
        for _ in range(forward_steps):
            re_second_epoch.append(next(iterator))
        assert re_second_epoch == second_epoch

    # Multi-card;
    # if a sampler has not been exhausted and iter(sampler) is called again, is it
    # handled correctly (a brand-new iterator should be produced)?
    def test_3(self):
        data_length = 100

        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False)
        random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
        iterable_items = [random_sampler_1, random_sampler_2]

        world_size = 3
        for pad in {True, False}:
            for iterable in iterable_items:
                for rank in range(world_size):
                    each_rank_iterable = iterable()
                    each_rank_iterable.set_epoch(0)
                    each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad)
                    # Iterate part of the data without exhausting the sampler;
                    iterator = iter(each_rank_iterable)
                    pre_data = []
                    forward_steps = 10
                    for _ in range(forward_steps):
                        pre_data.append(next(iterator))

                    # Re-creating the iterator must fully reset the state;
                    iterator = iter(each_rank_iterable)
                    res = []
                    for _ in range(forward_steps):
                        res.append(next(iterator))
                    assert res == pre_data

    # Checkpoint resumption (multi-card);
    # with shuffle, the next epoch must differ from the previous one; and after a
    # resume, both runs' next epochs must be identical;
    def test_4(self):
        data_length = 100
        torch_normal_data = TorchNormalDataset(num_of_data=data_length)
        random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)

        world_size_1 = 2
        forward_steps = 10

        for pad in {True, False}:
            all_rank_state = {}
            all_rank_first_left_data = {}
            all_rank_second_epoch = {}
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_1()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                iterator = iter(each_rank_iterable)
                # First epoch
                each_rank_iterable.set_epoch(0)
                first_epoch = []
                for _ in range(forward_steps):
                    first_epoch.append(next(iterator))

                # Save each rank's checkpoint state before the epoch finishes;
                all_rank_state[rank] = each_rank_iterable.state_dict()

                # Record the remainder of the first epoch to verify the resume against;
                first_left_data = []
                while True:
                    try:
                        first_left_data.append(next(iterator))
                    except StopIteration:
                        break
                all_rank_first_left_data[rank] = first_left_data
                # Second epoch
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                second_epoch = []
                for _ in range(forward_steps):
                    second_epoch.append(next(iterator))
                all_rank_second_epoch[rank] = second_epoch
                assert first_epoch != second_epoch

            # Reload each rank's first-epoch state and verify the resumed iteration;
            random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
            for rank in range(world_size_1):
                each_rank_iterable = random_sampler_2()
                each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
                each_rank_iterable.load_state_dict(all_rank_state[rank])
                each_rank_iterable.set_epoch(0)
                iterator = iter(each_rank_iterable)
                re_first_epoch = []
                while True:
                    try:
                        re_first_epoch.append(next(iterator))
                    except StopIteration:
                        break
                assert re_first_epoch == all_rank_first_left_data[rank]

                # The second epoch after the resume must match the original second epoch exactly;
                each_rank_iterable.set_epoch(1)
                iterator = iter(each_rank_iterable)
                re_second_epoch = []
                for _ in range(forward_steps):
                    re_second_epoch.append(next(iterator))
                assert re_second_epoch == all_rank_second_epoch[rank]

    # TODO: test checkpoint resumption when the ddp world_size changes;
    def test_5(self):
        ...
-
-
class DatasetWithVaryLength:
    """Toy map-style dataset whose sample at index ``i`` is the value ``i``
    (or ``num_of_data - 1 - i`` when ``reverse=True``), so the value doubles
    as a "length" for sorting-based samplers."""

    def __init__(self, num_of_data=100, reverse=False):
        ordered = np.arange(num_of_data)
        # Keep ascending order by default; flip the array when asked to.
        self.data = ordered[::-1] if reverse else ordered

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return self.data.shape[0]
-
-
class TestSortedSampler:
    """Tests for ``SortedSampler``: indices are yielded from longest to shortest,
    on single card, multi-card, and across checkpoint resumes."""

    def test_single(self):
        # Lengths equal indices, so the sampler must yield indices in descending order.
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SortedSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes==list(range(num_of_data-1, -1, -1))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = SortedSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=i, pad=pad)
            samplers.append(sampler)

        # Make sure the descending order is preserved on every rank.
        already_seen_index = set()
        for sampler in samplers:
            larger_count = 0  # starting at 0 is fine: any padded index is one of the larger values.
            prev_index = float('inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                seen_in_other_rank += int(index in already_seen_index)  # ranks must not overlap
                cur_set.add(index)
                larger_count += int(index <= prev_index)
                prev_index = index
            assert larger_count+1 >= len(sampler)  # only the last element may break the order
            assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
            already_seen_index.update(cur_set)

        # Together the ranks must cover the whole dataset (exactly, when padded).
        indexes = list(chain(*samplers))
        indexes = set(indexes)
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, num_consumed_samples):
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # Consume part of an epoch, then save the state.
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    # Descending order: each new index is smaller than all previous ones.
                    assert j<max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()

        # After reloading, the remaining indices must all be smaller and unseen.
        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i < max(already_numbers)
            assert i not in already_numbers

        # Switching to a multi-card setup after loading should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i >= max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            # With pad=True a single padded index may repeat; otherwise none may.
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1 # padding may duplicate at most one index across ranks
            assert smaller<=1 if pad else smaller==0

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
    def test_state_dict_2(self, pad, num_consumed_samples):
        # Resume from a multi-card state on a single card, or on a different number of cards.
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # Consume samples on both of two ranks, recording everything seen.
        already_numbers = set()
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j<=max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = SortedSampler(dataset, length=dataset.data)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples*2

        states = sampler.state_dict()

        # Single-card resume: remaining indices must all be smaller and unseen.
        new_sampler = SortedSampler(dataset, length=dataset.data)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i < max(already_numbers)
            assert i not in already_numbers

        # Switching to a different multi-card setup should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SortedSampler(dataset, length=dataset.data)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i>=max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1 # padding may duplicate at most one index across ranks
            assert smaller <= 1 if pad else smaller == 0
-
-
class TestSequentialSampler:
    """Tests for ``SequentialSampler``: indices are yielded in ascending order,
    on single card, multi-card, and across checkpoint resumes."""

    def test_single(self):
        # The sampler must yield indices 0..num_of_data-1 in order.
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = SequentialSampler(data)
        indexes = list(sampler)
        assert indexes==list(range(num_of_data))

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, pad, num_replica, num_of_data):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = SequentialSampler(dataset=data)
            sampler.set_distributed(num_replica, rank=i, pad=pad)
            samplers.append(sampler)

        # Make sure the ascending order is preserved on every rank.
        already_seen_index = set()
        for idx, sampler in enumerate(samplers):
            # Starts at 1: the first comparison against +inf below is always False.
            larger_count = 1
            prev_index = float('inf')
            cur_set = set()
            seen_in_other_rank = 0
            for index in sampler:
                seen_in_other_rank += int(index in already_seen_index)  # ranks must not overlap
                cur_set.add(index)
                larger_count += int(index >= prev_index)
                prev_index = index
            assert larger_count+1 >= len(sampler)  # only the last element may break the order
            # With pad, each earlier rank may contribute one padded duplicate.
            assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
            already_seen_index.update(cur_set)

        # Together the ranks must cover the whole dataset (exactly, when padded).
        indexes = list(chain(*samplers))
        indexes = set(indexes)
        if pad:
            assert indexes == set(range(num_of_data))
        else:
            assert len(indexes) <= num_of_data

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, num_consumed_samples):
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # Consume part of an epoch, then save the state.
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    # Ascending order: each new index is larger than all previous ones.
                    assert j>max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples

        states = sampler.state_dict()

        # After reloading, the remaining indices must all be larger and unseen.
        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i > max(already_numbers)
            assert i not in already_numbers

        # Switching to a multi-card setup after loading should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i <= max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            # With pad=True a single padded index may repeat; otherwise none may.
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=rank # padding may duplicate indices across ranks
            assert smaller<=1 if pad else smaller==0

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
    def test_state_dict_2(self, pad, num_consumed_samples):
        # Resume from a multi-card state on a single card, or on a different number of cards.
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # Consume samples on both of two ranks, recording everything seen.
        already_numbers = set()
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j>max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples>0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples*2

        states = sampler.state_dict()

        # Single-card resume: remaining indices must all be larger and unseen.
        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i > max(already_numbers)
            assert i not in already_numbers

        # Switching to a different multi-card setup should also work.
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i<max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank<=1 # padding may duplicate at most one index across ranks
            assert smaller <= rank if pad else smaller == 0
-
|