|
|
@@ -1,4 +1,5 @@ |
|
|
|
import io |
|
|
|
import os |
|
|
|
import pickle |
|
|
|
_pickler = pickle.Pickler |
|
|
|
_unpickler = pickle.Unpickler |
|
|
@@ -7,6 +8,7 @@ from typing import Any, List |
|
|
|
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 |
|
|
|
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
from fastNLP.envs.env import FASTNLP_NO_SYNC |
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
from torch import distributed as dist |
|
|
@@ -83,6 +85,14 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA |
|
|
|
>>> output |
|
|
|
['foo', 12, {1: 2}] |
|
|
|
""" |
|
|
|
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: |
|
|
|
return [obj] |
|
|
|
|
|
|
|
if dist.get_rank() == dst: |
|
|
|
object_gather_list = [None for _ in range(dist.get_world_size(group))] |
|
|
|
else: |
|
|
|
object_gather_list = None |
|
|
|
|
|
|
|
if group is None: |
|
|
|
group = DEFAULT_TORCH_GROUP |
|
|
|
|
|
|
@@ -207,6 +217,9 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) |
|
|
|
:param group: |
|
|
|
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 |
|
|
|
""" |
|
|
|
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: |
|
|
|
return [obj] |
|
|
|
|
|
|
|
if group is None: |
|
|
|
group = DEFAULT_TORCH_GROUP |
|
|
|
if isinstance(obj, torch.Tensor): |
|
|
@@ -233,6 +246,12 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GR |
|
|
|
:param group: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: |
|
|
|
if src == dist.get_rank(group): |
|
|
|
return obj |
|
|
|
else: |
|
|
|
return None |
|
|
|
|
|
|
|
if group is None: |
|
|
|
group = DEFAULT_TORCH_GROUP |
|
|
|
cur_rank = dist.get_rank(group) |
|
|
@@ -328,6 +347,9 @@ def all_gather_object(object_list, obj, group=None): |
|
|
|
>>> output |
|
|
|
['foo', 12, {1: 2}] |
|
|
|
""" |
|
|
|
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2: |
|
|
|
return [obj] |
|
|
|
|
|
|
|
if dist.distributed_c10d._rank_not_in_group(group): |
|
|
|
return |
|
|
|
if _TORCH_GREATER_EQUAL_1_8: |
|
|
|