@@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver | |||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | from fastNLP.envs import FASTNLP_MODEL_FILENAME | ||||
@@ -610,7 +610,7 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
用于断点重训的加载函数; | 用于断点重训的加载函数; | ||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | ||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; | |||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||||
注意我们目前不支持单卡到多卡的断点重训; | 注意我们目前不支持单卡到多卡的断点重训; | ||||
@@ -49,13 +49,13 @@ class Driver(ABC): | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | ||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | ||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | ||||
注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | ||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | ||||
可以可以加载。 | 可以可以加载。 | ||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | ||||
如果传入的 dataloader 中是 ReproducibleIterator 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||||
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的 | |||||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | ||||
""" | """ | ||||
if dist is None and reproducible is False: | if dist is None and reproducible is False: | ||||
@@ -3,7 +3,7 @@ from typing import Optional, Union | |||||
from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.samplers import ReproducibleIterator | |||||
from fastNLP.core.samplers import ReproducibleSampler | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | import jittor | ||||
@@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
pass | pass | ||||
@@ -3,7 +3,7 @@ from typing import Dict, Union | |||||
from .jittor_driver import JittorDriver | from .jittor_driver import JittorDriver | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
import jittor | import jittor | ||||
@@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver): | |||||
def is_distributed(self): | def is_distributed(self): | ||||
return False | return False | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# reproducible 的相关功能暂时没有实现 | # reproducible 的相关功能暂时没有实现 | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
if isinstance(dist, RandomBatchSampler): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
dataloader.batch_sampler = dist_sample | dataloader.batch_sampler = dist_sample | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
if reproducible: | if reproducible: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | return dataloader | ||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||||
return dataloader | return dataloader | ||||
else: | else: | ||||
# TODO | # TODO | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | batch_sampler=dataloader.batch_sampler, | ||||
batch_size=dataloader.batch_sampler.batch_size, | batch_size=dataloader.batch_sampler.batch_size, | ||||
drop_last=dataloader.drop_last | drop_last=dataloader.drop_last | ||||
@@ -19,7 +19,7 @@ from fastNLP.core.utils import ( | |||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | is_in_paddle_dist, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler | |||||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver): | |||||
def test_step(self, batch): | def test_step(self, batch): | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持iterableDataset | # 暂时不支持iterableDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
return dataloader | return dataloader | ||||
@@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank, | rank=self.global_rank, | ||||
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return dataloader | return dataloader | ||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
sampler = UnrepeatedSampler( | |||||
sampler = UnrepeatedRandomSampler( | |||||
dataset=dataloader.dataset, | dataset=dataloader.dataset, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | seed=int(os.environ.get("FASTNLP_SEED", 0)) | ||||
@@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -139,26 +139,26 @@ class PaddleSingleDriver(PaddleDriver): | |||||
""" | """ | ||||
return paddle_move_data_to_device(batch, "gpu:0") | return paddle_move_data_to_device(batch, "gpu:0") | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | reproducible: bool = False, sampler_or_batch_sampler=None): | ||||
# 暂时不支持IteratorDataset | # 暂时不支持IteratorDataset | ||||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | assert dataloader.dataset_kind != _DatasetKind.ITER, \ | ||||
"FastNLP does not support `IteratorDataset` now." | "FastNLP does not support `IteratorDataset` now." | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
if isinstance(dist, RandomBatchSampler): | |||||
dataloader.batch_sampler = dist | dataloader.batch_sampler = dist | ||||
return dataloader | return dataloader | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dataloader.batch_sampler.sampler = dist | dataloader.batch_sampler.sampler = dist | ||||
return dataloader | return dataloader | ||||
if reproducible: | if reproducible: | ||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||||
return dataloader | return dataloader | ||||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||||
return dataloader | return dataloader | ||||
else: | else: | ||||
# TODO | # TODO | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader.batch_sampler, | batch_sampler=dataloader.batch_sampler, | ||||
batch_size=dataloader.batch_sampler.batch_size, | batch_size=dataloader.batch_sampler.batch_size, | ||||
drop_last=dataloader.drop_last | drop_last=dataloader.drop_last | ||||
@@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
) | ) | ||||
from fastNLP.core.drivers.utils import distributed_open_proc | from fastNLP.core.drivers.utils import distributed_open_proc | ||||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | from fastNLP.core.utils import auto_param_call, check_user_specific_params | ||||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \ | |||||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchDDPDriver(TorchDriver): | class TorchDDPDriver(TorchDriver): | ||||
@@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver): | |||||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
if isinstance(dist, RandomBatchSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | ||||
@@ -462,10 +472,10 @@ class TorchDDPDriver(TorchDriver): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | ||||
"control.") | "control.") | ||||
else: | else: | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
if isinstance(dist, RandomBatchSampler): | |||||
dist = re_instantiate_sampler(dist) | dist = re_instantiate_sampler(dist) | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
if isinstance(dist, ReproducibleIterator): | |||||
if isinstance(dist, ReproducibleSampler): | |||||
dist = re_instantiate_sampler(dist) | dist = re_instantiate_sampler(dist) | ||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
return dataloader | return dataloader | ||||
@@ -473,7 +483,7 @@ class TorchDDPDriver(TorchDriver): | |||||
elif dist == "dist": | elif dist == "dist": | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | ||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | batch_sampler = re_instantiate_sampler(args.batch_sampler) | ||||
batch_sampler.set_distributed( | batch_sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
@@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver): | |||||
pad=True | pad=True | ||||
) | ) | ||||
return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
elif isinstance(args.sampler, ReproducibleIterator): | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | sampler = re_instantiate_sampler(args.sampler) | ||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
@@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
# todo @yh,补充 unrepeatdist 相关内容; | |||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
# todo 判断 batch_sampler; | |||||
sampler = UnrepeatedSampler( | |||||
dataset=args.dataset, | |||||
shuffle=args.shuffle, | |||||
) | |||||
if isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | |||||
elif not isinstance(args.sampler, UnrepeatedSampler): | |||||
sampler = UnrepeatedSequentialSampler( | |||||
dataset=args.dataset | |||||
) | |||||
else: | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
sampler.set_distributed( | sampler.set_distributed( | ||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
@@ -13,9 +13,8 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
class TorchSingleDriver(TorchDriver): | class TorchSingleDriver(TorchDriver): | ||||
@@ -130,26 +129,26 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, RandomBatchSampler): | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleIterator): | |||||
elif isinstance(dist, ReproducibleSampler): | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | ||||
args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | batch_sampler = re_instantiate_sampler(args.batch_sampler) | ||||
return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
elif isinstance(args.sampler, ReproducibleIterator): | |||||
elif isinstance(args.sampler, ReproducibleSampler): | |||||
sampler = re_instantiate_sampler(args.sampler) | sampler = re_instantiate_sampler(args.sampler) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -182,10 +182,10 @@ class TorchDriver(Driver): | |||||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | ||||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | ||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | |||||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; | |||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif dataloader_args.sampler: | elif dataloader_args.sampler: | ||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
@@ -245,15 +245,15 @@ class TorchDriver(Driver): | |||||
# 3. 恢复 sampler 的状态; | # 3. 恢复 sampler 的状态; | ||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | sampler = dataloader_args.batch_sampler | ||||
elif isinstance(dataloader_args.sampler, ReproducibleIterator): | elif isinstance(dataloader_args.sampler, ReproducibleIterator): | ||||
sampler = dataloader_args.sampler | sampler = dataloader_args.sampler | ||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | ||||
"`ReproducibleBatchSampler` or `ReproducibleIterator`.") | |||||
"`RandomBatchSampler` or `ReproducibleIterator`.") | |||||
else: | else: | ||||
sampler = ReproducibleBatchSampler( | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -263,7 +263,7 @@ class TorchDriver(Driver): | |||||
# 4. 修改 trainer_state.batch_idx_in_epoch | # 4. 修改 trainer_state.batch_idx_in_epoch | ||||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | ||||
if not isinstance(sampler, ReproducibleBatchSampler): | |||||
if not isinstance(sampler, RandomBatchSampler): | |||||
if dataloader_args.drop_last: | if dataloader_args.drop_last: | ||||
batch_idx_in_epoch = len( | batch_idx_in_epoch = len( | ||||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | ||||
@@ -291,7 +291,7 @@ class TorchDriver(Driver): | |||||
@staticmethod | @staticmethod | ||||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | ||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed | |||||
with ``seed_everything(seed, workers=True)``. | with ``seed_everything(seed, workers=True)``. | ||||
See also the PyTorch documentation on | See also the PyTorch documentation on | ||||
@@ -9,18 +9,24 @@ __all__ = [ | |||||
'MixSequentialSampler', | 'MixSequentialSampler', | ||||
'PollingSampler', | 'PollingSampler', | ||||
'ReproducibleIterator', | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | 'RandomSampler', | ||||
're_instantiate_sampler', | |||||
"SequentialSampler", | |||||
"SortedSampler", | |||||
'UnrepeatedSampler', | 'UnrepeatedSampler', | ||||
"UnrepeatedSortedSampler" | |||||
'UnrepeatedRandomSampler', | |||||
"UnrepeatedSortedSampler", | |||||
"UnrepeatedSequentialSampler", | |||||
"re_instantiate_sampler", | |||||
"conversion_between_reproducible_and_unrepeated_sampler" | |||||
] | ] | ||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler | ||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler | |||||
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | |||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler | |||||
@@ -1,6 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'BucketedBatchSampler', | 'BucketedBatchSampler', | ||||
"ReproducibleBatchSampler" | |||||
"RandomBatchSampler" | |||||
] | ] | ||||
import math | import math | ||||
@@ -16,7 +16,7 @@ from fastNLP.core.log import logger | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
class ReproducibleBatchIterator: | |||||
class ReproducibleBatchSampler: | |||||
@abstractmethod | @abstractmethod | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | ||||
@@ -42,13 +42,13 @@ class ReproducibleBatchIterator: | |||||
pass | pass | ||||
class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | ||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | 可以使得 batch_sampler 对象状态恢复的 wrapper 。 | ||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 | |||||
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 | ||||
:param batch_size: 每个 batch 的大小是多少。 | :param batch_size: 每个 batch 的大小是多少。 | ||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | ||||
@@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | ||||
class BucketedBatchSampler(ReproducibleBatchIterator): | |||||
class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
@@ -1,24 +1,21 @@ | |||||
from typing import Dict, List | |||||
from typing import Dict, List, Union | |||||
import math | import math | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.dataset import DataSet | |||||
__all__ = [ | __all__ = [ | ||||
'ReproducibleIterator', | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | 'RandomSampler', | ||||
're_instantiate_sampler' | |||||
"SortedSampler", | |||||
"SequentialSampler" | |||||
] | ] | ||||
def re_instantiate_sampler(sampler): | |||||
all_attributes = vars(sampler) | |||||
return type(sampler)(**all_attributes) | |||||
class ReproducibleIterator: | |||||
class ReproducibleSampler: | |||||
""" | """ | ||||
注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | |||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | ||||
""" | """ | ||||
@@ -46,7 +43,7 @@ class ReproducibleIterator: | |||||
pass | pass | ||||
class RandomSampler(ReproducibleIterator): | |||||
class RandomSampler(ReproducibleSampler): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
@@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator): | |||||
f"we cannot use {self.__class__.__name__} to load it." | f"we cannot use {self.__class__.__name__} to load it." | ||||
length = states['length'] | length = states['length'] | ||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({len(self.dataset)})." | |||||
self.seed = states['seed'] | self.seed = states['seed'] | ||||
self.epoch = states['epoch'] | self.epoch = states['epoch'] | ||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
@@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator): | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
class SequentialSampler(RandomSampler): | |||||
def __init__(self, dataset, dist_mode:str='interval', **kwargs): | |||||
""" | |||||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
def __iter__(self): | |||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||||
self.num_consumed_samples = 0 | |||||
self.during_iter = True | |||||
indices = self.generate_indices() | |||||
if self.pad: | |||||
# add extra samples to make it evenly divisible | |||||
padding_size = self.total_size - len(indices) | |||||
if padding_size <= len(indices): | |||||
indices += indices[:padding_size] | |||||
else: | |||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||||
else: | |||||
# remove tail of data to make it evenly divisible. | |||||
indices = indices[:self.total_size] | |||||
assert len(indices) == self.total_size | |||||
# subsample | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
assert len(indices) == self.num_left_samples | |||||
for index in indices: | |||||
self.num_consumed_samples += self.num_replicas | |||||
yield index | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
def generate_indices(self) -> List[int]: | |||||
""" | |||||
生成随机序列 | |||||
:return: | |||||
""" | |||||
return list(range(len(self.dataset))) | |||||
def state_dict(self) -> Dict: | |||||
states = { | |||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | |||||
'sampler_type': self.__class__.__name__, | |||||
'length': len(self.dataset), | |||||
} | |||||
return states | |||||
def load_state_dict(self, states: Dict): | |||||
# 如果 self.during_iter 是 True,那么 data_idx 一定是 0; | |||||
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ | |||||
f"we cannot use {self.__class__.__name__} to load it." | |||||
length = states['length'] | |||||
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ | |||||
f"and current dataset({len(self.dataset)})." | |||||
self.num_consumed_samples = states['num_consumed_samples'] | |||||
if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | |||||
self.num_consumed_samples = 0 | |||||
class SortedSampler(SequentialSampler): | |||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | |||||
""" | |||||
将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | |||||
super().__init__(dataset=dataset, **kwargs) | |||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | |||||
if not isinstance(length[0], int): | |||||
length = list(map(len, length)) | |||||
else: | |||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||||
def generate_indices(self) -> List[int]: | |||||
return self.sorted_indices | |||||
def __iter__(self): | |||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | |||||
self.num_consumed_samples = 0 | |||||
self.during_iter = True | |||||
indices = self.generate_indices() | |||||
if self.pad: | |||||
padding_size = self.total_size - len(indices) | |||||
if padding_size <= len(indices): | |||||
indices += indices[:padding_size] | |||||
else: | |||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||||
else: | |||||
# remove tail of data to make it evenly divisible. | |||||
indices = indices[:self.total_size] | |||||
assert len(indices) == self.total_size | |||||
# subsample | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
assert len(indices) == self.num_left_samples | |||||
for index in indices: | |||||
self.num_consumed_samples += self.num_replicas | |||||
yield index | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
@@ -1,6 +1,8 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'UnrepeatedSampler', | |||||
'UnrepeatedSortedSampler', | 'UnrepeatedSortedSampler', | ||||
'UnrepeatedSampler' | |||||
'UnrepeatedRandomSampler', | |||||
"UnrepeatedSequentialSampler" | |||||
] | ] | ||||
from typing import List, Union | from typing import List, Union | ||||
@@ -10,13 +12,21 @@ import numpy as np | |||||
class UnrepeatedSampler: | class UnrepeatedSampler: | ||||
""" | |||||
在多卡场景下保证 indice 不重复的 sampler | |||||
""" | |||||
pass | |||||
class UnrepeatedRandomSampler(UnrepeatedSampler): | |||||
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
考虑在多卡evaluate的场景下,不能重复sample。 | 考虑在多卡evaluate的场景下,不能重复sample。 | ||||
:param dataset: | |||||
:param shuffle: | |||||
:param seed: | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
""" | """ | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
@@ -33,8 +43,8 @@ class UnrepeatedSampler: | |||||
:return: | :return: | ||||
""" | """ | ||||
num_common = len(self.dataset)//self.num_replicas | num_common = len(self.dataset)//self.num_replicas | ||||
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
return self.num_samples | |||||
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) | |||||
return num_samples | |||||
def __iter__(self): | def __iter__(self): | ||||
indices = self.generate_indices() | indices = self.generate_indices() | ||||
@@ -83,8 +93,8 @@ class UnrepeatedSampler: | |||||
return self | return self | ||||
class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
def __init__(self, dataset, length:Union[str, List], seed: int = 0): | |||||
class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | |||||
""" | """ | ||||
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 | ||||
batch 数量不完全一致。 | batch 数量不完全一致。 | ||||
@@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | ||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | ||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__(dataset=dataset, shuffle=False, seed=seed) | |||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
length = dataset.get_field(length) | length = dataset.get_field(length) | ||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
@@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler): | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | ||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||||
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 | |||||
length = np.array(length, dtype=int) # 按照长到短排列的序号。 | |||||
self.sorted_indices = np.argsort(length)[::-1].tolist() # 按长度从高到低排序的 | |||||
def generate_indices(self) -> List[int]: | def generate_indices(self) -> List[int]: | ||||
return self.sorted_indices | return self.sorted_indices | ||||
class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): | |||||
def __init__(self, dataset, **kwargs): | |||||
""" | |||||
按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | |||||
:param kwargs: | |||||
""" | |||||
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) | |||||
def __iter__(self): | |||||
indices = self.generate_indices() | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
for index in indices: | |||||
yield index | |||||
def generate_indices(self) -> List[int]: | |||||
return list(range(len(self.dataset))) | |||||
@@ -0,0 +1,42 @@ | |||||
__all__ = [ | |||||
're_instantiate_sampler', | |||||
'conversion_between_reproducible_and_unrepeated_sampler' | |||||
] | |||||
from fastNLP.core.samplers.unrepeated_sampler import * | |||||
from fastNLP.core.samplers.reproducible_sampler import * | |||||
def conversion_between_reproducible_and_unrepeated_sampler(sampler): | |||||
""" | |||||
将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 | |||||
ReproducibleSampler, | |||||
:param sampler: | |||||
:return: | |||||
""" | |||||
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ | |||||
"The sampler must be UnrepeatedSampler or ReproducibleSampler" | |||||
if isinstance(sampler, UnrepeatedSampler): | |||||
if isinstance(sampler, UnrepeatedRandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) | |||||
elif isinstance(sampler, UnrepeatedSequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) | |||||
elif isinstance(sampler, UnrepeatedSortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no unrepeated version.") | |||||
else: | |||||
if isinstance(sampler, RandomSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) | |||||
elif isinstance(sampler, SequentialSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) | |||||
elif isinstance(sampler, SortedSampler): | |||||
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) | |||||
raise TypeError(f"{sampler.__class__} has no reproducible version.") | |||||
def re_instantiate_sampler(sampler, new_sampler_class=None): | |||||
all_attributes = vars(sampler) | |||||
if new_sampler_class is not None: | |||||
return new_sampler_class(**all_attributes) | |||||
return type(sampler)(**all_attributes) |
@@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler | |||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | from fastNLP.core.samplers.reproducible_sampler import RandomSampler | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification | ||||
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset | ||||
from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||
@@ -153,7 +153,7 @@ class TestSingleDeviceFunction: | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"dist_sampler", | "dist_sampler", | ||||
["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"reproducible", | "reproducible", | ||||
@@ -30,7 +30,7 @@ class SequenceDataSet: | |||||
def check_replace_sampler(driver): | def check_replace_sampler(driver): | ||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler | |||||
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler | |||||
# reproducible 是 True 和 False | # reproducible 是 True 和 False | ||||
# 需要 check 返回的 sampler 和 dataloader 都不同了 | # 需要 check 返回的 sampler 和 dataloader 都不同了 | ||||
@@ -4,7 +4,7 @@ import numpy as np | |||||
import pytest | import pytest | ||||
from itertools import chain | from itertools import chain | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: | |||||
before_batch_size = 7 | before_batch_size = 7 | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | dataloader = DataLoader(dataset, batch_size=before_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
forward_steps = 3 | forward_steps = 3 | ||||
@@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: | |||||
# 1. 保存状态 | # 1. 保存状态 | ||||
_get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | ||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
"sampler_type": "RandomBatchSampler"} | |||||
# 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
# 不改变 batch_size; | # 不改变 batch_size; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | dataloader = DataLoader(dataset, batch_size=before_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: | |||||
# 改变 batch_size; | # 改变 batch_size; | ||||
after_batch_size = 3 | after_batch_size = 3 | ||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | dataloader = DataLoader(dataset, batch_size=after_batch_size) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
# 将一轮的所有数据保存下来,看是否恢复的是正确的; | # 将一轮的所有数据保存下来,看是否恢复的是正确的; | ||||
@@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: | |||||
# 1. 保存状态 | # 1. 保存状态 | ||||
_get_re_batchsampler = dataloader.batch_sampler | _get_re_batchsampler = dataloader.batch_sampler | ||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | state = _get_re_batchsampler.state_dict() | ||||
# 2. 断点重训,重新生成一个 dataloader; | # 2. 断点重训,重新生成一个 dataloader; | ||||
# 不改变 batch_size; | # 不改变 batch_size; | ||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | ||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | re_batchsampler.load_state_dict(state) | ||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | dataloader = replace_batch_sampler(dataloader, re_batchsampler) | ||||
@@ -1,18 +1,14 @@ | |||||
import unittest | |||||
from itertools import product | |||||
import numpy as np | import numpy as np | ||||
import pytest | |||||
from functools import partial | from functools import partial | ||||
from array import array | |||||
from itertools import chain | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
class TestRandomSamplerYh(unittest.TestCase): | |||||
class TestRandomSamplerYh: | |||||
def test_init(self): | def test_init(self): | ||||
# 测试能否正确初始化 | # 测试能否正确初始化 | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
@@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
sampler = RandomSampler(dataset) | sampler = RandomSampler(dataset) | ||||
for i in sampler: | for i in sampler: | ||||
with self.assertRaises(AssertionError): | |||||
with pytest.raises(AssertionError): | |||||
sampler.set_distributed(1, 0) | sampler.set_distributed(1, 0) | ||||
break | break | ||||
@@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
sampler = RandomSampler(dataset, shuffle=False) | sampler = RandomSampler(dataset, shuffle=False) | ||||
sampler.set_distributed(num_replicas=2, rank=0, pad=False) | sampler.set_distributed(num_replicas=2, rank=0, pad=False) | ||||
self.assertEqual(len(sampler), 50) | |||||
assert len(sampler)==50 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 0) | |||||
assert i%2==0 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 50) | |||||
assert count == 50 | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=False) | sampler.set_distributed(num_replicas=2, rank=1, pad=False) | ||||
self.assertEqual(len(sampler), 50) | |||||
assert len(sampler)==50 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 1) | |||||
assert i%2==1 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 50) | |||||
assert count==50 | |||||
dataset = TorchNormalDataset(num_of_data=101) | dataset = TorchNormalDataset(num_of_data=101) | ||||
sampler = RandomSampler(dataset, shuffle=False) | sampler = RandomSampler(dataset, shuffle=False) | ||||
sampler.set_distributed(num_replicas=2, rank=0, pad=True) | sampler.set_distributed(num_replicas=2, rank=0, pad=True) | ||||
self.assertEqual(len(sampler), 51) | |||||
assert len(sampler)==51 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
self.assertEqual(i%2, 0) | |||||
assert i%2==0 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 51) | |||||
assert count == 51 | |||||
sampler.set_distributed(num_replicas=2, rank=1, pad=True) | sampler.set_distributed(num_replicas=2, rank=1, pad=True) | ||||
self.assertEqual(len(sampler), 51) | |||||
assert len(sampler) == 51 | |||||
count = 0 | count = 0 | ||||
for i in sampler: | for i in sampler: | ||||
if i!=0: | if i!=0: | ||||
self.assertEqual(i%2, 1) | |||||
assert i%2==1 | |||||
count += 1 | count += 1 | ||||
self.assertEqual(count, 51) | |||||
assert count == 51 | |||||
def test_state_dict_check_length(self): | def test_state_dict_check_length(self): | ||||
dataset = TorchNormalDataset(num_of_data=100) | dataset = TorchNormalDataset(num_of_data=100) | ||||
@@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
states = sampler.state_dict() | states = sampler.state_dict() | ||||
new_ds = TorchNormalDataset(num_of_data=10) | new_ds = TorchNormalDataset(num_of_data=10) | ||||
with self.assertRaises(AssertionError): | |||||
with pytest.raises(AssertionError): | |||||
new_sampler = RandomSampler(new_ds) | new_sampler = RandomSampler(new_ds) | ||||
new_sampler.load_state_dict(states) | new_sampler.load_state_dict(states) | ||||
@@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): | |||||
new_sampler = RandomSampler(new_ds) | new_sampler = RandomSampler(new_ds) | ||||
new_sampler.load_state_dict(states) | new_sampler.load_state_dict(states) | ||||
def test_state_dict(self): | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||||
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||||
num_samples = 100 | num_samples = 100 | ||||
dataset = TorchNormalDataset(num_of_data=num_samples) | dataset = TorchNormalDataset(num_of_data=num_samples) | ||||
# 测试使用 前后shuffle不一致的load操作 | # 测试使用 前后shuffle不一致的load操作 | ||||
lst = [0]+np.random.randint(1, num_samples, size=3).tolist() | |||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
def test_state_dict_2(self): | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
for i in new_sampler: | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('pre_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('post_shuffle', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||||
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): | |||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | ||||
num_samples = 100 | num_samples = 100 | ||||
dataset = TorchNormalDataset(num_of_data=num_samples) | dataset = TorchNormalDataset(num_of_data=num_samples) | ||||
# 测试使用 前后shuffle不一致的load操作 | # 测试使用 前后shuffle不一致的load操作 | ||||
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() | |||||
# lst = [30] | # lst = [30] | ||||
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], | |||||
lst): | |||||
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): | |||||
already_numbers = set() | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
self.assertEqual(len(already_numbers), num_consumed_samples*2) | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, already_numbers) | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) | |||||
count = 0 | |||||
for i in new_sampler: | |||||
self.assertNotIn(i, other_rank_number) | |||||
other_rank_number.add(i) | |||||
self.assertNotIn(i, already_numbers) | |||||
count += 1 | |||||
class TestRandomSampler(unittest.TestCase): | |||||
already_numbers = set() | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples*2 | |||||
states = sampler.state_dict() | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = RandomSampler(dataset, shuffle=post_shuffle) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
for i in new_sampler: | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
class TestRandomSampler: | |||||
# 测试单卡; | # 测试单卡; | ||||
def test_seed_work_when_shuffle_is_true(self): | def test_seed_work_when_shuffle_is_true(self): | ||||
data_length = 100 | data_length = 100 | ||||
@@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): | |||||
... | ... | ||||
class DatasetWithVaryLength: | |||||
def __init__(self, num_of_data=100, reverse=False): | |||||
self.data = np.arange(num_of_data) | |||||
if reverse: | |||||
self.data = self.data[::-1] | |||||
def __getitem__(self, item): | |||||
return self.data[item] | |||||
def __len__(self): | |||||
return len(self.data) | |||||
class TestSortedSampler: | |||||
def test_single(self): | |||||
num_of_data = 100 | |||||
data = DatasetWithVaryLength(num_of_data) | |||||
sampler = SortedSampler(data, length=data.data) | |||||
indexes = list(sampler) | |||||
assert indexes==list(range(num_of_data-1, -1, -1)) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||||
def test_multi(self, pad, num_replica, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||||
samplers = [] | |||||
for i in range(num_replica): | |||||
sampler = SortedSampler(dataset=data, length=data.data) | |||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
samplers.append(sampler) | |||||
# 保证顺序是没乱的 | |||||
already_seen_index = set() | |||||
for sampler in samplers: | |||||
larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。 | |||||
prev_index = float('inf') | |||||
cur_set = set() | |||||
seen_in_other_rank = 0 | |||||
for index in sampler: | |||||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 | |||||
cur_set.add(index) | |||||
larger_count += int(index <= prev_index) | |||||
prev_index = index | |||||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 | |||||
assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0 | |||||
already_seen_index.update(cur_set) | |||||
indexes = list(chain(*samplers)) | |||||
indexes = set(indexes) | |||||
if pad: | |||||
assert indexes == set(range(num_of_data)) | |||||
else: | |||||
assert len(indexes) <= num_of_data | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||||
def test_state_dict(self, pad, num_consumed_samples): | |||||
num_samples = 100 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
sampler = SortedSampler(dataset, length=dataset.data) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
if already_numbers: | |||||
assert j<max(already_numbers) | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples | |||||
states = sampler.state_dict() | |||||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
assert i < max(already_numbers) | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
smaller = 0 | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
smaller += int(i >= max(already_numbers)) | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
assert smaller<=1 if pad else smaller==0 | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||||
def test_state_dict_2(self, pad, num_consumed_samples): | |||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||||
num_samples = 100 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
# lst = [30] | |||||
already_numbers = set() | |||||
sampler = SortedSampler(dataset, length=dataset.data) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
if already_numbers: | |||||
assert j<=max(already_numbers) | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = SortedSampler(dataset, length=dataset.data) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples*2 | |||||
states = sampler.state_dict() | |||||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
assert i < max(already_numbers) | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = SortedSampler(dataset, length=dataset.data) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
smaller = 0 | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
smaller += int(i>=max(already_numbers)) | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
assert smaller <= 1 if pad else smaller == 0 | |||||
class TestSequentialSampler: | |||||
def test_single(self): | |||||
num_of_data = 100 | |||||
data = DatasetWithVaryLength(num_of_data) | |||||
sampler = SequentialSampler(data) | |||||
indexes = list(sampler) | |||||
assert indexes==list(range(num_of_data)) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||||
def test_multi(self, pad, num_replica, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||||
samplers = [] | |||||
for i in range(num_replica): | |||||
sampler = SequentialSampler(dataset=data) | |||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
samplers.append(sampler) | |||||
# 保证顺序是没乱的 | |||||
already_seen_index = set() | |||||
for idx, sampler in enumerate(samplers): | |||||
larger_count = 1 | |||||
prev_index = float('inf') | |||||
cur_set = set() | |||||
seen_in_other_rank = 0 | |||||
for index in sampler: | |||||
seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 | |||||
cur_set.add(index) | |||||
larger_count += int(index >= prev_index) | |||||
prev_index = index | |||||
assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 | |||||
assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0 | |||||
already_seen_index.update(cur_set) | |||||
indexes = list(chain(*samplers)) | |||||
indexes = set(indexes) | |||||
if pad: | |||||
assert indexes == set(range(num_of_data)) | |||||
else: | |||||
assert len(indexes) <= num_of_data | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) | |||||
def test_state_dict(self, pad, num_consumed_samples): | |||||
num_samples = 100 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
sampler = SequentialSampler(dataset=dataset) | |||||
sampler.set_epoch(0) | |||||
already_numbers = set() | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
if already_numbers: | |||||
assert j>max(already_numbers) | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples | |||||
states = sampler.state_dict() | |||||
new_sampler = SequentialSampler(dataset=dataset) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
assert i > max(already_numbers) | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = SequentialSampler(dataset=dataset) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
new_sampler.set_epoch(0) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
smaller = 0 | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
smaller += int(i <= max(already_numbers)) | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=rank # 因为pad可能重复 | |||||
assert smaller<=1 if pad else smaller==0 | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) | |||||
def test_state_dict_2(self, pad, num_consumed_samples): | |||||
# 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 | |||||
num_samples = 100 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
# 测试使用 前后shuffle不一致的load操作 | |||||
# lst = [30] | |||||
already_numbers = set() | |||||
sampler = SequentialSampler(dataset=dataset) | |||||
sampler.set_distributed(num_replicas=2, rank=0) | |||||
sampler.set_epoch(0) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
if already_numbers: | |||||
assert j>max(already_numbers) | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
sampler = SequentialSampler(dataset=dataset) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replicas=2, rank=1) | |||||
if num_consumed_samples>0: | |||||
for i, j in enumerate(sampler, start=1): | |||||
already_numbers.add(j) | |||||
if i == num_consumed_samples: | |||||
break | |||||
assert len(already_numbers) == num_consumed_samples*2 | |||||
states = sampler.state_dict() | |||||
new_sampler = SequentialSampler(dataset=dataset) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
assert i > max(already_numbers) | |||||
assert i not in already_numbers | |||||
# 测试切换成多卡也没有问题 | |||||
other_rank_number = set() | |||||
for rank in range(3): | |||||
new_sampler = SequentialSampler(dataset=dataset) | |||||
new_sampler.load_state_dict(states) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) | |||||
count = 0 | |||||
seen = 0 | |||||
seen_in_other_rank = 0 | |||||
smaller = 0 | |||||
for i in new_sampler: | |||||
if already_numbers: | |||||
smaller += int(i<max(already_numbers)) | |||||
seen_in_other_rank += int(i in other_rank_number) | |||||
other_rank_number.add(i) | |||||
seen += int(i in already_numbers) | |||||
count += 1 | |||||
assert seen <= 1 if pad else seen == 0 | |||||
assert seen_in_other_rank<=1 # 因为pad可能重复 | |||||
assert smaller <= rank if pad else smaller == 0 | |||||
@@ -2,7 +2,7 @@ from itertools import chain | |||||
import pytest | import pytest | ||||
from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler | |||||
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler | |||||
class DatasetWithVaryLength: | class DatasetWithVaryLength: | ||||
@@ -21,7 +21,7 @@ class TestUnrepeatedSampler: | |||||
def test_single(self, shuffle): | def test_single(self, shuffle): | ||||
num_of_data = 100 | num_of_data = 100 | ||||
data = DatasetWithVaryLength(num_of_data) | data = DatasetWithVaryLength(num_of_data) | ||||
sampler = UnrepeatedSampler(data, shuffle) | |||||
sampler = UnrepeatedRandomSampler(data, shuffle) | |||||
indexes = set(sampler) | indexes = set(sampler) | ||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
@@ -32,17 +32,18 @@ class TestUnrepeatedSampler: | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | for i in range(num_replica): | ||||
sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle) | |||||
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) | |||||
sampler.set_distributed(num_replica, rank=i) | sampler.set_distributed(num_replica, rank=i) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||
indexes = set(chain(*samplers)) | |||||
indexes = list(chain(*samplers)) | |||||
assert len(indexes) == num_of_data | |||||
indexes = set(indexes) | |||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
class TestUnrepeatedSortedSampler: | class TestUnrepeatedSortedSampler: | ||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
def test_single(self, shuffle): | |||||
def test_single(self): | |||||
num_of_data = 100 | num_of_data = 100 | ||||
data = DatasetWithVaryLength(num_of_data) | data = DatasetWithVaryLength(num_of_data) | ||||
sampler = UnrepeatedSortedSampler(data, length=data.data) | sampler = UnrepeatedSortedSampler(data, length=data.data) | ||||
@@ -51,8 +52,7 @@ class TestUnrepeatedSortedSampler: | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | @pytest.mark.parametrize('num_replica', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
@pytest.mark.parametrize('shuffle', [False, True]) | |||||
def test_multi(self, num_replica, num_of_data, shuffle): | |||||
def test_multi(self, num_replica, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replica): | for i in range(num_replica): | ||||
@@ -60,5 +60,45 @@ class TestUnrepeatedSortedSampler: | |||||
sampler.set_distributed(num_replica, rank=i) | sampler.set_distributed(num_replica, rank=i) | ||||
samplers.append(sampler) | samplers.append(sampler) | ||||
indexes = set(chain(*samplers)) | |||||
# 保证顺序是没乱的 | |||||
for sampler in samplers: | |||||
prev_index = float('inf') | |||||
for index in sampler: | |||||
assert index <= prev_index | |||||
prev_index = index | |||||
indexes = list(chain(*samplers)) | |||||
assert len(indexes) == num_of_data # 不同卡之间没有交叉 | |||||
indexes = set(indexes) | |||||
assert indexes==set(range(num_of_data)) | assert indexes==set(range(num_of_data)) | ||||
class TestUnrepeatedSequentialSampler: | |||||
def test_single(self): | |||||
num_of_data = 100 | |||||
data = DatasetWithVaryLength(num_of_data) | |||||
sampler = UnrepeatedSequentialSampler(data, length=data.data) | |||||
indexes = list(sampler) | |||||
assert indexes==list(range(num_of_data)) | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | |||||
def test_multi(self, num_replica, num_of_data): | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | |||||
samplers = [] | |||||
for i in range(num_replica): | |||||
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) | |||||
sampler.set_distributed(num_replica, rank=i) | |||||
samplers.append(sampler) | |||||
# 保证顺序是没乱的 | |||||
for sampler in samplers: | |||||
prev_index = float('-inf') | |||||
for index in sampler: | |||||
assert index>=prev_index | |||||
prev_index = index | |||||
indexes = list(chain(*samplers)) | |||||
assert len(indexes) == num_of_data | |||||
indexes = set(indexes) | |||||
assert indexes == set(range(num_of_data)) |