diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index e19aa648..44cabcf4 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -599,7 +599,7 @@ class TorchDDPDriver(TorchDriver): :param group: :return: """ - return fastnlp_torch_all_gather(obj, device=self.data_device, group=group) + return fastnlp_torch_all_gather(obj, group=group) def find_free_network_port() -> str: diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 37717f54..5e3819e7 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -1,11 +1,8 @@ import io import pickle -from typing import Mapping _pickler = pickle.Pickler _unpickler = pickle.Unpickler -from abc import ABC -from typing import Any, Union, List -import numpy as np +from typing import Any, List from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 @@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch from torch import distributed as dist + try: + from torch._C._distributed_c10d import ProcessGroupMPI + except ImportError: + _MPI_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupNCCL + except ImportError: + _NCCL_AVAILABLE = False + + try: + from torch._C._distributed_c10d import ProcessGroupGloo + from torch._C._distributed_c10d import _ProcessGroupWrapper + except ImportError: + _GLOO_AVAILABLE = False from fastNLP.core.utils import apply_to_collection - -def all_gather_object(object_list, obj, group=None): - """ - Gathers picklable objects from the whole group into a list. Similar to - :func:`all_gather`, but Python objects can be passed in. Note that the object - must be picklable in order to be gathered. - - Args: - object_list (list[Any]): Output list. It should be correctly sized as the - size of the group for this collective and will contain the output. - object (Any): Pickable Python object to be broadcast from current process. - group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. Default is ``None``. - - Returns: - None. If the calling rank is part of this group, the output of the - collective will be populated into the input ``object_list``. If the - calling rank is not part of the group, the passed in ``object_list`` will - be unmodified. - - .. note:: Note that this API differs slightly from the :func:`all_gather` - collective since it does not provide an ``async_op`` handle and thus - will be a blocking call. - - .. note:: For NCCL-based processed groups, internal tensor representations - of objects must be moved to the GPU device before communication takes - place. In this case, the device used is given by - ``torch.cuda.current_device()`` and it is the user's responsiblity to - ensure that this is set so that each rank has an individual GPU, via - ``torch.cuda.set_device()``. - - .. warning:: - :func:`all_gather_object` uses ``pickle`` module implicitly, which is - known to be insecure. It is possible to construct malicious pickle data - which will execute arbitrary code during unpickling. Only call this - function with data you trust. - - Example:: - >>> # Note: Process group initialization omitted on each rank. - >>> import torch.distributed as dist - >>> # Assumes world_size of 3. - >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object - >>> output = [None for _ in gather_objects] - >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) - >>> output - ['foo', 12, {1: 2}] - """ - if dist.distributed_c10d._rank_not_in_group(group): - return - - input_tensor, local_size = _object_to_tensor(obj) - current_device = torch.device("cpu") - if dist.is_nccl_available() and isinstance( - group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL - ): - # See note about using torch.cuda.current_device() here in docstring. - # We cannot simply use my_rank since rank == device is not necessarily - # true. - current_device = torch.device("cuda", torch.cuda.current_device()) - input_tensor = input_tensor.to(current_device) - local_size = local_size.to(current_device) - # Gather all local sizes. This is so that we can find the max size, and index - # until the correct size when deserializing the tensors. - group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - # Allgather tensor sizes - dist.all_gather(object_size_list, local_size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device - ) - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, input_tensor, group=group) - # Deserialize outputs back to object. - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - if tensor.device != torch.device("cpu"): - tensor = tensor.cpu() - tensor_size = object_size_list[i] - object_list[i] = _tensor_to_object(tensor, tensor_size) - - def _validate_output_list_for_rank(my_rank, dst, gather_list): if dst == my_rank: if not gather_list: @@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): ) -def gather_object(obj, object_gather_list=None, dst=0, group=None): +def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): """ + 从其它 rank gather 东西到 dst rank 。 + Gathers picklable objects from the whole group in a single process. Similar to :func:`gather`, but Python objects can be passed in. Note that the object must be picklable in order to be gathered. @@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None): # Ensure object_gather_list is specified appopriately. my_rank = dist.get_rank() _validate_output_list_for_rank(my_rank, dst, object_gather_list) + # 防止 unpickle 的时候出现在了发送的 gpu 上。 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) input_tensor, local_size = _object_to_tensor(obj) group_backend = dist.get_backend(group) current_device = torch.device("cpu") @@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0): return _tensor_to_object(tensor.cpu(), size) -def _all_gather(obj, **kwargs): - group = kwargs.get('group', None) - if isinstance(obj, torch.Tensor): - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - return gathered_tensor - - elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor): - tensor, size = obj - # 首先需要同步 size 吧? - group_size = dist.get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=tensor.device - ) - object_size_list = [ - object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) - ] - dist.all_gather(object_size_list, size, group=group) - max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] - # Resize tensor to max size across all ranks. - tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=tensor.device - ) - - # Output tensors are nonoverlapping views of coalesced_output_tensor - output_tensors = [ - coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)] - for i in range(group_size) - ] - dist.all_gather(output_tensors, tensor, group=group) - object_list = [] - for i, tensor in enumerate(output_tensors): - tensor = tensor.type(torch.uint8) - tensor_size = object_size_list[i] - object_list.append(_tensor_to_object(tensor, tensor_size)) - return object_list - elif isinstance(obj, tuple) and len(obj) == 2: - obj, _type = obj - gathered_tensor = [torch.zeros_like(obj) for _ in - range(torch.distributed.get_world_size(group=group))] - - torch.distributed.all_gather(gathered_tensor, obj, group=group) - - if _type == np.ndarray: - gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor] - else: - gathered_tensor = [_type(t.item()) for t in gathered_tensor] - - return gathered_tensor - else: - raise RuntimeError("Unsupported types to implement all_gather.") - - -class CanTransferDataType(ABC): - """ - 检测可以进行传输的对象。 - - """ - - @classmethod - def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: - if cls is CanTransferDataType: - if issubclass(subclass, Mapping): - return False - if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray): - return True - return False - return NotImplemented - - -def _tensorize(obj, device=None): - if isinstance(obj, torch.Tensor): - return obj - if isinstance(obj, bool): - return torch.tensor(obj, dtype=torch.uint8, device=device), bool - if isinstance(obj, float): - return torch.tensor(obj, dtype=torch.float, device=device), float - if isinstance(obj, int): - return torch.tensor(obj, dtype=torch.int, device=device), int - if isinstance(obj, np.ndarray): - return torch.from_numpy(obj), np.ndarray - return _object_to_tensor(obj, device) - - def _to_device(tensor, device): return tensor.contiguous().to(device) -def convert_to_tensors(data: Any, device=None) -> Any: - data = apply_to_collection(data, CanTransferDataType, _tensorize) - def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]): - if isinstance(t, tuple): - if isinstance(t[1], torch.Tensor): # 说明是 object 转的 - return t[0].to(device).contiguous(), t[1].to(device) - else: # 说明第二个元素是type,见 to_dtype_tensor 函数 - return t[0].to(device).contiguous(), t[1] - return t.to(device).contiguous() - - data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device) - return data - - -def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: +def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: """ 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 @@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: {'a': 1, 'b':[1, 2], 'c':{'d': 2}} ] - :param obj: 任意结构的数据,所有的 value 都会变成 list ,其长度为 world_size ,依次为每个 rank 上的对象值 - :param device: 当前 rank 使用的 device 是哪个。为 None 的话默认使用 torch.cuda.current_device() 获取。 + :param obj: 任意结构的数据,如果为 tensor ,需要保证每个显卡上的 tensor 的形状是一样的。如果传入的是非 tensor 对象都将直接进行 + 序列化之后进行传输。 + :param device: 当前该参数无意义。 :param group: :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 """ # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - if device is None: - device = torch.cuda.current_device() - if _TORCH_GREATER_EQUAL_1_8: + if isinstance(obj, torch.Tensor): + objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] + dist.all_gather(objs, obj, group=group) + else: objs = [None for _ in range(dist.get_world_size(group))] - dist.all_gather_object(objs, obj) - objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 - return objs - group = group if group is not None else torch.distributed.group.WORLD - data = convert_to_tensors(obj, device=device) - data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) - - objs = [] - - def _get_obj_on_idx(obj, idx): - return obj[idx] - - for i in range(dist.get_world_size(group)): - objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i)) - + # 防止 unpickle 的时候弄到发送的 gpu 上了 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + if _TORCH_GREATER_EQUAL_1_8: + dist.all_gather_object(objs, obj, group=group) + else: + objs = all_gather_object(objs, obj, group=group) return objs -def fastnlp_torch_broadcast_object(obj, src, device, group=None): +def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): """ 将 src 上的 obj 对象广播到其它 rank 上。 @@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): :return: """ cur_rank = dist.get_rank(group) - # if cur_rank == src: - # # 如果有 tensor 全部移动到 cpu 上,方便 pickle - # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) - + if cur_rank == src: + # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 + obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) if _TORCH_GREATER_EQUAL_1_8: if cur_rank!=src: get_obj = [None] @@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): else: dist.broadcast_object_list([obj], src=src, group=group) return obj + if device is None: + device = torch.cuda.current_device() if cur_rank == src: tensor, size = _object_to_tensor(obj, device=device) @@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None): return _tensor_to_object(tensor, tensor_size=size.item()) +def _check_for_nccl_backend(group): + pg = group or dist.distributed_c10d._get_default_group() + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, _ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + dist.is_nccl_available() and + isinstance(pg, dist.ProcessGroupNCCL) + ) + + +def all_gather_object(object_list, obj, group=None): + """ + 复制 pytorch 的代码,使得可以版本兼容低版本的 pytorch 。 + + Gathers picklable objects from the whole group into a list. Similar to + :func:`all_gather`, but Python objects can be passed in. Note that the object + must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + object (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based processed groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsiblity to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + Example:: + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if dist._rank_not_in_group(group): + return + + input_tensor, local_size = _object_to_tensor(obj) + current_device = torch.device("cpu") + is_nccl_backend = _check_for_nccl_backend(group) + if is_nccl_backend: + # See note about using torch.cuda.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device("cuda", torch.cuda.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = dist.get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + dist.all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + dist.all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + if tensor.device != torch.device("cpu"): + tensor = tensor.cpu() + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size) diff --git a/tests/core/drivers/torch_driver/test_dist_utils.py b/tests/core/drivers/torch_driver/test_dist_utils.py index 8fb7eb34..2d2145c8 100644 --- a/tests/core/drivers/torch_driver/test_dist_utils.py +++ b/tests/core/drivers/torch_driver/test_dist_utils.py @@ -7,38 +7,10 @@ import numpy as np # print(isinstance((1,), tuple)) # exit() -from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object +from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context - -def test_convert_to_tensors(): - local_rank = 0 - obj = { - 'tensor': torch.full(size=(2,), fill_value=local_rank), - 'numpy': np.full(shape=(1,), fill_value=local_rank), - 'bool': local_rank % 2 == 0, - 'float': local_rank + 0.1, - 'int': local_rank, - 'dict': { - 'rank': local_rank - }, - 'list': [local_rank] * 2, - 'str': 'xxx' - } - data = convert_to_tensors(obj) - assert len(data) == len(obj) - assert (data['tensor'] == obj['tensor']).sum() == 2 - for name in ['list', 'str']: - assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \ - isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1 - - for name in ['numpy', 'bool', 'float', 'int']: - assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1 - - assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1 - - @magic_argv_env_context def test_fastnlp_torch_all_gather(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather(): 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), torch.full(size=(2,), fill_value=local_rank).cuda()] } - data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) world_size = int(os.environ['WORLD_SIZE']) assert len(data) == world_size for i in range(world_size): @@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather(): assert data[i]['tensors'][0][0] == i for obj in [1, True, 'xxx']: - data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device()) + data = fastnlp_torch_all_gather(obj) assert len(data)==world_size assert data[0]==data[1] + dist.destroy_process_group() + @magic_argv_env_context def test_fastnlp_torch_broadcast_object(): os.environ['MASTER_ADDR'] = '127.0.0.1' @@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object(): for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) assert int(data)==0 + dist.destroy_process_group()