|
|
@@ -1,3 +1,4 @@ |
|
|
|
from dataclasses import replace |
|
|
|
import os |
|
|
|
from re import S |
|
|
|
os.environ["FASTNLP_BACKEND"] = "paddle" |
|
|
@@ -16,203 +17,303 @@ import paddle |
|
|
|
from paddle.io import DataLoader, BatchSampler |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试save和load相关的功能 |
|
|
|
# 测试基类 PaddleDrvier 中的一些简单函数 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
def generate_random_driver(features, labels):
    """Build and set up a fresh ``PaddleSingleDriver`` on CPU for testing.

    NOTE(review): the arguments are forwarded as ``(labels, features)`` to the
    model constructor — this mirrors the original call order; confirm it is
    intentional.
    """
    classifier = PaddleNormalModel_Classification_1(labels, features)
    adam = paddle.optimizer.Adam(
        parameters=classifier.parameters(),
        learning_rate=0.01,
    )
    single_driver = PaddleSingleDriver(classifier, device="cpu")
    single_driver.set_optimizers(adam)
    single_driver.setup()

    return single_driver
|
|
|
|
|
|
|
@pytest.fixture
def prepare_test_save_load():
    """Yield two independent drivers plus a dataloader over a random dataset."""
    random_dataset = PaddleRandomMaxDataset(320, 10)
    loader = DataLoader(random_dataset, batch_size=32)
    first_driver = generate_random_driver(10, 10)
    second_driver = generate_random_driver(10, 10)
    return first_driver, second_driver, loader
|
|
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_with_randombatchsampler(only_state_dict): |
|
|
|
class TestPaddleDriverFunctions: |
|
|
|
""" |
|
|
|
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 |
|
|
|
使用 PaddleSingleDriver 测试基类的函数 |
|
|
|
""" |
|
|
|
|
|
|
|
try: |
|
|
|
path = "model.ckp" |
|
|
|
@classmethod |
|
|
|
def setup_class(self): |
|
|
|
model = PaddleNormalModel_Classification_1(10, 32) |
|
|
|
self.driver = PaddleSingleDriver(model, device="cpu") |
|
|
|
|
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
dataset = PaddleRandomMaxDataset(80, 10) |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset=dataset, |
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) |
|
|
|
def test_check_single_optimizer_legality(self): |
|
|
|
""" |
|
|
|
测试传入单个optimizer时的表现 |
|
|
|
""" |
|
|
|
optimizer = paddle.optimizer.Adam( |
|
|
|
parameters=self.driver.model.parameters(), |
|
|
|
learning_rate=0.01 |
|
|
|
) |
|
|
|
num_consumed_batches = 2 |
|
|
|
|
|
|
|
# TODO 断点重训完善后在这里迭代几次 |
|
|
|
already_seen_set = set() |
|
|
|
for idx, batch in enumerate(dataloader): |
|
|
|
if idx >= num_consumed_batches: |
|
|
|
break |
|
|
|
already_seen_set.update(batch) |
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
|
|
|
|
sampler_states = dataloader.batch_sampler.state_dict() |
|
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
|
if only_state_dict: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
|
else: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) |
|
|
|
|
|
|
|
# 加载 |
|
|
|
# 更改 batch_size |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset=dataset, |
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False) |
|
|
|
) |
|
|
|
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) |
|
|
|
replaced_loader = load_states.pop("dataloader") |
|
|
|
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
# 传入torch的optimizer时,应该报错ValueError |
|
|
|
with pytest.raises(ValueError): |
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
|
|
|
|
# 1. 检查 optimizer 的状态 |
|
|
|
# TODO optimizer 的 state_dict 总是为空 |
|
|
|
def test_check_optimizers_legality(self): |
|
|
|
""" |
|
|
|
测试传入optimizer list的表现 |
|
|
|
""" |
|
|
|
optimizers = [ |
|
|
|
paddle.optimizer.Adam( |
|
|
|
parameters=self.driver.model.parameters(), |
|
|
|
learning_rate=0.01 |
|
|
|
) for i in range(10) |
|
|
|
] |
|
|
|
|
|
|
|
# 2. 检查 batch_sampler 是否被正确地加载和替换 |
|
|
|
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) |
|
|
|
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] |
|
|
|
assert replaced_loader.batch_sampler.data_idx == sampler_states["data_idx"] |
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
|
|
|
|
# 3. 检查 model 的参数是否被正确加载 |
|
|
|
for batch in dataloader: |
|
|
|
res1 = driver1.model.evaluate_step(**batch) |
|
|
|
res2 = driver2.model.evaluate_step(**batch) |
|
|
|
optimizers += [ |
|
|
|
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
] |
|
|
|
|
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
|
|
|
|
# 4. 检查 batch_idx |
|
|
|
start_batch = load_states.pop('batch_idx_in_epoch') |
|
|
|
assert start_batch == 2 * num_consumed_batches |
|
|
|
left_batches = set() |
|
|
|
for idx, batch in enumerate(replaced_loader): |
|
|
|
left_batches.update(batch) |
|
|
|
def test_check_dataloader_legality_in_train(self): |
|
|
|
""" |
|
|
|
测试is_train参数为True时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
assert len(left_batches) + len(already_seen_set) == len(dataset) |
|
|
|
assert len(left_batches | already_seen_set) == len(dataset) |
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
# 创建torch的dataloader |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
|
|
|
|
finally: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
def test_check_dataloader_legality_in_test(self): |
|
|
|
""" |
|
|
|
测试is_train参数为False时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
# 此时传入的应该是dict |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
} |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_with_randomsampler(only_state_dict): |
|
|
|
""" |
|
|
|
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 |
|
|
|
""" |
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
} |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
try: |
|
|
|
path = "model.ckp" |
|
|
|
# 传入的不是dict,应该报错 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
dataset = PaddleRandomMaxDataset(80, 10) |
|
|
|
batch_sampler = BatchSampler(dataset=dataset, batch_size=2) |
|
|
|
batch_sampler.sampler = RandomSampler(dataset, True) |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
batch_sampler=batch_sampler |
|
|
|
# 创建torch的dataloader |
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
num_consumed_batches = 2 |
|
|
|
|
|
|
|
# TODO 断点重训完善后在这里迭代几次 |
|
|
|
already_seen_set = set() |
|
|
|
for idx, batch in enumerate(dataloader): |
|
|
|
if idx >= num_consumed_batches: |
|
|
|
break |
|
|
|
already_seen_set.update(batch) |
|
|
|
|
|
|
|
sampler_states = dataloader.batch_sampler.sampler.state_dict() |
|
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
|
if only_state_dict: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
|
else: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) |
|
|
|
|
|
|
|
# 加载 |
|
|
|
# 更改 batch_size |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset=dataset, |
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2), 2, False) |
|
|
|
test_loader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) |
|
|
|
replaced_loader = load_states.pop("dataloader") |
|
|
|
dataloader = {"train": train_loader, "test": test_loader} |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# 1. 检查 optimizer 的状态 |
|
|
|
# TODO optimizer 的 state_dict 总是为空 |
|
|
|
def test_tensor_to_numeric(self): |
|
|
|
""" |
|
|
|
测试tensor_to_numeric函数 |
|
|
|
""" |
|
|
|
# 单个张量 |
|
|
|
tensor = paddle.to_tensor(3) |
|
|
|
res = PaddleSingleDriver.tensor_to_numeric(tensor) |
|
|
|
assert res == 3 |
|
|
|
|
|
|
|
# 2. 检查 sampler 是否被正确地加载和替换 |
|
|
|
replaced_loader = load_states["dataloader"] |
|
|
|
tensor = paddle.rand((3, 4)) |
|
|
|
res = PaddleSingleDriver.tensor_to_numeric(tensor) |
|
|
|
assert res == tensor.tolist() |
|
|
|
|
|
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) |
|
|
|
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] |
|
|
|
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] |
|
|
|
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == sampler_states["num_consumed_samples"] |
|
|
|
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] |
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] |
|
|
|
# 张量list |
|
|
|
tensor_list = [paddle.rand((6, 4, 2)) for i in range(10)] |
|
|
|
res = PaddleSingleDriver.tensor_to_numeric(tensor_list) |
|
|
|
assert isinstance(res, list) |
|
|
|
tensor_list = [t.tolist() for t in tensor_list] |
|
|
|
assert res == tensor_list |
|
|
|
|
|
|
|
# 3. 检查 model 的参数是否被正确加载 |
|
|
|
for batch in dataloader: |
|
|
|
res1 = driver1.model.evaluate_step(**batch) |
|
|
|
res2 = driver2.model.evaluate_step(**batch) |
|
|
|
# 张量tuple |
|
|
|
tensor_tuple = tuple([paddle.rand((6, 4, 2)) for i in range(10)]) |
|
|
|
res = PaddleSingleDriver.tensor_to_numeric(tensor_tuple) |
|
|
|
assert isinstance(res, tuple) |
|
|
|
tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) |
|
|
|
assert res == tensor_tuple |
|
|
|
|
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
# 张量dict |
|
|
|
tensor_dict = { |
|
|
|
"tensor": paddle.rand((3, 4)), |
|
|
|
"list": [paddle.rand((6, 4, 2)) for i in range(10)], |
|
|
|
"dict":{ |
|
|
|
"list": [paddle.rand((6, 4, 2)) for i in range(10)], |
|
|
|
"tensor": paddle.rand((3, 4)) |
|
|
|
}, |
|
|
|
"int": 2, |
|
|
|
"string": "test string" |
|
|
|
} |
|
|
|
|
|
|
|
# 4. 检查 batch_idx |
|
|
|
start_batch = load_states.pop('batch_idx_in_epoch') |
|
|
|
assert start_batch == 2 * num_consumed_batches |
|
|
|
left_batches = set() |
|
|
|
for idx, batch in enumerate(replaced_loader): |
|
|
|
left_batches.update(batch) |
|
|
|
res = PaddleSingleDriver.tensor_to_numeric(tensor_dict) |
|
|
|
assert isinstance(res, dict) |
|
|
|
assert res["tensor"] == tensor_dict["tensor"].tolist() |
|
|
|
assert isinstance(res["list"], list) |
|
|
|
for r, d in zip(res["list"], tensor_dict["list"]): |
|
|
|
assert r == d.tolist() |
|
|
|
assert isinstance(res["int"], int) |
|
|
|
assert isinstance(res["string"], str) |
|
|
|
assert isinstance(res["dict"], dict) |
|
|
|
assert isinstance(res["dict"]["list"], list) |
|
|
|
for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): |
|
|
|
assert r == d.tolist() |
|
|
|
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() |
|
|
|
|
|
|
|
assert len(left_batches) + len(already_seen_set) == len(dataset) |
|
|
|
assert len(left_batches | already_seen_set) == len(dataset) |
|
|
|
finally: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
def test_set_model_mode(self): |
|
|
|
""" |
|
|
|
测试set_model_mode函数 |
|
|
|
""" |
|
|
|
self.driver.set_model_mode("train") |
|
|
|
assert self.driver.model.training |
|
|
|
self.driver.set_model_mode("eval") |
|
|
|
assert not self.driver.model.training |
|
|
|
# 应该报错 |
|
|
|
with pytest.raises(AssertionError): |
|
|
|
self.driver.set_model_mode("test") |
|
|
|
|
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_model(prepare_test_save_load, only_state_dict): |
|
|
|
""" |
|
|
|
测试 save_model 和 load_model 函数 |
|
|
|
""" |
|
|
|
try: |
|
|
|
path = "model" |
|
|
|
driver1, driver2, dataloader = prepare_test_save_load |
|
|
|
def test_move_model_to_device_cpu(self): |
|
|
|
""" |
|
|
|
测试move_model_to_device函数 |
|
|
|
""" |
|
|
|
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") |
|
|
|
assert self.driver.model.linear1.weight.place.is_cpu_place() |
|
|
|
|
|
|
|
if only_state_dict: |
|
|
|
driver1.save_model(path, only_state_dict) |
|
|
|
def test_move_model_to_device_gpu(self): |
|
|
|
""" |
|
|
|
测试move_model_to_device函数 |
|
|
|
""" |
|
|
|
PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") |
|
|
|
assert self.driver.model.linear1.weight.place.is_gpu_place() |
|
|
|
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 |
|
|
|
|
|
|
|
def test_worker_init_function(self): |
|
|
|
""" |
|
|
|
测试worker_init_function |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
PaddleSingleDriver.worker_init_function(0) |
|
|
|
|
|
|
|
def test_set_deterministic_dataloader(self): |
|
|
|
""" |
|
|
|
测试set_deterministic_dataloader |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
self.driver.set_deterministic_dataloader(dataloader) |
|
|
|
|
|
|
|
def test_set_sampler_epoch(self): |
|
|
|
""" |
|
|
|
测试set_sampler_epoch |
|
|
|
""" |
|
|
|
# 先确保不影响运行 |
|
|
|
# TODO:正确性 |
|
|
|
dataloader = DataLoader(PaddleNormalDataset()) |
|
|
|
self.driver.set_sampler_epoch(dataloader, 0) |
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16]) |
|
|
|
@pytest.mark.parametrize("shuffle", [True, False]) |
|
|
|
@pytest.mark.parametrize("drop_last", [True, False]) |
|
|
|
def test_get_dataloader_args(self, batch_size, shuffle, drop_last): |
|
|
|
""" |
|
|
|
测试正常情况下 get_dataloader_args 的表现 |
|
|
|
""" |
|
|
|
dataloader = DataLoader( |
|
|
|
PaddleNormalDataset(), |
|
|
|
batch_size=batch_size, |
|
|
|
shuffle=shuffle, |
|
|
|
drop_last=drop_last, |
|
|
|
) |
|
|
|
res = PaddleSingleDriver.get_dataloader_args(dataloader) |
|
|
|
|
|
|
|
assert isinstance(res.dataset, PaddleNormalDataset) |
|
|
|
assert isinstance(res.batch_sampler, BatchSampler) |
|
|
|
if shuffle: |
|
|
|
assert isinstance(res.sampler, paddle.io.RandomSampler) |
|
|
|
else: |
|
|
|
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))]) |
|
|
|
driver2.load_model(path, only_state_dict) |
|
|
|
assert isinstance(res.sampler, paddle.io.SequenceSampler) |
|
|
|
assert res.shuffle == shuffle |
|
|
|
assert res.batch_size == batch_size |
|
|
|
assert res.drop_last == drop_last |
|
|
|
|
|
|
|
for batch in dataloader: |
|
|
|
batch = driver1.move_data_to_device(batch) |
|
|
|
res1 = driver1.model.evaluate_step(**batch) |
|
|
|
res2 = driver2.model.evaluate_step(**batch) |
|
|
|
@pytest.mark.parametrize("batch_size", [16]) |
|
|
|
@pytest.mark.parametrize("shuffle", [True, False]) |
|
|
|
@pytest.mark.parametrize("drop_last", [True, False]) |
|
|
|
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): |
|
|
|
""" |
|
|
|
测试替换了 batch_sampler 后 get_dataloader_args 的表现 |
|
|
|
""" |
|
|
|
dataset = PaddleNormalDataset() |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset, |
|
|
|
batch_sampler=RandomBatchSampler( |
|
|
|
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), |
|
|
|
batch_size, |
|
|
|
drop_last, |
|
|
|
) |
|
|
|
) |
|
|
|
res = PaddleSingleDriver.get_dataloader_args(dataloader) |
|
|
|
|
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
finally: |
|
|
|
if only_state_dict: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
assert isinstance(res.dataset, PaddleNormalDataset) |
|
|
|
assert isinstance(res.batch_sampler, RandomBatchSampler) |
|
|
|
if shuffle: |
|
|
|
assert isinstance(res.sampler, paddle.io.RandomSampler) |
|
|
|
else: |
|
|
|
synchronize_safe_rm(path + ".pdiparams") |
|
|
|
synchronize_safe_rm(path + ".pdiparams.info") |
|
|
|
synchronize_safe_rm(path + ".pdmodel") |
|
|
|
assert isinstance(res.sampler, paddle.io.SequenceSampler) |
|
|
|
assert res.shuffle == shuffle |
|
|
|
assert res.batch_size == batch_size |
|
|
|
assert res.drop_last == drop_last |
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
    """Check ``get_dataloader_args`` after the loader's sampler was replaced
    with a fastNLP ``RandomSampler``.
    """
    normal_dataset = PaddleNormalDataset()
    wrapped_batch_sampler = BatchSampler(
        normal_dataset, batch_size=batch_size, drop_last=drop_last
    )
    # swap in the reproducible sampler, as the driver would do
    wrapped_batch_sampler.sampler = RandomSampler(normal_dataset, shuffle)
    loader = DataLoader(normal_dataset, batch_sampler=wrapped_batch_sampler)
    args = PaddleSingleDriver.get_dataloader_args(loader)

    # every field must be recovered from the replaced sampler
    assert isinstance(args.dataset, PaddleNormalDataset)
    assert isinstance(args.batch_sampler, BatchSampler)
    assert isinstance(args.sampler, RandomSampler)
    assert args.shuffle == shuffle
    assert args.batch_size == batch_size
    assert args.drop_last == drop_last
|
|
|
|
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 PaddleSingleDrvier 中的一些简单函数 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
class TestSingleDeviceFunction: |
|
|
|
""" |
|
|
@@ -242,6 +343,12 @@ class TestSingleDeviceFunction: |
|
|
|
self.driver.move_data_to_device(paddle.rand((32, 64))) |
|
|
|
|
|
|
|
|
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 set_dist_repro_dataloader 函数 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
class TestSetDistReproDataloder: |
|
|
|
""" |
|
|
|
专门测试 set_dist_repro_dataloader 函数的类 |
|
|
@@ -423,287 +530,204 @@ class TestSetDistReproDataloder: |
|
|
|
assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) |
|
|
|
assert len(left_idxes | already_seen_idx) == len(self.dataset) |
|
|
|
|
|
|
|
class TestPaddleDriverFunctions: |
|
|
|
############################################################################ |
|
|
|
# |
|
|
|
# 测试 save 和 load 相关的功能 |
|
|
|
# |
|
|
|
############################################################################ |
|
|
|
|
|
|
|
def generate_random_driver(features, labels): |
|
|
|
""" |
|
|
|
使用 PaddleSingleDriver 测试基类的函数 |
|
|
|
生成driver |
|
|
|
""" |
|
|
|
model = PaddleNormalModel_Classification_1(labels, features) |
|
|
|
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.01) |
|
|
|
driver = PaddleSingleDriver(model, device="cpu") |
|
|
|
driver.set_optimizers(opt) |
|
|
|
driver.setup() |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def setup_class(self): |
|
|
|
model = PaddleNormalModel_Classification_1(10, 32) |
|
|
|
self.driver = PaddleSingleDriver(model, device="cpu") |
|
|
|
|
|
|
|
def test_check_single_optimizer_legality(self): |
|
|
|
""" |
|
|
|
测试传入单个optimizer时的表现 |
|
|
|
""" |
|
|
|
optimizer = paddle.optimizer.Adam( |
|
|
|
parameters=self.driver.model.parameters(), |
|
|
|
learning_rate=0.01 |
|
|
|
) |
|
|
|
|
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
return driver |
|
|
|
|
|
|
|
optimizer = torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
# 传入torch的optimizer时,应该报错ValueError |
|
|
|
with pytest.raises(ValueError): |
|
|
|
self.driver.set_optimizers(optimizer) |
|
|
|
@pytest.fixture
def prepare_test_save_load():
    """Fixture: a pair of freshly-initialised drivers and a batched dataloader."""
    data = PaddleRandomMaxDataset(320, 10)
    batched_loader = DataLoader(data, batch_size=32)
    driver_a = generate_random_driver(10, 10)
    driver_b = generate_random_driver(10, 10)
    return driver_a, driver_b, batched_loader
|
|
|
|
|
|
|
def test_check_optimizers_legality(self): |
|
|
|
""" |
|
|
|
测试传入optimizer list的表现 |
|
|
|
""" |
|
|
|
optimizers = [ |
|
|
|
paddle.optimizer.Adam( |
|
|
|
parameters=self.driver.model.parameters(), |
|
|
|
learning_rate=0.01 |
|
|
|
) for i in range(10) |
|
|
|
] |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_model(prepare_test_save_load, only_state_dict): |
|
|
|
""" |
|
|
|
测试 save_model 和 load_model 函数 |
|
|
|
""" |
|
|
|
try: |
|
|
|
path = "model" |
|
|
|
driver1, driver2, dataloader = prepare_test_save_load |
|
|
|
|
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
if only_state_dict: |
|
|
|
driver1.save_model(path, only_state_dict) |
|
|
|
else: |
|
|
|
driver1.save_model(path, only_state_dict, input_spec=[paddle.ones((32, 10))]) |
|
|
|
driver2.load_model(path, only_state_dict) |
|
|
|
|
|
|
|
optimizers += [ |
|
|
|
torch.optim.Adam(TorchNormalModel_Classification_1(10, 32).parameters(), 0.01) |
|
|
|
] |
|
|
|
for batch in dataloader: |
|
|
|
batch = driver1.move_data_to_device(batch) |
|
|
|
res1 = driver1.model.evaluate_step(**batch) |
|
|
|
res2 = driver2.model.evaluate_step(**batch) |
|
|
|
|
|
|
|
with pytest.raises(ValueError): |
|
|
|
self.driver.set_optimizers(optimizers) |
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
finally: |
|
|
|
if only_state_dict: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
else: |
|
|
|
synchronize_safe_rm(path + ".pdiparams") |
|
|
|
synchronize_safe_rm(path + ".pdiparams.info") |
|
|
|
synchronize_safe_rm(path + ".pdmodel") |
|
|
|
|
|
|
|
def test_check_dataloader_legality_in_train(self): |
|
|
|
""" |
|
|
|
测试is_train参数为True时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_with_randombatchsampler(only_state_dict): |
|
|
|
""" |
|
|
|
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 |
|
|
|
""" |
|
|
|
|
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
try: |
|
|
|
path = "model.ckp" |
|
|
|
|
|
|
|
# 创建torch的dataloader |
|
|
|
dataloader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) |
|
|
|
dataset = PaddleRandomMaxDataset(40, 10) |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset=dataset, |
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) |
|
|
|
) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) |
|
|
|
num_consumed_batches = 2 |
|
|
|
|
|
|
|
def test_check_dataloader_legality_in_test(self): |
|
|
|
""" |
|
|
|
测试is_train参数为False时,_check_dataloader_legality函数的表现 |
|
|
|
""" |
|
|
|
# 此时传入的应该是dict |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
} |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# batch_size 和 batch_sampler 均为 None 的情形 |
|
|
|
dataloader = { |
|
|
|
"train": paddle.io.DataLoader(PaddleNormalDataset()), |
|
|
|
"test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) |
|
|
|
} |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
# 传入的不是dict,应该报错 |
|
|
|
dataloader = paddle.io.DataLoader(PaddleNormalDataset()) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
already_seen_x_set = set() |
|
|
|
already_seen_y_set = set() |
|
|
|
for idx, batch in enumerate(dataloader): |
|
|
|
if idx >= num_consumed_batches: |
|
|
|
break |
|
|
|
already_seen_x_set.update(batch["x"]) |
|
|
|
already_seen_y_set.update(batch["y"]) |
|
|
|
|
|
|
|
# 创建torch的dataloader |
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
) |
|
|
|
test_loader = torch.utils.data.DataLoader( |
|
|
|
TorchNormalDataset(), |
|
|
|
batch_size=32, shuffle=True |
|
|
|
sampler_states = dataloader.batch_sampler.state_dict() |
|
|
|
save_states = {"num_consumed_batches": num_consumed_batches} |
|
|
|
if only_state_dict: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) |
|
|
|
else: |
|
|
|
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) |
|
|
|
# 加载 |
|
|
|
# 更改 batch_size |
|
|
|
dataloader = DataLoader( |
|
|
|
dataset=dataset, |
|
|
|
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) |
|
|
|
) |
|
|
|
dataloader = {"train": train_loader, "test": test_loader} |
|
|
|
with pytest.raises(ValueError): |
|
|
|
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) |
|
|
|
|
|
|
|
def test_tensor_to_numeric(self):
    """Exercise ``tensor_to_numeric`` on scalars, lists, tuples and nested dicts."""
    # scalar tensor -> plain Python number
    scalar = paddle.to_tensor(3)
    assert PaddleSingleDriver.tensor_to_numeric(scalar) == 3

    # 2-d tensor -> nested list
    matrix = paddle.rand((3, 4))
    assert PaddleSingleDriver.tensor_to_numeric(matrix) == matrix.tolist()

    # list of tensors -> list of nested lists
    tensors = [paddle.rand((6, 4, 2)) for _ in range(10)]
    converted = PaddleSingleDriver.tensor_to_numeric(tensors)
    assert isinstance(converted, list)
    assert converted == [t.tolist() for t in tensors]

    # tuple of tensors -> tuple of nested lists
    tensor_tup = tuple(paddle.rand((6, 4, 2)) for _ in range(10))
    converted = PaddleSingleDriver.tensor_to_numeric(tensor_tup)
    assert isinstance(converted, tuple)
    assert converted == tuple(t.tolist() for t in tensor_tup)

    # nested dict mixing tensors with plain Python values
    nested = {
        "tensor": paddle.rand((3, 4)),
        "list": [paddle.rand((6, 4, 2)) for _ in range(10)],
        "dict": {
            "list": [paddle.rand((6, 4, 2)) for _ in range(10)],
            "tensor": paddle.rand((3, 4)),
        },
        "int": 2,
        "string": "test string",
    }
    converted = PaddleSingleDriver.tensor_to_numeric(nested)
    assert isinstance(converted, dict)
    assert converted["tensor"] == nested["tensor"].tolist()
    assert isinstance(converted["list"], list)
    for got, want in zip(converted["list"], nested["list"]):
        assert got == want.tolist()
    assert isinstance(converted["int"], int)
    assert isinstance(converted["string"], str)
    assert isinstance(converted["dict"], dict)
    assert isinstance(converted["dict"]["list"], list)
    for got, want in zip(converted["dict"]["list"], nested["dict"]["list"]):
        assert got == want.tolist()
    assert converted["dict"]["tensor"] == nested["dict"]["tensor"].tolist()
|
|
|
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) |
|
|
|
replaced_loader = load_states.pop("dataloader") |
|
|
|
# 1. 检查 optimizer 的状态 |
|
|
|
# TODO optimizer 的 state_dict 总是为空 |
|
|
|
|
|
|
|
def test_set_model_mode(self):
    """``set_model_mode`` must toggle the training flag and reject unknown modes."""
    self.driver.set_model_mode("train")
    assert self.driver.model.training
    self.driver.set_model_mode("eval")
    assert not self.driver.model.training
    # an unsupported mode string must raise
    with pytest.raises(AssertionError):
        self.driver.set_model_mode("test")
|
|
|
# 2. 检查 batch_sampler 是否被正确地加载和替换 |
|
|
|
assert not (replaced_loader is dataloader) |
|
|
|
assert replaced_loader.batch_sampler is dataloader.batch_sampler |
|
|
|
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) |
|
|
|
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] |
|
|
|
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 |
|
|
|
|
|
|
|
def test_move_model_to_device_cpu(self):
    """Moving the model to "cpu" should leave its weights on a CPU place."""
    PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu")
    weight_place = self.driver.model.linear1.weight.place
    assert weight_place.is_cpu_place()
|
|
|
# 3. 检查 model 的参数是否正确 |
|
|
|
# 4. 检查 batch_idx |
|
|
|
start_batch = load_states.pop('batch_idx_in_epoch') |
|
|
|
assert start_batch == 2 * num_consumed_batches |
|
|
|
left_x_batches = set() |
|
|
|
left_y_batches = set() |
|
|
|
for idx, batch in enumerate(replaced_loader): |
|
|
|
|
|
|
|
def test_move_model_to_device_gpu(self):
    """Moving the model to "gpu" should place its weights on GPU device 0."""
    PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu")
    weight_place = self.driver.model.linear1.weight.place
    assert weight_place.is_gpu_place()
    assert weight_place.gpu_device_id() == 0
|
|
|
left_x_batches.update(batch["x"]) |
|
|
|
left_y_batches.update(batch["y"]) |
|
|
|
res1 = driver1.model.evaluate_step(**batch) |
|
|
|
res2 = driver2.model.evaluate_step(**batch) |
|
|
|
assert paddle.equal_all(res1["pred"], res2["pred"]) |
|
|
|
|
|
|
|
def test_worker_init_function(self):
    """Smoke-test ``worker_init_function``.

    Only verifies that the call does not crash.
    TODO: also verify the seeding is correct.
    """
    PaddleSingleDriver.worker_init_function(0)
|
|
|
assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) |
|
|
|
assert len(left_x_batches | already_seen_x_set) == len(dataset) |
|
|
|
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) |
|
|
|
assert len(left_y_batches | already_seen_y_set) == len(dataset) |
|
|
|
finally: |
|
|
|
synchronize_safe_rm(path) |
|
|
|
|
|
|
|
def test_set_deterministic_dataloader(self):
    """Smoke-test ``set_deterministic_dataloader``.

    Only verifies that the call does not crash.
    TODO: also verify determinism is actually enforced.
    """
    loader = DataLoader(PaddleNormalDataset())
    self.driver.set_deterministic_dataloader(loader)
|
|
|
@pytest.mark.parametrize("only_state_dict", ([True, False])) |
|
|
|
def test_save_and_load_with_randomsampler(only_state_dict): |
|
|
|
""" |
|
|
|
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 |
|
|
|
""" |
|
|
|
|
|
|
|
def test_set_sampler_epoch(self):
    """Smoke-test ``set_sampler_epoch``.

    Only verifies that the call does not crash.
    TODO: also verify the epoch is propagated to the sampler.
    """
    loader = DataLoader(PaddleNormalDataset())
    self.driver.set_sampler_epoch(loader, 0)
|
|
|
try: |
|
|
|
path = "model.ckp" |
|
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args(self, batch_size, shuffle, drop_last):
    """
    Test ``get_dataloader_args`` on an unmodified paddle DataLoader:
    every constructor argument must be reported back faithfully.
    """
    dataloader = DataLoader(
        PaddleNormalDataset(),
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    res = PaddleSingleDriver.get_dataloader_args(dataloader)

    assert isinstance(res.dataset, PaddleNormalDataset)
    assert isinstance(res.batch_sampler, BatchSampler)
    # shuffle=True makes paddle build a RandomSampler, otherwise a SequenceSampler.
    if shuffle:
        assert isinstance(res.sampler, paddle.io.RandomSampler)
    else:
        assert isinstance(res.sampler, paddle.io.SequenceSampler)
    assert res.shuffle == shuffle
    assert res.batch_size == batch_size
    assert res.drop_last == drop_last
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last):
    """
    Test ``get_dataloader_args`` after the loader's ``batch_sampler`` has
    been replaced with fastNLP's ``RandomBatchSampler`` wrapper.
    """
    dataset = PaddleNormalDataset()
    dataloader = DataLoader(
        dataset,
        batch_sampler=RandomBatchSampler(
            BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
            batch_size,
            drop_last,
        )
    )
    res = PaddleSingleDriver.get_dataloader_args(dataloader)

    assert isinstance(res.dataset, PaddleNormalDataset)
    # The wrapper itself must be surfaced, not the inner BatchSampler.
    assert isinstance(res.batch_sampler, RandomBatchSampler)
    if shuffle:
        assert isinstance(res.sampler, paddle.io.RandomSampler)
    else:
        assert isinstance(res.sampler, paddle.io.SequenceSampler)
    assert res.shuffle == shuffle
    assert res.batch_size == batch_size
    assert res.drop_last == drop_last
|
|
|
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last):
    """
    Test ``get_dataloader_args`` after the batch sampler's inner ``sampler``
    has been replaced with fastNLP's ``RandomSampler``.
    """
    dataset = PaddleNormalDataset()
    batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=drop_last)
    batch_sampler.sampler = RandomSampler(dataset, shuffle)
    dataloader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
    )
    res = PaddleSingleDriver.get_dataloader_args(dataloader)

    assert isinstance(res.dataset, PaddleNormalDataset)
    assert isinstance(res.batch_sampler, BatchSampler)
    # The replaced fastNLP sampler must be reported instead of paddle's own.
    assert isinstance(res.sampler, RandomSampler)
    assert res.shuffle == shuffle
    assert res.batch_size == batch_size
    assert res.drop_last == drop_last