diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 7611a581..01b61afa 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,6 +1,8 @@ import os from typing import List, Union, Optional, Dict, Tuple, Callable +from fastNLP.core.utils.paddle_utils import get_device_from_visible + from .paddle_driver import PaddleDriver from .fleet_launcher import FleetLauncher from .utils import ( @@ -630,7 +632,8 @@ class PaddleFleetDriver(PaddleDriver): 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 - return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) + device = get_device_from_visible(self.data_device) + return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group) def all_gather(self, obj, group=None) -> List: """ diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index ba243106..ffcb35e7 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -552,22 +552,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): return driver -@pytest.fixture -def prepare_test_save_load(): - dataset = PaddleRandomMaxDataset(40, 10) - dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - return driver1, driver2, dataloader - @pytest.mark.paddle @pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_model(prepare_test_save_load, only_state_dict): +def test_save_and_load_model(only_state_dict): """ 测试 save_model 和 load_model 函数 """ try: path = "model" - driver1, driver2, dataloader = prepare_test_save_load + dataset = PaddleRandomMaxDataset(40, 10) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") if only_state_dict: driver1.save_model(path, only_state_dict) diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 9115ed19..086f4251 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -545,22 +545,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): return driver -@pytest.fixture -def prepare_test_save_load(): - dataset = TorchArgMaxDataset(10, 40) - dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - return driver1, driver2, dataloader - @pytest.mark.torch @pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_model(prepare_test_save_load, only_state_dict): +def test_save_and_load_model(only_state_dict): """ 测试 save_model 和 load_model 函数 """ try: path = "model" - driver1, driver2, dataloader = prepare_test_save_load + dataset = TorchArgMaxDataset(10, 40) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) driver1.save_model(path, only_state_dict) driver2.load_model(path, only_state_dict)