
1. Adjust the save_and_load_model test case for the single-card drivers so it no longer relies on a pytest.fixture. 2. Restore the device conversion that was accidentally removed from broadcast in PaddleFleetDriver.

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
281a570b09
3 changed files with 12 additions and 19 deletions
  1. +4 -1  fastNLP/core/drivers/paddle_driver/fleet.py
  2. +4 -9  tests/core/drivers/paddle_driver/test_single_device.py
  3. +4 -9  tests/core/drivers/torch_driver/test_single_device.py

+4 -1  fastNLP/core/drivers/paddle_driver/fleet.py

@@ -1,6 +1,8 @@
 import os
 from typing import List, Union, Optional, Dict, Tuple, Callable

+from fastNLP.core.utils.paddle_utils import get_device_from_visible
+
 from .paddle_driver import PaddleDriver
 from .fleet_launcher import FleetLauncher
 from .utils import (
@@ -630,7 +632,8 @@ class PaddleFleetDriver(PaddleDriver):
            the received object; on the source rank, the object that was broadcast; if this rank is neither the sender nor a receiver, None.
        """
        # CUDA_VISIBLE_DEVICES has been set for the fleet workers, which may otherwise cause errors
-        return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group)
+        device = get_device_from_visible(self.data_device)
+        return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group)

    def all_gather(self, obj, group=None) -> List:
        """


+4 -9  tests/core/drivers/paddle_driver/test_single_device.py

@@ -552,22 +552,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"):

    return driver

-@pytest.fixture
-def prepare_test_save_load():
-    dataset = PaddleRandomMaxDataset(40, 10)
-    dataloader = DataLoader(dataset, batch_size=4)
-    driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
-    return driver1, driver2, dataloader
-
@pytest.mark.paddle
@pytest.mark.parametrize("only_state_dict", ([True, False]))
-def test_save_and_load_model(prepare_test_save_load, only_state_dict):
+def test_save_and_load_model(only_state_dict):
    """
    Test the save_model and load_model functions
    """
    try:
        path = "model"
-        driver1, driver2, dataloader = prepare_test_save_load
+        dataset = PaddleRandomMaxDataset(40, 10)
+        dataloader = DataLoader(dataset, batch_size=4)
+        driver1, driver2 = generate_random_driver(10, 10, device="gpu"), generate_random_driver(10, 10, device="gpu")

        if only_state_dict:
            driver1.save_model(path, only_state_dict)
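As a reference point for the only_state_dict flag, here is a minimal raw-Paddle round trip of the parameters-only case; this is an assumption about what save_model/load_model reduce to under the hood, not the drivers' actual code:

import paddle

net1 = paddle.nn.Linear(10, 10)
net2 = paddle.nn.Linear(10, 10)

# Save only the parameters, then load them into a fresh network.
paddle.save(net1.state_dict(), "model.pdparams")
net2.set_state_dict(paddle.load("model.pdparams"))

x = paddle.randn([4, 10])
assert paddle.allclose(net1(x), net2(x))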


+4 -9  tests/core/drivers/torch_driver/test_single_device.py

@@ -545,22 +545,17 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"):

    return driver

-@pytest.fixture
-def prepare_test_save_load():
-    dataset = TorchArgMaxDataset(10, 40)
-    dataloader = DataLoader(dataset, batch_size=4)
-    driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
-    return driver1, driver2, dataloader
-
@pytest.mark.torch
@pytest.mark.parametrize("only_state_dict", ([True, False]))
-def test_save_and_load_model(prepare_test_save_load, only_state_dict):
+def test_save_and_load_model(only_state_dict):
    """
    Test the save_model and load_model functions
    """
    try:
        path = "model"
-        driver1, driver2, dataloader = prepare_test_save_load
+        dataset = TorchArgMaxDataset(10, 40)
+        dataloader = DataLoader(dataset, batch_size=4)
+        driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)

        driver1.save_model(path, only_state_dict)
        driver2.load_model(path, only_state_dict)
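The torch test exercises the same flag; as a reference, a minimal raw-PyTorch sketch of the two branches the parametrization covers (an illustration of the underlying framework calls, not fastNLP's save_model/load_model implementation):

import torch

net1 = torch.nn.Linear(10, 10)
net2 = torch.nn.Linear(10, 10)

# only_state_dict=True: persist parameters only.
torch.save(net1.state_dict(), "model.pt")
net2.load_state_dict(torch.load("model.pt"))

# only_state_dict=False: pickle the whole module.
# (Newer PyTorch releases may require torch.load(..., weights_only=False) here.)
torch.save(net1, "model_full.pt")
net3 = torch.load("model_full.pt")

x = torch.randn(4, 10)
assert torch.allclose(net1(x), net2(x)) and torch.allclose(net1(x), net3(x))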

