|
@@ -1,21 +1,35 @@ |
|
|
|
|
|
from dataclasses import replace |
|
|
import pytest |
|
|
import pytest |
|
|
import os |
|
|
import os |
|
|
import numpy as np |
|
|
|
|
|
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle |
|
|
|
|
|
|
|
|
|
|
|
set_env_on_import_paddle() |
|
|
|
|
|
|
|
|
os.environ["FASTNLP_BACKEND"] = "paddle" |
|
|
|
|
|
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver |
|
|
|
|
|
from fastNLP.core.samplers import ( |
|
|
|
|
|
RandomSampler, |
|
|
|
|
|
UnrepeatedSampler, |
|
|
|
|
|
BucketedBatchSampler, |
|
|
|
|
|
UnrepeatedRandomSampler, |
|
|
|
|
|
UnrepeatedSequentialSampler, |
|
|
|
|
|
) |
|
|
|
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset |
|
|
|
|
|
from tests.helpers.utils import magic_argv_env_context |
|
|
|
|
|
|
|
|
import paddle |
|
|
import paddle |
|
|
import paddle.distributed as dist |
|
|
import paddle.distributed as dist |
|
|
from paddle.io import DataLoader |
|
|
|
|
|
|
|
|
from paddle.io import DataLoader, BatchSampler |
|
|
|
|
|
|
|
|
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver |
|
|
|
|
|
from fastNLP.core.samplers.reproducible_sampler import RandomSampler |
|
|
|
|
|
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK |
|
|
|
|
|
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification |
|
|
|
|
|
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST |
|
|
|
|
|
from tests.helpers.utils import magic_argv_env_context |
|
|
|
|
|
from fastNLP.core import synchronize_safe_rm |
|
|
|
|
|
|
|
|
def generate_driver(num_labels, feature_dimension, parallel_device=None, learning_rate=0.01):
    """
    Build and fully set up a ``PaddleFleetDriver`` around a fresh test model.

    :param num_labels: number of output classes for the test classification model.
    :param feature_dimension: input feature size for the test model.
    :param parallel_device: list of GPU ids to parallelize over; ``None`` means
        the previous hard-coded default ``[0, 1]`` (kept for backward compatibility).
    :param learning_rate: Adam learning rate; defaults to the previous 0.01.
    :return: a set-up ``PaddleFleetDriver`` with an Adam optimizer attached.
    """
    # Avoid a mutable default argument; None stands in for the old [0, 1] default.
    if parallel_device is None:
        parallel_device = [0, 1]
    paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension)
    paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=learning_rate)
    driver = PaddleFleetDriver(
        model=paddle_model,
        parallel_device=parallel_device,
    )
    driver.set_optimizers(paddle_opt)
    # setup() brings up the fleet workers before the driver is handed back.
    driver.setup()

    return driver
|
|
|
|
|
|
|
|
############################################################################ |
|
|
############################################################################ |
|
|
# |
|
|
# |
|
@@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm |
|
|
# |
|
|
# |
|
|
############################################################################ |
|
|
############################################################################ |
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_move_data_to_device():
    """
    Smoke test for ``move_data_to_device``.

    It only delegates to ``paddle_move_data_to_device``, whose real tests live in
    tests/core/utils/test_paddle_utils.py, so behavior is not re-tested here.
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes:
        # the launcher exits setup() with SystemExit(0); workers run the body.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.move_data_to_device(paddle.rand((32, 64)))
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes (launcher returned above).
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_is_distributed():
    """Check that ``is_distributed()`` reports True after the fleet driver is set up."""
    print(os.getenv("CUDA_VISIBLE_DEVICES"))
    print(paddle.device.get_device())
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
            output_from_new_proc='all'
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            assert driver.is_distributed() == True
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
|
|
def test_get_no_sync_context(): |
|
|
|
|
|
|
|
|
class TestFleetDriverFunction: |
|
|
""" |
|
|
""" |
|
|
测试能否运行 |
|
|
|
|
|
|
|
|
测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 |
|
|
""" |
|
|
""" |
|
|
try: |
|
|
|
|
|
paddle_model = PaddleNormalModel_Classification(10, 784) |
|
|
|
|
|
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) |
|
|
|
|
|
driver = PaddleFleetDriver( |
|
|
|
|
|
model=paddle_model, |
|
|
|
|
|
parallel_device=[0,1], |
|
|
|
|
|
) |
|
|
|
|
|
driver.set_optimizers(paddle_opt) |
|
|
|
|
|
# 区分launch和子进程setup的时候 |
|
|
|
|
|
if FASTNLP_DISTRIBUTED_CHECK not in os.environ: |
|
|
|
|
|
with pytest.raises(SystemExit) as e: |
|
|
|
|
|
driver.setup() |
|
|
|
|
|
assert e.value.code == 0 |
|
|
|
|
|
return |
|
|
|
|
|
else: |
|
|
|
|
|
driver.setup() |
|
|
|
|
|
res = driver.get_no_sync_context() |
|
|
|
|
|
finally: |
|
|
|
|
|
synchronize_safe_rm("log") |
|
|
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_is_global_zero():
    """Smoke test: ``is_global_zero()`` runs without error in a fleet setup."""
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.is_global_zero()
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_unwrap_model():
    """Smoke test: ``unwrap_model()`` runs without error in a fleet setup."""
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.unwrap_model()
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_get_local_rank():
    """Smoke test: ``get_local_rank()`` runs without error in a fleet setup."""
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            driver.get_local_rank()
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
@pytest.mark.parametrize(
    "dist_sampler",
    ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))]
)
@pytest.mark.parametrize(
    "reproducible",
    [True, False]
)
def test_replace_sampler(dist_sampler, reproducible):
    """
    Smoke test for ``set_dist_repro_dataloader`` across all ``dist``/``reproducible``
    combinations (sampler replacement itself is exercised, not its exact output).
    """
    try:
        paddle_model = PaddleNormalModel_Classification(10, 784)
        paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01)
        driver = PaddleFleetDriver(
            model=paddle_model,
            parallel_device=[0,1],
        )
        driver.set_optimizers(paddle_opt)
        # Distinguish the launcher process from the spawned worker processes.
        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
            with pytest.raises(SystemExit) as e:
                driver.setup()
            assert e.value.code == 0
            return
        else:
            driver.setup()
            dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True)
            driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible)
    finally:
        synchronize_safe_rm("log")

    # Only reached by worker processes.
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def setup_class(cls):
    # One shared fleet driver for the whole test class; generate_driver also
    # calls setup(), so workers are live from this point on.
    cls.driver = generate_driver(10, 10)
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_move_data_to_device(self):
    """
    Smoke call only: ``move_data_to_device`` delegates to
    ``paddle_move_data_to_device``, tested in tests/core/utils/test_paddle_utils.py.
    """
    self.driver.move_data_to_device(paddle.rand((32, 64)))

    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_is_distributed(self):
    """Test the ``is_distributed`` function: must report True under the fleet driver."""
    assert self.driver.is_distributed() == True
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_get_no_sync_context(self):
    """Test the ``get_no_sync_context`` function (smoke call; result unchecked)."""
    res = self.driver.get_no_sync_context()
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_is_global_zero(self):
    """Test the ``is_global_zero`` function (smoke call; result unchecked)."""
    self.driver.is_global_zero()
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_unwrap_model(self):
    """Test the ``unwrap_model`` function (smoke call; result unchecked)."""
    self.driver.unwrap_model()
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_get_local_rank(self):
    """Test the ``get_local_rank`` function (smoke call; result unchecked)."""
    self.driver.get_local_rank()
    dist.barrier()
|
|
|
|
|
|
|
|
############################################################################ |
|
|
############################################################################ |
|
|
# |
|
|
# |
|
|
# 测试单机多卡的训练情况 |
|
|
|
|
|
|
|
|
# 测试 set_dist_repro_dataloader 函数 |
|
|
# |
|
|
# |
|
|
############################################################################ |
|
|
############################################################################ |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
|
|
class SingleMachineMultiGPUTrainingTestCase: |
|
|
|
|
|
|
|
|
class TestSetDistReproDataloader: |
|
|
|
|
|
|
|
|
|
|
|
@classmethod
def setup_class(cls):
    # Shared fleet driver for all set_dist_repro_dataloader tests;
    # generate_driver also performs setup().
    cls.driver = generate_driver(10, 10)
|
|
|
|
|
|
|
|
|
|
|
def setup_method(self):
    # Fresh 20-sample dataset before every test method.
    self.dataset = PaddleNormalDataset(20)
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
测试在单机多卡上使用PaddleFleetDriver进行训练。 |
|
|
|
|
|
分布式训练用pytest会有些混乱 |
|
|
|
|
|
|
|
|
传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 |
|
|
|
|
|
此时对应 driver.load 中的情况 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def test_case1(self): |
|
|
|
|
|
|
|
|
|
|
|
gpus = [0, 1] |
|
|
|
|
|
lr = 0.0003 |
|
|
|
|
|
epochs = 20 |
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_batch_sampler(self):
    """
    Behavior of ``set_dist_repro_dataloader`` when ``dist`` is a concrete
    ``BucketedBatchSampler`` (corresponds to the ``driver.load`` case).
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

    # A new dataloader must be built around exactly the sampler we passed in.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
    assert replaced_loader.batch_sampler is batch_sampler
    self.check_distributed_sampler(replaced_loader.batch_sampler)

    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_sampler(self):
    """
    Behavior of ``set_dist_repro_dataloader`` when ``dist`` is a concrete
    ``RandomSampler`` (corresponds to the ``driver.load`` case).
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    sampler = RandomSampler(self.dataset, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)

    # New dataloader, new batch_sampler, but wrapping exactly our sampler and
    # preserving the original batch size.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert replaced_loader.batch_sampler.sampler is sampler
    assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)

    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` |
|
|
|
|
|
参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 |
|
|
|
|
|
当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 |
|
|
|
|
|
是否重新实例化 dataloader |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
paddle_model = PaddleNormalModel_Classification() |
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self):
    """
    Behavior when ``dist`` is None and ``reproducible`` is True: the driver
    cannot satisfy reproducibility without a dist sampler, so it must raise.
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    with pytest.raises(RuntimeError):
        # A RuntimeError is expected here.
        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)

    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self):
    """
    Behavior when ``dist`` is None, ``reproducible`` is False, and the dataloader
    already carries a ``BucketedBatchSampler`` configured for distribution.
    """
    dataloader = DataLoader(
        self.dataset,
        batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4),
    )
    dataloader.batch_sampler.set_distributed(
        num_replicas=self.driver.world_size,
        rank=self.driver.global_rank,
        pad=True
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

    # The loader is re-instantiated but keeps the reproducible batch sampler setup.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
    assert replaced_loader.batch_sampler.batch_size == 4
    self.check_distributed_sampler(dataloader.batch_sampler)

    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self):
    """
    Behavior when ``dist`` is None, ``reproducible`` is False, and the dataloader
    already carries a distributed-configured ``RandomSampler``.
    """
    batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    batch_sampler.sampler = RandomSampler(self.dataset, True)
    batch_sampler.sampler.set_distributed(
        num_replicas=self.driver.world_size,
        rank=self.driver.global_rank
    )
    dataloader = DataLoader(
        self.dataset,
        batch_sampler=batch_sampler
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

    # Loader and batch_sampler are re-instantiated; the sampler is a fresh
    # RandomSampler (not the original object) with unchanged batch settings.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
    assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
    assert replaced_loader.batch_sampler.batch_size == 2
    assert replaced_loader.batch_sampler.drop_last == False
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self):
    """
    Behavior when ``dist`` is None, ``reproducible`` is False, and the dataloader
    is an ordinary one: nothing to replace, the same object comes back.
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)

    assert replaced_loader is dataloader
    dist.barrier()
|
|
|
|
|
|
|
|
paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr) |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 |
|
|
|
|
|
为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
train_dataset = PaddleDataset_MNIST("train") |
|
|
|
|
|
test_dataset = PaddleDataset_MNIST("test") |
|
|
|
|
|
loss_func = paddle.nn.CrossEntropyLoss() |
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self):
    """
    Behavior when ``dist`` is the string 'dist' and ``dataloader.batch_sampler``
    is already a ReproducibleBatchSampler (``use_dist_sampler=True`` trainer path).
    """
    dataloader = DataLoader(
        dataset=self.dataset,
        batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4)
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

    # A fresh loader and batch_sampler of the same type, same batch settings.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert replaced_loader.batch_sampler.batch_size == 4
    assert replaced_loader.drop_last == dataloader.drop_last
    self.check_distributed_sampler(replaced_loader.batch_sampler)
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self):
    """
    Behavior when ``dist`` is the string 'dist' and
    ``dataloader.batch_sampler.sampler`` is already a ReproducibleSampler.
    """
    batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    batch_sampler.sampler = RandomSampler(self.dataset, True)
    dataloader = DataLoader(
        self.dataset,
        batch_sampler=batch_sampler
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

    # Fresh loader; the sampler is a new RandomSampler preserving shuffle=True
    # and the original batch size.
    assert not (replaced_loader is dataloader)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
    assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
    assert replaced_loader.batch_sampler.batch_size == 2
    assert replaced_loader.batch_sampler.sampler.shuffle == True
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self):
    """
    Behavior when ``dist`` is the string 'dist' and the dataloader is an
    ordinary one: a reproducible RandomSampler is substituted in.
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)

    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
    assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
    # shuffle=True carried over from the original dataloader.
    assert replaced_loader.batch_sampler.sampler.shuffle == True
    dist.barrier()
|
|
|
|
|
|
|
|
dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True) |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 |
|
|
|
|
|
为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
driver = PaddleFleetDriver( |
|
|
|
|
|
model=paddle_model, |
|
|
|
|
|
parallel_device=gpus, |
|
|
|
|
|
|
|
|
@magic_argv_env_context |
|
|
|
|
|
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): |
|
|
|
|
|
""" |
|
|
|
|
|
测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler |
|
|
|
|
|
的表现 |
|
|
|
|
|
""" |
|
|
|
|
|
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) |
|
|
|
|
|
batch_sampler.sampler = RandomSampler(self.dataset, True) |
|
|
|
|
|
dataloader = DataLoader( |
|
|
|
|
|
self.dataset, |
|
|
|
|
|
batch_sampler=batch_sampler |
|
|
) |
|
|
) |
|
|
driver.set_optimizers(paddle_opt) |
|
|
|
|
|
dataloader = driver.set_dist_repro_dataloader(dataloader, ) |
|
|
|
|
|
driver.setup() |
|
|
|
|
|
# 检查model_device |
|
|
|
|
|
self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") |
|
|
|
|
|
|
|
|
|
|
|
driver.barrier() |
|
|
|
|
|
|
|
|
|
|
|
driver.zero_grad() |
|
|
|
|
|
current_epoch_idx = 0 |
|
|
|
|
|
while current_epoch_idx < epochs: |
|
|
|
|
|
epoch_loss, batch = 0, 0 |
|
|
|
|
|
driver.set_model_mode("train") |
|
|
|
|
|
driver.set_sampler_epoch(dataloader, current_epoch_idx) |
|
|
|
|
|
for batch, (img, label) in enumerate(dataloader): |
|
|
|
|
|
|
|
|
|
|
|
img = paddle.to_tensor(img) |
|
|
|
|
|
out = driver.train_step(img) |
|
|
|
|
|
label + 1 |
|
|
|
|
|
loss = loss_func(out, label) |
|
|
|
|
|
epoch_loss += loss.item() |
|
|
|
|
|
|
|
|
|
|
|
if batch % 50 == 0: |
|
|
|
|
|
print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank)) |
|
|
|
|
|
|
|
|
|
|
|
driver.backward(loss) |
|
|
|
|
|
driver.step() |
|
|
|
|
|
driver.zero_grad() |
|
|
|
|
|
driver.barrier() |
|
|
|
|
|
current_epoch_idx += 1 |
|
|
|
|
|
|
|
|
|
|
|
# test |
|
|
|
|
|
correct = 0 |
|
|
|
|
|
driver.set_model_mode("eval") |
|
|
|
|
|
for img, label in test_dataset: |
|
|
|
|
|
|
|
|
|
|
|
img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) |
|
|
|
|
|
out = driver.test_step(img) |
|
|
|
|
|
res = paddle.nn.functional.softmax(out).argmax().item() |
|
|
|
|
|
label = label.item() |
|
|
|
|
|
if res == label: |
|
|
|
|
|
correct += 1 |
|
|
|
|
|
|
|
|
|
|
|
print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset))) |
|
|
|
|
|
|
|
|
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) |
|
|
|
|
|
|
|
|
|
|
|
assert not (replaced_loader is dataloader) |
|
|
|
|
|
assert isinstance(replaced_loader.batch_sampler, BatchSampler) |
|
|
|
|
|
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) |
|
|
|
|
|
assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) |
|
|
|
|
|
assert replaced_loader.batch_sampler.batch_size == 2 |
|
|
|
|
|
assert replaced_loader.batch_sampler.sampler.shuffle == True |
|
|
|
|
|
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) |
|
|
|
|
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self):
    """
    Behavior when ``dist`` is 'unrepeatdist' and
    ``dataloader.batch_sampler.sampler`` is already an UnrepeatedSampler.
    """
    batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2)
    batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True)
    dataloader = DataLoader(
        self.dataset,
        batch_sampler=batch_sampler
    )
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

    # Fresh loader; a new UnrepeatedRandomSampler replaces the original object
    # while batch size and drop_last are preserved.
    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler)
    assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler)
    assert replaced_loader.batch_sampler.batch_size == 2
    assert replaced_loader.drop_last == dataloader.drop_last
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
@magic_argv_env_context
def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self):
    """
    Behavior when ``dist`` is 'unrepeatdist' and the dataloader is an ordinary
    one: an UnrepeatedSequentialSampler is substituted in.
    """
    dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
    replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

    assert not (replaced_loader is dataloader)
    assert isinstance(replaced_loader.batch_sampler, BatchSampler)
    assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
    assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler)
    assert replaced_loader.batch_sampler.batch_size == 4
    assert replaced_loader.drop_last == dataloader.drop_last
    self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
    dist.barrier()
|
|
|
|
|
|
|
|
|
|
|
def check_distributed_sampler(self, sampler):
    """
    Verify that a replaced sampler/batch_sampler carries the correct
    distributed configuration (replica count and rank match the process group).
    """
    assert sampler.num_replicas == dist.get_world_size()
    assert sampler.rank == dist.get_rank()
    # Unrepeated samplers never pad; all other reproducible samplers must
    # have been configured with pad=True by the driver.
    if not isinstance(sampler, UnrepeatedSampler):
        assert sampler.pad == True
|
|
|
|
|
|