Browse Source

整理PaddleSingleDriver的部分测试例

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
00b5baf67a
5 changed files with 119 additions and 120 deletions
  1. +34
    -32
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  2. +12
    -12
      fastNLP/core/drivers/paddle_driver/single_device.py
  3. +12
    -7
      fastNLP/core/drivers/paddle_driver/utils.py
  4. +4
    -1
      fastNLP/core/utils/paddle_utils.py
  5. +57
    -68
      tests/core/drivers/paddle_driver/test_single_device.py

+ 34
- 32
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -11,8 +11,13 @@ from .utils import _build_fp16_env, optimizer_state_to_device
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.drivers.driver import Driver
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.envs import (
FASTNLP_SEED_WORKERS,
FASTNLP_MODEL_FILENAME,
FASTNLP_CHECKPOINT_FILENAME,
FASTNLP_GLOBAL_RANK,
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler

@@ -91,7 +96,7 @@ class PaddleDriver(Driver):
f"type, not {type(each_dataloader)}.")
if isinstance(each_dataloader.dataset, IterableDataset):
raise TypeError("`IterableDataset` is not allowed.")
if dataloader.batch_sampler is None and dataloader.batch_size is None:
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None:
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of "
f"`batch_sampler` and `batch_size` should be set.")

@@ -171,56 +176,45 @@ class PaddleDriver(Driver):
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数;
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数;

:param filepath: 保存文件的文件位置(需要包括文件名);
:param only_state_dict: 是否只保存模型的 `state_dict`;
:param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数
保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。
:param kwargs:
input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save`
的文档:
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save
:return:
"""
debug = kwargs.get("debug", False)
model = self.unwrap_model()
if only_state_dict:
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
paddle.save(states, filepath)
if debug:
logger.debug("Save model state dict.")
else:
# paddle 在保存整个模型时需要传入额外参数
input_spec = kwargs.get("input_spec", None)
if input_spec is None:
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.")
if self.model_device is not None:
if not self.is_distributed():
self.move_model_to_device(model, "cpu")
paddle.jit.save(model, filepath, input_spec)
if not self.is_distributed():
self.move_model_to_device(model, self.model_device)
else:
paddle.jit.save(model, filepath, input_spec)
if debug:
logger.debug("Save model.")
paddle.jit.save(model, filepath, input_spec)

def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs):
r"""
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数;

:param filepath: 需要被加载的对象的文件位置(需要包括文件名);
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时,
即保存了整个模型时,这个参数必须也为False
:param only_state_dict: 是否加载state_dict,默认为True。
:param kwargs:
:return:
"""
debug = kwargs.get("debug", False)
model = self.unwrap_model()
if only_state_dict:
model.load_dict(paddle.load(filepath))
if debug:
logger.debug("Load model state dict.")
else:
model.load_dict(paddle.jit.load(filepath).state_dict())
if debug:
logger.debug("Load model.")
# paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict
# 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。
dirname, filename = os.path.split(filepath)
if not only_state_dict and dirname == "":
# 如果传入的是单个文件,则加上相对路径
filepath = os.path.join(".", filepath)
model.load_dict(paddle.load(filepath))

@rank_zero_call
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
@@ -261,7 +255,11 @@ class PaddleDriver(Driver):

# 2. 保存模型的状态;
if should_save_model:
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs)
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs)
if only_state_dict:
logger.debug("Save model state dict.")
else:
logger.debug("Save model.")

# 3. 保存 optimizers 的状态;
optimizers_state_dict = {}
@@ -288,7 +286,11 @@ class PaddleDriver(Driver):

# 2. 加载模型状态;
if should_load_model:
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True)
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict)
if only_state_dict:
logger.debug("Load model state dict.")
else:
logger.debug("Load model.")

# 3. 恢复 sampler 的状态;
dataloader_args = self.get_dataloader_args(dataloader)
@@ -359,7 +361,7 @@ class PaddleDriver(Driver):
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
"""
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
global_rank = rank if rank is not None else rank_zero_call.rank
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))
# TODO gpu
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed()
# back out the base seed so we can use all the bits


+ 12
- 12
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -2,7 +2,7 @@ import os
from typing import Optional, Dict, Union

from .paddle_driver import PaddleDriver
from .utils import replace_batch_sampler, replace_sampler
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
@@ -29,10 +29,7 @@ class PaddleSingleDriver(PaddleDriver):
if device is None:
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.")

if isinstance(device, int):
self.model_device = get_paddle_gpu_str(device)
else:
self.model_device = device
self.model_device = get_paddle_gpu_str(device)

self.local_rank = 0
self.global_rank = 0
@@ -94,11 +91,14 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward

def setup(self):
device_id = get_paddle_device_id(self.model_device)
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")
device = self.model_device
if device != "cpu":
device_id = get_paddle_device_id(device)
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
device = get_device_from_visible(device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)

def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
@@ -131,11 +131,11 @@ class PaddleSingleDriver(PaddleDriver):
r"""
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0`

:return: 将移动到指定机器上的 batch 对象返回;
"""
return paddle_move_data_to_device(batch, "gpu:0")
device = get_device_from_visible(self.data_device)
return paddle_move_data_to_device(batch, device)

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
reproducible: bool = False):


+ 12
- 7
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -255,20 +255,23 @@ def get_host_name_ip():
except:
return None

def get_device_from_visible(device: Union[str, int]):
def get_device_from_visible(device: Union[str, int], output_type=int):
"""
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。
:param devices:未转化的设备名
:param devices: 未转化的设备名
:param output_type: 返回值的类型
:return: 转化后的设备id
"""
if output_type not in [int, str]:
raise ValueError("Parameter `output_type` should be one of these types: [int, str]")
if device == "cpu":
return device
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
idx = get_paddle_device_id(device)
if cuda_visible_devices is None or cuda_visible_devices == "":
# 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES
return idx
raise RuntimeError("This situation should not happen, please report us this bug.")
else:
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
@@ -277,11 +280,13 @@ def get_device_from_visible(device: Union[str, int]):
idx = user_visible_devices.split(",")[idx]

cuda_visible_devices_list = cuda_visible_devices.split(',')
assert idx in cuda_visible_devices_list, "Can't find "\
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\
% (idx, cuda_visible_devices)
if idx not in cuda_visible_devices_list:
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].")
res = cuda_visible_devices_list.index(idx)
return res
if output_type == int:
return res
else:
return f"gpu:{res}"

def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"):
"""


+ 4
- 1
fastNLP/core/utils/paddle_utils.py View File

@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]):
device = device.lower()
if device == "cpu":
raise ValueError("Cannot get device id from `cpu`.")
elif device == "gpu":
return 0

match_res = re.match(r"gpu:\d+", device)
if not match_res:
raise ValueError(
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x', "
f"not '{device}'"
)
device_id = device.split(':', 1)[1]
device_id = int(device_id)


+ 57
- 68
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,10 +1,11 @@
import os
from numpy import isin
os.environ["FASTNLP_BACKEND"] = "paddle"
import pytest

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.samplers import RandomBatchSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -30,6 +31,7 @@ def generate_random_driver(features, labels):
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
driver = PaddleSingleDriver(model, device="cpu")
driver.set_optimizers(opt)
driver.setup()

return driver

@@ -77,6 +79,7 @@ def test_save_and_load_state_dict(prepare_test_save_load):
driver2.load_model(path)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

@@ -93,10 +96,11 @@ def test_save_and_load_whole_model(prepare_test_save_load):
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]])
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict=False)

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.validate_step(batch)
res2 = driver2.validate_step(batch)

@@ -115,7 +119,7 @@ class TestSingleDeviceFunction:
@classmethod
def setup_class(cls):
model = PaddleNormalModel_Classification_1(10, 784)
cls.driver = PaddleSingleDriver(model, device="gpu")
cls.driver = PaddleSingleDriver(model, device="cpu")

def test_unwrap_model(self):
"""
@@ -130,22 +134,6 @@ class TestSingleDeviceFunction:
self.driver.check_evaluator_mode("validate")
self.driver.check_evaluator_mode("test")

def test_get_model_device_cpu(self):
"""
测试get_model_device
"""
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "cpu")
device = self.driver.get_model_device()
assert device == "cpu", device

def test_get_model_device_gpu(self):
"""
测试get_model_device
"""
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "gpu:0")
device = self.driver.get_model_device()
assert device == "gpu:0", device

def test_is_distributed(self):
assert self.driver.is_distributed() == False

@@ -156,24 +144,24 @@ class TestSingleDeviceFunction:
"""
self.driver.move_data_to_device(paddle.rand((32, 64)))

@pytest.mark.parametrize(
"dist_sampler", [
"dist",
ReproducibleBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False),
RandomSampler(PaddleRandomMaxDataset(320, 10))
]
)
@pytest.mark.parametrize(
"reproducible",
[True, False]
)
def test_repalce_sampler(self, dist_sampler, reproducible):
"""
测试set_dist_repro_dataloader函数
"""
dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True)
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
# @pytest.mark.parametrize(
# "dist_sampler", [
# "dist",
# RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False),
# RandomSampler(PaddleRandomMaxDataset(320, 10))
# ]
# )
# @pytest.mark.parametrize(
# "reproducible",
# [True, False]
# )
# def test_set_dist_repro_dataloader(self, dist_sampler, reproducible):
# """
# 测试set_dist_repro_dataloader函数
# """
# dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True)
# res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)

class TestPaddleDriverFunctions:
"""
@@ -183,7 +171,7 @@ class TestPaddleDriverFunctions:
@classmethod
def setup_class(self):
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="gpu")
self.driver = PaddleSingleDriver(model, device="cpu")

def test_check_single_optimizer_legality(self):
"""
@@ -198,7 +186,7 @@ class TestPaddleDriverFunctions:

optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# 传入torch的optimizer时,应该报错ValueError
with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)

def test_check_optimizers_legality(self):
@@ -218,7 +206,7 @@ class TestPaddleDriverFunctions:
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]

with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizers)

def test_check_dataloader_legality_in_train(self):
@@ -230,7 +218,7 @@ class TestPaddleDriverFunctions:

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

# 创建torch的dataloader
@@ -238,7 +226,7 @@ class TestPaddleDriverFunctions:
TorchNormalDataset(),
batch_size=32, shuffle=True
)
with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True)

def test_check_dataloader_legacy_in_test(self):
@@ -257,11 +245,12 @@ class TestPaddleDriverFunctions:
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# 传入的不是dict,应该报错
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

# 创建torch的dataloader
@@ -274,7 +263,7 @@ class TestPaddleDriverFunctions:
batch_size=32, shuffle=True
)
dataloader = {"train": train_loader, "test": test_loader}
with self.assertRaises(ValueError) as cm:
with pytest.raises(ValueError):
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
@@ -284,25 +273,25 @@ class TestPaddleDriverFunctions:
# 单个张量
tensor = paddle.to_tensor(3)
res = PaddleSingleDriver.tensor_to_numeric(tensor)
self.assertEqual(res, 3)
assert res == 3

tensor = paddle.rand((3, 4))
res = PaddleSingleDriver.tensor_to_numeric(tensor)
self.assertListEqual(res, tensor.tolist())
assert res == tensor.tolist()

# 张量list
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
res = PaddleSingleDriver.tensor_to_numeric(tensor_list)
self.assertTrue(res, list)
assert isinstance(res, list)
tensor_list = [t.tolist() for t in tensor_list]
self.assertListEqual(res, tensor_list)
assert res == tensor_list

# 张量tuple
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple)
self.assertTrue(res, tuple)
assert isinstance(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
self.assertTupleEqual(res, tensor_tuple)
assert res == tensor_tuple

# 张量dict
tensor_dict = {
@@ -317,29 +306,29 @@ class TestPaddleDriverFunctions:
}

res = PaddleSingleDriver.tensor_to_numeric(tensor_dict)
self.assertIsInstance(res, dict)
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist())
self.assertIsInstance(res["list"], list)
assert isinstance(res, dict)
assert res["tensor"] == tensor_dict["tensor"].tolist()
assert isinstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
self.assertListEqual(r, d.tolist())
self.assertIsInstance(res["int"], int)
self.assertIsInstance(res["string"], str)
self.assertIsInstance(res["dict"], dict)
self.assertIsInstance(res["dict"]["list"], list)
assert r == d.tolist()
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
self.assertListEqual(r, d.tolist())
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist())
assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

def test_set_model_mode(self):
"""
测试set_model_mode函数
"""
self.driver.set_model_mode("train")
self.assertTrue(self.driver.model.training)
assert self.driver.model.training
self.driver.set_model_mode("eval")
self.assertFalse(self.driver.model.training)
assert not self.driver.model.training
# 应该报错
with self.assertRaises(AssertionError) as cm:
with pytest.raises(AssertionError):
self.driver.set_model_mode("test")

def test_move_model_to_device_cpu(self):
@@ -347,15 +336,15 @@ class TestPaddleDriverFunctions:
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place())
assert self.driver.model.linear1.weight.place.is_cpu_place()

def test_move_model_to_device_gpu(self):
"""
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu:0")
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place())
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0)
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
assert self.driver.model.linear1.weight.place.is_gpu_place()
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0

def test_worker_init_function(self):
"""


Loading…
Cancel
Save