From 817f6d6ad62fa81a5f4d580636455352b178f43a Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 10 Apr 2022 22:28:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9fastnlp=5Ftorch=5Fall=5Fgathe?= =?UTF-8?q?r=E5=87=BD=E6=95=B0=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=AE=83gather?= =?UTF-8?q?=E5=90=8E=E7=9A=84tensor=E9=83=BD=E5=9C=A8=E5=BD=93=E5=89=8Ddev?= =?UTF-8?q?ice?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/dist_utils.py | 5 +++-- tests/core/samplers/test_reproducible_batch_sampler.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 0333df16..2228b240 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -397,12 +397,13 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: """ # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + if device is None: + device = torch.cuda.current_device() if _TORCH_GREATER_EQUAL_1_8: objs = [None for _ in range(dist.get_world_size(group))] dist.all_gather_object(objs, obj) + apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 return objs - if device is None: - device = torch.cuda.current_device() group = group if group is not None else torch.distributed.group.WORLD data = convert_to_tensors(obj, device=device) data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index 42b86dcd..edc7b86b 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -416,7 +416,6 @@ class TestBucketedBatchSampler: @pytest.mark.parametrize('num_replica', [2, 3]) def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): - # TODO 两个 rank 上的长度是要在同一个bucket的 dataset = DatasetWithVaryLength(num_of_data=num_samples) batch_size = 6 if num_replica*batch_size > num_samples: