From 14fffcb36cc5622d48ae719c635815b3dacebb6a Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 11 Apr 2022 21:44:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9ESortedSampler=E5=92=8CSequent?= =?UTF-8?q?ialSampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 4 +- fastNLP/core/drivers/driver.py | 4 +- fastNLP/core/drivers/jittor_driver/mpi.py | 4 +- .../drivers/jittor_driver/single_device.py | 14 +- fastNLP/core/drivers/paddle_driver/fleet.py | 10 +- .../drivers/paddle_driver/single_device.py | 14 +- fastNLP/core/drivers/torch_driver/ddp.py | 45 +- .../drivers/torch_driver/single_device.py | 17 +- .../core/drivers/torch_driver/torch_driver.py | 18 +- fastNLP/core/samplers/__init__.py | 20 +- .../samplers/reproducible_batch_sampler.py | 10 +- fastNLP/core/samplers/reproducible_sampler.py | 146 ++++- fastNLP/core/samplers/unrepeated_sampler.py | 55 +- fastNLP/core/samplers/utils.py | 42 ++ .../paddle_driver/test_single_device.py | 4 +- .../test_torch_replace_sampler.py | 2 +- .../test_reproducible_batch_sampler.py | 18 +- .../samplers/test_reproducible_sampler.py | 538 ++++++++++++++---- .../core/samplers/test_unrepeated_sampler.py | 58 +- 19 files changed, 797 insertions(+), 226 deletions(-) create mode 100644 fastNLP/core/samplers/utils.py diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index b7456b61..11697bdc 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.envs import rank_zero_call -from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME @@ -610,7 +610,7 @@ class Trainer(TrainerEventTrigger): r""" 用于断点重训的加载函数; 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 - 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; + 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; 注意我们目前不支持单卡到多卡的断点重训; diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index d9d66970..84c5f373 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -49,13 +49,13 @@ class Driver(ABC): 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; - 注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; + 注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用; 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 可以可以加载。 :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, - 如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 + 如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的 dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 """ if dist is None and reproducible is False: diff --git a/fastNLP/core/drivers/jittor_driver/mpi.py b/fastNLP/core/drivers/jittor_driver/mpi.py index 596148bc..c467b868 100644 --- a/fastNLP/core/drivers/jittor_driver/mpi.py +++ b/fastNLP/core/drivers/jittor_driver/mpi.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .jittor_driver import JittorDriver from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleIterator +from fastNLP.core.samplers import ReproducibleSampler if _NEED_IMPORT_JITTOR: import jittor @@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): def test_step(self, batch): return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): pass diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index f39053d3..0bd7188d 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -3,7 +3,7 @@ from typing import Dict, Union from .jittor_driver import JittorDriver from fastNLP.core.utils import auto_param_call from fastNLP.envs.imports import _NEED_IMPORT_JITTOR -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler if _NEED_IMPORT_JITTOR: import jittor @@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): def is_distributed(self): return False - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], reproducible: bool = False, sampler_or_batch_sampler=None): # reproducible 的相关功能暂时没有实现 - if isinstance(dist, ReproducibleBatchSampler): + if isinstance(dist, RandomBatchSampler): raise NotImplementedError dataloader.batch_sampler = dist_sample - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): raise NotImplementedError dataloader.batch_sampler.sampler = dist if reproducible: raise NotImplementedError - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): return dataloader - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): + elif isinstance(dataloader.batch_sampler, RandomBatchSampler): return dataloader else: # TODO - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=dataloader.batch_sampler, batch_size=dataloader.batch_sampler.batch_size, drop_last=dataloader.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index d2d548f5..65af48a1 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -19,7 +19,7 @@ from fastNLP.core.utils import ( paddle_move_data_to_device, is_in_paddle_dist, ) -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): def test_step(self, batch): return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], reproducible: bool = False, sampler_or_batch_sampler=None): # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dataloader.batch_sampler.sampler = dist return dataloader @@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): # trainer elif dist == "dist": # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): dataloader.batch_sampler.sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, @@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): return dataloader # evaluator elif dist == "unrepeatdist": - sampler = UnrepeatedSampler( + sampler = UnrepeatedRandomSampler( dataset=dataloader.dataset, shuffle=shuffle, seed=int(os.environ.get("FASTNLP_SEED", 0)) diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 97f14bb6..a124b9be 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -10,7 +10,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -139,26 +139,26 @@ class PaddleSingleDriver(PaddleDriver): """ return paddle_move_data_to_device(batch, "gpu:0") - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], reproducible: bool = False, sampler_or_batch_sampler=None): # 暂时不支持IteratorDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." - if isinstance(dist, ReproducibleBatchSampler): + if isinstance(dist, RandomBatchSampler): dataloader.batch_sampler = dist return dataloader - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dataloader.batch_sampler.sampler = dist return dataloader if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): + if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): return dataloader - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): + elif isinstance(dataloader.batch_sampler, RandomBatchSampler): return dataloader else: # TODO - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=dataloader.batch_sampler, batch_size=dataloader.batch_sampler.batch_size, drop_last=dataloader.drop_last diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 9e5e16fd..e1408df1 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( ) from fastNLP.core.drivers.utils import distributed_open_proc from fastNLP.core.utils import auto_param_call, check_user_specific_params -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \ + re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object -from fastNLP.core.samplers import re_instantiate_sampler class TorchDDPDriver(TorchDriver): @@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None, reproducible: bool = False): - # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; + # 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; - if isinstance(dist, ReproducibleBatchSampler): + if isinstance(dist, RandomBatchSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_sampler(dataloader, dist) # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; @@ -462,10 +472,10 @@ class TorchDDPDriver(TorchDriver): raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " "control.") else: - if isinstance(dist, ReproducibleBatchSampler): + if isinstance(dist, RandomBatchSampler): dist = re_instantiate_sampler(dist) return replace_batch_sampler(dataloader, dist) - if isinstance(dist, ReproducibleIterator): + if isinstance(dist, ReproducibleSampler): dist = re_instantiate_sampler(dist) return replace_sampler(dataloader, dist) return dataloader @@ -473,7 +483,7 @@ class TorchDDPDriver(TorchDriver): elif dist == "dist": args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(args.batch_sampler, ReproducibleBatchSampler): + if isinstance(args.batch_sampler, RandomBatchSampler): batch_sampler = re_instantiate_sampler(args.batch_sampler) batch_sampler.set_distributed( num_replicas=self.world_size, @@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): pad=True ) return replace_batch_sampler(dataloader, batch_sampler) - elif isinstance(args.sampler, ReproducibleIterator): + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, @@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): return replace_sampler(dataloader, sampler) # evaluator elif dist == "unrepeatdist": - # todo @yh,补充 unrepeatdist 相关内容; args = self.get_dataloader_args(dataloader) - - # todo 判断 batch_sampler; - sampler = UnrepeatedSampler( - dataset=args.dataset, - shuffle=args.shuffle, - ) + if isinstance(args.sampler, ReproducibleSampler): + sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) + elif not isinstance(args.sampler, UnrepeatedSampler): + sampler = UnrepeatedSequentialSampler( + dataset=args.dataset + ) + else: + sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 14a135ee..cf8c19a8 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -13,9 +13,8 @@ __all__ = [ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler from fastNLP.core.log import logger -from fastNLP.core.samplers import re_instantiate_sampler class TorchSingleDriver(TorchDriver): @@ -130,26 +129,26 @@ class TorchSingleDriver(TorchDriver): else: return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): - # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; - if isinstance(dist, ReproducibleBatchSampler): + # 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; + if isinstance(dist, RandomBatchSampler): return replace_batch_sampler(dataloader, dist) - elif isinstance(dist, ReproducibleIterator): + elif isinstance(dist, ReproducibleSampler): return replace_sampler(dataloader, dist) # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; args = self.get_dataloader_args(dataloader) - if isinstance(args.batch_sampler, ReproducibleBatchSampler): + if isinstance(args.batch_sampler, RandomBatchSampler): batch_sampler = re_instantiate_sampler(args.batch_sampler) return replace_batch_sampler(dataloader, batch_sampler) - elif isinstance(args.sampler, ReproducibleIterator): + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index ce1bff14..b3386f5a 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator class TorchDriver(Driver): @@ -182,10 +182,10 @@ class TorchDriver(Driver): # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; - # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 - # sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; + # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 + # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; dataloader_args = self.get_dataloader_args(dataloader) - if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): sampler = dataloader_args.batch_sampler elif dataloader_args.sampler: sampler = dataloader_args.sampler @@ -245,15 +245,15 @@ class TorchDriver(Driver): # 3. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) - if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): sampler = dataloader_args.batch_sampler elif isinstance(dataloader_args.sampler, ReproducibleIterator): sampler = dataloader_args.sampler elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " - "`ReproducibleBatchSampler` or `ReproducibleIterator`.") + "`RandomBatchSampler` or `ReproducibleIterator`.") else: - sampler = ReproducibleBatchSampler( + sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last @@ -263,7 +263,7 @@ class TorchDriver(Driver): # 4. 修改 trainer_state.batch_idx_in_epoch # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; - if not isinstance(sampler, ReproducibleBatchSampler): + if not isinstance(sampler, RandomBatchSampler): if dataloader_args.drop_last: batch_idx_in_epoch = len( sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size @@ -291,7 +291,7 @@ class TorchDriver(Driver): @staticmethod def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover - """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed + """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with ``seed_everything(seed, workers=True)``. See also the PyTorch documentation on diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index bb2ee661..3d6813f7 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -9,18 +9,24 @@ __all__ = [ 'MixSequentialSampler', 'PollingSampler', - 'ReproducibleIterator', + 'ReproducibleSampler', 'RandomSampler', - - 're_instantiate_sampler', + "SequentialSampler", + "SortedSampler", 'UnrepeatedSampler', - "UnrepeatedSortedSampler" + 'UnrepeatedRandomSampler', + "UnrepeatedSortedSampler", + "UnrepeatedSequentialSampler", + + "re_instantiate_sampler", + "conversion_between_reproducible_and_unrepeated_sampler" ] from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler -from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler +from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler -from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler -from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler +from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler +from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler +from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 3e39aca5..5a25110b 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -1,6 +1,6 @@ __all__ = [ 'BucketedBatchSampler', - "ReproducibleBatchSampler" + "RandomBatchSampler" ] import math @@ -16,7 +16,7 @@ from fastNLP.core.log import logger from abc import abstractmethod -class ReproducibleBatchIterator: +class ReproducibleBatchSampler: @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") @@ -42,13 +42,13 @@ class ReproducibleBatchIterator: pass -class ReproducibleBatchSampler(ReproducibleBatchIterator): +class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ 可以使得 batch_sampler 对象状态恢复的 wrapper 。 - :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 + :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 :param batch_size: 每个 batch 的大小是多少。 :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 @@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size -class BucketedBatchSampler(ReproducibleBatchIterator): +class BucketedBatchSampler(ReproducibleBatchSampler): def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 6d2c8246..1dc226a5 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,24 +1,21 @@ -from typing import Dict, List +from typing import Dict, List, Union import math import numpy as np from fastNLP.core.log import logger +from fastNLP.core.dataset import DataSet __all__ = [ - 'ReproducibleIterator', + 'ReproducibleSampler', 'RandomSampler', - 're_instantiate_sampler' + "SortedSampler", + "SequentialSampler" ] -def re_instantiate_sampler(sampler): - all_attributes = vars(sampler) - return type(sampler)(**all_attributes) - - -class ReproducibleIterator: +class ReproducibleSampler: """ - 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler + 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ @@ -46,7 +43,7 @@ class ReproducibleIterator: pass -class RandomSampler(ReproducibleIterator): +class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): """ @@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ - "and current dataset." + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." self.seed = states['seed'] self.epoch = states['epoch'] self.num_consumed_samples = states['num_consumed_samples'] @@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) +class SequentialSampler(RandomSampler): + def __init__(self, dataset, dist_mode:str='interval', **kwargs): + """ + 按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + + if self.pad: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + + assert len(indices) == self.total_size + + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 + def generate_indices(self) -> List[int]: + """ + 生成随机序列 + :return: + """ + return list(range(len(self.dataset))) + def state_dict(self) -> Dict: + states = { + 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; + 'sampler_type': self.__class__.__name__, + 'length': len(self.dataset), + } + return states + def load_state_dict(self, states: Dict): + # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ + "during an unfinished iteration." + + assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ + f"we cannot use {self.__class__.__name__} to load it." + + length = states['length'] + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." + self.num_consumed_samples = states['num_consumed_samples'] + if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 + self.num_consumed_samples = 0 + + +class SortedSampler(SequentialSampler): + def __init__(self, dataset, length:Union[str, List], **kwargs): + """ + 将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 + DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 + """ + super().__init__(dataset=dataset, **kwargs) + if isinstance(dataset, DataSet): + length = dataset.get_field(length) + if not isinstance(length[0], int): + length = list(map(len, length)) + else: + assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ + "the length parameter can only be List[int]" + + assert len(length) == len(dataset), "The length of `data` and `length` should be equal." + + self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 + + def generate_indices(self) -> List[int]: + return self.sorted_indices + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + + if self.pad: + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + + assert len(indices) == self.total_size + + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py index 18ae16db..d7913d20 100644 --- a/fastNLP/core/samplers/unrepeated_sampler.py +++ b/fastNLP/core/samplers/unrepeated_sampler.py @@ -1,6 +1,8 @@ __all__ = [ + 'UnrepeatedSampler', 'UnrepeatedSortedSampler', - 'UnrepeatedSampler' + 'UnrepeatedRandomSampler', + "UnrepeatedSequentialSampler" ] from typing import List, Union @@ -10,13 +12,21 @@ import numpy as np class UnrepeatedSampler: + """ + 在多卡场景下保证 indice 不重复的 sampler + """ + pass + + +class UnrepeatedRandomSampler(UnrepeatedSampler): def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): """ 考虑在多卡evaluate的场景下,不能重复sample。 - :param dataset: - :param shuffle: - :param seed: + :param dataset: 实现了 __len__ 方法的数据容器。 + :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 """ self.dataset = dataset self.shuffle = shuffle @@ -33,8 +43,8 @@ class UnrepeatedSampler: :return: """ num_common = len(self.dataset)//self.num_replicas - self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) - return self.num_samples + num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) + return num_samples def __iter__(self): indices = self.generate_indices() @@ -83,8 +93,8 @@ class UnrepeatedSampler: return self -class UnrepeatedSortedSampler(UnrepeatedSampler): - def __init__(self, dataset, length:Union[str, List], seed: int = 0): +class UnrepeatedSortedSampler(UnrepeatedRandomSampler): + def __init__(self, dataset, length:Union[str, List], **kwargs): """ 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 batch 数量不完全一致。 @@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): :param dataset: 实现了 __len__ 方法的数据容器。 :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 - :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 - :param seed: 设置的随机数种子 :param kwargs: fastNLP 保留使用 """ - super().__init__(dataset=dataset, shuffle=False, seed=seed) + super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) if isinstance(dataset, DataSet): length = dataset.get_field(length) if not isinstance(length[0], int): @@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): assert len(length) == len(dataset), "The length of `data` and `length` should be equal." - self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 - self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 + length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 def generate_indices(self) -> List[int]: return self.sorted_indices + + +class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): + def __init__(self, dataset, **kwargs): + """ + 按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + indices = self.generate_indices() + indices = indices[self.rank:len(indices):self.num_replicas] + for index in indices: + yield index + + def generate_indices(self) -> List[int]: + return list(range(len(self.dataset))) + diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py new file mode 100644 index 00000000..dd90fe7c --- /dev/null +++ b/fastNLP/core/samplers/utils.py @@ -0,0 +1,42 @@ +__all__ = [ + 're_instantiate_sampler', + 'conversion_between_reproducible_and_unrepeated_sampler' +] + +from fastNLP.core.samplers.unrepeated_sampler import * +from fastNLP.core.samplers.reproducible_sampler import * + + +def conversion_between_reproducible_and_unrepeated_sampler(sampler): + """ + 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 + ReproducibleSampler, + + :param sampler: + :return: + """ + assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ + "The sampler must be UnrepeatedSampler or ReproducibleSampler" + if isinstance(sampler, UnrepeatedSampler): + if isinstance(sampler, UnrepeatedRandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) + elif isinstance(sampler, UnrepeatedSequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) + elif isinstance(sampler, UnrepeatedSortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) + raise TypeError(f"{sampler.__class__} has no unrepeated version.") + else: + if isinstance(sampler, RandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) + elif isinstance(sampler, SequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) + elif isinstance(sampler, SortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) + raise TypeError(f"{sampler.__class__} has no reproducible version.") + + +def re_instantiate_sampler(sampler, new_sampler_class=None): + all_attributes = vars(sampler) + if new_sampler_class is not None: + return new_sampler_class(**all_attributes) + return type(sampler)(**all_attributes) \ No newline at end of file diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 33662d7f..b2f5864b 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import RandomBatchSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset from fastNLP.core import synchronize_safe_rm @@ -153,7 +153,7 @@ class TestSingleDeviceFunction: @pytest.mark.parametrize( "dist_sampler", - ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] + ["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] ) @pytest.mark.parametrize( "reproducible", diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index 81d693fc..161bbfe8 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -30,7 +30,7 @@ class SequenceDataSet: def check_replace_sampler(driver): - # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler + # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler # reproducible 是 True 和 False # 需要 check 返回的 sampler 和 dataloader 都不同了 diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index edc7b86b..d51dd912 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -4,7 +4,7 @@ import numpy as np import pytest from itertools import chain -from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: before_batch_size = 7 dataset = TorchNormalDataset(num_of_data=100) dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) forward_steps = 3 @@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, - "sampler_type": "ReproducibleBatchSampler"} + "sampler_type": "RandomBatchSampler"} # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: # 改变 batch_size; after_batch_size = 3 dataloader = DataLoader(dataset, batch_size=after_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: dataset = TorchNormalDataset(num_of_data=100) # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) # 将一轮的所有数据保存下来,看是否恢复的是正确的; @@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py index 0a3697d3..981d6a03 100644 --- a/tests/core/samplers/test_reproducible_sampler.py +++ b/tests/core/samplers/test_reproducible_sampler.py @@ -1,18 +1,14 @@ -import unittest - -from itertools import product import numpy as np +import pytest from functools import partial -from array import array +from itertools import chain -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler +from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler from tests.helpers.datasets.torch_data import TorchNormalDataset - -class TestRandomSamplerYh(unittest.TestCase): +class TestRandomSamplerYh: def test_init(self): # 测试能否正确初始化 dataset = TorchNormalDataset(num_of_data=100) @@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset) for i in sampler: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): sampler.set_distributed(1, 0) break @@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 50) + assert count == 50 sampler.set_distributed(num_replicas=2, rank=1, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 50) + assert count==50 dataset = TorchNormalDataset(num_of_data=101) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler)==51 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 51) + assert count == 51 sampler.set_distributed(num_replicas=2, rank=1, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler) == 51 count = 0 for i in sampler: if i!=0: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 51) + assert count == 51 def test_state_dict_check_length(self): dataset = TorchNormalDataset(num_of_data=100) @@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): states = sampler.state_dict() new_ds = TorchNormalDataset(num_of_data=10) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): new_sampler = RandomSampler(new_ds) new_sampler.load_state_dict(states) @@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): new_sampler = RandomSampler(new_ds) new_sampler.load_state_dict(states) - def test_state_dict(self): + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples, size=3).tolist() - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - sampler = RandomSampler(dataset, shuffle=pre_shuffle) - sampler.set_epoch(0) - already_numbers = set() - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - new_sampler.set_epoch(0) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - def test_state_dict_2(self): + sampler = RandomSampler(dataset, shuffle=pre_shuffle) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() # lst = [30] - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - already_numbers = set() - sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) - sampler.set_distributed(num_replicas=2, rank=0) - sampler.set_epoch(0) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) - sampler.set_epoch(0) - sampler.set_distributed(num_replicas=2, rank=1) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples*2) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - -class TestRandomSampler(unittest.TestCase): + already_numbers = set() + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + +class TestRandomSampler: # 测试单卡; def test_seed_work_when_shuffle_is_true(self): data_length = 100 @@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): ... +class DatasetWithVaryLength: + def __init__(self, num_of_data=100, reverse=False): + self.data = np.arange(num_of_data) + if reverse: + self.data = self.data[::-1] + + def __getitem__(self, item): + return self.data[item] + + def __len__(self): + return len(self.data) + + +class TestSortedSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SortedSampler(data, length=data.data) + indexes = list(sampler) + assert indexes==list(range(num_of_data-1, -1, -1)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SortedSampler(dataset=data, length=data.data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for sampler in samplers: + larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + larger_count += int(index <= prev_index) + prev_index = index + assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 + assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0 + already_seen_index.update(cur_set) + + indexes = list(chain(*samplers)) + indexes = set(indexes) + if pad: + assert indexes == set(range(num_of_data)) + else: + assert len(indexes) <= num_of_data + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, num_consumed_samples): + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j= max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller<=1 if pad else smaller==0 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j<=max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i < max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i>=max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller <= 1 if pad else smaller == 0 + + +class TestSequentialSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SequentialSampler(data) + indexes = list(sampler) + assert indexes==list(range(num_of_data)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SequentialSampler(dataset=data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for idx, sampler in enumerate(samplers): + larger_count = 1 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + larger_count += int(index >= prev_index) + prev_index = index + assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 + assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0 + already_seen_index.update(cur_set) + + indexes = list(chain(*samplers)) + indexes = set(indexes) + if pad: + assert indexes == set(range(num_of_data)) + else: + assert len(indexes) <= num_of_data + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, num_consumed_samples): + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + sampler = SequentialSampler(dataset=dataset) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j>max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i > max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i <= max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=rank # 因为pad可能重复 + assert smaller<=1 if pad else smaller==0 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SequentialSampler(dataset=dataset) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j>max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SequentialSampler(dataset=dataset) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i > max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i=prev_index + prev_index = index + + indexes = list(chain(*samplers)) + assert len(indexes) == num_of_data + indexes = set(indexes) + assert indexes == set(range(num_of_data))