@@ -1,6 +1,8 @@ | |||||
import os | import os | ||||
from typing import List, Union, Optional, Dict, Tuple, Callable | from typing import List, Union, Optional, Dict, Tuple, Callable | ||||
from fastNLP.core.utils.paddle_utils import get_device_from_visible | |||||
from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
from .fleet_launcher import FleetLauncher | from .fleet_launcher import FleetLauncher | ||||
from .utils import ( | from .utils import ( | ||||
@@ -630,7 +632,8 @@ class PaddleFleetDriver(PaddleDriver): | |||||
接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 | ||||
""" | """ | ||||
# 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 | # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 | ||||
return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) | |||||
device = get_device_from_visible(self.data_device) | |||||
return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group) | |||||
def all_gather(self, obj, group=None) -> List: | def all_gather(self, obj, group=None) -> List: | ||||
""" | """ | ||||
@@ -552,22 +552,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
return driver | return driver | ||||
@pytest.fixture | |||||
def prepare_test_save_load(): | |||||
dataset = PaddleRandomMaxDataset(40, 10) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
return driver1, driver2, dataloader | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
def test_save_and_load_model(only_state_dict): | |||||
""" | """ | ||||
测试 save_model 和 load_model 函数 | 测试 save_model 和 load_model 函数 | ||||
""" | """ | ||||
try: | try: | ||||
path = "model" | path = "model" | ||||
driver1, driver2, dataloader = prepare_test_save_load | |||||
dataset = PaddleRandomMaxDataset(40, 10) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu") | |||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
@@ -545,22 +545,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): | |||||
return driver | return driver | ||||
@pytest.fixture | |||||
def prepare_test_save_load(): | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
return driver1, driver2, dataloader | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
def test_save_and_load_model(only_state_dict): | |||||
""" | """ | ||||
测试 save_model 和 load_model 函数 | 测试 save_model 和 load_model 函数 | ||||
""" | """ | ||||
try: | try: | ||||
path = "model" | path = "model" | ||||
driver1, driver2, dataloader = prepare_test_save_load | |||||
dataset = TorchArgMaxDataset(10, 40) | |||||
dataloader = DataLoader(dataset, batch_size=4) | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | |||||
driver1.save_model(path, only_state_dict) | driver1.save_model(path, only_state_dict) | ||||
driver2.load_model(path, only_state_dict) | driver2.load_model(path, only_state_dict) | ||||