
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
parent commit b117f6170c
6 changed files with 377 additions and 298 deletions
  1. fastNLP/core/drivers/paddle_driver/dist_utils.py (+139, -235)
  2. fastNLP/core/drivers/paddle_driver/fleet.py (+6, -8)
  3. fastNLP/core/metrics/backend/paddle_backend/backend.py (+10, -47)
  4. tests/core/controllers/test_trainer_fleet_outside.py (+1, -1)
  5. tests/core/drivers/paddle_driver/test_dist_utils.py (+185, -0)
  6. tests/core/drivers/paddle_driver/test_fleet.py (+36, -7)

fastNLP/core/drivers/paddle_driver/dist_utils.py (+139, -235)

@@ -1,27 +1,25 @@
import io
import os
import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
import os
from typing import Any, List

from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import FASTNLP_NO_SYNC
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
if _TORCH_GREATER_EQUAL_1_8:
try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
pass


from fastNLP.core.utils import apply_to_collection

from fastNLP.core.utils import paddle_move_data_to_device

if _NEED_IMPORT_PADDLE:
import paddle
import paddle.distributed as dist
from paddle.framework.io import (
_is_state_dict,
_build_saved_state_dict,
_unpack_saved_dict,
_pickle_save,
_pack_loaded_dict,
_ndarray_to_tensor,
_parse_load_result,
)

def _validate_output_list_for_rank(my_rank, dst, gather_list):
if dst == my_rank:
@@ -35,48 +33,65 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
"on non-destination ranks."
)


def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
def paddle_pickle_dump(obj, stream, protocol):
"""
Gather objects from the other ranks onto the dst rank.
Reference to `paddle.save`
"""
if _is_state_dict(obj):
saved_obj = _build_saved_state_dict(obj)
saved_obj = _unpack_saved_dict(saved_obj, protocol)
pickle.dump(saved_obj, stream, protocol=protocol)
else:
_pickle_save(obj, stream, protocol)

Gathers picklable objects from the whole group in a single process.
Similar to :func:`gather`, but Python objects can be passed in. Note that the
object must be picklable in order to be gathered.
def paddle_pickle_load(stream):
"""
Reference to `paddle.load`
"""
load_result = pickle.load(stream)
if isinstance(load_result, dict):
load_result = _pack_loaded_dict(load_result)
if "StructuredToParameterName@@" in load_result:

for key in load_result["StructuredToParameterName@@"]:
if isinstance(load_result[key], np.ndarray):
load_result[key] = _ndarray_to_tensor(
load_result[key], return_numpy=False)

if "StructuredToParameterName@@" in load_result:
del load_result["StructuredToParameterName@@"]
else:
load_result = _parse_load_result(load_result, return_numpy=False)

Args:
obj (Any): Input object. Must be picklable.
object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
should be correctly sized as the size of the group for this
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
else:
load_result = _parse_load_result(load_result, return_numpy=False)

Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
output of the collective.
return load_result

.. note:: Note that this API differs slightly from the gather collective
since it does not provide an async_op handle and thus will be a blocking
call.
def _object_to_tensor(obj, device=None):
f = io.BytesIO()
paddle_pickle_dump(obj, f, protocol=2)
byte_data = list(f.getvalue())
byte_tensor = paddle.to_tensor(byte_data, dtype=paddle.int32)
local_size = paddle.to_tensor([byte_tensor.numel()])
if device is not None:
byte_tensor = paddle_move_data_to_device(byte_tensor, device)
local_size = paddle_move_data_to_device(local_size, device)
return byte_tensor, local_size

.. note:: Note that this API is not supported when using the NCCL backend.
def _tensor_to_object(tensor, tensor_size):
buf = tensor.astype(paddle.uint8).detach().cpu().numpy().tobytes()[:tensor_size]
return paddle_pickle_load(io.BytesIO(buf))

.. warning::
:func:`gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.
def fastnlp_paddle_gather_object(obj, dst=0, group=None):
"""
Gather objects from the other ranks onto the dst rank.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.gather_object(
>>> fastnlp_paddle_gather_object(
gather_objects[dist.get_rank()],
output if dist.get_rank() == 0 else None,
dst=0
@@ -84,99 +99,58 @@ def fastnlp_paddle_gather_object(obj, object_gather_list=None, dst=0, group=DEFA
>>> # On rank 0
>>> output
['foo', 12, {1: 2}]

:param obj: the object to send; it must be picklable
:param dst: the destination rank.
:param group: the process group on which to run this function.
:return: on dst, a list of length world_size containing the obj from rank 0, rank 1, ... in order
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if dist.get_rank() == dst:
object_gather_list = [None for _ in range(dist.get_world_size(group))]
object_gather_list = [None for _ in range(dist.get_world_size())]
else:
object_gather_list = None

if group is None:
group = DEFAULT_TORCH_GROUP
# if group is None:
# TODO there is a bug in paddle 2.2
# group = dist.collective._get_global_group()

if dist.distributed_c10d._rank_not_in_group(group):
if group is not None and not group.is_member():
return

# Ensure object_gather_list is specified appropriately.
my_rank = dist.get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
# prevent the data from ending up on the sending GPU when unpickled.
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
obj = paddle_move_data_to_device(obj, device="cpu")
input_tensor, local_size = _object_to_tensor(obj)
group_backend = dist.get_backend(group)
current_device = torch.device("cpu")
is_nccl_backend = group_backend == dist.Backend.NCCL
if is_nccl_backend:
current_device = torch.device('cuda', torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(group_size, dtype=torch.long, device=current_device)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes. An all-gather is needed here despite this being a
# gather, since each rank needs to broadcast a tensor of the same (maximal)
# size.
# paddle groups currently only support nccl
input_tensor = paddle_move_data_to_device(input_tensor, device=paddle.device.get_device())
local_size = paddle_move_data_to_device(local_size, device=paddle.device.get_device())

# gather all local_size values and find the maximum size
object_size_list = []
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
# Avoid populating output tensors if the result won't be gathered on this rank.
if my_rank == dst:
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
# All ranks call gather with equal-sized tensors.
dist.gather(
input_tensor,
gather_list=output_tensors if my_rank == dst else None,
dst=dst,
group=group,
)
input_tensor.reshape_(max_object_size)
# TODO paddle does not yet seem to provide an equivalent of torch.distributed.gather
output_tensors = []
dist.all_gather(output_tensors, input_tensor, group)
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor = tensor.astype(paddle.uint8)
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)


def _object_to_tensor(obj, device=None):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will cause 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage)
local_size = torch.LongTensor([byte_tensor.numel()])
if device is not None:
byte_tensor = byte_tensor.to(device)
local_size = local_size.to(device)
return byte_tensor, local_size


def _tensor_to_object(tensor, tensor_size):
buf = tensor.detach().cpu().numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()


def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
def send_recv_object(obj, src, cur_rank, device, group=None, use_calc_stream=True):
# src rank send to all other ranks
size = torch.LongTensor([0]).to(device)
size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

if cur_rank == src:
world_size = dist.get_world_size(group=group)
world_size = dist.get_world_size()
tensor, size = _object_to_tensor(obj)
tensor = tensor.to(device)
size = size.to(device)
@@ -185,15 +159,15 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
dist.broadcast(size, src, group=group)
for subrank in range(world_size):
if subrank != src:
dist.send(tensor=tensor, dst=subrank, group=group, tag=tag)
dist.send(tensor=tensor, dst=subrank, group=group, use_calc_stream=use_calc_stream)
else:
dist.broadcast(size, src, group=group)
tensor = torch.ByteTensor([0] * size).to(device)
dist.recv(tensor=tensor, src=src, group=group, tag=tag)
tensor = paddle_move_data_to_device(paddle.to_tensor([0] * size), device)
dist.recv(tensor=tensor, src=src, group=group, use_calc_stream=use_calc_stream)

return _tensor_to_object(tensor.cpu(), size)

def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List:
"""
Data of any type can be all_gathered through this interface. Non-tensor data is transferred by pickling it and unpickling it on the receiving side.

@@ -220,178 +194,108 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP)
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if group is None:
group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
# if group is None:
# TODO there is a bug in paddle 2.2
# group = dist.collective._get_global_group()
if isinstance(obj, paddle.Tensor):
objs = []
dist.all_gather(objs, obj, group=group)
else:
objs = [None for _ in range(dist.get_world_size(group))]
objs = [None for _ in range(dist.get_world_size())]
# prevent the data from ending up on the sending GPU when unpickled
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
dist.all_gather_object(objs, obj, group=group)
else:
objs = all_gather_object(objs, obj, group=group)
obj = paddle_move_data_to_device(obj, "cpu")
objs = all_gather_object(objs, obj, group=group)

return objs


def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
def fastnlp_paddle_broadcast_object(obj, src, device=None, group=None):
"""
Broadcast obj from src to all other ranks.

:param obj:
:param src:
:param obj: the object to send
:param src: the rank the object is sent from.
:param device:
:param group:
:param group: the communication group to use
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
if src == dist.get_rank(group):
if src == dist.get_rank():
return obj
else:
return None

if group is None:
group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
cur_rank = dist.get_rank()
if cur_rank == src:
# move any tensors to cpu so pickling is easy; otherwise unpickling may place them on the sender's device
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
if cur_rank!=src:
get_obj = [None]
dist.broadcast_object_list(get_obj, src=src, group=group)
return get_obj[0]
else:
dist.broadcast_object_list([obj], src=src, group=group)
return obj
obj = paddle_move_data_to_device(obj, "cpu")

if device is None:
device = torch.cuda.current_device()
device = paddle.device.get_device()

if cur_rank == src:
tensor, size = _object_to_tensor(obj, device=device)
else:
size = torch.LongTensor([0]).to(device)
size = paddle_move_data_to_device(paddle.to_tensor([0]), device)

dist.broadcast(size, src=src, group=group)
if cur_rank != src:
tensor = torch.empty(
size.int().item(), # type: ignore[arg-type]
dtype=torch.uint8,
device=device
tensor = paddle.empty(
size.astype(paddle.int32), # type: ignore[arg-type]
dtype=paddle.int32,
)
dist.broadcast(tensor, src=src, group=group)

return _tensor_to_object(tensor, tensor_size=size.item())


def _check_for_nccl_backend(group):
pg = group or dist.distributed_c10d._get_default_group()
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, _ProcessGroupWrapper):
pg = pg.wrapped_pg

return (
dist.is_nccl_available() and
isinstance(pg, dist.ProcessGroupNCCL)
)


def all_gather_object(object_list, obj, group=None):
"""
Copied from pytorch so that older versions of pytorch are also supported.

Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.

Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Picklable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.

.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.

.. note:: For NCCL-based processed groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsibility to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.

.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]

:param object_list:
:param obj:
:param group:
:return:
"""
if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2:
return [obj]

if dist.distributed_c10d._rank_not_in_group(group):
if group is not None and not group.is_member():
return
if _TORCH_GREATER_EQUAL_1_8:
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
else:
current_device = torch.cuda.current_device()
current_device = paddle.device.get_device()

input_tensor, local_size = _object_to_tensor(obj, device=current_device)

# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# gather the tensor sizes and find the maximum
object_size_list = []
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# pad the tensor
pad_dims = []
pad_by = (max_object_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
tensor_padded = paddle.nn.functional.pad(input_tensor, pad_dims)

# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
output_tensors = []
dist.all_gather(output_tensors, tensor_padded, group=group)
dist.barrier()
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.astype(paddle.uint8)
if not tensor.place.is_cpu_place():
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
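As a quick illustration of the serialization helpers introduced above, the following single-process sketch round-trips a nested object through _object_to_tensor and _tensor_to_object. It mirrors the new test in tests/core/drivers/paddle_driver/test_dist_utils.py; no process group is required, and the example object is purely illustrative.

    import os
    os.environ["FASTNLP_BACKEND"] = "paddle"  # as in the tests, set before importing fastNLP

    import paddle
    from fastNLP.core.drivers.paddle_driver.dist_utils import (
        _object_to_tensor,
        _tensor_to_object,
    )

    # any picklable object works; this nested dict is only an example
    obj = {"tensor": paddle.rand((3, 4)).cpu(), "meta": {"rank": 0, "tag": "demo"}}

    # serialize into a 1-D tensor of bytes plus a size tensor recording its length
    byte_tensor, local_size = _object_to_tensor(obj, device="cpu")

    # deserialize; the size tensor tells how many bytes are valid
    restored = _tensor_to_object(byte_tensor, local_size)
    assert paddle.equal_all(restored["tensor"], obj["tensor"])
    assert restored["meta"] == {"rank": 0, "tag": "demo"}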

fastNLP/core/drivers/paddle_driver/fleet.py (+6, -8)

@@ -11,6 +11,7 @@ from .utils import (
replace_sampler,
replace_batch_sampler,
)
from .dist_utils import fastnlp_paddle_all_gather, fastnlp_paddle_broadcast_object

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import (
@@ -451,12 +452,12 @@ class PaddleFleetDriver(PaddleDriver):
:return: if the current driver is not distributed, the input obj is returned directly. If the current rank is a receiver (its global rank is included in dst),
the received object is returned; if it is the source, the content it sent is returned; if it is neither sender nor receiver, None is returned.
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # return directly if FASTNLP_NO_SYNC == 2.
return
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
device = self.data_device
# CUDA_VISIBLE_DEVICES is set, which may otherwise cause errors
device = get_device_from_visible(device)
return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)

def all_gather(self, obj, group) -> List:
def all_gather(self, obj, group=None) -> List:
"""
Send obj to all other ranks; obj may be a Tensor or a nested object. Data that is not a basic type is serialized with
pickle and deserialized after it is received.
@@ -479,7 +480,4 @@ class PaddleFleetDriver(PaddleDriver):
:param group:
:return:
"""
return
if int(os.environ.get(FASTNLP_NO_SYNC, 0)) == 2: # FASTNLP_NO_SYNC == 2 means do not synchronize
return [obj]
return fastnlp_paddle_all_gather(obj, group=group)
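The two driver methods above delegate to fastnlp_paddle_broadcast_object and fastnlp_paddle_all_gather from dist_utils.py. Below is a hedged sketch of calling those functions directly; it assumes a script launched on at least two GPUs with python -m paddle.distributed.launch, and it initializes the collective context with paddle.distributed.init_parallel_env, whereas the tests in this commit spawn workers with FleetLauncher and fleet.init instead.

    # assumed launch: python -m paddle.distributed.launch --gpus=0,1 this_script.py
    import os
    os.environ["FASTNLP_BACKEND"] = "paddle"

    import paddle
    import paddle.distributed as dist
    from fastNLP.core.drivers.paddle_driver.dist_utils import (
        fastnlp_paddle_all_gather,
        fastnlp_paddle_broadcast_object,
    )

    dist.init_parallel_env()
    rank = dist.get_rank()

    # every rank contributes a nested object; tensors and plain Python values both work
    gathered = fastnlp_paddle_all_gather({"rank": rank, "twice": [rank] * 2})
    assert gathered[rank]["rank"] == rank

    # rank 0 broadcasts an object and the other ranks receive it
    obj = {"msg": "hello from rank 0"} if rank == 0 else None
    received = fastnlp_paddle_broadcast_object(obj, src=0, device=paddle.device.get_device())
    assert received["msg"] == "hello from rank 0"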

fastNLP/core/metrics/backend/paddle_backend/backend.py (+10, -47)

@@ -5,8 +5,8 @@ import numpy as np
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.paddle_utils import paddle_to
from fastNLP.core.metrics.utils import AggregateMethodError
from fastNLP.core.utils import is_in_paddle_dist
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible
from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
@@ -34,7 +34,7 @@ class PaddleBackend(Backend):
if parallel_helper._is_parallel_ctx_initialized():
if method is None:
raise AggregateMethodError(should_have_aggregate_method=True)
tensor = self._gather_all(tensor)
tensor = self.all_gather_object(tensor)
if isinstance(tensor[0], paddle.Tensor):
tensor = paddle.stack(tensor)
# step 1: aggregate the results
@@ -74,55 +74,18 @@ class PaddleBackend(Backend):
return tensor.cpu().detach().numpy()
elif isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, (float, int)):
return tensor
else:
raise ValueError(f"tensor: {tensor} can not convert to ndarray!")

@staticmethod
def _gather_all(result, group: Optional[Any] = None) -> List:
"""
Aggregate the result from every process in the group; because results can differ in size across the group, padding is applied when necessary.
"""
# TODO check correctness
# there is a bug on the paddle side that is fixed in 2.3; update this then
# if group is None:
# group = dist.get_group(0)

world_size = group.nranks if group is not None else dist.get_world_size()
dist.barrier(group=group)

# if the tensor is a scalar, a simple gather is enough
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)

# get the shape of result
local_size = paddle.to_tensor(result.shape)
# gather the sizes of all results in the group
local_sizes = []
dist.all_gather(local_sizes, local_size, group=group)
# after stacking, take the maximum of each shape dimension
max_size = paddle.stack(local_sizes).max(axis=0)
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

# if all results have the same size they can be gathered directly
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)

# otherwise, pad to align with the largest tensor
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = paddle.nn.functional.pad(result, pad_dims)
# gather again
gathered_result = []
dist.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size.tolist()]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result

def move_tensor_to_device(self, tensor, device):
# TODO could handling it here cause bugs elsewhere?
device = get_device_from_visible(device)
return paddle_to(tensor, device)

def all_gather_object(self, obj, group=None) -> List:
if self.is_distributed():
obj_list = fastnlp_paddle_all_gather(obj, group=group)
return obj_list
return [obj]
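The removed _gather_all above aligned tensors of different sizes by padding them to the per-dimension maximum and slicing back after the collective; the backend now reaches the same effect through all_gather_object, which pads the pickled byte tensors inside dist_utils instead. A minimal single-process sketch of that pad-and-slice idea follows (the three "ranks" are simulated by a plain list and no process group is involved):

    import paddle
    import paddle.nn.functional as F

    # pretend these payloads came from three ranks and differ in length
    per_rank = [paddle.rand((2,)), paddle.rand((5,)), paddle.rand((3,))]

    local_sizes = [t.shape[0] for t in per_rank]
    max_size = max(local_sizes)

    # pad every payload to the common maximum so a fixed-size collective could move them
    padded = [F.pad(t, [0, max_size - n]) for t, n in zip(per_rank, local_sizes)]

    # in the real code the padded tensors go through dist.all_gather;
    # here they are simply sliced back to their original sizes
    recovered = [t[:n] for t, n in zip(padded, local_sizes)]

    for before, after in zip(per_rank, recovered):
        assert paddle.equal_all(before, after)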

tests/core/controllers/test_trainer_fleet_outside.py (+1, -1)

@@ -1,7 +1,7 @@
"""
This file tests the case where the user launches with python -m paddle.distributed.launch
and initializes fleet by themselves
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py
"""
import os
os.environ["FASTNLP_BACKEND"] = "paddle"


tests/core/drivers/paddle_driver/test_dist_utils.py (+185, -0)

@@ -0,0 +1,185 @@
import os
import sys
import signal
import pytest
import traceback
os.environ["FASTNLP_BACKEND"] = "paddle"

import numpy as np

from fastNLP.core.drivers.paddle_driver.dist_utils import (
_tensor_to_object,
_object_to_tensor,
fastnlp_paddle_all_gather,
fastnlp_paddle_broadcast_object,
)
from fastNLP.core.drivers.paddle_driver.fleet_launcher import FleetLauncher
from tests.helpers.utils import magic_argv_env_context

import paddle
import paddle.distributed as dist

class TestDistUtilsTools:
"""
Test some utility functions
"""

@pytest.mark.parametrize("device", (["cpu", 0]))
def test_tensor_object_transfer_tensor(self, device):
"""
Test whether the results of _object_to_tensor and _tensor_to_object convert back and forth correctly
"""
# tensor
paddle_tensor = paddle.rand((3, 4, 5)).cpu()
obj_tensor, size = _object_to_tensor(paddle_tensor, device=device)
res = _tensor_to_object(obj_tensor, size)
assert paddle.equal_all(res, paddle_tensor)

# list
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
obj_tensor, size = _object_to_tensor(paddle_list, device=device)
res = _tensor_to_object(obj_tensor, size)
assert isinstance(res, list)
for before, after in zip(paddle_list, res):
assert paddle.equal_all(after, before)

# tuple
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
paddle_tuple = tuple(paddle_list)
obj_tensor, size = _object_to_tensor(paddle_tuple, device=device)
res = _tensor_to_object(obj_tensor, size)
assert isinstance(res, tuple)
for before, after in zip(paddle_list, res):
assert paddle.equal_all(after, before)
# dict
paddle_dict = {
"tensor": paddle.rand((3, 4)),
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"tensor": paddle.rand((3, 4))
},
"int": 2,
"string": "test string"
}
obj_tensor, size = _object_to_tensor(paddle_dict, device=device)
res = _tensor_to_object(obj_tensor, size)
assert isinstance(res, dict)
assert paddle.equal_all(res["tensor"], paddle_dict["tensor"])
assert isinstance(res["list"], list)
for before, after in zip(paddle_dict["list"], res["list"]):
assert paddle.equal_all(after, before)

assert isinstance(res["dict"], dict)
assert paddle.equal_all(res["dict"]["tensor"], paddle_dict["dict"]["tensor"])
for before, after in zip(paddle_dict["dict"]["list"], res["dict"]["list"]):
assert paddle.equal_all(after, before)
assert res["int"] == paddle_dict["int"]
assert res["string"] == paddle_dict["string"]


class TestAllGatherAndBroadCast:

@classmethod
def setup_class(cls):
devices = [0,1,2]
output_from_new_proc = "only_error"

launcher = FleetLauncher(devices=devices, output_from_new_proc=output_from_new_proc)
cls.local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", "0"))
if cls.local_rank == 0:
launcher = FleetLauncher(devices, output_from_new_proc)
launcher.launch()
dist.fleet.init(is_collective=True)
dist.barrier()

# cls._pids = []
# dist.all_gather(cls._pids, paddle.to_tensor(os.getpid(), dtype="int32"))
# local_world_size = paddle.to_tensor(cls.local_rank, dtype="int32")
# dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX)
# local_world_size = local_world_size.item() + 1

def on_exception(self):
if self._pids is not None:

exc_type, exc_value, exc_traceback_obj = sys.exc_info()
traceback.print_tb(exc_traceback_obj, file=sys.stderr)
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
for pid in self._pids:
pid = pid.item()
if pid != os.getpid():
os.kill(pid, signal.SIGKILL)

@magic_argv_env_context
def test_fastnlp_paddle_all_gather(self):
obj = {
'tensor': paddle.full(shape=(2, ), fill_value=self.local_rank).cuda(),
'numpy': np.full(shape=(2, ), fill_value=self.local_rank),
'bool': self.local_rank % 2 == 0,
'float': self.local_rank + 0.1,
'int': self.local_rank,
'dict': {
'rank': self.local_rank
},
'list': [self.local_rank] * 2,
'str': f'{self.local_rank}',
'tensors': [paddle.full(shape=(2,), fill_value=self.local_rank).cuda(),
paddle.full(shape=(2,), fill_value=self.local_rank).cuda()]
}
data = fastnlp_paddle_all_gather(obj)
world_size = int(os.environ['PADDLE_TRAINERS_NUM'])
assert len(data) == world_size
for i in range(world_size):
assert (data[i]['tensor'] == i).sum() == 2
assert (data[i]['numpy'] == i).sum() == 2
assert data[i]['bool'] == (i % 2 == 0)
assert np.allclose(data[i]['float'], i + 0.1)
assert data[i]['int'] == i
assert data[i]['dict']['rank'] == i
assert data[i]['list'][0] == i
assert data[i]['str'] == f'{i}'
assert data[i]['tensors'][0][0] == i

for obj in [1, True, 'xxx']:
data = fastnlp_paddle_all_gather(obj)
assert len(data) == world_size
assert data[0] == data[1]

dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1, 2]))
def test_fastnlp_paddle_broadcast_object(self, src_rank):
if self.local_rank == src_rank:
obj = {
'tensor': paddle.full(shape=(2, ), fill_value=self.local_rank).cuda(),
'numpy': np.full(shape=(2, ), fill_value=self.local_rank),
'bool': self.local_rank % 2 == 0,
'float': self.local_rank + 0.1,
'int': self.local_rank,
'dict': {
'rank': self.local_rank
},
'list': [self.local_rank] * 2,
'str': f'{self.local_rank}',
'tensors': [paddle.full(shape=(2,), fill_value=self.local_rank).cuda(),
paddle.full(shape=(2,), fill_value=self.local_rank).cuda()]
}
else:
obj = None
data = fastnlp_paddle_broadcast_object(obj, src=src_rank, device=paddle.device.get_device())
assert data['tensor'][0] == src_rank
assert data['numpy'][0] == src_rank
assert data['bool'] == (src_rank % 2 == 0)
assert np.allclose(data['float'], src_rank + 0.1)
assert data['int'] == src_rank
assert data['dict']['rank'] == src_rank
assert data['list'][0] == src_rank
assert data['str'] == f'{src_rank}'
assert data['tensors'][0][0] == src_rank

for obj in [self.local_rank, bool(self.local_rank == 1), str(self.local_rank)]:
data = fastnlp_paddle_broadcast_object(obj, src=0, device=paddle.device.get_device())
assert int (data) == 0
dist.barrier()
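The tests above need a multi-GPU fleet launch. For a quick single-process sanity check, the FASTNLP_NO_SYNC guard that appears throughout dist_utils.py short-circuits the collectives; the sketch below assumes only that setting the variable to 2 before the call is enough, as the if int(os.environ.get(FASTNLP_NO_SYNC, '0')) == 2 branches in this diff suggest.

    import os
    os.environ["FASTNLP_BACKEND"] = "paddle"

    from fastNLP.envs.env import FASTNLP_NO_SYNC

    # 2 means: perform no synchronization at all, so no process group is needed
    os.environ[FASTNLP_NO_SYNC] = "2"

    from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather

    # with the guard active, all_gather degenerates to a one-element list
    assert fastnlp_paddle_all_gather({"rank": 0}) == [{"rank": 0}]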

tests/core/drivers/paddle_driver/test_fleet.py (+36, -7)

@@ -14,7 +14,7 @@ from fastNLP.core.samplers import (
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core import rank_zero_rm

import paddle
import paddle.distributed as dist
@@ -112,6 +112,35 @@ class TestFleetDriverFunction:
self.driver.get_local_rank()
dist.barrier()

@magic_argv_env_context
def test_all_gather(self):
"""
Test the all_gather function.
Detailed tests are done in test_dist_utils.py
"""
obj = {
"rank": self.driver.global_rank
}
obj_list = self.driver.all_gather(obj, group=None)
for i, res in enumerate(obj_list):
assert res["rank"] == i

@magic_argv_env_context
@pytest.mark.parametrize("src_rank", ([0, 1]))
def test_broadcast_object(self, src_rank):
"""
Test the broadcast_object function.
Detailed tests are done in test_dist_utils.py
"""
if self.driver.global_rank == src_rank:
obj = {
"rank": self.driver.global_rank
}
else:
obj = None
res = self.driver.broadcast_object(obj, src=src_rank)
assert res["rank"] == src_rank

############################################################################
#
# Tests for the set_dist_repro_dataloader function
@@ -543,11 +572,11 @@ class TestSaveLoad:
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
rank_zero_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
rank_zero_rm(path + ".pdiparams")
rank_zero_rm(path + ".pdiparams.info")
rank_zero_rm(path + ".pdmodel")

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -658,7 +687,7 @@ class TestSaveLoad:
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
synchronize_safe_rm(path)
rank_zero_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -769,4 +798,4 @@ class TestSaveLoad:
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
synchronize_safe_rm(path)
rank_zero_rm(path)
