import io
import os
import pickle
from typing import Any, List

import numpy as np

from fastNLP.core.utils import paddle_move_data_to_device
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import FASTNLP_NO_SYNC

if _NEED_IMPORT_PADDLE:
    import paddle
    import paddle.distributed as dist
    from paddle.framework.io import (
        _is_state_dict,
        _build_saved_state_dict,
        _unpack_saved_dict,
        _pickle_save,
        _pack_loaded_dict,
        _ndarray_to_tensor,
        _parse_load_result,
    )


def _validate_output_list_for_rank(my_rank, dst, gather_list):
    if dst == my_rank:
        if not gather_list:
            raise ValueError(
                "Argument ``gather_list`` must be specified on destination rank."
            )
    elif gather_list:
        raise ValueError(
            "Argument ``gather_list`` must NOT be specified "
            "on non-destination ranks."
        )
def paddle_pickle_dump(obj, stream, protocol):
    """
    Reference to `paddle.save`.
    """
    if _is_state_dict(obj):
        saved_obj = _build_saved_state_dict(obj)
        saved_obj = _unpack_saved_dict(saved_obj, protocol)
        pickle.dump(saved_obj, stream, protocol=protocol)
    else:
        _pickle_save(obj, stream, protocol)


def paddle_pickle_load(stream):
    """
    Reference to `paddle.load`.
    """
    load_result = pickle.load(stream)
    if isinstance(load_result, dict):
        load_result = _pack_loaded_dict(load_result)
        if "StructuredToParameterName@@" in load_result:
            for key in load_result["StructuredToParameterName@@"]:
                if isinstance(load_result[key], np.ndarray):
                    load_result[key] = _ndarray_to_tensor(
                        load_result[key], return_numpy=False)

            if "StructuredToParameterName@@" in load_result:
                del load_result["StructuredToParameterName@@"]
        else:
            load_result = _parse_load_result(load_result, return_numpy=False)
    else:
        load_result = _parse_load_result(load_result, return_numpy=False)

    return load_result
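

# A minimal usage sketch, not part of the original module: round-trip an ordinary
# picklable object through the two helpers above, the same way the byte-tensor
# helpers below do internally. The function name is illustrative only.
def _pickle_roundtrip_sketch(obj):
    buffer = io.BytesIO()
    paddle_pickle_dump(obj, buffer, protocol=2)
    buffer.seek(0)
    return paddle_pickle_load(buffer)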


def _object_to_tensor(obj, device=None):
    # Serialize ``obj`` into a flat int32 tensor of byte values plus a 1-element size tensor.
    f = io.BytesIO()
    paddle_pickle_dump(obj, f, protocol=2)
    byte_data = list(f.getvalue())
    byte_tensor = paddle.to_tensor(byte_data, dtype=paddle.int32)
    local_size = paddle.to_tensor([byte_tensor.numel()])
    if device is not None:
        byte_tensor = paddle_move_data_to_device(byte_tensor, device)
        local_size = paddle_move_data_to_device(local_size, device)
    return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
    # ``tensor_size`` may arrive as a 1-element tensor; normalize it to a python int before slicing.
    if not isinstance(tensor_size, int):
        tensor_size = int(tensor_size.item())
    buf = tensor.astype(paddle.uint8).detach().cpu().numpy().tobytes()[:tensor_size]
    return paddle_pickle_load(io.BytesIO(buf))
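

# Illustrative sketch only (not part of the original fastNLP API): serialize an object
# into a byte tensor and recover it again. ``local_size`` records how many leading bytes
# are meaningful once tensors get padded to a common length for communication.
def _tensor_roundtrip_sketch():
    obj = {"rank": 0, "metric": [0.1, 0.2]}
    byte_tensor, local_size = _object_to_tensor(obj)
    return _tensor_to_object(byte_tensor, local_size)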


def fastnlp_paddle_gather_object(obj, dst=0, group=None):
    """
    Gather ``obj`` from every other rank onto rank ``dst``.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = fastnlp_paddle_gather_object(
        ...     gather_objects[dist.get_rank()],
        ...     dst=0
        ... )
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]

    :param obj: the object to send; it must be picklable
    :param dst: the destination rank
    :param group: the process group on which to run this collective
    :return: on ``dst``, a list of length ``world_size`` whose i-th entry is the ``obj`` from
        rank i; ``None`` on all other ranks
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if dist.get_rank() == dst:
        object_gather_list = [None for _ in range(dist.get_world_size())]
    else:
        object_gather_list = None

    # if group is None:
    #     TODO there is a bug in paddle 2.2
    #     group = dist.collective._get_global_group()

    if group is not None and not group.is_member():
        return

    # Ensure object_gather_list is specified appropriately.
    my_rank = dist.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # Move everything to CPU first so that unpickling does not end up on the sending GPU.
    obj = paddle_move_data_to_device(obj, device="cpu")
    input_tensor, local_size = _object_to_tensor(obj)

    # Paddle groups currently only support the NCCL backend, so the byte tensors must live on the current device.
    input_tensor = paddle_move_data_to_device(input_tensor, device=paddle.device.get_device())
    local_size = paddle_move_data_to_device(local_size, device=paddle.device.get_device())

    # Gather all local sizes so that we can find the max size and pad to it.
    object_size_list = []
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Pad the local tensor to the max size across all ranks; paddle tensors cannot be
    # resized in place the way ``torch.Tensor.resize_`` allows.
    pad_dims = []
    pad_by = (max_object_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    input_tensor = paddle.nn.functional.pad(input_tensor, pad_dims)
    # TODO paddle does not yet provide an equivalent of torch.distributed.gather, so fall
    # back to all_gather and discard the result on non-destination ranks.
    output_tensors = []
    dist.all_gather(output_tensors, input_tensor, group)
    if my_rank != dst:
        return
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.astype(paddle.uint8)
        tensor_size = object_size_list[i]
        object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
    return object_gather_list
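

# Hedged usage sketch, not from the original source: gather per-rank metric counts onto
# rank 0 and merge them there. Assumes ``paddle.distributed.init_parallel_env()`` has
# already been called; the function name is hypothetical.
def _gather_counts_sketch(local_counts: dict):
    gathered = fastnlp_paddle_gather_object(local_counts, dst=0)
    if dist.get_rank() != 0:
        return None
    # ``gathered`` holds one dict per rank, in rank order.
    total = {}
    for rank_counts in gathered:
        for key, value in rank_counts.items():
            total[key] = total.get(key, 0) + value
    return total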


def send_recv_object(obj, src, cur_rank, device, group=None, use_calc_stream=True):
    # The src rank sends ``obj`` to every other rank point to point.
    size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    if cur_rank == src:
        world_size = dist.get_world_size()
        tensor, size = _object_to_tensor(obj)
        tensor = paddle_move_data_to_device(tensor, device)
        size = paddle_move_data_to_device(size, device)

        # First synchronize the size of the serialized object.
        dist.broadcast(size, src, group=group)
        for subrank in range(world_size):
            if subrank != src:
                dist.send(tensor=tensor, dst=subrank, group=group, use_calc_stream=use_calc_stream)
    else:
        dist.broadcast(size, src, group=group)
        # The receive buffer must match the sender's dtype (int32, see ``_object_to_tensor``).
        tensor = paddle_move_data_to_device(paddle.to_tensor([0] * int(size.item()), dtype=paddle.int32), device)
        dist.recv(tensor=tensor, src=src, group=group, use_calc_stream=use_calc_stream)

    return _tensor_to_object(tensor.cpu(), size)
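

# Hedged usage sketch (not part of the original file): rank ``src`` pushes a python object
# to every other rank; every rank must enter the call. The helper name is hypothetical.
def _send_recv_sketch(obj, src=0):
    cur_rank = dist.get_rank()
    device = paddle.device.get_device()
    return send_recv_object(obj if cur_rank == src else None, src, cur_rank, device)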


def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) -> List:
    """
    Run all_gather on data of any type: non-tensor data is transferred by pickling it on the
    sending side and unpickling it again on the receiving side.
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    # if group is None:
    #     TODO there is a bug in paddle 2.2
    #     group = dist.collective._get_global_group()
    if isinstance(obj, paddle.Tensor):
        objs = []
        dist.all_gather(objs, obj, group=group)
    else:
        objs = [None for _ in range(dist.get_world_size())]
        # Keep everything on CPU so that unpickling does not end up on the sending GPU.
        obj = paddle_move_data_to_device(obj, "cpu")
        objs = all_gather_object(objs, obj, group=group)

    return objs
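

# Hedged usage sketch, not from the original source: collect per-rank evaluation results so
# that every worker sees the full list, e.g. before computing a corpus-level metric.
def _all_gather_results_sketch(local_results: list):
    gathered = fastnlp_paddle_all_gather(local_results)
    merged = []
    for rank_results in gathered:
        merged.extend(rank_results)
    return merged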


def fastnlp_paddle_broadcast_object(obj, src, device=None, group=None):
    """
    Broadcast the ``obj`` on rank ``src`` to every other rank.

    :param obj: the object to send
    :param src: the rank the object is sent from
    :param device:
    :param group: the communication group this call belongs to
    :return:
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        if src == dist.get_rank():
            return obj
        else:
            return None

    cur_rank = dist.get_rank()
    if cur_rank == src:
        # Move any tensors to CPU before pickling; otherwise unpickling may end up on the
        # card they were sent from.
        obj = paddle_move_data_to_device(obj, "cpu")

    if device is None:
        device = paddle.device.get_device()

    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
    else:
        size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

    dist.broadcast(size, src=src, group=group)
    if cur_rank != src:
        tensor = paddle.empty(
            size.astype(paddle.int32),  # type: ignore[arg-type]
            dtype=paddle.int32,
        )
    dist.broadcast(tensor, src=src, group=group)

    return _tensor_to_object(tensor, tensor_size=size.item())
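

# Hedged usage sketch (not part of the original file): rank 0 decides a value, e.g. a freshly
# generated output directory name, and shares it with every other rank before any rank uses it.
def _broadcast_run_dir_sketch(run_dir: str = None):
    return fastnlp_paddle_broadcast_object(run_dir, src=0)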


def all_gather_object(object_list, obj, group=None):
    """
    Code copied from PyTorch so that it also works with older framework versions.

    Gathers picklable objects from the whole group into a list. Similar to
    :func:`all_gather`, but Python objects can be passed in. Note that the object
    must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        object (Any): Picklable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. If the calling rank is part of this group, the output of the
        collective will be populated into the input ``object_list``. If the
        calling rank is not part of the group, the passed in ``object_list`` will
        be unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``paddle.device.get_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU.

    .. warning::
        :func:`all_gather_object` uses the ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    Example::
        >>> # Note: Process group initialization omitted on each rank.
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]

    :param object_list:
    :param obj:
    :param group:
    :return:
    """
    if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
        return [obj]

    if group is not None and not group.is_member():
        return

    # Paddle groups only support the NCCL backend, so the byte tensors live on the current device.
    current_device = paddle.device.get_device()

    input_tensor, local_size = _object_to_tensor(obj, device=current_device)

    # Gather all local sizes so that we can find the max size and pad to it.
    object_size_list = []
    # Allgather tensor sizes
    dist.all_gather(object_size_list, local_size, group=group)
    max_object_size = int(max(object_size_list).item())  # type: ignore[type-var]
    # Pad the local tensor up to the max size across all ranks.
    pad_dims = []
    pad_by = (max_object_size - local_size).detach().cpu()
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    tensor_padded = paddle.nn.functional.pad(input_tensor, pad_dims)

    output_tensors = []
    dist.all_gather(output_tensors, tensor_padded, group=group)
    dist.barrier()
    # Deserialize outputs back to objects.
    for i, tensor in enumerate(output_tensors):
        tensor = tensor.astype(paddle.uint8)
        if not tensor.place.is_cpu_place():
            tensor = tensor.cpu()
        tensor_size = object_size_list[i]
        object_list[i] = _tensor_to_object(tensor, tensor_size)