diff --git a/fastNLP/core/drivers/oneflow_driver/ddp.py b/fastNLP/core/drivers/oneflow_driver/ddp.py
index 681ff5d0..974c0b69 100644
--- a/fastNLP/core/drivers/oneflow_driver/ddp.py
+++ b/fastNLP/core/drivers/oneflow_driver/ddp.py
@@ -245,10 +245,15 @@ class OneflowDDPDriver(OneflowDriver):
         # evaluator
         elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
+            if type(args.batch_sampler) != BatchSampler:
+                # TODO: this is meant to detect a user-customized batch_sampler; the check may need refinement
+                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate dataloader or " \
+                               "the train dataloader while testing ``overfit_batches``, which may cause the data " \
+                               "for distributed evaluation to be repeated.")
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
-                _check_dataloader_args_for_distributed(args, controller="Evaluator")
+                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
                 )
@@ -258,6 +263,7 @@ class OneflowDDPDriver(OneflowDriver):
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
+            # TODO: for now the sampler is uniformly wrapped in a plain BatchSampler
             batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
             return replace_batch_sampler(dataloader, batch_sampler)
         else:
diff --git a/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py
index 7475900a..dd0597b2 100644
--- a/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py
+++ b/fastNLP/core/drivers/oneflow_driver/initialize_oneflow_driver.py
@@ -43,7 +43,6 @@ def initialize_oneflow_driver(driver: str, device: Optional[Union[str, "oneflow.
                 raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
             device = [oneflow.device(f"cuda:{w}") for w in range(_could_use_device_num)]
         elif device >= _could_use_device_num:
-            print(device, _could_use_device_num)
             raise ValueError("The gpu device that parameter `device` specifies is not existed.")
         else:
             device = oneflow.device(f"cuda:{device}")
diff --git a/fastNLP/core/drivers/oneflow_driver/utils.py b/fastNLP/core/drivers/oneflow_driver/utils.py
index f52cc13a..175c7714 100644
--- a/fastNLP/core/drivers/oneflow_driver/utils.py
+++ b/fastNLP/core/drivers/oneflow_driver/utils.py
@@ -280,12 +280,23 @@ def optimizer_state_to_device(state, device):
 
 
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not oneflowBatchSampler or (type(args.sampler) not in {oneflowRandomSampler,
-                                                                                          oneflowSequentialSampler}):
-        mode = 'training' if controller == 'Trainer' else 'evaluation'
-        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+    """
+    Check the dataloader's sampler; if the user has substituted a customized
+    sampler, raise an error to prevent failures during distributed training.
+    """
+    error_flag = (type(args.sampler) not in {oneflowRandomSampler, oneflowSequentialSampler})
+    if controller == 'Trainer':
+        mode = 'training'
+        substitution = 'fastNLP.RandomSampler'
+        error_flag = (type(args.batch_sampler) != oneflowBatchSampler) or error_flag
+    else:  # Evaluator
+        mode = 'evaluation'
+        substitution = 'fastNLP.UnrepeatedSequentialSampler'
+    if error_flag:
         raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                         f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                         f"``{substitution}``. The customized sampler should set for distributed running "
                         f"before initializing ``{controller}`` , and then set the "
-                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
+                        f"\n Current sampler: {type(args.sampler)}")
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 3458d340..fc8c0695 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -112,6 +112,7 @@ if _NEED_IMPORT_PADDLE:
     from paddle.optimizer import Optimizer
     from paddle.fluid.reader import _DatasetKind
     from paddle.fluid.dygraph import parallel_helper
+    from paddle.io import BatchSampler
 
 __all__ = [
     "PaddleFleetDriver",
@@ -471,9 +472,15 @@ class PaddleFleetDriver(PaddleDriver):
         # evaluator
         elif dist == "unrepeatdist":
             args = self.get_dataloader_args(dataloader)
+            if type(args.batch_sampler) != BatchSampler:
+                # TODO: this is meant to detect a user-customized batch_sampler; the check may need refinement
+                logger.warning("Note that you are using a customized ``batch_sampler`` in the evaluate dataloader or " \
+                               "the train dataloader while testing ``overfit_batches``, which may cause the data " \
+                               "for distributed evaluation to be repeated.")
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
+                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
                 )
@@ -483,7 +490,9 @@ class PaddleFleetDriver(PaddleDriver):
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
-            return replace_sampler(dataloader, sampler)
+            # TODO: for now the sampler is uniformly wrapped in a plain BatchSampler
+            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
 
diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py
index 9220f3a3..296bcebe 100644
--- a/fastNLP/core/drivers/paddle_driver/utils.py
+++ b/fastNLP/core/drivers/paddle_driver/utils.py
@@ -266,12 +266,23 @@ def optimizer_state_to_device(state, device):
     return new_state
 
 
 def _check_dataloader_args_for_distributed(args, controller='Trainer'):
-    if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {RandomSampler,
-                                                                                   SequenceSampler}):
-        mode = 'training' if controller == 'Trainer' else 'evaluation'
-        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+    """
+    Check the dataloader's sampler; if the user has substituted a customized
+    sampler, raise an error to prevent failures during distributed training.
+    """
+    error_flag = (type(args.sampler) not in {RandomSampler, SequenceSampler})
+    if controller == 'Trainer':
+        mode = 'training'
+        substitution = 'fastNLP.RandomSampler'
+        error_flag = (type(args.batch_sampler) != BatchSampler) or error_flag
+    else:  # Evaluator
+        mode = 'evaluation'
+        substitution = 'fastNLP.UnrepeatedSequentialSampler'
+    if error_flag:
         raise TypeError(f"Using customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
                         f"unpredictable problems, because fastNLP will substitute the dataloader's sampler into "
                         f"``{substitution}``. The customized sampler should set for distributed running "
                         f"before initializing ``{controller}`` , and then set the "
-                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``."
+                        f"\n Current batch_sampler: {type(args.batch_sampler)}"
+                        f"\n Current sampler: {type(args.sampler)}")
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 5dd8125a..2bf38a0d 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -617,7 +617,6 @@ class TorchDDPDriver(TorchDriver):
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
-                # TODO: sidestep the case of a customized batch_sampler
                 _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
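Reviewer note: all three drivers now converge on the same "unrepeatdist" pattern for distributed evaluation: pick (or build) an unrepeated per-rank sampler, wrap it in a stock BatchSampler, and rebuild the dataloader around it, so that any customized batch_sampler is overridden. The sketch below is a minimal, hypothetical illustration of that pattern in plain PyTorch; UnrepeatedSequentialShard is an invented stand-in for fastNLP's UnrepeatedSequentialSampler, not the library's actual code. Each rank reads a disjoint, unpadded slice of the dataset, so no sample is evaluated twice.

    from torch.utils.data import BatchSampler, DataLoader, Sampler

    class UnrepeatedSequentialShard(Sampler):
        """Yield indices rank, rank + world_size, ... with no padding or duplication."""
        def __init__(self, dataset_len, rank, world_size):
            self.dataset_len = dataset_len
            self.rank = rank
            self.world_size = world_size

        def __iter__(self):
            return iter(range(self.rank, self.dataset_len, self.world_size))

        def __len__(self):
            return len(range(self.rank, self.dataset_len, self.world_size))

    # Mirror the patch: wrap the unrepeated sampler in a plain BatchSampler and
    # hand it to the dataloader, replacing whatever batch_sampler was set before.
    sampler = UnrepeatedSequentialShard(dataset_len=10, rank=0, world_size=2)
    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
    loader = DataLoader(list(range(10)), batch_sampler=batch_sampler)
    print([batch.tolist() for batch in loader])  # rank 0 -> [[0, 2, 4, 6], [8]]

Because the shards are unpadded and drop_last=False, ranks can end up with different final batch sizes; that is the intended trade-off for evaluation, where padding or repeating samples would skew metrics.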