@@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver | |||
from fastNLP.core.drivers.utils import choose_driver | |||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
@@ -610,7 +610,7 @@ class Trainer(TrainerEventTrigger): | |||
r""" | |||
用于断点重训的加载函数; | |||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | |||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; | |||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||
注意我们目前不支持单卡到多卡的断点重训; | |||
@@ -49,13 +49,13 @@ class Driver(ABC): | |||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
注意当 dist 为 ReproducibleSampler, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
可以加载。 | |||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的 | |||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
""" | |||
if dist is None and reproducible is False: | |||
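A minimal sketch of the dispatch contract documented above (an editor's illustration under the signatures in this diff, not the actual driver body):

    from typing import Optional, Union
    from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler

    def set_dist_repro_dataloader(self, dataloader,
                                  dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]] = None,
                                  reproducible: bool = False):
        if dist is None and reproducible is False:
            return dataloader                  # nothing to do, return the original object
        if isinstance(dist, (ReproducibleSampler, RandomBatchSampler)):
            ...                                # checkpoint reload: driver.load passes in a restored sampler
        elif dist in ("dist", "unrepeatdist"):
            ...                                # trainer/evaluator init with use_dist_sampler=True
        # every non-trivial branch must hand back a *new* dataloader object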
@@ -3,7 +3,7 @@ from typing import Optional, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleIterator | |||
from fastNLP.core.samplers import ReproducibleSampler | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
@@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): | |||
def test_step(self, batch): | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
pass | |||
@@ -3,7 +3,7 @@ from typing import Dict, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
@@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): | |||
def is_distributed(self): | |||
return False | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# reproducible 的相关功能暂时没有实现 | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
if isinstance(dist, RandomBatchSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler = dist | |||
if isinstance(dist, ReproducibleIterator): | |||
if isinstance(dist, ReproducibleSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler.sampler = dist | |||
if reproducible: | |||
raise NotImplementedError | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
@@ -19,7 +19,7 @@ from fastNLP.core.utils import ( | |||
paddle_move_data_to_device, | |||
is_in_paddle_dist, | |||
) | |||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.log import logger | |||
@@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): | |||
def test_step(self, batch): | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# 暂时不支持 IterableDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
if isinstance(dist, ReproducibleIterator): | |||
if isinstance(dist, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler = dist | |||
return dataloader | |||
@@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
# trainer | |||
elif dist == "dist": | |||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): | |||
return dataloader | |||
# evaluator | |||
elif dist == "unrepeatdist": | |||
sampler = UnrepeatedSampler( | |||
sampler = UnrepeatedRandomSampler( | |||
dataset=dataloader.dataset, | |||
shuffle=shuffle, | |||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
@@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||
get_paddle_device_id, | |||
paddle_move_data_to_device, | |||
) | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
@@ -139,26 +139,26 @@ class PaddleSingleDriver(PaddleDriver): | |||
""" | |||
return paddle_move_data_to_device(batch, "gpu:0") | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# 暂时不支持 IterableDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
if isinstance(dist, RandomBatchSampler): | |||
dataloader.batch_sampler = dist | |||
return dataloader | |||
if isinstance(dist, ReproducibleIterator): | |||
if isinstance(dist, ReproducibleSampler): | |||
dataloader.batch_sampler.sampler = dist | |||
return dataloader | |||
if reproducible: | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
@@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
) | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \ | |||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
class TorchDDPDriver(TorchDriver): | |||
@@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): | |||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 RandomBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
if isinstance(dist, RandomBatchSampler): | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return replace_batch_sampler(dataloader, dist) | |||
if isinstance(dist, ReproducibleIterator): | |||
if isinstance(dist, ReproducibleSampler): | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return replace_sampler(dataloader, dist) | |||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||
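Before moving on, a concrete picture of the checkpoint-reload branch above, assuming a plain torch DataLoader (the helpers come from the imports earlier in this diff):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.samplers import RandomSampler
    from fastNLP.core.drivers.torch_driver.utils import replace_sampler

    dataset = TensorDataset(torch.arange(100))
    dataloader = DataLoader(dataset, batch_size=8)

    # `dist` would arrive here with its state already restored by driver.load
    sampler = RandomSampler(dataset, shuffle=True)
    sampler.set_distributed(num_replicas=2, rank=0, pad=True)   # 50 indices per rank
    dataloader = replace_sampler(dataloader, sampler)           # a new dataloader object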
@@ -462,10 +472,10 @@ class TorchDDPDriver(TorchDriver): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||
"control.") | |||
else: | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
if isinstance(dist, RandomBatchSampler): | |||
dist = re_instantiate_sampler(dist) | |||
return replace_batch_sampler(dataloader, dist) | |||
if isinstance(dist, ReproducibleIterator): | |||
if isinstance(dist, ReproducibleSampler): | |||
dist = re_instantiate_sampler(dist) | |||
return replace_sampler(dataloader, dist) | |||
return dataloader | |||
@@ -473,7 +483,7 @@ class TorchDDPDriver(TorchDriver): | |||
elif dist == "dist": | |||
args = self.get_dataloader_args(dataloader) | |||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
batch_sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
@@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): | |||
pad=True | |||
) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleIterator): | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
@@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): | |||
return replace_sampler(dataloader, sampler) | |||
# evaluator | |||
elif dist == "unrepeatdist": | |||
# todo @yh,补充 unrepeatdist 相关内容; | |||
args = self.get_dataloader_args(dataloader) | |||
# todo 判断 batch_sampler; | |||
sampler = UnrepeatedSampler( | |||
dataset=args.dataset, | |||
shuffle=args.shuffle, | |||
) | |||
if isinstance(args.sampler, ReproducibleSampler): | |||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||
sampler = UnrepeatedSequentialSampler( | |||
dataset=args.dataset | |||
) | |||
else: | |||
sampler = re_instantiate_sampler(args.sampler) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank | |||
@@ -13,9 +13,8 @@ __all__ = [ | |||
from .torch_driver import TorchDriver | |||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
class TorchSingleDriver(TorchDriver): | |||
@@ -130,26 +129,26 @@ class TorchSingleDriver(TorchDriver): | |||
else: | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
# 如果 dist 为 RandomBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
if isinstance(dist, RandomBatchSampler): | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleIterator): | |||
elif isinstance(dist, ReproducibleSampler): | |||
return replace_sampler(dataloader, dist) | |||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleIterator): | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
sampler = re_instantiate_sampler(args.sampler) | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
class TorchDriver(Driver): | |||
@@ -182,10 +182,10 @@ class TorchDriver(Driver): | |||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_dist_repro_dataloader` 中将 dataloader 的 | |||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif dataloader_args.sampler: | |||
sampler = dataloader_args.sampler | |||
@@ -245,15 +245,15 @@ class TorchDriver(Driver): | |||
# 3. 恢复 sampler 的状态; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
sampler = dataloader_args.sampler | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||
"`ReproducibleBatchSampler` or `ReproducibleIterator`.") | |||
"`RandomBatchSampler` or `ReproducibleIterator`.") | |||
else: | |||
sampler = ReproducibleBatchSampler( | |||
sampler = RandomBatchSampler( | |||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
@@ -263,7 +263,7 @@ class TorchDriver(Driver): | |||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
if not isinstance(sampler, ReproducibleBatchSampler): | |||
if not isinstance(sampler, RandomBatchSampler): | |||
if dataloader_args.drop_last: | |||
batch_idx_in_epoch = len( | |||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
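A worked instance of the formula above (editor's illustration): with 100 samples, batch_size 8 and 20 samples left, 10 of the epoch's 12 batches were already consumed:

    # drop_last=True branch:
    #   batch_idx_in_epoch = len(sampler) // batch_size - num_left_samples // batch_size
    #                      = 100 // 8 - 20 // 8 = 12 - 2 = 10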
@@ -291,7 +291,7 @@ class TorchDriver(Driver): | |||
@staticmethod | |||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed | |||
with ``seed_everything(seed, workers=True)``. | |||
See also the PyTorch documentation on | |||
@@ -9,18 +9,24 @@ __all__ = [ | |||
'MixSequentialSampler', | |||
'PollingSampler', | |||
'ReproducibleIterator', | |||
'ReproducibleSampler', | |||
'RandomSampler', | |||
're_instantiate_sampler', | |||
"SequentialSampler", | |||
"SortedSampler", | |||
'UnrepeatedSampler', | |||
"UnrepeatedSortedSampler" | |||
'UnrepeatedRandomSampler', | |||
"UnrepeatedSortedSampler", | |||
"UnrepeatedSequentialSampler", | |||
"re_instantiate_sampler", | |||
"conversion_between_reproducible_and_unrepeated_sampler" | |||
] | |||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | |||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler | |||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler | |||
@@ -1,6 +1,6 @@ | |||
__all__ = [ | |||
'BucketedBatchSampler', | |||
"ReproducibleBatchSampler" | |||
"RandomBatchSampler" | |||
] | |||
import math | |||
@@ -16,7 +16,7 @@ from fastNLP.core.log import logger | |||
from abc import abstractmethod | |||
class ReproducibleBatchIterator: | |||
class ReproducibleBatchSampler: | |||
@abstractmethod | |||
def set_distributed(self, num_replicas, rank, pad=True): | |||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||
@@ -42,13 +42,13 @@ class ReproducibleBatchIterator: | |||
pass | |||
class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||
class RandomBatchSampler(ReproducibleBatchSampler): | |||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | |||
""" | |||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | |||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 | |||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一遍该对象,然后将迭代 | |||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | |||
:param batch_size: 每个 batch 的大小是多少。 | |||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||
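A short usage sketch of the wrapper (it mirrors the tests near the end of this diff, assuming a plain torch DataLoader):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.samplers import RandomBatchSampler

    dataset = TensorDataset(torch.arange(100))
    dataloader = DataLoader(dataset, batch_size=8)
    wrapper = RandomBatchSampler(dataloader.batch_sampler, batch_size=8, drop_last=False)

    for batch_indices in wrapper:      # yields lists of up to 8 indices
        state = wrapper.state_dict()   # {'index_list': ..., 'data_idx': 8, 'sampler_type': 'RandomBatchSampler'}
        break

    resumed = RandomBatchSampler(DataLoader(dataset, batch_size=8).batch_sampler, 8, drop_last=False)
    resumed.load_state_dict(state)     # resumed iteration skips the batch already seen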
@@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | |||
class BucketedBatchSampler(ReproducibleBatchIterator): | |||
class BucketedBatchSampler(ReproducibleBatchSampler): | |||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | |||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | |||
""" | |||
@@ -1,24 +1,21 @@ | |||
from typing import Dict, List | |||
from typing import Dict, List, Union | |||
import math | |||
import numpy as np | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.dataset import DataSet | |||
__all__ = [ | |||
'ReproducibleIterator', | |||
'ReproducibleSampler', | |||
'RandomSampler', | |||
're_instantiate_sampler' | |||
"SortedSampler", | |||
"SequentialSampler" | |||
] | |||
def re_instantiate_sampler(sampler): | |||
all_attributes = vars(sampler) | |||
return type(sampler)(**all_attributes) | |||
class ReproducibleIterator: | |||
class ReproducibleSampler: | |||
""" | |||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们在断点重训时重新实例化这个 sampler | |||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | |||
""" | |||
@@ -46,7 +43,7 @@ class ReproducibleIterator: | |||
pass | |||
class RandomSampler(ReproducibleIterator): | |||
class RandomSampler(ReproducibleSampler): | |||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | |||
""" | |||
@@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||
"and current dataset." | |||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({len(self.dataset)})." | |||
self.seed = states['seed'] | |||
self.epoch = states['epoch'] | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
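A resume round-trip for the sampler above (it mirrors TestRandomSamplerYh later in this diff; _TinyData is a hypothetical dataset with only __len__):

    from fastNLP.core.samplers import RandomSampler

    class _TinyData:
        def __len__(self):
            return 100

    sampler = RandomSampler(_TinyData(), shuffle=True)
    it = iter(sampler)
    seen = {next(it) for _ in range(10)}   # consume 10 samples
    states = sampler.state_dict()          # records num_consumed_samples=10, seed, epoch, length

    new_sampler = RandomSampler(_TinyData(), shuffle=False)   # shuffle may differ on reload
    new_sampler.load_state_dict(states)
    assert seen.isdisjoint(set(new_sampler))                  # the 10 seen indices are skipped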
@@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): | |||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||
class SequentialSampler(RandomSampler): | |||
def __init__(self, dataset, dist_mode:str='interval', **kwargs): | |||
""" | |||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
def __iter__(self): | |||
if self.during_iter:  # 如果 during_iter 为 True,说明上一轮迭代尚未正常结束,只能强制重新初始化 | |||
self.num_consumed_samples = 0 | |||
self.during_iter = True | |||
indices = self.generate_indices() | |||
if self.pad: | |||
# add extra samples to make it evenly divisible | |||
padding_size = self.total_size - len(indices) | |||
if padding_size <= len(indices): | |||
indices += indices[:padding_size] | |||
else: | |||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||
else: | |||
# remove tail of data to make it evenly divisible. | |||
indices = indices[:self.total_size] | |||
assert len(indices) == self.total_size | |||
# subsample | |||
indices = indices[self.num_consumed_samples:] | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples += self.num_replicas | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
def generate_indices(self) -> List[int]: | |||
""" | |||
按顺序生成索引序列 | |||
:return: | |||
""" | |||
return list(range(len(self.dataset))) | |||
def state_dict(self) -> Dict: | |||
states = { | |||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||
'sampler_type': self.__class__.__name__, | |||
'length': len(self.dataset), | |||
} | |||
return states | |||
def load_state_dict(self, states: Dict): | |||
# 正在迭代过程中不允许加载状态; | |||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||
"during an unfinished iteration." | |||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||
f"we cannot use {self.__class__.__name__} to load it." | |||
length = states['length'] | |||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||
f"and current dataset({len(self.dataset)})." | |||
self.num_consumed_samples = states['num_consumed_samples'] | |||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||
self.num_consumed_samples = 0 | |||
class SortedSampler(SequentialSampler): | |||
def __init__(self, dataset, length:Union[str, List], **kwargs): | |||
""" | |||
将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于 padding,最后一个 sample 可能是最长的那个 sample。 | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||
:param seed: 设置的随机数种子 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__(dataset=dataset, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||
"the length parameter can only be List[int]" | |||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||
self.length = np.array(length, dtype=int)  # 每个 sample 的长度。 | |||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
def generate_indices(self) -> List[int]: | |||
return self.sorted_indices | |||
def __iter__(self): | |||
if self.during_iter:  # 如果 during_iter 为 True,说明上一轮迭代尚未正常结束,只能强制重新初始化 | |||
self.num_consumed_samples = 0 | |||
self.during_iter = True | |||
indices = self.generate_indices() | |||
if self.pad: | |||
padding_size = self.total_size - len(indices) | |||
if padding_size <= len(indices): | |||
indices += indices[:padding_size] | |||
else: | |||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||
else: | |||
# remove tail of data to make it evenly divisible. | |||
indices = indices[:self.total_size] | |||
assert len(indices) == self.total_size | |||
# subsample | |||
indices = indices[self.num_consumed_samples:] | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
assert len(indices) == self.num_left_samples | |||
for index in indices: | |||
self.num_consumed_samples += self.num_replicas | |||
yield index | |||
self.during_iter = False | |||
self.num_consumed_samples = 0 | |||
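How the two new samplers behave, using the import path from the tests below and the split rule from __iter__ above (a sketch; _TinyData is a hypothetical six-element dataset):

    from fastNLP.core.samplers.reproducible_sampler import SequentialSampler, SortedSampler

    class _TinyData:
        def __len__(self):
            return 6

    s0 = SequentialSampler(_TinyData())
    s0.set_distributed(num_replicas=2, rank=0, pad=True)
    s1 = SequentialSampler(_TinyData())
    s1.set_distributed(num_replicas=2, rank=1, pad=True)
    assert list(s0) == [0, 2, 4] and list(s1) == [1, 3, 5]   # interval reading

    sorted_sampler = SortedSampler(_TinyData(), length=[0, 1, 2, 3, 4, 5])
    assert list(sorted_sampler) == [5, 4, 3, 2, 1, 0]        # longest sample first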
@@ -1,6 +1,8 @@ | |||
__all__ = [ | |||
'UnrepeatedSampler', | |||
'UnrepeatedSortedSampler', | |||
'UnrepeatedSampler' | |||
'UnrepeatedRandomSampler', | |||
"UnrepeatedSequentialSampler" | |||
] | |||
from typing import List, Union | |||
@@ -10,13 +12,21 @@ import numpy as np | |||
class UnrepeatedSampler: | |||
""" | |||
在多卡场景下保证 indices 不重复的 sampler | |||
""" | |||
pass | |||
class UnrepeatedRandomSampler(UnrepeatedSampler): | |||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | |||
""" | |||
考虑在多卡evaluate的场景下,不能重复sample。 | |||
:param dataset: | |||
:param shuffle: | |||
:param seed: | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param shuffle: 如果为 True,将对数据进行 shuffle;默认为 False。 | |||
:param seed: 设置的随机数种子 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
self.dataset = dataset | |||
self.shuffle = shuffle | |||
@@ -33,8 +43,8 @@ class UnrepeatedSampler: | |||
:return: | |||
""" | |||
num_common = len(self.dataset)//self.num_replicas | |||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
return self.num_samples | |||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||
return num_samples | |||
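A worked instance of this split (editor's note):

    # With 10 samples on 3 replicas:
    #   num_common = 10 // 3 = 3 and the remainder is 10 - 3*3 = 1, so
    #   rank 0 gets 3 + int(0 < 1) = 4 samples while ranks 1 and 2 get 3 each;
    #   4 + 3 + 3 = 10, i.e. every sample appears exactly once across the ranks.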
def __iter__(self): | |||
indices = self.generate_indices() | |||
@@ -83,8 +93,8 @@ class UnrepeatedSampler: | |||
return self | |||
class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
def __init__(self, dataset, length:Union[str, List], seed: int = 0): | |||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
def __init__(self, dataset, length:Union[str, List], **kwargs): | |||
""" | |||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | |||
batch 数量不完全一致。 | |||
@@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||
:param seed: 设置的随机数种子 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__(dataset=dataset, shuffle=False, seed=seed) | |||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if not isinstance(length[0], int): | |||
@@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||
length = np.array(length, dtype=int)  # 每个 sample 的长度。 | |||
self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 | |||
def generate_indices(self) -> List[int]: | |||
return self.sorted_indices | |||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||
def __init__(self, dataset, **kwargs): | |||
""" | |||
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) | |||
def __iter__(self): | |||
indices = self.generate_indices() | |||
indices = indices[self.rank:len(indices):self.num_replicas] | |||
for index in indices: | |||
yield index | |||
def generate_indices(self) -> List[int]: | |||
return list(range(len(self.dataset))) | |||
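The interval split above in action (a sketch; _TinyData is a hypothetical five-element dataset with only __len__):

    from fastNLP.core.samplers import UnrepeatedSequentialSampler

    class _TinyData:
        def __len__(self):
            return 5

    s0 = UnrepeatedSequentialSampler(_TinyData())
    s0.set_distributed(num_replicas=2, rank=0)
    s1 = UnrepeatedSequentialSampler(_TinyData())
    s1.set_distributed(num_replicas=2, rank=1)
    assert list(s0) == [0, 2, 4] and list(s1) == [1, 3]   # merged, exactly the original data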
@@ -0,0 +1,42 @@ | |||
__all__ = [ | |||
're_instantiate_sampler', | |||
'conversion_between_reproducible_and_unrepeated_sampler' | |||
] | |||
from fastNLP.core.samplers.unrepeated_sampler import * | |||
from fastNLP.core.samplers.reproducible_sampler import * | |||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||
""" | |||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||
ReproducibleSampler,将抛出 TypeError;反之亦然。 | |||
:param sampler: 需要转换的 sampler; | |||
:return: 转换后的 sampler; | |||
""" | |||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||
if isinstance(sampler, UnrepeatedSampler): | |||
if isinstance(sampler, UnrepeatedRandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||
else: | |||
if isinstance(sampler, RandomSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||
elif isinstance(sampler, SequentialSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||
elif isinstance(sampler, SortedSampler): | |||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||
raise TypeError(f"{sampler.__class__} has no reproducible version.") | |||
def re_instantiate_sampler(sampler, new_sampler_class=None): | |||
all_attributes = vars(sampler) | |||
if new_sampler_class is not None: | |||
return new_sampler_class(**all_attributes) | |||
return type(sampler)(**all_attributes) |
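A round-trip through the converter (assuming the signatures above; _TinyData is a hypothetical dataset with only __len__):

    from fastNLP.core.samplers import (RandomSampler, UnrepeatedRandomSampler,
                                       conversion_between_reproducible_and_unrepeated_sampler)

    class _TinyData:
        def __len__(self):
            return 8

    repro = RandomSampler(_TinyData(), shuffle=True)
    unrep = conversion_between_reproducible_and_unrepeated_sampler(repro)
    assert isinstance(unrep, UnrepeatedRandomSampler)     # evaluator-side counterpart
    back = conversion_between_reproducible_and_unrepeated_sampler(unrep)
    assert isinstance(back, RandomSampler)                # and back again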
@@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||
from fastNLP.core.samplers import RandomBatchSampler | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | |||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | |||
from fastNLP.core import synchronize_safe_rm | |||
@@ -153,7 +153,7 @@ class TestSingleDeviceFunction: | |||
@pytest.mark.parametrize( | |||
"dist_sampler", | |||
["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||
) | |||
@pytest.mark.parametrize( | |||
"reproducible", | |||
@@ -30,7 +30,7 @@ class SequenceDataSet: | |||
def check_replace_sampler(driver): | |||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler | |||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler | |||
# reproducible 是 True 和 False | |||
# 需要 check 返回的 sampler 和 dataloader 都不同了 | |||
@@ -4,7 +4,7 @@ import numpy as np | |||
import pytest | |||
from itertools import chain | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
@@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: | |||
before_batch_size = 7 | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
forward_steps = 3 | |||
@@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: | |||
# 1. 保存状态 | |||
_get_re_batchsampler = dataloader.batch_sampler | |||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
state = _get_re_batchsampler.state_dict() | |||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||
"sampler_type": "ReproducibleBatchSampler"} | |||
"sampler_type": "RandomBatchSampler"} | |||
# 2. 断点重训,重新生成一个 dataloader; | |||
# 不改变 batch_size; | |||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler.load_state_dict(state) | |||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
@@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: | |||
# 改变 batch_size; | |||
after_batch_size = 3 | |||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler.load_state_dict(state) | |||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
@@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
# 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||
@@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: | |||
# 1. 保存状态 | |||
_get_re_batchsampler = dataloader.batch_sampler | |||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||
state = _get_re_batchsampler.state_dict() | |||
# 2. 断点重训,重新生成一个 dataloader; | |||
# 不改变 batch_size; | |||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||
re_batchsampler.load_state_dict(state) | |||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||
@@ -1,18 +1,14 @@ | |||
import unittest | |||
from itertools import product | |||
import numpy as np | |||
import pytest | |||
from functools import partial | |||
from array import array | |||
from itertools import chain | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
class TestRandomSamplerYh(unittest.TestCase): | |||
class TestRandomSamplerYh: | |||
def test_init(self): | |||
# 测试能否正确初始化 | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
@@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
sampler = RandomSampler(dataset) | |||
for i in sampler: | |||
with self.assertRaises(AssertionError): | |||
with pytest.raises(AssertionError): | |||
sampler.set_distributed(1, 0) | |||
break | |||
@@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
sampler = RandomSampler(dataset, shuffle=False) | |||
sampler.set_distributed(num_replicas=2, rank=0, pad=False) | |||
self.assertEqual(len(sampler), 50) | |||
assert len(sampler)==50 | |||
count = 0 | |||
for i in sampler: | |||
self.assertEqual(i%2, 0) | |||
assert i%2==0 | |||
count += 1 | |||
self.assertEqual(count, 50) | |||
assert count == 50 | |||
sampler.set_distributed(num_replicas=2, rank=1, pad=False) | |||
self.assertEqual(len(sampler), 50) | |||
assert len(sampler)==50 | |||
count = 0 | |||
for i in sampler: | |||
self.assertEqual(i%2, 1) | |||
assert i%2==1 | |||
count += 1 | |||
self.assertEqual(count, 50) | |||
assert count==50 | |||
dataset = TorchNormalDataset(num_of_data=101) | |||
sampler = RandomSampler(dataset, shuffle=False) | |||
sampler.set_distributed(num_replicas=2, rank=0, pad=True) | |||
self.assertEqual(len(sampler), 51) | |||
assert len(sampler)==51 | |||
count = 0 | |||
for i in sampler: | |||
self.assertEqual(i%2, 0) | |||
assert i%2==0 | |||
count += 1 | |||
self.assertEqual(count, 51) | |||
assert count == 51 | |||
sampler.set_distributed(num_replicas=2, rank=1, pad=True) | |||
self.assertEqual(len(sampler), 51) | |||
assert len(sampler) == 51 | |||
count = 0 | |||
for i in sampler: | |||
if i!=0: | |||
self.assertEqual(i%2, 1) | |||
assert i%2==1 | |||
count += 1 | |||
self.assertEqual(count, 51) | |||
assert count == 51 | |||
def test_state_dict_check_length(self): | |||
dataset = TorchNormalDataset(num_of_data=100) | |||
@@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
states = sampler.state_dict() | |||
new_ds = TorchNormalDataset(num_of_data=10) | |||
with self.assertRaises(AssertionError): | |||
with pytest.raises(AssertionError): | |||
new_sampler = RandomSampler(new_ds) | |||
new_sampler.load_state_dict(states) | |||
@@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): | |||
new_sampler = RandomSampler(new_ds) | |||
new_sampler.load_state_dict(states) | |||
def test_state_dict(self): | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||
num_samples = 100 | |||
dataset = TorchNormalDataset(num_of_data=num_samples) | |||
# 测试前后 shuffle 设置不一致时的 load 操作 | |||
lst = [0]+np.random.randint(1, num_samples, size=3).tolist() | |||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||
lst): | |||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||
sampler.set_epoch(0) | |||
already_numbers = set() | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
self.assertEqual(len(already_numbers), num_consumed_samples) | |||
states = sampler.state_dict() | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
self.assertNotIn(i, already_numbers) | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||
new_sampler.set_epoch(0) | |||
count = 0 | |||
for i in new_sampler: | |||
self.assertNotIn(i, other_rank_number) | |||
other_rank_number.add(i) | |||
self.assertNotIn(i, already_numbers) | |||
count += 1 | |||
def test_state_dict_2(self): | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||
sampler.set_epoch(0) | |||
already_numbers = set() | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
assert len(already_numbers) == num_consumed_samples | |||
states = sampler.state_dict() | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
assert i not in already_numbers | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
new_sampler.set_epoch(0) | |||
count = 0 | |||
seen = 0 | |||
seen_in_other_rank = 0 | |||
for i in new_sampler: | |||
seen_in_other_rank += int(i in other_rank_number) | |||
other_rank_number.add(i) | |||
seen += int(i in already_numbers) | |||
count += 1 | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||
num_samples = 100 | |||
dataset = TorchNormalDataset(num_of_data=num_samples) | |||
# 测试前后 shuffle 设置不一致时的 load 操作 | |||
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() | |||
# lst = [30] | |||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||
lst): | |||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||
already_numbers = set() | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
sampler.set_distributed(num_replicas=2, rank=0) | |||
sampler.set_epoch(0) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
sampler.set_epoch(0) | |||
sampler.set_distributed(num_replicas=2, rank=1) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
self.assertEqual(len(already_numbers), num_consumed_samples*2) | |||
states = sampler.state_dict() | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
self.assertNotIn(i, already_numbers) | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||
count = 0 | |||
for i in new_sampler: | |||
self.assertNotIn(i, other_rank_number) | |||
other_rank_number.add(i) | |||
self.assertNotIn(i, already_numbers) | |||
count += 1 | |||
class TestRandomSampler(unittest.TestCase): | |||
already_numbers = set() | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
sampler.set_distributed(num_replicas=2, rank=0) | |||
sampler.set_epoch(0) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||
sampler.set_epoch(0) | |||
sampler.set_distributed(num_replicas=2, rank=1) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
assert len(already_numbers) == num_consumed_samples*2 | |||
states = sampler.state_dict() | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
assert i not in already_numbers | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
count = 0 | |||
seen = 0 | |||
seen_in_other_rank = 0 | |||
for i in new_sampler: | |||
seen_in_other_rank += int(i in other_rank_number) | |||
other_rank_number.add(i) | |||
seen += int(i in already_numbers) | |||
count += 1 | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
class TestRandomSampler: | |||
# 测试单卡; | |||
def test_seed_work_when_shuffle_is_true(self): | |||
data_length = 100 | |||
@@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): | |||
... | |||
class DatasetWithVaryLength: | |||
def __init__(self, num_of_data=100, reverse=False): | |||
self.data = np.arange(num_of_data) | |||
if reverse: | |||
self.data = self.data[::-1] | |||
def __getitem__(self, item): | |||
return self.data[item] | |||
def __len__(self): | |||
return len(self.data) | |||
class TestSortedSampler: | |||
def test_single(self): | |||
num_of_data = 100 | |||
data = DatasetWithVaryLength(num_of_data) | |||
sampler = SortedSampler(data, length=data.data) | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
samplers = [] | |||
for i in range(num_replica): | |||
sampler = SortedSampler(dataset=data, length=data.data) | |||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
# 保证顺序是没乱的 | |||
already_seen_index = set() | |||
for sampler in samplers: | |||
larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。 | |||
prev_index = float('inf') | |||
cur_set = set() | |||
seen_in_other_rank = 0 | |||
for index in sampler: | |||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 | |||
cur_set.add(index) | |||
larger_count += int(index <= prev_index) | |||
prev_index = index | |||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 | |||
assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0 | |||
already_seen_index.update(cur_set) | |||
indexes = list(chain(*samplers)) | |||
indexes = set(indexes) | |||
if pad: | |||
assert indexes == set(range(num_of_data)) | |||
else: | |||
assert len(indexes) <= num_of_data | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||
def test_state_dict(self, pad, num_consumed_samples): | |||
num_samples = 100 | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
# 测试部分迭代后保存状态并恢复 | |||
sampler = SortedSampler(dataset, length=dataset.data) | |||
sampler.set_epoch(0) | |||
already_numbers = set() | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
if already_numbers: | |||
assert j<max(already_numbers) | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
assert len(already_numbers) == num_consumed_samples | |||
states = sampler.state_dict() | |||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
if already_numbers: | |||
assert i < max(already_numbers) | |||
assert i not in already_numbers | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
new_sampler.set_epoch(0) | |||
count = 0 | |||
seen = 0 | |||
seen_in_other_rank = 0 | |||
smaller = 0 | |||
for i in new_sampler: | |||
if already_numbers: | |||
smaller += int(i >= max(already_numbers)) | |||
seen_in_other_rank += int(i in other_rank_number) | |||
other_rank_number.add(i) | |||
seen += int(i in already_numbers) | |||
count += 1 | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
assert smaller<=1 if pad else smaller==0 | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||
def test_state_dict_2(self, pad, num_consumed_samples): | |||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||
num_samples = 100 | |||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||
# 在两张卡上分别消耗部分数据后保存状态并恢复 | |||
# lst = [30] | |||
already_numbers = set() | |||
sampler = SortedSampler(dataset, length=dataset.data) | |||
sampler.set_distributed(num_replicas=2, rank=0) | |||
sampler.set_epoch(0) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
if already_numbers: | |||
assert j<=max(already_numbers) | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
sampler = SortedSampler(dataset, length=dataset.data) | |||
sampler.set_epoch(0) | |||
sampler.set_distributed(num_replicas=2, rank=1) | |||
if num_consumed_samples>0: | |||
for i, j in enumerate(sampler, start=1): | |||
already_numbers.add(j) | |||
if i == num_consumed_samples: | |||
break | |||
assert len(already_numbers) == num_consumed_samples*2 | |||
states = sampler.state_dict() | |||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
for i in new_sampler: | |||
if already_numbers: | |||
assert i < max(already_numbers) | |||
assert i not in already_numbers | |||
# 测试切换成多卡也没有问题 | |||
other_rank_number = set() | |||
for rank in range(3): | |||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||
new_sampler.load_state_dict(states) | |||
new_sampler.set_epoch(0) | |||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||
count = 0 | |||
seen = 0 | |||
seen_in_other_rank = 0 | |||
smaller = 0 | |||
for i in new_sampler: | |||
if already_numbers: | |||
smaller += int(i>=max(already_numbers)) | |||
seen_in_other_rank += int(i in other_rank_number) | |||
other_rank_number.add(i) | |||
seen += int(i in already_numbers) | |||
count += 1 | |||
assert seen <= 1 if pad else seen == 0 | |||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||
assert smaller <= 1 if pad else smaller == 0 | |||
class TestSequentialSampler: | |||
def test_single(self): | |||
num_of_data = 100 | |||
data = DatasetWithVaryLength(num_of_data) | |||
sampler = SequentialSampler(data) | |||
indexes = list(sampler) | |||
assert indexes==list(range(num_of_data)) | |||
@pytest.mark.parametrize('pad', [True, False]) | |||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||
def test_multi(self, pad, num_replica, num_of_data): | |||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||
samplers = [] | |||
for i in range(num_replica): | |||
sampler = SequentialSampler(dataset=data) | |||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||
samplers.append(sampler) | |||
# 保证顺序是没乱的 | |||
already_seen_index = set() | |||
for idx, sampler in enumerate(samplers): | |||
larger_count = 1 | |||
prev_index = float('inf') | |||
cur_set = set() | |||
seen_in_other_rank = 0 | |||
for index in sampler: | |||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 | |||
cur_set.add(index) | |||
larger_count += int(index >= prev_index) | |||
prev_index = index | |||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 | |||
assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0 | |||
already_seen_index.update(cur_set) | |||
indexes = list(chain(*samplers)) | |||
indexes = set(indexes) | |||
if pad: | |||
assert indexes == set(range(num_of_data)) | |||
else: | |||
assert len(indexes) <= num_of_data | |||
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0] + np.random.randint(1, 100, size=3).tolist())
    def test_state_dict(self, pad, num_consumed_samples):
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # test a load where the shuffle setting differs before and after
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        already_numbers = set()
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples
        states = sampler.state_dict()
        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i > max(already_numbers)
            assert i not in already_numbers
        # check that switching to multiple GPUs also works
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            new_sampler.set_epoch(0)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i <= max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= rank  # pad may duplicate samples
            assert smaller <= 1 if pad else smaller == 0

    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_consumed_samples', [0] + np.random.randint(1, 100 // 2, size=3).tolist())
    def test_state_dict_2(self, pad, num_consumed_samples):
        # test switching from multiple GPUs to a single GPU, or to a different number of GPUs
        num_samples = 100
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        # test a load where the shuffle setting differs before and after
        already_numbers = set()
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_distributed(num_replicas=2, rank=0)
        sampler.set_epoch(0)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                if already_numbers:
                    assert j > max(already_numbers)
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        sampler = SequentialSampler(dataset=dataset)
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=2, rank=1)
        if num_consumed_samples > 0:
            for i, j in enumerate(sampler, start=1):
                already_numbers.add(j)
                if i == num_consumed_samples:
                    break
        assert len(already_numbers) == num_consumed_samples * 2
        states = sampler.state_dict()
        new_sampler = SequentialSampler(dataset=dataset)
        new_sampler.load_state_dict(states)
        new_sampler.set_epoch(0)
        for i in new_sampler:
            if already_numbers:
                assert i > max(already_numbers)
            assert i not in already_numbers
        # check that switching to multiple GPUs also works
        other_rank_number = set()
        for rank in range(3):
            new_sampler = SequentialSampler(dataset=dataset)
            new_sampler.load_state_dict(states)
            new_sampler.set_epoch(0)
            new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
            count = 0
            seen = 0
            seen_in_other_rank = 0
            smaller = 0
            for i in new_sampler:
                if already_numbers:
                    smaller += int(i < max(already_numbers))
                seen_in_other_rank += int(i in other_rank_number)
                other_rank_number.add(i)
                seen += int(i in already_numbers)
                count += 1
            assert seen <= 1 if pad else seen == 0
            assert seen_in_other_rank <= 1  # pad may duplicate samples
            assert smaller <= rank if pad else smaller == 0
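
# Illustrative sketch (editor's addition, not part of the original diff): resuming
# a checkpoint taken under num_replicas=2 on three replicas, the scenario that
# test_state_dict_2 covers. load_state_dict() restores global progress, then
# set_distributed() re-shards the remaining indices. The helper name is hypothetical.
def _sketch_resume_on_more_replicas(states, pad=True):
    dataset = DatasetWithVaryLength(num_of_data=100)
    remaining = []
    for rank in range(3):
        sampler = SequentialSampler(dataset=dataset)
        sampler.load_state_dict(states)  # states were saved under num_replicas=2
        sampler.set_epoch(0)
        sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
        remaining.extend(sampler)
    return remaining  # union of the three ranks' leftover indices
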
@@ -2,7 +2,7 @@ from itertools import chain
import pytest
from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
class DatasetWithVaryLength:
@@ -21,7 +21,7 @@ class TestUnrepeatedSampler:
    def test_single(self, shuffle):
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSampler(data, shuffle)
        sampler = UnrepeatedRandomSampler(data, shuffle)
        indexes = set(sampler)
        assert indexes == set(range(num_of_data))
@@ -32,17 +32,18 @@ class TestUnrepeatedSampler:
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
            sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)
        indexes = set(chain(*samplers))
        indexes = list(chain(*samplers))
        assert len(indexes) == num_of_data
        indexes = set(indexes)
        assert indexes == set(range(num_of_data))
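
# Illustrative sketch (editor's addition, not part of the original diff): the
# Unrepeated* samplers partition the dataset exactly across ranks, with no padding
# and no duplicates, which is what distributed evaluation needs. Hypothetical helper.
def _sketch_unrepeated_partition():
    data = DatasetWithVaryLength(num_of_data=10)
    shards = []
    for rank in range(3):
        sampler = UnrepeatedRandomSampler(dataset=data, shuffle=False)
        sampler.set_distributed(3, rank=rank)
        shards.append(list(sampler))
    assert sum(len(shard) for shard in shards) == 10  # exact partition; shard sizes may differ
    assert set(chain(*shards)) == set(range(10))  # every sample appears exactly once
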
class TestUnrepeatedSortedSampler:
    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
    def test_single(self):
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSortedSampler(data, length=data.data)
@@ -51,8 +52,7 @@ class TestUnrepeatedSortedSampler:
    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replica, num_of_data, shuffle):
    def test_multi(self, num_replica, num_of_data):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
@@ -60,5 +60,45 @@ class TestUnrepeatedSortedSampler:
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)
        indexes = set(chain(*samplers))
        # make sure the order within each rank is preserved
        for sampler in samplers:
            prev_index = float('inf')
            for index in sampler:
                assert index <= prev_index
                prev_index = index
        indexes = list(chain(*samplers))
        assert len(indexes) == num_of_data  # no overlap between ranks
        indexes = set(indexes)
        assert indexes == set(range(num_of_data))
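
# Illustrative sketch (editor's addition, not part of the original diff): as the
# assertions above check, UnrepeatedSortedSampler yields each rank's indices in
# non-increasing order of the provided length. The helper name is hypothetical.
def _sketch_unrepeated_sorted_order():
    data = DatasetWithVaryLength(num_of_data=10)
    sampler = UnrepeatedSortedSampler(data, length=data.data)
    sampler.set_distributed(2, rank=0)
    out = list(sampler)
    assert out == sorted(out, reverse=True)  # descending within the rank
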
class TestUnrepeatedSequentialSampler:
    def test_single(self):
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSequentialSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data))

    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    def test_multi(self, num_replica, num_of_data):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)
        # make sure the order within each rank is preserved
        for sampler in samplers:
            prev_index = float('-inf')
            for index in sampler:
                assert index >= prev_index
                prev_index = index
        indexes = list(chain(*samplers))
        assert len(indexes) == num_of_data
        indexes = set(indexes)
        assert indexes == set(range(num_of_data))
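
# Illustrative sketch (editor's addition, not part of the original diff):
# UnrepeatedSequentialSampler is the ascending counterpart; concatenating the
# rank outputs gives an order-preserving, duplicate-free cover of the dataset.
def _sketch_unrepeated_sequential_order():
    data = DatasetWithVaryLength(num_of_data=10)
    outputs = []
    for rank in range(2):
        sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
        sampler.set_distributed(2, rank=rank)
        outputs.append(list(sampler))
    for out in outputs:
        assert out == sorted(out)  # ascending within each rank
    assert sorted(chain(*outputs)) == list(range(10))  # exact cover, no repeats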