
Test cases for set_dist_repro_dataloader with multiple GPUs

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
f6f489dc90
1 changed file with 345 additions and 260 deletions
  1. +345 -260 tests/core/drivers/paddle_driver/test_fleet.py

tests/core/drivers/paddle_driver/test_fleet.py (+345, -260)

@@ -1,21 +1,35 @@
from dataclasses import replace
import pytest
import os
import numpy as np
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle


set_env_on_import_paddle()
os.environ["FASTNLP_BACKEND"] = "paddle"
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from fastNLP.core.samplers import (
RandomSampler,
UnrepeatedSampler,
BucketedBatchSampler,
UnrepeatedRandomSampler,
UnrepeatedSequentialSampler,
)
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset
from tests.helpers.utils import magic_argv_env_context

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader
from paddle.io import DataLoader, BatchSampler


from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
def generate_driver(num_labels, feature_dimension):
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
driver.setup()


return driver
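
# Usage sketch (an assumption about how these tests are driven, not code from this
# commit): with two visible GPUs and a distributed launch such as
# ``python -m paddle.distributed.launch``, the helper above would be used as
#     driver = generate_driver(num_labels=10, feature_dimension=10)
#     batch = driver.move_data_to_device(paddle.rand((4, 10)))  # placed on this rank's GPU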


############################################################################
#
@@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm
#
############################################################################


@magic_argv_env_context
def test_move_data_to_device():
"""
This function only calls paddle_move_data_to_device; its test cases live in tests/core/utils/test_paddle_utils.py,
so they are not repeated here
"""
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
driver.move_data_to_device(paddle.rand((32, 64)))
finally:
synchronize_safe_rm("log")

dist.barrier()


@magic_argv_env_context
def test_is_distributed():
print(os.getenv("CUDA_VISIBLE_DEVICES"))
print(paddle.device.get_device())
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
output_from_new_proc='all'
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
assert driver.is_distributed() == True
finally:
synchronize_safe_rm("log")
dist.barrier()


@magic_argv_env_context
def test_get_no_sync_context():
class TestFleetDriverFunction:
""" """
测试能否运行
测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
""" """
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
res = driver.get_no_sync_context()
finally:
synchronize_safe_rm("log")
dist.barrier()


@magic_argv_env_context
def test_is_global_zero():
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
driver.is_global_zero()
finally:
synchronize_safe_rm("log")
dist.barrier()



@magic_argv_env_context
def test_unwrap_model():
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
driver.unwrap_model()
finally:
synchronize_safe_rm("log")
dist.barrier()

@magic_argv_env_context
def test_get_local_rank():
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
driver.get_local_rank()
finally:
synchronize_safe_rm("log")
dist.barrier()

@magic_argv_env_context
@pytest.mark.parametrize(
"dist_sampler",
["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
"reproducible",
[True, False]
)
def test_replace_sampler(dist_sampler, reproducible):
"""
Test replace_sampler
"""
try:
paddle_model = PaddleNormalModel_Classification(10, 784)
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=[0,1],
)
driver.set_optimizers(paddle_opt)
# distinguish the launching process from subprocesses running setup
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
with pytest.raises(SystemExit) as e:
driver.setup()
assert e.value.code == 0
return
else:
driver.setup()
dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
finally:
synchronize_safe_rm("log")
dist.barrier()

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)

@magic_argv_env_context
def test_move_data_to_device(self):
"""
This function only calls paddle_move_data_to_device; its test cases live in tests/core/utils/test_paddle_utils.py,
so they are not repeated here
"""
self.driver.move_data_to_device(paddle.rand((32, 64)))

dist.barrier()

@magic_argv_env_context
def test_is_distributed(self):
"""
Test the is_distributed function
"""
assert self.driver.is_distributed() == True
dist.barrier()

@magic_argv_env_context
def test_get_no_sync_context(self):
"""
Test the get_no_sync_context function
"""
res = self.driver.get_no_sync_context()
dist.barrier()

@magic_argv_env_context
def test_is_global_zero(self):
"""
Test the is_global_zero function
"""
self.driver.is_global_zero()
dist.barrier()

@magic_argv_env_context
def test_unwrap_model(self):
"""
Test the unwrap_model function
"""
self.driver.unwrap_model()
dist.barrier()

@magic_argv_env_context
def test_get_local_rank(self):
"""
Test the get_local_rank function
"""
self.driver.get_local_rank()
dist.barrier()


############################################################################
#
# Test single-machine multi-GPU training
# Test the set_dist_repro_dataloader function
#
############################################################################


@magic_argv_env_context
class SingleMachineMultiGPUTrainingTestCase:
class TestSetDistReproDataloader:

@classmethod
def setup_class(cls):
cls.driver = generate_driver(10, 10)

def setup_method(self):
self.dataset = PaddleNormalDataset(20)

""" """
测试在单机多卡上使用PaddleFleetDriver进行训练。
分布式训练用pytest会有些混乱
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况
此时对应 driver.load 中的情况
""" """


def test_case1(self):

gpus = [0, 1]
lr = 0.0003
epochs = 20
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is a BucketedBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler is batch_sampler
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is a RandomSampler
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
sampler = RandomSampler(self.dataset, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.sampler is sampler
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)

dist.barrier()
"""
The case where the `dist` argument is None. This occurs during the initialization of the trainer and the evaluator
when the user sets the `use_dist_sampler` parameter to False. The function then handles things differently depending
on the `reproducible` setting. When `reproducible` is False, whether the dataloader is re-instantiated depends on
whether its batch_sampler or sampler is Reproducible
"""


paddle_model = PaddleNormalModel_Classification()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is None and reproducible is True
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
with pytest.raises(RuntimeError):
# a RuntimeError should be raised
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)

dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader
has a BucketedBatchSampler
"""
dataloader = DataLoader(
self.dataset,
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
)
dataloader.batch_sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank,
pad=True
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert replaced_loader.batch_sampler.batch_size == 4
self.check_distributed_sampler(dataloader.batch_sampler)

dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader has a RandomSampler
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
batch_sampler.sampler.set_distributed(
num_replicas=self.driver.world_size,
rank=self.driver.global_rank
)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.drop_last == False
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

assert replaced_loader is dataloader
dist.barrier()


paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr)
"""
The case where the `dist` argument is 'dist'. This occurs during trainer initialization when the user sets the
`use_dist_sampler` parameter to True. The function then decides how to re-instantiate the dataloader based on
whether its batch_sampler or sampler is Reproducible
"""


train_dataset = PaddleDataset_MNIST("train")
test_dataset = PaddleDataset_MNIST("test")
loss_func = paddle.nn.CrossEntropyLoss()
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler is a
ReproducibleBatchSampler
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and dataloader.batch_sampler.sampler is a
ReproducibleSampler
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'dist' and the dataloader is an ordinary one
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.batch_sampler.sampler.shuffle == True
dist.barrier()


dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
"""
The case where the `dist` argument is 'unrepeatdist'. This occurs during evaluator initialization when the user sets
the `use_dist_sampler` parameter to True. The function then decides how to re-instantiate the dataloader based on
whether its sampler is Unrepeated and Reproducible
"""


driver = PaddleFleetDriver(
model=paddle_model,
parallel_device=gpus,
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is a
ReproducibleSampler
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = RandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
driver.set_optimizers(paddle_opt)
dataloader = driver.set_dist_repro_dataloader(dataloader, )
driver.setup()
# check model_device
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}")

driver.barrier()

driver.zero_grad()
current_epoch_idx = 0
while current_epoch_idx < epochs:
epoch_loss, batch = 0, 0
driver.set_model_mode("train")
driver.set_sampler_epoch(dataloader, current_epoch_idx)
for batch, (img, label) in enumerate(dataloader):

img = paddle.to_tensor(img)
out = driver.train_step(img)
label + 1
loss = loss_func(out, label)
epoch_loss += loss.item()

if batch % 50 == 0:
print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank))

driver.backward(loss)
driver.step()
driver.zero_grad()
driver.barrier()
current_epoch_idx += 1

# test
correct = 0
driver.set_model_mode("eval")
for img, label in test_dataset:

img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1))
out = driver.test_step(img)
res = paddle.nn.functional.softmax(out).argmax().item()
label = label.item()
if res == label:
correct += 1

print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset)))
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.batch_sampler.sampler.shuffle == True
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepeated_sampler(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and dataloader.batch_sampler.sampler is an
UnrepeatedSampler
"""
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
dataloader = DataLoader(
self.dataset,
batch_sampler=batch_sampler
)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
assert replaced_loader.batch_sampler.batch_size == 2
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
"""
Test the behavior of set_dist_repro_dataloader when dist is 'unrepeatdist' and the dataloader is an ordinary one
"""
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, BatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler)
assert replaced_loader.batch_sampler.batch_size == 4
assert replaced_loader.drop_last == dataloader.drop_last
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
dist.barrier()

def check_distributed_sampler(self, sampler):
"""
Check that the distributed settings of the replaced sampler or batch_sampler are correct
"""
assert sampler.num_replicas == dist.get_world_size()
assert sampler.rank == dist.get_rank()
if not isinstance(sampler, UnrepeatedSampler):
assert sampler.pad == True
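
# Note: pad is asserted only for repeated samplers; UnrepeatedSampler subclasses are
# expected not to pad, since each rank should see a disjoint shard during evaluation
# (an inference from the isinstance check above, not a statement from this commit).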

