From fcd27cfc3f88ec4a6154d12d3bb0d8bdb3c44ac5 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sat, 16 Apr 2022 05:50:53 +0000
Subject: [PATCH] Add FASTNLP_NO_SYNC related settings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/drivers/paddle_driver/dist_utils.py | 22 +++++++++++++++++++
 fastNLP/core/drivers/paddle_driver/fleet.py  |  9 ++++++--
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/drivers/paddle_driver/dist_utils.py b/fastNLP/core/drivers/paddle_driver/dist_utils.py
index 3bfbbd4f..4d9ae5f0 100644
--- a/fastNLP/core/drivers/paddle_driver/dist_utils.py
+++ b/fastNLP/core/drivers/paddle_driver/dist_utils.py
@@ -1,4 +1,5 @@
 import io
+import os
 import pickle
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
@@ -7,6 +8,7 @@ from typing import Any, List
 from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
 from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.envs.env import FASTNLP_NO_SYNC
 if _NEED_IMPORT_TORCH:
     import torch
     from torch import distributed as dist
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
+    if dist.get_rank() == dst:
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+    else:
+        object_gather_list = None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
 
@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
     :param group:
     :return: The result is a list [obj0, obj1, ...], where obj_i is the obj from rank i.
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     if isinstance(obj, torch.Tensor):
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR
     :param group:
     :return:
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        if src == dist.get_rank(group):
+            return obj
+        else:
+            return None
+
     if group is None:
         group = DEFAULT_TORCH_GROUP
     cur_rank = dist.get_rank(group)
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
+        return [obj]
+
     if dist.distributed_c10d._rank_not_in_group(group):
         return
     if _TORCH_GREATER_EQUAL_1_8:
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index ad07da8b..c407ab9f 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -29,7 +29,7 @@ from fastNLP.core.samplers import (
     re_instantiate_sampler,
     conversion_between_reproducible_and_unrepeated_sampler,
 )
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED
+from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_PADDLE:
@@ -234,7 +234,8 @@ class PaddleFleetDriver(PaddleDriver):
         self.global_rank = paddledist.get_rank()
 
     def barrier(self):
-        paddledist.barrier()
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1:  # only synchronize when FASTNLP_NO_SYNC is less than 1
+            paddledist.barrier()
 
     def configure_fleet(self):
         if not self._has_fleetwrapped and not isinstance(self.model, DataParallel):
@@ -451,6 +452,8 @@ class PaddleFleetDriver(PaddleDriver):
            the received object; on the source rank, the broadcast content is returned; ranks that are neither sender nor receiver return None.
         """
         return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # return immediately when FASTNLP_NO_SYNC == 2
+            return
         return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
 
     def all_gather(self, obj, group) -> List:
@@ -477,4 +480,6 @@ class PaddleFleetDriver(PaddleDriver):
         :return:
         """
         return
+        if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2:  # when FASTNLP_NO_SYNC == 2, skip the gather and return the local object
+            return [obj]
         return fastnlp_paddle_all_gather(obj, group=group)
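
Usage note (not part of the patch): the FASTNLP_NO_SYNC environment variable added above acts as a switch for distributed synchronization. When it is set to 2, the gather/broadcast helpers skip all cross-rank communication and return only the local object, and barrier() only runs while the value stays below 1. The minimal sketch below illustrates that expected behavior under those assumptions; no_sync_level and all_gather_or_passthrough are illustrative helpers, not fastNLP APIs.

import os

# Hypothetical single-process dry run of the switch introduced by this patch.
# "FASTNLP_NO_SYNC" matches the env var name the patch imports from fastNLP.envs.env;
# the two helpers below are illustrative only.
os.environ["FASTNLP_NO_SYNC"] = "2"

def no_sync_level() -> int:
    # Same parsing pattern as the patch: an unset variable is treated as "0".
    return int(os.environ.get("FASTNLP_NO_SYNC", "0"))

def all_gather_or_passthrough(obj):
    # Level 2 short-circuits the collective and returns a single-element list
    # holding only the local object, mirroring the early returns added above.
    if no_sync_level() == 2:
        return [obj]
    raise RuntimeError("would call fastnlp_paddle_all_gather(obj, group=...) here")

print(all_gather_or_passthrough({"loss": 0.5}))  # -> [{'loss': 0.5}]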