From 281a570b0902b485e16b9dddead17e31364ad377 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 3 May 2022 15:34:17 +0000 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E8=B0=83=E6=95=B4=E5=8D=95=E5=8D=A1?= =?UTF-8?q?=E4=B8=AD=20save=5Fand=5Fload=5Fmodel=20=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BE=8B=EF=BC=8C=E4=B8=8D=E5=86=8D=E4=BD=BF=E7=94=A8=20pytest?= =?UTF-8?q?.fixture=202=E3=80=81=E6=B7=BB=E5=8A=A0=20PaddleFleetDriver=20?= =?UTF-8?q?=E4=B8=AD=20broadcast=20=E8=AF=AF=E5=88=A0=E7=9A=84=E8=AE=BE?= =?UTF-8?q?=E5=A4=87=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 5 ++++- .../drivers/paddle_driver/test_single_device.py | 13 ++++--------- .../core/drivers/torch_driver/test_single_device.py | 13 ++++--------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 7611a581..01b61afa 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,6 +1,8 @@ import os from typing import List, Union, Optional, Dict, Tuple, Callable +from fastNLP.core.utils.paddle_utils import get_device_from_visible + from .paddle_driver import PaddleDriver from .fleet_launcher import FleetLauncher from .utils import ( @@ -630,7 +632,8 @@ class PaddleFleetDriver(PaddleDriver): 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 - return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) + device = get_device_from_visible(self.data_device) + return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group) def all_gather(self, obj, group=None) -> List: """ diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index ba243106..ffcb35e7 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -552,22 +552,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): return driver -@pytest.fixture -def prepare_test_save_load(): - dataset = PaddleRandomMaxDataset(40, 10) - dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - return driver1, driver2, dataloader - @pytest.mark.paddle @pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_model(prepare_test_save_load, only_state_dict): +def test_save_and_load_model(only_state_dict): """ 测试 save_model 和 load_model 函数 """ try: path = "model" - driver1, driver2, dataloader = prepare_test_save_load + dataset = PaddleRandomMaxDataset(40, 10) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") if only_state_dict: driver1.save_model(path, only_state_dict) diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 9115ed19..086f4251 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -545,22 +545,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): return driver -@pytest.fixture -def prepare_test_save_load(): - dataset = TorchArgMaxDataset(10, 40) - dataloader = DataLoader(dataset, batch_size=4) - driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) - return driver1, driver2, dataloader - @pytest.mark.torch @pytest.mark.parametrize("only_state_dict", ([True, False])) -def test_save_and_load_model(prepare_test_save_load, only_state_dict): +def test_save_and_load_model(only_state_dict): """ 测试 save_model 和 load_model 函数 """ try: path = "model" - driver1, driver2, dataloader = prepare_test_save_load + dataset = TorchArgMaxDataset(10, 40) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) driver1.save_model(path, only_state_dict) driver2.load_model(path, only_state_dict)