| @@ -245,10 +245,15 @@ class OneflowDDPDriver(OneflowDriver): | |||||
| # evaluator | # evaluator | ||||
| elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
| args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
| if type(args.batch_sampler) != BatchSampler: | |||||
| # TODO 这里的目的是判断用户的 batch_sampler 是定制的,可能需要完善 | |||||
| logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or" \ | |||||
| "train dataloader while testing ``overfit_batches``, which may cause that" \ | |||||
| "the data for distributed evaluation is not unrepeated.") | |||||
| if isinstance(args.sampler, ReproducibleSampler): | if isinstance(args.sampler, ReproducibleSampler): | ||||
| sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | ||||
| elif not isinstance(args.sampler, UnrepeatedSampler): | elif not isinstance(args.sampler, UnrepeatedSampler): | ||||
| _check_dataloader_args_for_distributed(args, controller="Evaluator") | |||||
| _check_dataloader_args_for_distributed(args, controller='Evaluator') | |||||
| sampler = UnrepeatedSequentialSampler( | sampler = UnrepeatedSequentialSampler( | ||||
| dataset=args.dataset | dataset=args.dataset | ||||
| ) | ) | ||||
| @@ -258,6 +263,7 @@ class OneflowDDPDriver(OneflowDriver): | |||||
| num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
| rank=self.global_rank | rank=self.global_rank | ||||
| ) | ) | ||||
| # TODO 这里暂时统一替换为 BatchSampler | |||||
| batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | ||||
| return replace_batch_sampler(dataloader, batch_sampler) | return replace_batch_sampler(dataloader, batch_sampler) | ||||
| else: | else: | ||||
| @@ -43,7 +43,6 @@ def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow. | |||||
| raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
| device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)] | device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)] | ||||
| elif device >= _could_use_device_num: | elif device >= _could_use_device_num: | ||||
| print(device, _could_use_device_num) | |||||
| raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
| else: | else: | ||||
| device = oneflow.device(f"cuda:{device}") | device = oneflow.device(f"cuda:{device}") | ||||
| @@ -280,12 +280,23 @@ def optimizer_state_to_device(state, device): | |||||
| def _check_dataloader_args_for_distributed(args, controller='Trainer'): | def _check_dataloader_args_for_distributed(args, controller='Trainer'): | ||||
| if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler, | |||||
| oneflowSequentialSampler}): | |||||
| mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
| substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
| """ | |||||
| 检查 dataloader 的 sampler 情况,如果用户替换了自己定制的 sampler ,为了防止 | |||||
| 在分布式训练中出现错误会报错。 | |||||
| """ | |||||
| error_flag = (type(args.sampler) not in {oneflowRandomSampler, oneflowSequentialSampler}) | |||||
| if controller == 'Trainer': | |||||
| mode = 'training' | |||||
| substitution = 'fastNLP.RandomSampler' | |||||
| error_flag = (type(args.batch_sampler) != oneflowBatchSampler) or error_flag | |||||
| else: # Evaluator | |||||
| mode = 'evaluation' | |||||
| substitution = 'fastNLP.UnrepeatedSequentialSampler' | |||||
| if error_flag: | |||||
| raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | ||||
| f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | ||||
| f"``{substitution}``. The customized sampler should set for distributed running " | f"``{substitution}``. The customized sampler should set for distributed running " | ||||
| f"before initializing ``{controller}`` , and then set the " | f"before initializing ``{controller}`` , and then set the " | ||||
| f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") | |||||
| f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``." | |||||
| f"\n Current batch_sampler: {type(args.batch_sampler)}" | |||||
| f"\n Current sampler: {type(args.sampler)}") | |||||
| @@ -112,6 +112,7 @@ if _NEED_IMPORT_PADDLE: | |||||
| from paddle.optimizer import Optimizer | from paddle.optimizer import Optimizer | ||||
| from paddle.fluid.reader import _DatasetKind | from paddle.fluid.reader import _DatasetKind | ||||
| from paddle.fluid.dygraph import parallel_helper | from paddle.fluid.dygraph import parallel_helper | ||||
| from paddle.io import BatchSampler | |||||
| __all__ = [ | __all__ = [ | ||||
| "PaddleFleetDriver", | "PaddleFleetDriver", | ||||
| @@ -471,9 +472,15 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| # evaluator | # evaluator | ||||
| elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
| args = self.get_dataloader_args(dataloader) | args = self.get_dataloader_args(dataloader) | ||||
| if type(args.batch_sampler) != BatchSampler: | |||||
| # TODO 这里的目的是判断用户的 batch_sampler 是定制的,可能需要完善 | |||||
| logger.warning("Note that you are using customized ``batch_sampler`` in evaluate dataloader or" \ | |||||
| "train dataloader while testing ``overfit_batches``, which may cause that" \ | |||||
| "the data for distributed evaluation is not unrepeated.") | |||||
| if isinstance(args.sampler, ReproducibleSampler): | if isinstance(args.sampler, ReproducibleSampler): | ||||
| sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | ||||
| elif not isinstance(args.sampler, UnrepeatedSampler): | elif not isinstance(args.sampler, UnrepeatedSampler): | ||||
| _check_dataloader_args_for_distributed(args, controller='Evaluator') | |||||
| sampler = UnrepeatedSequentialSampler( | sampler = UnrepeatedSequentialSampler( | ||||
| dataset=args.dataset | dataset=args.dataset | ||||
| ) | ) | ||||
| @@ -483,7 +490,9 @@ class PaddleFleetDriver(PaddleDriver): | |||||
| num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
| rank=self.global_rank | rank=self.global_rank | ||||
| ) | ) | ||||
| return replace_sampler(dataloader, sampler) | |||||
| # TODO 这里暂时统一替换为 BatchSampler | |||||
| batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||||
| return replace_batch_sampler(dataloader, batch_sampler) | |||||
| else: | else: | ||||
| raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
| @@ -266,12 +266,23 @@ def optimizer_state_to_device(state, device): | |||||
| return new_state | return new_state | ||||
| def _check_dataloader_args_for_distributed(args, controller='Trainer'): | def _check_dataloader_args_for_distributed(args, controller='Trainer'): | ||||
| if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler, | |||||
| SequenceSampler}): | |||||
| mode = 'training' if controller == 'Trainer' else 'evaluation' | |||||
| substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler' | |||||
| """ | |||||
| 检查 dataloader 的 sampler 情况,如果用户替换了自己定制的 sampler ,为了防止 | |||||
| 在分布式训练中出现错误会报错。 | |||||
| """ | |||||
| error_flag = (type(args.sampler) not in {RandomSampler, SequenceSampler}) | |||||
| if controller == 'Trainer': | |||||
| mode = 'training' | |||||
| substitution = 'fastNLP.RandomSampler' | |||||
| error_flag = (type(args.batch_sampler) != BatchSampler) or error_flag | |||||
| else: # Evaluator | |||||
| mode = 'evaluation' | |||||
| substitution = 'fastNLP.UnrepeatedSequentialSampler' | |||||
| if error_flag: | |||||
| raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause " | ||||
| f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into " | ||||
| f"``{substitution}``. The customized sampler should set for distributed running " | f"``{substitution}``. The customized sampler should set for distributed running " | ||||
| f"before initializing ``{controller}`` , and then set the " | f"before initializing ``{controller}`` , and then set the " | ||||
| f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.") | |||||
| f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``." | |||||
| f"\n Current batch_sampler: {type(args.batch_sampler)}" | |||||
| f"\n Current sampler: {type(args.sampler)}") | |||||
| @@ -617,7 +617,6 @@ class TorchDDPDriver(TorchDriver): | |||||
| if isinstance(args.sampler, ReproducibleSampler): | if isinstance(args.sampler, ReproducibleSampler): | ||||
| sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) | ||||
| elif not isinstance(args.sampler, UnrepeatedSampler): | elif not isinstance(args.sampler, UnrepeatedSampler): | ||||
| # TODO 避开 batch_sampler 的情况 | |||||
| _check_dataloader_args_for_distributed(args, controller='Evaluator') | _check_dataloader_args_for_distributed(args, controller='Evaluator') | ||||
| sampler = UnrepeatedSequentialSampler( | sampler = UnrepeatedSequentialSampler( | ||||
| dataset=args.dataset | dataset=args.dataset | ||||