@@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
     """
     # # First move everything to the CPU and make it contiguous, to guard against pickle problems
     # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    if device is None:
+        device = torch.cuda.current_device()
     if _TORCH_GREATER_EQUAL_1_8:
         objs = [None for _ in range(dist.get_world_size(group))]
         dist.all_gather_object(objs, obj)
+        apply_to_collection(obj, torch.Tensor, _to_device, device=device)  # make sure that, if there are tensors, they all end up on the current device
         return objs
-    if device is None:
-        device = torch.cuda.current_device()
     group = group if group is not None else torch.distributed.group.WORLD
     data = convert_to_tensors(obj, device=device)
     data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)
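
For context, a minimal sketch of what the `_TORCH_GREATER_EQUAL_1_8` branch does: `dist.all_gather_object` pickles the local object on every rank and fills one slot per rank. The two-process gloo setup, loopback address, and port below are assumptions of this example so it runs on a CPU-only machine (the real code path typically runs under nccl with one GPU per rank); it is not fastNLP code.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # gloo backend so the sketch runs without GPUs; address/port are
    # example values, not anything fastNLP prescribes.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500",
        rank=rank, world_size=world_size,
    )

    # An arbitrarily structured object: tensors and plain Python values mixed.
    obj = {"pred": torch.tensor([rank]), "meta": f"rank-{rank}"}

    # This is what the _TORCH_GREATER_EQUAL_1_8 branch in the hunk does:
    # pickle obj on every rank and collect one entry per rank.
    objs = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(objs, obj)

    # On every rank: [{'pred': tensor([0]), 'meta': 'rank-0'},
    #                 {'pred': tensor([1]), 'meta': 'rank-1'}]
    print(rank, objs)
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```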
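The pre-1.8 fallback goes through `convert_to_tensors` and `_all_gather`, which are not part of this hunk. The sketch below is a generic reconstruction of the pickle-and-pad pattern such helpers typically implement (serialize to a byte tensor, gather lengths, pad to the global maximum since `dist.all_gather` requires same-shaped tensors, then strip and unpickle), shown on CPU for simplicity; the name `pickle_all_gather` and all details are illustrative, not fastNLP's actual implementation.

```python
import pickle
import torch
import torch.distributed as dist

def pickle_all_gather(obj, group=None):
    """Gather arbitrary picklable objects on torch < 1.8.

    Generic reconstruction of the pickle-and-pad pattern; the helper
    name and details are illustrative, not fastNLP's exact code.
    """
    world_size = dist.get_world_size(group)

    # Serialize the object into a 1-D uint8 tensor.
    buf = torch.tensor(list(pickle.dumps(obj)), dtype=torch.uint8)

    # Exchange byte lengths so every rank can pad to the global maximum
    # (all_gather requires identically shaped tensors on every rank).
    local_len = torch.tensor([buf.numel()], dtype=torch.long)
    lens = [torch.zeros_like(local_len) for _ in range(world_size)]
    dist.all_gather(lens, local_len, group=group)
    max_len = int(max(l.item() for l in lens))

    # Pad, gather the padded buffers, then strip padding and unpickle.
    padded = torch.zeros(max_len, dtype=torch.uint8)
    padded[:buf.numel()] = buf
    out = [torch.zeros(max_len, dtype=torch.uint8) for _ in range(world_size)]
    dist.all_gather(out, padded, group=group)
    return [pickle.loads(bytes(t[:int(n.item())].tolist())) for t, n in zip(out, lens)]
```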