@@ -11,8 +11,13 @@ from .utils import _build_fp16_env, optimizer_state_to_device | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.envs import ( | |||
FASTNLP_SEED_WORKERS, | |||
FASTNLP_MODEL_FILENAME, | |||
FASTNLP_CHECKPOINT_FILENAME, | |||
FASTNLP_GLOBAL_RANK, | |||
rank_zero_call, | |||
) | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
@@ -91,7 +96,7 @@ class PaddleDriver(Driver): | |||
f"type, not {type(each_dataloader)}.") | |||
if isinstance(each_dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
if dataloader.batch_sampler is None and dataloader.batch_size is None: | |||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: | |||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " | |||
f"`batch_sampler` and `batch_size` should be set.") | |||
@@ -171,56 +176,45 @@ class PaddleDriver(Driver): | |||
def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||
r""" | |||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | |||
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; | |||
:param filepath: 保存文件的文件位置(需要包括文件名); | |||
:param only_state_dict: 是否只保存模型的 `state_dict`; | |||
:param only_state_dict: 是否只保存模型的 `state_dict`;如果为 False,则会调用 `paddle.jit.save` 函数 | |||
保存整个模型的参数,此时需要传入 `input_spec` 参数,否则在 load 时会报错。 | |||
:param kwargs: | |||
input_spec: 描述存储模型 forward 方法的输入,当 `only_state_dict` 为 False时必须传入,否则加载时会报错。 | |||
可以通过 InputSpec 或者示例 Tensor 进行描述。详细的可以参考 paddle 关于`paddle.jit.save` | |||
的文档: | |||
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/save_cn.html#save | |||
:return: | |||
""" | |||
debug = kwargs.get("debug", False) | |||
model = self.unwrap_model() | |||
if only_state_dict: | |||
states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} | |||
paddle.save(states, filepath) | |||
if debug: | |||
logger.debug("Save model state dict.") | |||
else: | |||
# paddle 在保存整个模型时需要传入额外参数 | |||
input_spec = kwargs.get("input_spec", None) | |||
if input_spec is None: | |||
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") | |||
if self.model_device is not None: | |||
if not self.is_distributed(): | |||
self.move_model_to_device(model, "cpu") | |||
paddle.jit.save(model, filepath, input_spec) | |||
if not self.is_distributed(): | |||
self.move_model_to_device(model, self.model_device) | |||
else: | |||
paddle.jit.save(model, filepath, input_spec) | |||
if debug: | |||
logger.debug("Save model.") | |||
paddle.jit.save(model, filepath, input_spec) | |||
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): | |||
r""" | |||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | |||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); | |||
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, | |||
即保存了整个模型时,这个参数必须也为False | |||
:param only_state_dict: 是否加载state_dict,默认为True。 | |||
:param kwargs: | |||
:return: | |||
""" | |||
debug = kwargs.get("debug", False) | |||
model = self.unwrap_model() | |||
if only_state_dict: | |||
model.load_dict(paddle.load(filepath)) | |||
if debug: | |||
logger.debug("Load model state dict.") | |||
else: | |||
model.load_dict(paddle.jit.load(filepath).state_dict()) | |||
if debug: | |||
logger.debug("Load model.") | |||
# paddle 中,通过 paddle.jit.save 函数保存的模型也可以通过 paddle.load 加载为相应的 state dict | |||
# 但是此时对输入的 path 有要求,必须是 dir/filename 的形式,否则会报错。 | |||
dirname, filename = os.path.split(filepath) | |||
if not only_state_dict and dirname == "": | |||
# 如果传入的是单个文件,则加上相对路径 | |||
filepath = os.path.join(".", filepath) | |||
model.load_dict(paddle.load(filepath)) | |||
@rank_zero_call | |||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
@@ -261,7 +255,11 @@ class PaddleDriver(Driver): | |||
# 2. 保存模型的状态; | |||
if should_save_model: | |||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs) | |||
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) | |||
if only_state_dict: | |||
logger.debug("Save model state dict.") | |||
else: | |||
logger.debug("Save model.") | |||
# 3. 保存 optimizers 的状态; | |||
optimizers_state_dict = {} | |||
@@ -288,7 +286,11 @@ class PaddleDriver(Driver): | |||
# 2. 加载模型状态; | |||
if should_load_model: | |||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True) | |||
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) | |||
if only_state_dict: | |||
logger.debug("Load model state dict.") | |||
else: | |||
logger.debug("Load model.") | |||
# 3. 恢复 sampler 的状态; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
@@ -359,7 +361,7 @@ class PaddleDriver(Driver): | |||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | |||
""" | |||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | |||
global_rank = rank if rank is not None else rank_zero_call.rank | |||
global_rank = rank if rank is not None else int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) | |||
# TODO gpu | |||
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | |||
# back out the base seed so we can use all the bits | |||
@@ -2,7 +2,7 @@ import os | |||
from typing import Optional, Dict, Union | |||
from .paddle_driver import PaddleDriver | |||
from .utils import replace_batch_sampler, replace_sampler | |||
from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.utils import ( | |||
@@ -29,10 +29,7 @@ class PaddleSingleDriver(PaddleDriver): | |||
if device is None: | |||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | |||
if isinstance(device, int): | |||
self.model_device = get_paddle_gpu_str(device) | |||
else: | |||
self.model_device = device | |||
self.model_device = get_paddle_gpu_str(device) | |||
self.local_rank = 0 | |||
self.global_rank = 0 | |||
@@ -94,11 +91,14 @@ class PaddleSingleDriver(PaddleDriver): | |||
self._test_signature_fn = model.forward | |||
def setup(self): | |||
device_id = get_paddle_device_id(self.model_device) | |||
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||
paddle.device.set_device("gpu:0") | |||
self.model.to("gpu:0") | |||
device = self.model_device | |||
if device != "cpu": | |||
device_id = get_paddle_device_id(device) | |||
device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] | |||
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) | |||
device = get_device_from_visible(device, output_type=str) | |||
paddle.device.set_device(device) | |||
self.model.to(device) | |||
def train_step(self, batch) -> Dict: | |||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||
@@ -131,11 +131,11 @@ class PaddleSingleDriver(PaddleDriver): | |||
r""" | |||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` | |||
:return: 将移动到指定机器上的 batch 对象返回; | |||
""" | |||
return paddle_move_data_to_device(batch, "gpu:0") | |||
device = get_device_from_visible(self.data_device) | |||
return paddle_move_data_to_device(batch, device) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
reproducible: bool = False): | |||
@@ -255,20 +255,23 @@ def get_host_name_ip(): | |||
except: | |||
return None | |||
def get_device_from_visible(device: Union[str, int]): | |||
def get_device_from_visible(device: Union[str, int], output_type=int): | |||
""" | |||
在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 | |||
如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 | |||
:param devices:未转化的设备名 | |||
:param devices: 未转化的设备名 | |||
:param output_type: 返回值的类型 | |||
:return: 转化后的设备id | |||
""" | |||
if output_type not in [int, str]: | |||
raise ValueError("Parameter `output_type` should be one of these types: [int, str]") | |||
if device == "cpu": | |||
return device | |||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||
idx = get_paddle_device_id(device) | |||
if cuda_visible_devices is None or cuda_visible_devices == "": | |||
# 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES | |||
return idx | |||
raise RuntimeError("This situation should not happen, please report us this bug.") | |||
else: | |||
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 | |||
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) | |||
@@ -277,11 +280,13 @@ def get_device_from_visible(device: Union[str, int]): | |||
idx = user_visible_devices.split(",")[idx] | |||
cuda_visible_devices_list = cuda_visible_devices.split(',') | |||
assert idx in cuda_visible_devices_list, "Can't find "\ | |||
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ | |||
% (idx, cuda_visible_devices) | |||
if idx not in cuda_visible_devices_list: | |||
raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") | |||
res = cuda_visible_devices_list.index(idx) | |||
return res | |||
if output_type == int: | |||
return res | |||
else: | |||
return f"gpu:{res}" | |||
def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): | |||
""" | |||
@@ -46,11 +46,14 @@ def get_paddle_device_id(device: Union[str, int]): | |||
device = device.lower() | |||
if device == "cpu": | |||
raise ValueError("Cannot get device id from `cpu`.") | |||
elif device == "gpu": | |||
return 0 | |||
match_res = re.match(r"gpu:\d+", device) | |||
if not match_res: | |||
raise ValueError( | |||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" | |||
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x', " | |||
f"not '{device}'" | |||
) | |||
device_id = device.split(':', 1)[1] | |||
device_id = int(device_id) | |||
@@ -1,10 +1,11 @@ | |||
import os | |||
from numpy import isin | |||
os.environ["FASTNLP_BACKEND"] = "paddle" | |||
import pytest | |||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | |||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||
from fastNLP.core.samplers import RandomBatchSampler | |||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
@@ -30,6 +31,7 @@ def generate_random_driver(features, labels): | |||
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) | |||
driver = PaddleSingleDriver(model, device="cpu") | |||
driver.set_optimizers(opt) | |||
driver.setup() | |||
return driver | |||
@@ -77,6 +79,7 @@ def test_save_and_load_state_dict(prepare_test_save_load): | |||
driver2.load_model(path) | |||
for batch in dataloader: | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
@@ -93,10 +96,11 @@ def test_save_and_load_whole_model(prepare_test_save_load): | |||
path = "model" | |||
driver1, driver2, dataloader = prepare_test_save_load | |||
driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]]) | |||
driver1.save_model(path, only_state_dict=False, input_spec=[paddle.ones((32, 10))]) | |||
driver2.load_model(path, only_state_dict=False) | |||
for batch in dataloader: | |||
batch = driver1.move_data_to_device(batch) | |||
res1 = driver1.validate_step(batch) | |||
res2 = driver2.validate_step(batch) | |||
@@ -115,7 +119,7 @@ class TestSingleDeviceFunction: | |||
@classmethod | |||
def setup_class(cls): | |||
model = PaddleNormalModel_Classification_1(10, 784) | |||
cls.driver = PaddleSingleDriver(model, device="gpu") | |||
cls.driver = PaddleSingleDriver(model, device="cpu") | |||
def test_unwrap_model(self): | |||
""" | |||
@@ -130,22 +134,6 @@ class TestSingleDeviceFunction: | |||
self.driver.check_evaluator_mode("validate") | |||
self.driver.check_evaluator_mode("test") | |||
def test_get_model_device_cpu(self): | |||
""" | |||
测试get_model_device | |||
""" | |||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "cpu") | |||
device = self.driver.get_model_device() | |||
assert device == "cpu", device | |||
def test_get_model_device_gpu(self): | |||
""" | |||
测试get_model_device | |||
""" | |||
self.driver = PaddleSingleDriver(PaddleNormalModel_Classification_1(10, 784), "gpu:0") | |||
device = self.driver.get_model_device() | |||
assert device == "gpu:0", device | |||
def test_is_distributed(self): | |||
assert self.driver.is_distributed() == False | |||
@@ -156,24 +144,24 @@ class TestSingleDeviceFunction: | |||
""" | |||
self.driver.move_data_to_device(paddle.rand((32, 64))) | |||
@pytest.mark.parametrize( | |||
"dist_sampler", [ | |||
"dist", | |||
ReproducibleBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), | |||
RandomSampler(PaddleRandomMaxDataset(320, 10)) | |||
] | |||
) | |||
@pytest.mark.parametrize( | |||
"reproducible", | |||
[True, False] | |||
) | |||
def test_repalce_sampler(self, dist_sampler, reproducible): | |||
""" | |||
测试set_dist_repro_dataloader函数 | |||
""" | |||
dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) | |||
res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||
# @pytest.mark.parametrize( | |||
# "dist_sampler", [ | |||
# "dist", | |||
# RandomBatchSampler(BatchSampler(PaddleRandomMaxDataset(320, 10)), 32, False), | |||
# RandomSampler(PaddleRandomMaxDataset(320, 10)) | |||
# ] | |||
# ) | |||
# @pytest.mark.parametrize( | |||
# "reproducible", | |||
# [True, False] | |||
# ) | |||
# def test_set_dist_repro_dataloader(self, dist_sampler, reproducible): | |||
# """ | |||
# 测试set_dist_repro_dataloader函数 | |||
# """ | |||
# dataloader = DataLoader(PaddleRandomMaxDataset(320, 10), batch_size=100, shuffle=True) | |||
# res = self.driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) | |||
class TestPaddleDriverFunctions: | |||
""" | |||
@@ -183,7 +171,7 @@ class TestPaddleDriverFunctions: | |||
@classmethod | |||
def setup_class(self): | |||
model = PaddleNormalModel_Classification_1(10, 32) | |||
self.driver = PaddleSingleDriver(model, device="gpu") | |||
self.driver = PaddleSingleDriver(model, device="cpu") | |||
def test_check_single_optimizer_legality(self): | |||
""" | |||
@@ -198,7 +186,7 @@ class TestPaddleDriverFunctions: | |||
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
# 传入torch的optimizer时,应该报错ValueError | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizer) | |||
def test_check_optimizers_legality(self): | |||
@@ -218,7 +206,7 @@ class TestPaddleDriverFunctions: | |||
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) | |||
] | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
self.driver.set_optimizers(optimizers) | |||
def test_check_dataloader_legality_in_train(self): | |||
@@ -230,7 +218,7 @@ class TestPaddleDriverFunctions: | |||
# batch_size 和 batch_sampler 均为 None 的情形 | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
# 创建torch的dataloader | |||
@@ -238,7 +226,7 @@ class TestPaddleDriverFunctions: | |||
TorchNormalDataset(), | |||
batch_size=32, shuffle=True | |||
) | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) | |||
def test_check_dataloader_legacy_in_test(self): | |||
@@ -257,11 +245,12 @@ class TestPaddleDriverFunctions: | |||
"train": paddle.io.DataLoader(PaddleNormalDataset()), | |||
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) | |||
} | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# 传入的不是dict,应该报错 | |||
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
# 创建torch的dataloader | |||
@@ -274,7 +263,7 @@ class TestPaddleDriverFunctions: | |||
batch_size=32, shuffle=True | |||
) | |||
dataloader = {"train": train_loader, "test": test_loader} | |||
with self.assertRaises(ValueError) as cm: | |||
with pytest.raises(ValueError): | |||
PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) | |||
def test_tensor_to_numeric(self): | |||
@@ -284,25 +273,25 @@ class TestPaddleDriverFunctions: | |||
# 单个张量 | |||
tensor = paddle.to_tensor(3) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||
self.assertEqual(res, 3) | |||
assert res == 3 | |||
tensor = paddle.rand((3, 4)) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor) | |||
self.assertListEqual(res, tensor.tolist()) | |||
assert res == tensor.tolist() | |||
# 张量list | |||
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_list) | |||
self.assertTrue(res, list) | |||
assert isinstance(res, list) | |||
tensor_list = [t.tolist() for t in tensor_list] | |||
self.assertListEqual(res, tensor_list) | |||
assert res == tensor_list | |||
# 张量tuple | |||
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) | |||
self.assertTrue(res, tuple) | |||
assert isinstance(res, tuple) | |||
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) | |||
self.assertTupleEqual(res, tensor_tuple) | |||
assert res == tensor_tuple | |||
# 张量dict | |||
tensor_dict = { | |||
@@ -317,29 +306,29 @@ class TestPaddleDriverFunctions: | |||
} | |||
res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) | |||
self.assertIsInstance(res, dict) | |||
self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist()) | |||
self.assertIsInstance(res["list"], list) | |||
assert isinstance(res, dict) | |||
assert res["tensor"] == tensor_dict["tensor"].tolist() | |||
assert isinstance(res["list"], list) | |||
for r, d in zip(res["list"], tensor_dict["list"]): | |||
self.assertListEqual(r, d.tolist()) | |||
self.assertIsInstance(res["int"], int) | |||
self.assertIsInstance(res["string"], str) | |||
self.assertIsInstance(res["dict"], dict) | |||
self.assertIsInstance(res["dict"]["list"], list) | |||
assert r == d.tolist() | |||
assert isinstance(res["int"], int) | |||
assert isinstance(res["string"], str) | |||
assert isinstance(res["dict"], dict) | |||
assert isinstance(res["dict"]["list"], list) | |||
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): | |||
self.assertListEqual(r, d.tolist()) | |||
self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist()) | |||
assert r == d.tolist() | |||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | |||
def test_set_model_mode(self): | |||
""" | |||
测试set_model_mode函数 | |||
""" | |||
self.driver.set_model_mode("train") | |||
self.assertTrue(self.driver.model.training) | |||
assert self.driver.model.training | |||
self.driver.set_model_mode("eval") | |||
self.assertFalse(self.driver.model.training) | |||
assert not self.driver.model.training | |||
# 应该报错 | |||
with self.assertRaises(AssertionError) as cm: | |||
with pytest.raises(AssertionError): | |||
self.driver.set_model_mode("test") | |||
def test_move_model_to_device_cpu(self): | |||
@@ -347,15 +336,15 @@ class TestPaddleDriverFunctions: | |||
测试move_model_to_device函数 | |||
""" | |||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | |||
self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place()) | |||
assert self.driver.model.linear1.weight.place.is_cpu_place() | |||
def test_move_model_to_device_gpu(self): | |||
""" | |||
测试move_model_to_device函数 | |||
""" | |||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu:0") | |||
self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place()) | |||
self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0) | |||
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") | |||
assert self.driver.model.linear1.weight.place.is_gpu_place() | |||
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 | |||
def test_worker_init_function(self): | |||
""" | |||