diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 0fd74795..d2d548f5 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -19,7 +19,7 @@ from fastNLP.core.utils import (
     paddle_move_data_to_device,
     is_in_paddle_dist,
 )
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler
 from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.log import logger
 
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver):
             return dataloader
         # evaluator
         elif dist == "unrepeatdist":
-            sampler = UnrepeatedDistributedSampler(
+            sampler = UnrepeatedSampler(
                 dataset=dataloader.dataset,
                 shuffle=shuffle,
                 seed=int(os.environ.get("FASTNLP_SEED", 0))
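A note on the `unrepeatdist` branch above: every rank builds an identical `UnrepeatedSampler` and then claims a disjoint slice of the dataset, so multi-card evaluation sees each sample exactly once and never pads with duplicates. A minimal sketch of that call sequence, where `world_size`, `rank`, and the final sampler-replacement helper are stand-ins for the driver's real attributes and utilities, not code quoted from fleet.py:

    sampler = UnrepeatedSampler(
        dataset=dataloader.dataset,
        shuffle=shuffle,
        seed=int(os.environ.get("FASTNLP_SEED", 0)),
    )
    # no pad argument here: unlike the training-time samplers, nothing is repeated
    sampler.set_distributed(num_replicas=world_size, rank=rank)
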
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 9b3325d8..7fe0bcee 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import (
     ForwardState,
     _MODE_PARAMETER,
     reset_seed,
-    replace_sampler
+    replace_sampler,
+    replace_batch_sampler
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
@@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver):
         #     return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
-                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
+                                  reproducible: bool = False):
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+
         if isinstance(dist, ReproducibleIterator):
             # note: there used to be no need to call dist_sampler.set_distributed here, because with TorchDDPDriver the Trainer already calls it during initialization; re-instantiation discards that state, however, so it is applied again below;
             dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
             return replace_sampler(dataloader, dist)
 
         # trainer, evaluator
@@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver):
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether they resume from a checkpoint or not does not affect the behaviour here;
-            if isinstance(args.sampler, ReproducibleIterator):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleIterator):
                 sampler = re_instantiate_sampler(args.sampler)
                 sampler.set_distributed(
                     num_replicas=self.world_size,
@@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver):
                     shuffle=args.shuffle,
                     seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                 )
-                # TODO: two angles here; first, even when the dataloader detects that its sampler is one of our reproducible samplers, it must not call set_distributed directly; second, in the single-card case the sampler also has to be replaced, and its state switched over, to guard against a previously multi-card run now running on a single card
                 sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
@@ -487,8 +509,11 @@ class TorchDDPDriver(TorchDriver):
 
         # evaluator
         elif dist == "unrepeatdist":
+            # TODO @yh: fill in the remaining unrepeatdist-related handling;
             args = self.get_dataloader_args(dataloader)
-            sampler = UnrepeatedDistributedSampler(
+
+            # TODO: also check for a batch_sampler here;
+            sampler = UnrepeatedSampler(
                 dataset=args.dataset,
                 shuffle=args.shuffle,
             )
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index 952712be..3375d557 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver):
     def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
                                   reproducible: bool = False):
         if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
            return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleIterator):
+            dist = re_instantiate_sampler(dist)
             return replace_sampler(dataloader, dist)
 
         if reproducible:
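The `re_instantiate_sampler` calls added in both drivers above rely on the helper shown further down in `reproducible_sampler.py`: it rebuilds the sampler as `type(sampler)(**all_attributes)`, i.e. a fresh instance with the same constructor arguments but none of the post-construction state. That is why the DDP driver re-applies the distributed settings immediately afterwards. A condensed sketch of the pattern (variable names are illustrative; `self` stands for the driver):

    fresh = re_instantiate_sampler(dist)   # same class and settings, fresh iteration state
    fresh.set_distributed(                 # distributed state is not a constructor argument,
        num_replicas=self.world_size,      # so it has to be applied again on the new instance
        rank=self.global_rank,
        pad=True,
    )
    return replace_sampler(dataloader, fresh)
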
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 685e1f63..369e4432 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -244,8 +244,34 @@ class TorchDriver(Driver):
         logger.debug("Load model.")
 
         # 3. restore the sampler's state;
+        """
+        Usage scenarios:
+
+        the current sampler/batch_sampler replacement depends on:
+            1. single card vs multi card;
+            2. whether we are resuming from a checkpoint;
+
+            3. whether the user passed one in via `dist`;
+            4. whether the user directly replaced the dataloader's sampler or batch_sampler themselves;
+
+        rule that should be settled on:
+            the batch_sampler takes priority over the sampler;
+
+        single card:
+            not resuming from a checkpoint:
+                the user's own
+
+
+        the user does not replace the sampler or batch_sampler directly themselves
+        1. single card:
+
+        """
         dataloader_args = self.get_dataloader_args(dataloader)
+        # TODO: think this through first;
+        # batch_sampler = dataloader_args.batch_sampler
+        # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
+
         sampler = dataloader_args.sampler
         if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
             # this means we have to fall back to one of our reproducible samplers here
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
index 68928b66..bb2ee661 100644
--- a/fastNLP/core/samplers/__init__.py
+++ b/fastNLP/core/samplers/__init__.py
@@ -3,19 +3,24 @@ __all__ = [
     'SortedSampler',
     'ConstTokenNumSampler',
     'ConstantTokenNumSampler',
-    'UnrepeatedDistributedSampler',
+
     'MixSampler',
-    'InnerSampler',
     'DopedSampler',
     'MixSequentialSampler',
     'PollingSampler',
+
     'ReproducibleIterator',
     'RandomSampler',
-    're_instantiate_sampler'
+
+    're_instantiate_sampler',
+
+    'UnrepeatedSampler',
+    'UnrepeatedSortedSampler'
 ]
 
-from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
-from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
+from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
+from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler
+from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
 from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py
index e219b6e2..f53c06a5 100644
--- a/fastNLP/core/samplers/mix_sampler.py
+++ b/fastNLP/core/samplers/mix_sampler.py
@@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict
 
 __all__ = [
     'MixSampler',
-    'InnerSampler',
     'DopedSampler',
     'MixSequentialSampler',
     'PollingSampler'
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index 0a4ac7bf..6d2c8246 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler):
     return type(sampler)(**all_attributes)
 
 
-
 class ReproducibleIterator:
     """
     Note that the `__init__` method of every class inheriting from `ReproducibleIterator` must accept `**kwargs`, so that the sampler can be re-instantiated when we resume training from a checkpoint
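To make the `**kwargs` requirement in the `ReproducibleIterator` docstring above concrete: every constructor argument has to end up as an attribute, so that `re_instantiate_sampler` can rebuild the object from `**all_attributes` during checkpoint resumption. A hedged sketch of a conforming subclass (`MySampler` is hypothetical, not part of fastNLP):

    class MySampler(ReproducibleIterator):
        def __init__(self, dataset, shuffle=False, seed=0, **kwargs):
            # each argument is stored under its own name, so that
            # type(sampler)(**all_attributes) reproduces this sampler exactly
            self.dataset = dataset
            self.shuffle = shuffle
            self.seed = seed
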
diff --git a/fastNLP/core/samplers/sampler.py b/fastNLP/core/samplers/sampler.py
index e41472bf..89751884 100644
--- a/fastNLP/core/samplers/sampler.py
+++ b/fastNLP/core/samplers/sampler.py
@@ -7,7 +7,6 @@ __all__ = [
     "SortedSampler",
     'ConstTokenNumSampler',
     "ConstantTokenNumSampler",
-    "UnrepeatedDistributedSampler",
 ]
 
 from itertools import chain
@@ -18,7 +17,7 @@ import numpy as np
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 
 if _NEED_IMPORT_TORCH:
-    from torch.utils.data import SequentialSampler, Sampler, RandomSampler
+    from torch.utils.data import Sampler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Sampler
 
@@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
         if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
             bucket_data[bucket_id].append(idx)
     return bucket_data
-
-
-class UnrepeatedDistributedSampler:
-    def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
-        """
-        For the multi-card evaluate scenario, where no sample may be drawn twice.
-
-        :param dataset:
-        :param shuffle:
-        :param seed:
-        """
-        self.dataset = dataset
-        self.shuffle = shuffle
-        self.seed = seed
-
-        # multi-card related parameters
-        self.num_replicas = 1
-        self.rank = 0
-        self.epoch = -1
-
-    def __len__(self):
-        """
-        Returns how many indices one full iteration of the sampler produces. In the multi-card case, only the current rank is considered;
-        :return:
-        """
-        num_common = len(self.dataset)//self.num_replicas
-        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
-        return self.num_samples
-
-    def __iter__(self):
-        r"""
-        The current num_consumed_samples approach runs into problems when two iterators are used alternately;
-        Example:
-            >>> sampler = RandomSampler()
-            >>> iter1 = iter(sampler)
-            >>> iter2 = iter(sampler)
-            >>> next(iter1)
-            >>> next(iter2)  # the num_consumed_samples count changes here
-        """
-
-        indices = self.generate_indices()
-
-        # subsample
-        indices = indices[self.rank:len(indices):self.num_replicas]
-        assert len(indices) == len(self)
-
-        for index in indices:
-            yield index
-
-    def generate_indices(self) -> List[int]:
-        """
-        Generate the index sequence
-
-        :return:
-        """
-        if self.shuffle:
-            indices = list(range(len(self.dataset)))
-            seed = self.seed + self.epoch
-            rng = np.random.default_rng(abs(seed))
-            rng.shuffle(indices)
-            if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order every epoch.
-                self.epoch -= 1
-        else:
-            indices = list(range(len(self.dataset)))
-        return indices
-
-    def set_epoch(self, epoch: int) -> None:
-        self.epoch = epoch
-
-    def set_distributed(self, num_replicas, rank):
-        """
-        This method is essentially the part of initialization left unfinished in the ddp case; it should be called immediately after the sampler itself has been constructed;
-
-        :param num_replicas:
-        :param rank:
-        :return:
-        """
-        assert num_replicas>0 and isinstance(num_replicas, int)
-        assert isinstance(rank, int) and 0<=rank<num_replicas
-        # note that when this is called, all state should still default to that of a freshly started epoch;
-        self.num_replicas = num_replicas
-        self.rank = rank
-
-        return self
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
new file mode 100644
--- /dev/null
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -0,0 +1,88 @@
+__all__ = [
+    'UnrepeatedSampler',
+    'UnrepeatedSortedSampler'
+]
+
+from typing import List
+
+import numpy as np
+
+
+class UnrepeatedSampler:
+    def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
+        """
+        For the multi-card evaluate scenario, where no sample may be drawn twice.
+
+        :param dataset:
+        :param shuffle:
+        :param seed:
+        """
+        self.dataset = dataset
+        self.shuffle = shuffle
+        self.seed = seed
+
+        # multi-card related parameters
+        self.num_replicas = 1
+        self.rank = 0
+        self.epoch = -1
+
+    def __len__(self):
+        num_common = len(self.dataset)//self.num_replicas
+        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
+        return self.num_samples
+
+    def __iter__(self):
+        indices = self.generate_indices()
+
+        # subsample
+        indices = indices[self.rank:len(indices):self.num_replicas]
+        assert len(indices) == len(self)
+
+        for index in indices:
+            yield index
+
+    def generate_indices(self) -> List[int]:
+        """
+        Generate the index sequence
+
+        :return:
+        """
+        if self.shuffle:
+            indices = list(range(len(self.dataset)))
+            seed = self.seed + self.epoch
+            rng = np.random.default_rng(abs(seed))
+            rng.shuffle(indices)
+            if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least this guarantees a different index order every epoch.
+                self.epoch -= 1
+        else:
+            indices = list(range(len(self.dataset)))
+        return indices
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+    def set_distributed(self, num_replicas, rank):
+        """
+        This method is essentially the part of initialization left unfinished in the ddp case; it should be called immediately after the sampler itself has been constructed;
+
+        :param num_replicas:
+        :param rank:
+        :return:
+        """
+        assert num_replicas>0 and isinstance(num_replicas, int)
+        assert isinstance(rank, int) and 0<=rank<num_replicas
+        self.num_replicas = num_replicas
+        self.rank = rank
+
+        return self
+
+
+class UnrepeatedSortedSampler(UnrepeatedSampler):
+    def __init__(self, dataset, length: List[int], **kwargs):
+        super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
+        assert len(length) == len(dataset), "The length of `length` and `dataset` should be equal."
+        # iterate from the longest sample to the shortest one;
+        self.sorted_indices = np.argsort(length)[::-1].tolist()
+
+    def generate_indices(self) -> List[int]:
+        return self.sorted_indices
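Why the round-robin slice in `__iter__` (`indices[self.rank:len(indices):self.num_replicas]`) can never repeat a sample: each index lands in exactly one rank's shard. A plain-Python illustration with a hypothetical 10-example dataset and 3 replicas:

    indices = list(range(10))
    shards = [indices[rank:len(indices):3] for rank in range(3)]
    # shards == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
    # the shards partition the dataset, and their sizes (4, 3, 3) match
    # what __len__ computes for each rank
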
diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py
new file mode 100644
index 00000000..3e2f79ed
--- /dev/null
+++ b/tests/core/samplers/test_unrepeated_sampler.py
@@ -0,0 +1,64 @@
+from itertools import chain
+
+import pytest
+
+from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler
+
+
+class DatasetWithVaryLength:
+    def __init__(self, num_of_data=100):
+        self.data = list(range(num_of_data))
+
+    def __getitem__(self, item):
+        return self.data[item]
+
+    def __len__(self):
+        return len(self.data)
+
+
+class TestUnrepeatedSampler:
+    @pytest.mark.parametrize('shuffle', [True, False])
+    def test_single(self, shuffle):
+        num_of_data = 100
+        data = DatasetWithVaryLength(num_of_data)
+        sampler = UnrepeatedSampler(data, shuffle)
+        indexes = set(sampler)
+        assert indexes == set(range(num_of_data))
+
+    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
+    @pytest.mark.parametrize('shuffle', [False, True])
+    def test_multi(self, num_replica, num_of_data, shuffle):
+        data = DatasetWithVaryLength(num_of_data=num_of_data)
+        samplers = []
+        for i in range(num_replica):
+            sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
+            sampler.set_distributed(num_replica, rank=i)
+            samplers.append(sampler)
+
+        indexes = set(chain(*samplers))
+        assert indexes == set(range(num_of_data))
+
+
+class TestUnrepeatedSortedSampler:
+    @pytest.mark.parametrize('shuffle', [True, False])
+    def test_single(self, shuffle):
+        num_of_data = 100
+        data = DatasetWithVaryLength(num_of_data)
+        sampler = UnrepeatedSortedSampler(data, length=data.data)
+        indexes = list(sampler)
+        assert indexes == list(range(num_of_data - 1, -1, -1))
+
+    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
+    @pytest.mark.parametrize('shuffle', [False, True])
+    def test_multi(self, num_replica, num_of_data, shuffle):
+        data = DatasetWithVaryLength(num_of_data=num_of_data)
+        samplers = []
+        for i in range(num_replica):
+            sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
+            sampler.set_distributed(num_replica, rank=i)
+            samplers.append(sampler)
+
+        indexes = set(chain(*samplers))
+        assert indexes == set(range(num_of_data))
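One note on `TestUnrepeatedSortedSampler.test_single`: since `DatasetWithVaryLength` uses the values `0..num_of_data-1` as their own lengths, sorting from longest to shortest must yield the indices in strictly descending order, which is exactly what the assertion encodes. In miniature:

    lengths = [10, 20, 30]
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    # order == [2, 1, 0], i.e. list(range(len(lengths) - 1, -1, -1))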