
Add tests/core/drivers/paddle_driver

tags/v1.0.0alpha
x54-729 3 years ago
parent commit 23502a0753
6 changed files with 810 additions and 0 deletions
  1. +0   -0  tests/core/drivers/paddle_driver/__init__.py
  2. +288 -0  tests/core/drivers/paddle_driver/test_fleet.py
  3. +83  -0  tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py
  4. +268 -0  tests/core/drivers/paddle_driver/test_paddle_driver.py
  5. +167 -0  tests/core/drivers/paddle_driver/test_single_device.py
  6. +4   -0  tests/core/drivers/paddle_driver/test_utils.py

+ 0 - 0   tests/core/drivers/paddle_driver/__init__.py


+ 288 - 0   tests/core/drivers/paddle_driver/test_fleet.py

@@ -0,0 +1,288 @@
import pytest
import sys
import os
import numpy as np
from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader

from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm


############################################################################
#
# Tests for some functions of PaddleFleetDriver
#
############################################################################

@magic_argv_env_context
def test_move_data_to_device():
    """
    This function only calls paddle_move_data_to_device; the test cases live in
    tests/core/utils/test_paddle_utils.py, so they are not repeated here.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.move_data_to_device(paddle.rand((32, 64)))
    finally:
        synchronize_safe_rm("log")

    dist.barrier()

@magic_argv_env_context
def test_is_distributed():
    print(os.getenv("CUDA_VISIBLE_DEVICES"))
    print(paddle.device.get_device())
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            assert driver.is_distributed() == True
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

@magic_argv_env_context
def test_get_no_sync_context():
    """
    Check that it runs without error.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            res = driver.get_no_sync_context()
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

@magic_argv_env_context
def test_is_global_zero():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.is_global_zero()
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

@magic_argv_env_context
def test_unwrap_model():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.unwrap_model()
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

@magic_argv_env_context
def test_get_local_rank():
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.get_local_rank()
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize(
    "dist_sampler",
    ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
    "reproducible",
    [True, False]
)
def test_replace_sampler(dist_sampler, reproducible):
    """
    Test replace_sampler.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0, 1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launching process from the sub-processes that call setup()
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
            driver.replace_sampler(dataloader, dist_sampler, reproducible)
    finally:
        synchronize_safe_rm("log")
    dist.barrier()

############################################################################
#
# Test training on a single machine with multiple GPUs
#
############################################################################

@magic_argv_env_context
class SingleMachineMultiGPUTrainingTestCase:
    """
    Test training with PaddleFleetDriver on a single machine with multiple GPUs.
    Running distributed training under pytest gets somewhat messy, so this case
    is not collected automatically.
    """

    def test_case1(self):

        gpus = [0, 1]
        lr = 0.0003
        epochs = 20

        paddle_model = PaddleNormalModel_Classification(10, 784)

        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr)

        train_dataset = PaddleDataset_MNIST("train")
        test_dataset = PaddleDataset_MNIST("test")
        loss_func = paddle.nn.CrossEntropyLoss()

        dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)

        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=gpus,
        )
        driver.set_optimizers(paddle_opt)
        dataloader = driver.replace_sampler(dataloader)
        driver.setup()
        # Check model_device
        assert driver.model_device == f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}"

        driver.barrier()

        driver.zero_grad()
        current_epoch_idx = 0
        while current_epoch_idx < epochs:
            epoch_loss, batch = 0, 0
            driver.set_model_mode("train")
            driver.set_sampler_epoch(dataloader, current_epoch_idx)
            for batch, (img, label) in enumerate(dataloader):

                img = paddle.to_tensor(img)
                out = driver.train_step(img)
                loss = loss_func(out, label)
                epoch_loss += loss.item()

                if batch % 50 == 0:
                    print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank))

                driver.backward(loss)
                driver.step()
                driver.zero_grad()
            driver.barrier()
            current_epoch_idx += 1

        # test
        correct = 0
        driver.set_model_mode("eval")
        for img, label in test_dataset:

            img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
            out = driver.test_step(img)
            res = paddle.nn.functional.softmax(out).argmax().item()
            label = label.item()
            if res == label:
                correct += 1

        print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))

+ 83 - 0   tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py

@@ -0,0 +1,83 @@
import pytest

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle

from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification


def test_incorrect_driver():

    with pytest.raises(ValueError):
        driver = initialize_paddle_driver("torch")

@pytest.mark.parametrize(
    "device",
    ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
)
def test_get_single_device(device):
    """
    Test initializing PaddleSingleDriver under normal conditions.
    """

    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)

    assert isinstance(driver, PaddleSingleDriver)

@pytest.mark.parametrize(
    "device",
    ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"]
)
def test_get_single_device_with_visible_devices(device):
    """
    Test initializing PaddleSingleDriver when launched with CUDA_VISIBLE_DEVICES set.
    """
    # TODO

    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)

    assert isinstance(driver, PaddleSingleDriver)

@pytest.mark.parametrize(
    "device",
    [[1, 2, 3]]
)
def test_get_fleet(device):
    """
    Test multi-card fleet initialization.
    """

    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)

    assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.parametrize(
    "device",
    [[1, 2, 3]]
)
def test_get_fleet_with_launch(device):
    """
    Test multi-card fleet initialization when started via launch.
    """
    # TODO

    model = PaddleNormalModel_Classification(2, 100)
    driver = initialize_paddle_driver("paddle", device, model)

    assert isinstance(driver, PaddleFleetDriver)

def test_device_out_of_range():
    """
    Test the case where the given device is out of range.
    """
    pass

+ 268 - 0   tests/core/drivers/paddle_driver/test_paddle_driver.py

@@ -0,0 +1,268 @@
import unittest

import torch
from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle
from paddle.io import Dataset, DataLoader

from fastNLP.core.drivers.paddle_driver.paddle_driver import PaddleDriver


class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = paddle.nn.Linear(784, 64)
        self.fc2 = paddle.nn.Linear(64, 32)
        self.fc3 = paddle.nn.Linear(32, 10)
        self.fc4 = paddle.nn.Linear(10, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)

        return x


class PaddleDataset(Dataset):
    def __init__(self):
        super(PaddleDataset, self).__init__()
        self.items = [paddle.rand((3, 4)) for i in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class TorchNet(torch.nn.Module):
    def __init__(self):
        super(TorchNet, self).__init__()

        self.torch_fc1 = torch.nn.Linear(10, 10)
        self.torch_softmax = torch.nn.Softmax(0)
        self.torch_conv2d1 = torch.nn.Conv2d(10, 10, 3)
        self.torch_tensor = torch.ones(3, 3)
        self.torch_param = torch.nn.Parameter(torch.ones(4, 4))


class TorchDataset(torch.utils.data.Dataset):
    def __init__(self):
        super(TorchDataset, self).__init__()
        self.items = [torch.ones(3, 4) for i in range(320)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class PaddleDriverTestCase(unittest.TestCase):
    """
    Test class for PaddleDriver. Due to the special nature of this class only some
    of its functions are tested here; the rest are covered by the PaddleSingleDriver
    and PaddleFleetDriver tests.
    """

    def setUp(self):
        model = Net()
        self.driver = PaddleDriver(model)

    def test_check_single_optimizer_legacy(self):
        """
        Test the behavior when a single optimizer is passed in.
        """
        optimizer = paddle.optimizer.Adam(
            parameters=self.driver.model.parameters(),
            learning_rate=0.01
        )

        self.driver.set_optimizers(optimizer)

        optimizer = torch.optim.Adam(TorchNet().parameters(), 0.01)
        # Passing a torch optimizer should raise a ValueError
        with self.assertRaises(ValueError) as cm:
            self.driver.set_optimizers(optimizer)

    def test_check_optimizers_legacy(self):
        """
        Test the behavior when a list of optimizers is passed in.
        """
        optimizers = [
            paddle.optimizer.Adam(
                parameters=self.driver.model.parameters(),
                learning_rate=0.01
            ) for i in range(10)
        ]

        self.driver.set_optimizers(optimizers)

        optimizers += [
            torch.optim.Adam(TorchNet().parameters(), 0.01)
        ]

        with self.assertRaises(ValueError) as cm:
            self.driver.set_optimizers(optimizers)

    def test_check_dataloader_legacy_in_train(self):
        """
        Test the behavior of _check_dataloader_legality when is_train is True.
        """
        dataloader = paddle.io.DataLoader(PaddleDataset())
        PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

        # Create a torch dataloader
        dataloader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", True)

    def test_check_dataloader_legacy_in_test(self):
        """
        Test the behavior of _check_dataloader_legality when is_train is False.
        """
        # In this case a dict should be passed in
        dataloader = {"train": paddle.io.DataLoader(PaddleDataset()), "test": paddle.io.DataLoader(PaddleDataset())}
        PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Passing something that is not a dict should raise an error
        dataloader = paddle.io.DataLoader(PaddleDataset())
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

        # Create torch dataloaders
        train_loader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            TorchDataset(),
            batch_size=32, shuffle=True
        )
        dataloader = {"train": train_loader, "test": test_loader}
        with self.assertRaises(ValueError) as cm:
            PaddleDriver._check_dataloader_legality(dataloader, "dataloader", False)

    def test_tensor_to_numeric(self):
        """
        Test the tensor_to_numeric function.
        """
        # A single tensor
        tensor = paddle.to_tensor(3)
        res = PaddleDriver.tensor_to_numeric(tensor)
        self.assertEqual(res, 3)

        tensor = paddle.rand((3, 4))
        res = PaddleDriver.tensor_to_numeric(tensor)
        self.assertListEqual(res, tensor.tolist())

        # A list of tensors
        tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
        res = PaddleDriver.tensor_to_numeric(tensor_list)
        self.assertIsInstance(res, list)
        tensor_list = [t.tolist() for t in tensor_list]
        self.assertListEqual(res, tensor_list)

        # A tuple of tensors
        tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
        res = PaddleDriver.tensor_to_numeric(tensor_tuple)
        self.assertIsInstance(res, tuple)
        tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
        self.assertTupleEqual(res, tensor_tuple)

        # A dict containing tensors
        tensor_dict = {
            "tensor": paddle.rand((3, 4)),
            "list": [paddle.rand((6, 4, 2)) for i in range(10)],
            "dict": {
                "list": [paddle.rand((6, 4, 2)) for i in range(10)],
                "tensor": paddle.rand((3, 4))
            },
            "int": 2,
            "string": "test string"
        }

        res = PaddleDriver.tensor_to_numeric(tensor_dict)
        self.assertIsInstance(res, dict)
        self.assertListEqual(res["tensor"], tensor_dict["tensor"].tolist())
        self.assertIsInstance(res["list"], list)
        for r, d in zip(res["list"], tensor_dict["list"]):
            self.assertListEqual(r, d.tolist())
        self.assertIsInstance(res["int"], int)
        self.assertIsInstance(res["string"], str)
        self.assertIsInstance(res["dict"], dict)
        self.assertIsInstance(res["dict"]["list"], list)
        for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
            self.assertListEqual(r, d.tolist())
        self.assertListEqual(res["dict"]["tensor"], tensor_dict["dict"]["tensor"].tolist())

    def test_set_model_mode(self):
        """
        Test the set_model_mode function.
        """
        self.driver.set_model_mode("train")
        self.assertTrue(self.driver.model.training)
        self.driver.set_model_mode("eval")
        self.assertFalse(self.driver.model.training)
        # An invalid mode should raise an error
        with self.assertRaises(AssertionError) as cm:
            self.driver.set_model_mode("test")

    def test_move_model_to_device_cpu(self):
        """
        Test the move_model_to_device function.
        """
        PaddleDriver.move_model_to_device(self.driver.model, "cpu")
        self.assertTrue(self.driver.model.fc1.weight.place.is_cpu_place())

    def test_move_model_to_device_gpu(self):
        """
        Test the move_model_to_device function.
        """
        PaddleDriver.move_model_to_device(self.driver.model, "gpu:0")
        self.assertTrue(self.driver.model.fc1.weight.place.is_gpu_place())
        self.assertEqual(self.driver.model.fc1.weight.place.gpu_device_id(), 0)

    def test_worker_init_function(self):
        """
        Test worker_init_function.
        """
        # For now just make sure it runs
        # TODO: verify correctness
        PaddleDriver.worker_init_function(0)

    def test_set_deterministic_dataloader(self):
        """
        Test set_deterministic_dataloader.
        """
        # For now just make sure it runs
        # TODO: verify correctness
        dataloader = DataLoader(PaddleDataset())
        self.driver.set_deterministic_dataloader(dataloader)

    def test_set_sampler_epoch(self):
        """
        Test set_sampler_epoch.
        """
        # For now just make sure it runs
        # TODO: verify correctness
        dataloader = DataLoader(PaddleDataset())
        self.driver.set_sampler_epoch(dataloader, 0)

    def test_get_dataloader_args(self):
        """
        Test get_dataloader_args.
        """
        # For now just make sure it runs
        # TODO: verify correctness
        dataloader = DataLoader(PaddleDataset())
        res = PaddleDriver.get_dataloader_args(dataloader)

+ 167 - 0   tests/core/drivers/paddle_driver/test_single_device.py

@@ -0,0 +1,167 @@
import pytest

from fastNLP.envs.set_backend import set_env
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

set_env_on_import_paddle()
set_env("paddle")
import paddle
from paddle.io import DataLoader, BatchSampler

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import ReproducibleBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset
from fastNLP.core import synchronize_safe_rm


############################################################################
#
# Tests for save- and load-related functionality
#
############################################################################

def generate_random_driver(features, labels):
    """
    Create a driver.
    """
    model = PaddleNormalModel_Classification(labels, features)
    opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
    driver = PaddleSingleDriver(model)
    driver.set_optimizers(opt)

    return driver

@pytest.fixture
def prepare_test_save_load():
    dataset = PaddleRandomDataset(num_of_data=320, features=64, labels=8)
    dataloader = DataLoader(dataset, batch_size=32)
    driver1, driver2 = generate_random_driver(64, 8), generate_random_driver(64, 8)
    return driver1, driver2, dataloader

def test_save_and_load(prepare_test_save_load):
    """
    Test the save and load functions.
    TODO: the optimizer's state_dict is empty, so it is not checked for now.
    """

    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save(path, {})
        driver2.load(path)

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_state_dict(prepare_test_save_load):
    """
    Test the save_model and load_model functions with a state dict.
    TODO: the optimizer's state_dict is empty, so it is not checked for now.
    """
    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path)
        driver2.model.load_dict(driver2.load_model(path))

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)

def test_save_and_load_whole_model(prepare_test_save_load):
    """
    Test the save_model and load_model functions with the whole model.
    TODO: the optimizer's state_dict is empty, so it is not checked for now.
    """
    try:
        path = "model.pdparams"
        driver1, driver2, dataloader = prepare_test_save_load

        driver1.save_model(path, only_state_dict=False, input_spec=[next(iter(dataloader))["x"]])
        driver2.model = driver2.load_model(path, load_dict=False)

        for batch in dataloader:
            res1 = driver1.validate_step(batch)
            res2 = driver2.validate_step(batch)

            assert paddle.equal_all(res1["pred"], res2["pred"])
    finally:
        synchronize_safe_rm(path)


class TestSingleDeviceFunction:
    """
    Tests for the remaining functions.
    """

    @classmethod
    def setup_class(cls):
        model = PaddleNormalModel_Classification(10, 784)
        cls.driver = PaddleSingleDriver(model)

    def test_unwrap_model(self):
        """
        Check that it runs without error.
        """
        res = self.driver.unwrap_model()

    def test_check_evaluator_mode(self):
        """
        These two calls have no return value and raise no exception; this only
        checks for import errors or other problems that would keep them from running.
        """
        self.driver.check_evaluator_mode("validate")
        self.driver.check_evaluator_mode("test")

    def test_get_model_device_cpu(self):
        """
        Test get_model_device.
        """
        self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "cpu")
        device = self.driver.get_model_device()
        assert device == "cpu", device

    def test_get_model_device_gpu(self):
        """
        Test get_model_device.
        """
        self.driver = PaddleSingleDriver(PaddleNormalModel_Classification(10, 784), "gpu:0")
        device = self.driver.get_model_device()
        assert device == "gpu:0", device

    def test_is_distributed(self):
        assert self.driver.is_distributed() == False

    def test_move_data_to_device(self):
        """
        This function only calls paddle_move_data_to_device; the test cases live in
        tests/core/utils/test_paddle_utils.py, so they are not repeated here.
        """
        self.driver.move_data_to_device(paddle.rand((32, 64)))

    @pytest.mark.parametrize(
        "dist_sampler",
        ["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))]
    )
    @pytest.mark.parametrize(
        "reproducible",
        [True, False]
    )
    def test_replace_sampler(self, dist_sampler, reproducible):
        """
        Test the replace_sampler function.
        """
        dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)

        res = self.driver.replace_sampler(dataloader, dist_sampler, reproducible)

+ 4 - 0   tests/core/drivers/paddle_driver/test_utils.py

@@ -0,0 +1,4 @@
import unittest

import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
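
For reference, the new suite can be collected with pytest from the repository root; a minimal sketch, assuming a GPU build of PaddlePaddle is installed (the fleet cases additionally expect at least two visible GPUs):

# Sketch only, not part of the commit: collect the new driver tests programmatically.
import pytest

# The fleet tests launch their own worker processes via PaddleFleetDriver.setup().
pytest.main(["tests/core/drivers/paddle_driver"])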
