Browse Source

完成了paddle fleet的save load函数测试

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
514415e9d4
1 changed files with 127 additions and 59 deletions
  1. +127
    -59
      tests/core/drivers/paddle_driver/test_fleet.py

+ 127
- 59
tests/core/drivers/paddle_driver/test_fleet.py View File

@@ -1,5 +1,6 @@
import pytest
import os
from pathlib import Path

os.environ["FASTNLP_BACKEND"] = "paddle"
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
@@ -33,20 +34,6 @@ def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, out

return driver

@magic_argv_env_context
def test_multi_drivers():
"""
测试使用了多个 PaddleFleetDriver 的情况。
"""
driver1 = generate_driver(10, 10)
driver2 = generate_driver(20, 10)

with pytest.raises(RuntimeError):
# 设备设置不同,应该报错
driver3 = generate_driver(20, 3, device=[0,2])

dist.barrier()

############################################################################
#
# 测试 PaddleFleetDriver 的一些函数
@@ -62,6 +49,19 @@ class TestFleetDriverFunction:
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context
def test_multi_drivers(self):
"""
测试使用了多个 PaddleFleetDriver 的情况。
"""
driver2 = generate_driver(20, 10)

with pytest.raises(RuntimeError):
# 设备设置不同,应该报错
driver3 = generate_driver(20, 3, device=[0,2])

dist.barrier()

@magic_argv_env_context
def test_move_data_to_device(self):
"""
@@ -494,9 +494,14 @@ class TestSaveLoad:
"""
测试多卡情况下 save 和 load 相关函数的表现
"""

@classmethod
def setup_class(cls):
# 不在这里 setup 的话会报错
cls.driver = generate_driver(10, 10)

def setup_method(self):
self.dataset = PaddleRandomMaxDataset(20, 10)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@@ -506,7 +511,9 @@ class TestSaveLoad:
"""
try:
path = "model"

dataloader = DataLoader(self.dataset, batch_size=2)
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)

if only_state_dict:
self.driver1.save_model(path, only_state_dict)
@@ -545,20 +552,30 @@ class TestSaveLoad:
@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randombatchsampler(self, only_state_dict, fp16):
return
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况
"""

try:
path = "model.ckp"
num_replicas = len(device)

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(40, 10)
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
generate_driver(10, 10, device=device, fp16=False)
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=4,
)
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
num_consumed_batches = 2

@@ -570,19 +587,32 @@ class TestSaveLoad:
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

# 同步
dist.barrier()

# 保存状态
sampler_states = dataloader.batch_sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
dataloader = DataLoader(
dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(
self.dataset,
length=[10 for i in range(len(self.dataset))],
batch_size=4,
)
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")
# 1. 检查 optimizer 的状态
# TODO optimizer 的 state_dict 总是为空
@@ -590,13 +620,13 @@ class TestSaveLoad:
# 2. 检查 batch_sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas

# 3. 检查 fp16 是否被加载
if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx
@@ -608,22 +638,33 @@ class TestSaveLoad:

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
# Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert paddle.equal_all(res1["pred"], res2["pred"])

assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas
finally:
synchronize_safe_rm(path)

@magic_argv_env_context
@pytest.mark.parametrize("only_state_dict", ([True, False]))
@pytest.mark.parametrize("fp16", ([True, False]))
def test_save_and_load_with_randomsampler(self, only_state_dict, fp16):
return
@pytest.mark.parametrize("device", ([[0,1]]))
def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16):
"""
测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况
"""
@@ -631,12 +672,19 @@ class TestSaveLoad:
try:
path = "model.ckp"

driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10)
dataset = PaddleRandomMaxDataset(40, 10)
batch_sampler = BatchSampler(dataset=dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(dataset, True)
num_replicas = len(device)

self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
self.driver2 = generate_driver(10, 10, device=device, fp16=False)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=4)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver1.world_size,
rank=self.driver1.global_rank,
pad=True
)
dataloader = DataLoader(
dataset,
self.dataset,
batch_sampler=batch_sampler
)
num_consumed_batches = 2
@@ -649,22 +697,30 @@ class TestSaveLoad:
already_seen_x_set.update(batch["x"])
already_seen_y_set.update(batch["y"])

# 同步
dist.barrier()

# 保存状态
sampler_states = dataloader.batch_sampler.sampler.state_dict()
save_states = {"num_consumed_batches": num_consumed_batches}
if only_state_dict:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
else:
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))])
# 加载
# 更改 batch_size
batch_sampler = BatchSampler(dataset=dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(dataset, True)
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver2.world_size,
rank=self.driver2.global_rank,
pad=True
)
dataloader = DataLoader(
dataset,
self.dataset,
batch_sampler=batch_sampler
)
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader")

# 1. 检查 optimizer 的状态
@@ -675,12 +731,12 @@ class TestSaveLoad:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"]
assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"]
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches
assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas
assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"]
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
# 3. 检查 fp16 是否被加载
if fp16:
assert isinstance(driver2.grad_scaler, paddle.amp.GradScaler)
assert isinstance(self.driver2.grad_scaler, paddle.amp.GradScaler)

# 4. 检查 model 的参数是否正确
# 5. 检查 batch_idx
@@ -692,13 +748,25 @@ class TestSaveLoad:

left_x_batches.update(batch["x"])
left_y_batches.update(batch["y"])
res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch)
res1 = self.driver1.model(
batch,
fastnlp_fn=self.driver1.model._layers.model.evaluate_step,
# Driver.model -> DataParallel._layers -> _FleetWrappingModel.model
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
res2 = self.driver2.model(
batch,
fastnlp_fn=self.driver2.model._layers.model.evaluate_step,
fastnlp_signature_fn=None,
wo_auto_param_call=False,
)
assert paddle.equal_all(res1["pred"], res2["pred"])

assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset)
assert len(left_y_batches) + len(already_seen_y_set) == len(dataset)
assert len(left_y_batches | already_seen_y_set) == len(dataset)
assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas
assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

finally:
synchronize_safe_rm(path)
synchronize_safe_rm(path)

Loading…
Cancel
Save