Browse Source

PaddleSingleDriver的save load函数测试

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
9e97155312
1 changed files with 435 additions and 411 deletions
  1. +435
    -411
      tests/core/drivers/paddle_driver/test_single_device.py

+ 435
- 411
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -1,3 +1,4 @@
from dataclasses import replace
import os
from re import S
os.environ["FASTNLP_BACKEND"] = "paddle"
@@ -16,203 +17,303 @@ import paddle
from paddle.io import DataLoader, BatchSampler
import torch


############################################################################
#
# 测试save和load相关的功能
# 测试基类 PaddleDrvier 中的一些简单函数
#
############################################################################

def generate_random_driver(features, labels):
"""
生成driver
"""
model = PaddleNormalModel_Classification_1(labels, features)
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
driver = PaddleSingleDriver(model, device="cpu")
driver.set_optimizers(opt)
driver.setup()

return driver

@pytest.fixture
def prepare_test_save_load():
dataset = PaddleRandomMaxDataset(320, 10)
dataloader = DataLoader(dataset, batch_size=32)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
return driver1, driver2, dataloader

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
class TestPaddleDriverFunctions:
"""
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
使用 PaddleSingleDriver 测试基类的函数
"""

try:
path = "model.ckp"
@classmethod
def setup_class(self):
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(80, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
def test_check_single_optimizer_legality(self):
"""
测试传入单个optimizer时的表现
"""
optimizer = paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
)
num_consumed_batches = 2

# TODO 断点重训完善后在这里迭代几次
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)
self.driver.set_optimizers(optimizer)

sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# 传入torch的optimizer时,应该报错ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)

# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空
def test_check_optimizers_legality(self):
"""
测试传入optimizer list的表现
"""
optimizers = [
paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
) for i in range(10)
]

# 2. 检查 batch_sampler 是否被正确地加载和替换
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"]
self.driver.set_optimizers(optimizers)

# 3. 检查 model 的参数是否被正确加载
for batch in dataloader:
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
optimizers += [
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]

assert paddle.equal_all(res1["pred"], res2["pred"])
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizers)

# 4. 检查 batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)
def test_check_dataloader_legality_in_train(self):
"""
测试is_train参数为True时,_check_dataloader_legality函数的表现
"""
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
# batch_size 和 batch_sampler 均为 None 的情形
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

# 创建torch的dataloader
dataloader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)

finally:
synchronize_safe_rm(path)
def test_check_dataloader_legality_in_test(self):
"""
测试is_train参数为False时,_check_dataloader_legality函数的表现
"""
# 此时传入的应该是dict
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
"""
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
"""
# batch_size 和 batch_sampler 均为 None 的情形
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

try:
path = "model.ckp"
# 传入的不是dict,应该报错
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(80, 10)
batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = DataLoader(
dataset,
batch_sampler=batch_sampler
# 创建torch的dataloader
train_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
num_consumed_batches = 2

# TODO 断点重训完善后在这里迭代几次
already_seen_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_set.update(batch)

sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])

# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False)
test_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空
def test_tensor_to_numeric(self):
"""
测试tensor_to_numeric函数
"""
# 单个张量
tensor = paddle.to_tensor(3)
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == 3

# 2. 检查 sampler 是否被正确地加载和替换
replaced_loader = load_states["dataloader"]
tensor = paddle.rand((3, 4))
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == tensor.tolist()

assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"]
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 张量list
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
res = PaddleSingleDriver.tensor_to_numeric(tensor_list)
assert isinstance(res, list)
tensor_list = [t.tolist() for t in tensor_list]
assert res == tensor_list

# 3. 检查 model 的参数是否被正确加载
for batch in dataloader:
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
# 张量tuple
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple)
assert isinstance(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
assert res == tensor_tuple

assert paddle.equal_all(res1["pred"], res2["pred"])
# 张量dict
tensor_dict = {
"tensor": paddle.rand((3, 4)),
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"tensor": paddle.rand((3, 4))
},
"int": 2,
"string": "test string"
}

# 4. 检查 batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_batches = set()
for idx, batch in enumerate(replaced_loader):
left_batches.update(batch)
res = PaddleSingleDriver.tensor_to_numeric(tensor_dict)
assert isinstance(res, dict)
assert res["tensor"] == tensor_dict["tensor"].tolist()
assert isinstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
assert r == d.tolist()
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()

assert len(left_batches) + len(already_seen_set) == len(dataset)
assert len(left_batches | already_seen_set) == len(dataset)
finally:
synchronize_safe_rm(path)
def test_set_model_mode(self):
"""
测试set_model_mode函数
"""
self.driver.set_model_mode("train")
assert self.driver.model.training
self.driver.set_model_mode("eval")
assert not self.driver.model.training
# 应该报错
with pytest.raises(AssertionError):
self.driver.set_model_mode("test")

@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
测试 save_model 和 load_model 函数
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load
def test_move_model_to_device_cpu(self):
"""
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
assert self.driver.model.linear1.weight.place.is_cpu_place()

if only_state_dict:
driver1.save_model(path, only_state_dict)
def test_move_model_to_device_gpu(self):
"""
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
assert self.driver.model.linear1.weight.place.is_gpu_place()
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0

def test_worker_init_function(self):
"""
测试worker_init_function
"""
# 先确保不影响运行
# TODO:正确性
PaddleSingleDriver.worker_init_function(0)

def test_set_deterministic_dataloader(self):
"""
测试set_deterministic_dataloader
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_deterministic_dataloader(dataloader)

def test_set_sampler_epoch(self):
"""
测试set_sampler_epoch
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_sampler_epoch(dataloader, 0)

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
"""
测试正常情况下 get_dataloader_args 的表现
"""
dataloader = DataLoader(
PaddleNormalDataset(),
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict)
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 batch_sampler 后 get_dataloader_args 的表现
"""
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
)
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 sampler 后 get_dataloader_args 的表现
"""
dataset = PaddleNormalDataset()
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
batch_sampler.sampler = RandomSampler(dataset, shuffle)
dataloader = DataLoader(
dataset,
batch_sampler=batch_sampler,
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
assert isinstance(res.sampler, RandomSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last


############################################################################
#
# 测试 PaddleSingleDrvier 中的一些简单函数
#
############################################################################

class TestSingleDeviceFunction:
"""
@@ -242,6 +343,12 @@ class TestSingleDeviceFunction:
self.driver.move_data_to_device(paddle.rand((32, 64)))


############################################################################
#
# 测试 set_dist_repro_dataloader 函数
#
############################################################################

class TestSetDistReproDataloder:
"""
专门测试 set_dist_repro_dataloader 函数的类
@@ -423,287 +530,204 @@ class TestSetDistReproDataloder:
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset)
assert len(left_idxes | already_seen_idx) == len(self.dataset)

class TestPaddleDriverFunctions:
############################################################################
#
# 测试 save 和 load 相关的功能
#
############################################################################

def generate_random_driver(features, labels):
"""
使用 PaddleSingleDriver 测试基类的函数
生成driver
"""
model = PaddleNormalModel_Classification_1(labels, features)
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01)
driver = PaddleSingleDriver(model, device="cpu")
driver.set_optimizers(opt)
driver.setup()

@classmethod
def setup_class(self):
model = PaddleNormalModel_Classification_1(10, 32)
self.driver = PaddleSingleDriver(model, device="cpu")

def test_check_single_optimizer_legality(self):
"""
测试传入单个optimizer时的表现
"""
optimizer = paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
)

self.driver.set_optimizers(optimizer)
return driver

optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
# 传入torch的optimizer时,应该报错ValueError
with pytest.raises(ValueError):
self.driver.set_optimizers(optimizer)
@pytest.fixture
def prepare_test_save_load():
dataset = PaddleRandomMaxDataset(320, 10)
dataloader = DataLoader(dataset, batch_size=32)
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
return driver1, driver2, dataloader

def test_check_optimizers_legality(self):
"""
测试传入optimizer list的表现
"""
optimizers = [
paddle.optimizer.Adam(
parameters=self.driver.model.parameters(),
learning_rate=0.01
) for i in range(10)
]
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_model(prepare_test_save_load, only_state_dict):
"""
测试 save_model 和 load_model 函数
"""
try:
path = "model"
driver1, driver2, dataloader = prepare_test_save_load

self.driver.set_optimizers(optimizers)
if only_state_dict:
driver1.save_model(path, only_state_dict)
else:
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))])
driver2.load_model(path, only_state_dict)

optimizers += [
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01)
]
for batch in dataloader:
batch = driver1.move_data_to_device(batch)
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)

with pytest.raises(ValueError):
self.driver.set_optimizers(optimizers)
assert paddle.equal_all(res1["pred"], res2["pred"])
finally:
if only_state_dict:
synchronize_safe_rm(path)
else:
synchronize_safe_rm(path + ".pdiparams")
synchronize_safe_rm(path + ".pdiparams.info")
synchronize_safe_rm(path + ".pdmodel")

def test_check_dataloader_legality_in_train(self):
"""
测试is_train参数为True时,_check_dataloader_legality函数的表现
"""
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randombatchsampler(only_state_dict):
"""
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
"""

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
try:
path = "model.ckp"

# 创建torch的dataloader
dataloader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
)
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
num_consumed_batches = 2

def test_check_dataloader_legality_in_test(self):
"""
测试is_train参数为False时,_check_dataloader_legality函数的表现
"""
# 此时传入的应该是dict
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset())
}
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# batch_size 和 batch_sampler 均为 None 的情形
dataloader = {
"train": paddle.io.DataLoader(PaddleNormalDataset()),
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None)
}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

# 传入的不是dict,应该报错
dataloader = paddle.io.DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

# 创建torch的dataloader
train_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
TorchNormalDataset(),
batch_size=32, shuffle=True
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
)
dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)

def test_tensor_to_numeric(self):
"""
测试tensor_to_numeric函数
"""
# 单个张量
tensor = paddle.to_tensor(3)
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == 3

tensor = paddle.rand((3, 4))
res = PaddleSingleDriver.tensor_to_numeric(tensor)
assert res == tensor.tolist()

# 张量list
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)]
res = PaddleSingleDriver.tensor_to_numeric(tensor_list)
assert isinstance(res, list)
tensor_list = [t.tolist() for t in tensor_list]
assert res == tensor_list

# 张量tuple
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)])
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple)
assert isinstance(res, tuple)
tensor_tuple = tuple([t.tolist() for t in tensor_tuple])
assert res == tensor_tuple

# 张量dict
tensor_dict = {
"tensor": paddle.rand((3, 4)),
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"dict":{
"list": [paddle.rand((6, 4, 2)) for i in range(10)],
"tensor": paddle.rand((3, 4))
},
"int": 2,
"string": "test string"
}

res = PaddleSingleDriver.tensor_to_numeric(tensor_dict)
assert isinstance(res, dict)
assert res["tensor"] == tensor_dict["tensor"].tolist()
assert isinstance(res["list"], list)
for r, d in zip(res["list"], tensor_dict["list"]):
assert r == d.tolist()
assert isinstance(res["int"], int)
assert isinstance(res["string"], str)
assert isinstance(res["dict"], dict)
assert isinstance(res["dict"]["list"], list)
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]):
assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空

def test_set_model_mode(self):
"""
测试set_model_mode函数
"""
self.driver.set_model_mode("train")
assert self.driver.model.training
self.driver.set_model_mode("eval")
assert not self.driver.model.training
# 应该报错
with pytest.raises(AssertionError):
self.driver.set_model_mode("test")
# 2. 检查 batch_sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4

def test_move_model_to_device_cpu(self):
"""
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
assert self.driver.model.linear1.weight.place.is_cpu_place()
# 3. 检查 model 的参数是否正确
# 4. 检查 batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
for idx, batch in enumerate(replaced_loader):

def test_move_model_to_device_gpu(self):
"""
测试move_model_to_device函数
"""
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
assert self.driver.model.linear1.weight.place.is_gpu_place()
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0
left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])

def test_worker_init_function(self):
"""
测试worker_init_function
"""
# 先确保不影响运行
# TODO:正确性
PaddleSingleDriver.worker_init_function(0)
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
synchronize_safe_rm(path)

def test_set_deterministic_dataloader(self):
"""
测试set_deterministic_dataloader
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_deterministic_dataloader(dataloader)
@pytest.mark.parametrize("only_state_dict", ([True, False]))
def test_save_and_load_with_randomsampler(only_state_dict):
"""
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
"""

def test_set_sampler_epoch(self):
"""
测试set_sampler_epoch
"""
# 先确保不影响运行
# TODO:正确性
dataloader = DataLoader(PaddleNormalDataset())
self.driver.set_sampler_epoch(dataloader, 0)
try:
path = "model.ckp"

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
"""
测试正常情况下 get_dataloader_args 的表现
"""
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(40, 10)
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = DataLoader(
PaddleNormalDataset(),
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
dataset,
batch_sampler=batch_sampler
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)
num_consumed_batches = 2

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last
already_seen_x_set = set()
already_seen_y_set = set()
for idx, batch in enumerate(dataloader):
if idx >= num_consumed_batches:
break
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 batch_sampler 后 get_dataloader_args 的表现
"""
dataset = PaddleNormalDataset()
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(dataset, True)
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
)
batch_sampler=batch_sampler
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
assert isinstance(res.sampler, paddle.io.SequenceSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last
# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空

@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
"""
测试替换了 sampler 后 get_dataloader_args 的表现
"""
dataset = PaddleNormalDataset()
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
batch_sampler.sampler = RandomSampler(dataset, shuffle)
dataloader = DataLoader(
dataset,
batch_sampler=batch_sampler,
)
res = PaddleSingleDriver.get_dataloader_args(dataloader)
# 2. 检查 sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, BatchSampler)
assert isinstance(res.sampler, RandomSampler)
assert res.shuffle == shuffle
assert res.batch_size == batch_size
assert res.drop_last == drop_last
# 3. 检查 model 的参数是否正确
# 4. 检查 batch_idx
start_batch = load_states.pop('batch_idx_in_epoch')
assert start_batch == 2 * num_consumed_batches
left_x_batches = set()
left_y_batches = set()
for idx, batch in enumerate(replaced_loader):

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
assert paddle.equal_all(res1["pred"], res2["pred"])

assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
finally:
synchronize_safe_rm(path)

Loading…
Cancel
Save