@@ -0,0 +1,288 @@
import pytest
import sys
import os

import numpy as np

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader

from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm

############################################################################
#
# Tests for some functions of PaddleFleetDriver
#
############################################################################
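# Each test below runs in two roles. In the launcher process (where
# FASTNLP_DISTRIBUTED_CHECK is not yet in the environment) driver.setup() is expected
# to spawn the worker processes and then exit, hence the asserted SystemExit with
# code 0. In the spawned workers setup() completes and the actual check executes.
# The finally block removes the "log" directory produced by the distributed launch
# and waits at a barrier so the processes stay in step between tests.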
@magic_argv_env_context
def test_move_data_to_device():
    """
    This function only calls paddle_move_data_to_device; its test cases live in
    tests/core/utils/test_paddle_utils.py, so they are not repeated here.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.move_data_to_device(paddle.rand((32, 64)))
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
def test_is_distributed():
    print(os.getenv("CUDA_VISIBLE_DEVICES"))
    print(paddle.device.get_device())
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            assert driver.is_distributed()
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
def test_get_no_sync_context():
    """
    Check that it runs
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            res = driver.get_no_sync_context()
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
def test_is_global_zero():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.is_global_zero()
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
def test_unwrap_model():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.unwrap_model()
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
def test_get_local_rank():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.get_local_rank()
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
@magic_argv_env_context
@pytest.mark.parametrize(
    "dist_sampler",
    ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
    "reproducible",
    [True, False]
)
def test_replace_sampler(dist_sampler, reproducible):
    """
    Test replace_sampler
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish between the launcher process and setup in the subprocesses
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
            driver.replace_sampler(dataloader, dist_sampler, reproducible)
    finally:
        synchronize_safe_rm("log")
        dist.barrier()
############################################################################
#
# Test single-machine multi-GPU training
#
############################################################################
@magic_argv_env_context
class SingleMachineMultiGPUTrainingTestCase:
    """
    Test training with PaddleFleetDriver on a single machine with multiple GPUs.
    Driving distributed training through pytest gets a bit messy.
    """

    def test_case1(self):
        gpus = [0, 1]
        lr = 0.0003
        epochs = 20

        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr)
        train_dataset = PaddleDataset_MNIST("train")
        test_dataset = PaddleDataset_MNIST("test")
        loss_func = paddle.nn.CrossEntropyLoss()

        dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)

        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=gpus,
        )
        driver.set_optimizers(paddle_opt)
        dataloader = driver.replace_sampler(dataloader)
        driver.setup()
        # Check model_device
        assert driver.model_device == f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}"

        driver.barrier()
        driver.zero_grad()
        current_epoch_idx = 0
        while current_epoch_idx < epochs:
            epoch_loss, batch = 0, 0
            driver.set_model_mode("train")
            driver.set_sampler_epoch(dataloader, current_epoch_idx)
            for batch, (img, label) in enumerate(dataloader):
                img = paddle.to_tensor(img)
                out = driver.train_step(img)
                loss = loss_func(out, label)
                epoch_loss += loss.item()
                if batch % 50 == 0:
                    print("epoch:{}, batch:{}, loss: {}, rank:{}".format(
                        current_epoch_idx, batch, loss.item(), driver.local_rank))
                driver.backward(loss)
                driver.step()
                driver.zero_grad()
            driver.barrier()
            current_epoch_idx += 1

        # test
        correct = 0
        driver.set_model_mode("eval")
        for img, label in test_dataset:
            img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
            out = driver.test_step(img)
            res = paddle.nn.functional.softmax(out).argmax().item()
            label = label.item()
            if res == label:
                correct += 1
        print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))
@@ -0,0 +1,83 @@
import pytest

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")

import paddle

from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification

def test_incorrect_driver():
    """
    Passing a backend name other than "paddle" should raise ValueError
    """
    # device and model are supplied so the call fails on the backend check rather
    # than on missing arguments
    model = PaddleNormalModel_Classification(2, 100)
    with pytest.raises(ValueError):
        driver = initialize_paddle_driver("torch", 0, model)
@pytest.mark.parametrize(
    "device",
    ["cpu", "gpu:0", 0, "gpu:1"]
)
def test_get_single_device(device):
    """
    Test initializing PaddleSingleDriver in the normal case
    """
    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)
    assert isinstance(driver, PaddleSingleDriver)
@pytest.mark.parametrize(
    "device",
    ["cpu", "gpu:0", 0, "gpu:1"]
)
def test_get_single_device_with_visible_devices(device):
    """
    Test initializing PaddleSingleDriver when launched with CUDA_VISIBLE_DEVICES set
    """
    # TODO
    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)
    assert isinstance(driver, PaddleSingleDriver)
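# A hedged sketch of one way the TODO above might be filled in: pytest's built-in
# monkeypatch fixture can set CUDA_VISIBLE_DEVICES before the driver is created.
# This only approximates launching the process with the variable already set, since
# paddle may read it during its own initialization.
def test_get_single_device_with_visible_devices_sketch(monkeypatch):
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "0,1")
    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", "gpu:0", model)
    assert isinstance(driver, PaddleSingleDriver)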
@pytest.mark.parametrize(
    "device",
    [[1, 2, 3]]
)
def test_get_fleet(device):
    """
    Test fleet (multi-GPU) initialization
    """
    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)
    assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.parametrize(
    "device",
    [[1, 2, 3]]
)
def test_get_fleet_with_launch(device):
    """
    Test fleet (multi-GPU) initialization when started via launch
    """
    # TODO
    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)
    assert isinstance(driver, PaddleFleetDriver)
def test_device_out_of_range():
    """
    Test the case where the passed device is out of range
    """
    pass
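# A minimal sketch of what the placeholder above might check. It assumes that asking
# for a device index far beyond the available GPUs makes initialize_paddle_driver fail
# with some exception; the exact exception type is not pinned down here.
def test_device_out_of_range_sketch():
    model = PaddleNormalModel_Classification(2, 100)
    with pytest.raises(Exception):
        initialize_paddle_driver("paddle", 100, model)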
@@ -0,0 +1,268 @@
import unittest

import torch

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")

import paddle
from paddle.io import Dataset, DataLoader

from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver
class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = paddle.nn.Linear(784, 64)
        self.fc2 = paddle.nn.Linear(64, 32)
        self.fc3 = paddle.nn.Linear(32, 10)
        self.fc4 = paddle.nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x

class PaddleDataset(Dataset):
    def __init__(self):
        super(PaddleDataset, self).__init__()
        self.items = [paddle.rand((3, 4)) for i in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
class TorchNet(torch.nn.Module):
    def __init__(self):
        super(TorchNet, self).__init__()

        self.torch_fc1 = torch.nn.Linear(10, 10)
        self.torch_softmax = torch.nn.Softmax(0)
        self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
        self.torch_tensor = torch.ones(3, 3)
        self.torch_param = torch.nn.Parameter(torch.ones(4, 4))

class TorchDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(TorchDataset, self).__init__()
        self.items = [torch.ones(3, 4) for i in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
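# TorchNet and TorchDataset exist only so the tests below can verify that PaddleDriver
# rejects torch optimizers and torch dataloaders instead of silently accepting them.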
class PaddleDriverTestCase(unittest.TestCase):
    """
    Test class for PaddleDriver. Because of the class's nature, only some functions are
    tested here; the rest are covered by the PaddleSingleDriver and PaddleFleetDriver tests.
    """

    def setUp(self):
        model = Net()
        self.driver = PaddleDriver(model)

    def test_check_single_optimizer_legacy(self):
        """
        Test behavior when a single optimizer is passed
        """
        optimizer = paddle.optimizer.Adam(
            parameters=self.driver.model.parameters(),
            learning_rate=0.01
        )
        self.driver.set_optimizers(optimizer)
        optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01)
        # Passing a torch optimizer should raise ValueError
        with self.assertRaises(ValueError) as cm:
            self.driver.set_optimizers(optimizer)
    def test_check_optimizers_legacy(self):
        """
        Test behavior when a list of optimizers is passed
        """
        optimizers = [
            paddle.optimizer.Adam(
                parameters=self.driver.model.parameters(),
                learning_rate=0.01
            ) for i in range(10)
        ]
        self.driver.set_optimizers(optimizers)
        optimizers += [
            torch.optim.Adam(TorchNet().parameters(), 0.01)
        ]
        with self.assertRaises(ValueError) as cm:
            self.driver.set_optimizers(optimizers)
    def test_check_dataloader_legacy_in_train(self):
        """
        Test _check_dataloader_legality when is_train is True
        """
        dataloader = paddle.io.DataLoader(PaddleDataset())
        PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

        # Create a torch dataloader
        dataloader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

    def test_check_dataloader_legacy_in_test(self):
        """
        Test _check_dataloader_legality when is_train is False
        """
        # In this case the argument should be a dict
        dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test": paddle.io.DataLoader(PaddleDataset())}
        PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Passing something that is not a dict should raise an error
        dataloader = paddle.io.DataLoader(PaddleDataset())
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Create torch dataloaders
        train_loader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        dataloader = {"train": train_loader, "test": test_loader}
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)
    def test_tensor_to_numeric(self):
        """
        Test the tensor_to_numeric function
        """
        # A single tensor
        tensor = paddle.to_tensor(3)
        res = PaddleDriver.tensor_to_numeric(tensor)
        self.assertEqual(res, 3)

        tensor = paddle.rand((3, 4))
        res = PaddleDriver.tensor_to_numeric(tensor)
        self.assertListEqual(res, tensor.tolist())

        # A list of tensors
        tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        res = PaddleDriver.tensor_to_numeric(tensor_list)
        self.assertIsInstance(res, list)
        tensor_list = [t.tolist() for t in tensor_list]
        self.assertListEqual(res, tensor_list)

        # A tuple of tensors
        tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
        res = PaddleDriver.tensor_to_numeric(tensor_tuple)
        self.assertIsInstance(res, tuple)
        tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
        self.assertTupleEqual(res, tensor_tuple)

        # A dict containing tensors
        tensor_dict = {
            "tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(10)],
                "tensor": paddle.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }

        res = PaddleDriver.tensor_to_numeric(tensor_dict)
        self.assertIsInstance(res, dict)
        self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist())
        self.assertIsInstance(res["list"], list)
        for r, d in zip(res["list"], tensor_dict["list"]):
            self.assertListEqual(r, d.tolist())
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
            self.assertListEqual(r, d.tolist())
        self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist())
    def test_set_model_mode(self):
        """
        Test the set_model_mode function
        """
        self.driver.set_model_mode("train")
        self.assertTrue(self.driver.model.training)
        self.driver.set_model_mode("eval")
        self.assertFalse(self.driver.model.training)
        # An unknown mode should raise
        with self.assertRaises(AssertionError) as cm:
            self.driver.set_model_mode("test")

    def test_move_model_to_device_cpu(self):
        """
        Test the move_model_to_device function on cpu
        """
        PaddleDriver.move_model_to_device(self.driver.model, "cpu")
        self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place())

    def test_move_model_to_device_gpu(self):
        """
        Test the move_model_to_device function on gpu
        """
        PaddleDriver.move_model_to_device(self.driver.model, "gpu:0")
        self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place())
        self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0)
    def test_worker_init_function(self):
        """
        Test worker_init_function
        """
        # For now just make sure it runs
        # TODO: check correctness
        PaddleDriver.worker_init_function(0)

    def test_set_deterministic_dataloader(self):
        """
        Test set_deterministic_dataloader
        """
        # For now just make sure it runs
        # TODO: check correctness
        dataloader = DataLoader(PaddleDataset())
        self.driver.set_deterministic_dataloader(dataloader)

    def test_set_sampler_epoch(self):
        """
        Test set_sampler_epoch
        """
        # For now just make sure it runs
        # TODO: check correctness
        dataloader = DataLoader(PaddleDataset())
        self.driver.set_sampler_epoch(dataloader, 0)

    def test_get_dataloader_args(self):
        """
        Test get_dataloader_args
        """
        # For now just make sure it runs
        # TODO: check correctness
        dataloader = DataLoader(PaddleDataset())
        res = PaddleDriver.get_dataloader_args(dataloader)
@@ -0,0 +1,167 @@
import pytest

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")

import paddle
from paddle.io import DataLoader, BatchSampler

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import ReproducibleBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset
from fastNLP.core import synchronize_safe_rm

############################################################################
#
# Tests for save and load functionality
#
############################################################################
def generate_random_driver(features, labels):
    """
    Generate a driver with a randomly initialized model
    """
    model = PaddleNormalModel_Classification(labels, features)
    opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
    driver = PaddleSingleDriver(model)
    driver.set_optimizers(opt)

    return driver

@pytest.fixture
def prepare_test_save_load():
    dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8)
    dataloader = DataLoader(dataset, batch_size=32)
    driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8)
    return driver1, driver2, dataloader
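# driver1 and driver2 are initialized independently with random weights, so their
# predictions only match after loading if save/load actually restored the parameters.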
def test_save_and_load(prepare_test_save_load):
    """
    Test the save and load functions
    TODO: the optimizer's state_dict is empty, so it is not checked for now
    """
    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save(path, {})
        driver2.load(path)

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_state_dict(prepare_test_save_load):
    """
    Test save_model and load_model with only the state dict
    TODO: the optimizer's state_dict is empty, so it is not checked for now
    """
    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path)
        driver2.model.load_dict(driver2.load_model(path))

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_whole_model(prepare_test_save_load):
    """
    Test save_model and load_model with the whole model
    TODO: the optimizer's state_dict is empty, so it is not checked for now
    """
    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]])
        driver2.model = driver2.load_model(path, load_dict=False)

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)
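# Note on the whole-model case above: saving with only_state_dict=False requires an
# example input (input_spec), presumably because the model is exported as a traced
# static graph rather than as a plain parameter state dict.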
class TestSingleDeviceFunction:
    """
    Test cases for the remaining functions
    """

    @classmethod
    def setup_class(cls):
        model = PaddleNormalModel_Classification(10, 784)
        cls.driver = PaddleSingleDriver(model)

    def test_unwrap_model(self):
        """
        Check that it runs
        """
        res = self.driver.unwrap_model()

    def test_check_evaluator_mode(self):
        """
        These calls return nothing and raise nothing; we only check that no import
        errors or similar issues prevent them from running.
        """
        self.driver.check_evaluator_mode("validate")
        self.driver.check_evaluator_mode("test")

    def test_get_model_device_cpu(self):
        """
        Test get_model_device on cpu
        """
        self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu")
        device = self.driver.get_model_device()
        assert device == "cpu", device

    def test_get_model_device_gpu(self):
        """
        Test get_model_device on gpu
        """
        self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0")
        device = self.driver.get_model_device()
        assert device == "gpu:0", device

    def test_is_distributed(self):
        assert not self.driver.is_distributed()

    def test_move_data_to_device(self):
        """
        This function only calls paddle_move_data_to_device; its test cases live in
        tests/core/utils/test_paddle_utils.py, so they are not repeated here.
        """
        self.driver.move_data_to_device(paddle.rand((32, 64)))
    @pytest.mark.parametrize(
        "dist_sampler",
        ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))]
    )
    @pytest.mark.parametrize(
        "reproducible",
        [True, False]
    )
    def test_replace_sampler(self, dist_sampler, reproducible):
        """
        Test the replace_sampler function
        """
        dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)

        res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible)
@@ -0,0 +1,4 @@
import unittest

import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler