diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py
index 0333df16..2228b240 100644
--- a/fastNLP/core/drivers/torch_driver/dist_utils.py
+++ b/fastNLP/core/drivers/torch_driver/dist_utils.py
@@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
     """
     # # First move everything to cpu and make it contiguous, to guard against pickle issues
     # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
+    if device is None:
+        device = torch.cuda.current_device()
     if _TORCH_GREATER_EQUAL_1_8:
         objs = [None for _ in range(dist.get_world_size(group))]
         dist.all_gather_object(objs, obj)
+        apply_to_collection(obj, torch.Tensor, _to_device, device=device)  # if there are tensors, make sure they are all on the current card
         return objs
-    if device is None:
-        device = torch.cuda.current_device()
     group = group if group is not None else torch.distributed.group.WORLD
     data = convert_to_tensors(obj, device=device)
     data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)
diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py
index 42b86dcd..edc7b86b 100644
--- a/tests/core/samplers/test_reproducible_batch_sampler.py
+++ b/tests/core/samplers/test_reproducible_batch_sampler.py
@@ -416,7 +416,6 @@ class TestBucketedBatchSampler:
     @pytest.mark.parametrize('num_replica', [2, 3])
     def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
     # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
-        # TODO the lengths on the two ranks should fall into the same bucket
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        batch_size = 6
        if num_replica*batch_size > num_samples:
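
Note on the first hunk: for torch >= 1.8 the gather goes through dist.all_gather_object, which pickles arbitrary Python objects, so gathered tensors can come back on another rank's card; the patch therefore resolves `device` before the version check and maps gathered tensors onto the local device afterwards. The sketch below is an illustration of that gather path, not fastNLP's code: the helper name gather_objects and the flat `.to(device)` loop (which handles only top-level tensors, unlike apply_to_collection) are assumptions, and it presumes dist.init_process_group() has already been called on every rank.

    # Illustrative sketch only -- not fastNLP's implementation.
    import torch
    import torch.distributed as dist

    def gather_objects(obj, device=None, group=None):
        # Hypothetical helper mirroring the torch>=1.8 branch of the patch above.
        if device is None and torch.cuda.is_available():
            device = torch.device('cuda', torch.cuda.current_device())
        objs = [None for _ in range(dist.get_world_size(group))]
        dist.all_gather_object(objs, obj, group=group)  # pickles obj from every rank
        if device is not None:
            # gathered tensors may sit on another rank's card; bring them local
            objs = [o.to(device) if isinstance(o, torch.Tensor) else o for o in objs]
        return objs

On rank r, gather_objects(tensor) would then return a list of world_size tensors, all moved onto rank r's device.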