diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index e20866b3..434e9e5b 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -1,21 +1,35 @@ +from dataclasses import replace import pytest import os -import numpy as np -from fastNLP.envs.set_env_on_import import set_env_on_import_paddle -set_env_on_import_paddle() +os.environ["FASTNLP_BACKEND"] = "paddle" +from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver +from fastNLP.core.samplers import ( + RandomSampler, + UnrepeatedSampler, + BucketedBatchSampler, + UnrepeatedRandomSampler, + UnrepeatedSequentialSampler, +) +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.utils import magic_argv_env_context + import paddle import paddle.distributed as dist -from paddle.io import DataLoader +from paddle.io import DataLoader, BatchSampler -from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification -from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST -from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import synchronize_safe_rm +def generate_driver(num_labels, feature_dimension): + paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) + paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) + driver = PaddleFleetDriver( + model=paddle_model, + parallel_device=[0,1], + ) + driver.set_optimizers(paddle_opt) + driver.setup() + return driver ############################################################################ # @@ -23,269 +37,340 @@ from fastNLP.core import synchronize_safe_rm # ############################################################################ -@magic_argv_env_context -def test_move_data_to_device(): - """ - 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 - 就不重复测试了 - """ - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - driver.move_data_to_device(paddle.rand((32, 64))) - finally: - synchronize_safe_rm("log") - - dist.barrier() - - -@magic_argv_env_context -def test_is_distributed(): - print(os.getenv("CUDA_VISIBLE_DEVICES")) - print(paddle.device.get_device()) - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - output_from_new_proc='all' - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - assert driver.is_distributed() == True - finally: - synchronize_safe_rm("log") - dist.barrier() - - -@magic_argv_env_context -def test_get_no_sync_context(): +class TestFleetDriverFunction: """ - 测试能否运行 + 测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 """ - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - res = driver.get_no_sync_context() - finally: - synchronize_safe_rm("log") - dist.barrier() - - -@magic_argv_env_context -def test_is_global_zero(): - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - driver.is_global_zero() - finally: - synchronize_safe_rm("log") - dist.barrier() - - - -@magic_argv_env_context -def test_unwrap_model(): - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - driver.unwrap_model() - finally: - synchronize_safe_rm("log") - dist.barrier() - -@magic_argv_env_context -def test_get_local_rank(): - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - driver.get_local_rank() - finally: - synchronize_safe_rm("log") - dist.barrier() - -@magic_argv_env_context -@pytest.mark.parametrize( - "dist_sampler", - ["dist", "unrepeatdist", RandomSampler(PaddleDataset_MNIST("train"))] -) -@pytest.mark.parametrize( - "reproducible", - [True, False] -) -def test_replace_sampler(dist_sampler, reproducible): - """ - 测试replace_sampler - """ - try: - paddle_model = PaddleNormalModel_Classification(10, 784) - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=0.01) - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=[0,1], - ) - driver.set_optimizers(paddle_opt) - # 区分launch和子进程setup的时候 - if FASTNLP_DISTRIBUTED_CHECK not in os.environ: - with pytest.raises(SystemExit) as e: - driver.setup() - assert e.value.code == 0 - return - else: - driver.setup() - dataloader = DataLoader(PaddleDataset_MNIST("train"), batch_size=100, shuffle=True) - driver.set_dist_repro_dataloader(dataloader, dist_sampler, reproducible) - finally: - synchronize_safe_rm("log") - dist.barrier() + + @classmethod + def setup_class(cls): + cls.driver = generate_driver(10, 10) + + @magic_argv_env_context + def test_move_data_to_device(self): + """ + 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 + 就不重复测试了 + """ + self.driver.move_data_to_device(paddle.rand((32, 64))) + + dist.barrier() + + @magic_argv_env_context + def test_is_distributed(self): + """ + 测试 is_distributed 函数 + """ + assert self.driver.is_distributed() == True + dist.barrier() + + @magic_argv_env_context + def test_get_no_sync_context(self): + """ + 测试 get_no_sync_context 函数 + """ + res = self.driver.get_no_sync_context() + dist.barrier() + + @magic_argv_env_context + def test_is_global_zero(self): + """ + 测试 is_global_zero 函数 + """ + self.driver.is_global_zero() + dist.barrier() + + @magic_argv_env_context + def test_unwrap_model(self): + """ + 测试 unwrap_model 函数 + """ + self.driver.unwrap_model() + dist.barrier() + + @magic_argv_env_context + def test_get_local_rank(self): + """ + 测试 get_local_rank 函数 + """ + self.driver.get_local_rank() + dist.barrier() ############################################################################ # -# 测试单机多卡的训练情况 +# 测试 set_dist_repro_dataloader 函数 # ############################################################################ -@magic_argv_env_context -class SingleMachineMultiGPUTrainingTestCase: +class TestSetDistReproDataloader: + + @classmethod + def setup_class(cls): + cls.driver = generate_driver(10, 10) + + def setup_method(self): + self.dataset = PaddleNormalDataset(20) + """ - 测试在单机多卡上使用PaddleFleetDriver进行训练。 - 分布式训练用pytest会有些混乱 + 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 + 此时对应 driver.load 中的情况 """ - def test_case1(self): - - gpus = [0, 1] - lr = 0.0003 - epochs = 20 + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler is batch_sampler + self.check_distributed_sampler(replaced_loader.batch_sampler) + + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + sampler = RandomSampler(self.dataset, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is sampler + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + + dist.barrier() + + """ + 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` + 参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 + 当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 + 是否重新实例化 dataloader + """ - paddle_model = PaddleNormalModel_Classification() + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_none_reproducible_true(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + with pytest.raises(RuntimeError): + # 应当抛出 RuntimeError + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) + + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler + 时的表现 + """ + dataloader = DataLoader( + self.dataset, + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4), + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank, + pad=True + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + self.check_distributed_sampler(dataloader.batch_sampler) + + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_reproducible_smpler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + batch_sampler.sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank + ) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.drop_last == False + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_none_reproducible_false_dataloader_normal(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert replaced_loader is dataloader + dist.barrier() - paddle_opt = paddle.optimizer.Adam(parameters=paddle_model.parameters(), learning_rate=lr) + """ + 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader + """ - train_dataset = PaddleDataset_MNIST("train") - test_dataset = PaddleDataset_MNIST("test") - loss_func = paddle.nn.CrossEntropyLoss() + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_batch_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler + 的表现 + """ + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4) + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler) + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_reproducible_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.sampler.shuffle == True + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_dist_dataloader_normal(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.batch_sampler.sampler.shuffle == True + dist.barrier() - dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True) + """ + 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader + """ - driver = PaddleFleetDriver( - model=paddle_model, - parallel_device=gpus, + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_reproducible_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = RandomSampler(self.dataset, True) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler ) - driver.set_optimizers(paddle_opt) - dataloader = driver.set_dist_repro_dataloader(dataloader, ) - driver.setup() - # 检查model_device - self.assertEqual(driver.model_device, f"gpu:{os.environ['PADDLE_LOCAL_DEVICE_IDS']}") - - driver.barrier() - - driver.zero_grad() - current_epoch_idx = 0 - while current_epoch_idx < epochs: - epoch_loss, batch = 0, 0 - driver.set_model_mode("train") - driver.set_sampler_epoch(dataloader, current_epoch_idx) - for batch, (img, label) in enumerate(dataloader): - - img = paddle.to_tensor(img) - out = driver.train_step(img) - label + 1 - loss = loss_func(out, label) - epoch_loss += loss.item() - - if batch % 50 == 0: - print("epoch:{}, batch:{}, loss: {}, rank:{}".format(current_epoch_idx, batch, loss.item(), driver.local_rank)) - - driver.backward(loss) - driver.step() - driver.zero_grad() - driver.barrier() - current_epoch_idx += 1 - - # test - correct = 0 - driver.set_model_mode("eval") - for img, label in test_dataset: - - img = paddle.to_tensor(np.array(img).astype('float32').reshape(1, -1)) - out = driver.test_step(img) - res = paddle.nn.functional.softmax(out).argmax().item() - label = label.item() - if res == label: - correct += 1 - - print("{} / {}, acc: {}".format(correct, len(test_dataset), correct / len(test_dataset))) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.sampler.shuffle == True + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_unrepreated_sampler(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler + 的表现 + """ + batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) + batch_sampler.sampler = UnrepeatedRandomSampler(self.dataset, True) + dataloader = DataLoader( + self.dataset, + batch_sampler=batch_sampler + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + def test_set_dist_repro_dataloader_with_dist_unrepeat_dataloader_normal(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + def check_distributed_sampler(self, sampler): + """ + 测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 + """ + assert sampler.num_replicas == dist.get_world_size() + assert sampler.rank == dist.get_rank() + if not isinstance(sampler, UnrepeatedSampler): + assert sampler.pad == True +