
Modify the fastnlp_torch_all_gather function so that all tensors it gathers end up on the current device

tags/v1.0.0alpha
yh_cc committed 3 years ago
commit 817f6d6ad6
2 changed files with 3 additions and 3 deletions:
  1. fastNLP/core/drivers/torch_driver/dist_utils.py (+3, -2)
  2. tests/core/samplers/test_reproducible_batch_sampler.py (+0, -1)

fastNLP/core/drivers/torch_driver/dist_utils.py (+3, -2)

@@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
     """
     # # First move everything to cpu and make it contiguous, to avoid pickle problems
     # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    if device is None:
+        device = torch.cuda.current_device()
     if _TORCH_GREATER_EQUAL_1_8:
         objs = [None for _ in range(dist.get_world_size(group))]
         dist.all_gather_object(objs, obj)
+        apply_to_collection(obj, torch.Tensor, _to_device, device=device)  # make sure that if there are any tensors, they are all on the current device
         return objs
-    if device is None:
-        device = torch.cuda.current_device()
     group = group if group is not None else torch.distributed.group.WORLD
     data = convert_to_tensors(obj, device=device)
     data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)
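For context, a minimal self-contained sketch of the behaviour this hunk is after: gather an arbitrary picklable object from every rank with dist.all_gather_object, then move any tensors inside the gathered result onto the current CUDA device. The helper names here (_move_tensors, all_gather_to_current_device) are illustrative only, not fastNLP's API; fastNLP itself does the recursive move with apply_to_collection as shown above.

import torch
import torch.distributed as dist


def _move_tensors(obj, device):
    # Recursively move any tensors nested in lists/tuples/dicts to `device`;
    # everything else is returned unchanged. (Illustrative helper, not fastNLP code.)
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_move_tensors(o, device) for o in obj)
    if isinstance(obj, dict):
        return {k: _move_tensors(v, device) for k, v in obj.items()}
    return obj


def all_gather_to_current_device(obj, device=None, group=None):
    # Gather `obj` from every rank (dist.all_gather_object requires torch >= 1.8),
    # then make sure every tensor in the gathered list lives on the current device.
    if device is None:
        device = torch.cuda.current_device()
    objs = [None for _ in range(dist.get_world_size(group))]
    dist.all_gather_object(objs, obj)
    return _move_tensors(objs, device)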


tests/core/samplers/test_reproducible_batch_sampler.py (+0, -1)

@@ -416,7 +416,6 @@ class TestBucketedBatchSampler:
     @pytest.mark.parametrize('num_replica', [2, 3])
     def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
         # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
-        # TODO the lengths on the two ranks should fall into the same bucket
         dataset = DatasetWithVaryLength(num_of_data=num_samples)
         batch_size = 6
         if num_replica*batch_size > num_samples:

