From fe449bd5438faf97678619bc1573c6354858a3ab Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Mon, 9 May 2022 08:59:03 +0000
Subject: [PATCH] Remove the restriction on the `driver` parameter of
 initialize_paddle_driver so that the driver can be initialized based on
 device
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/drivers/paddle_driver/fleet.py   |  8 ++-
 .../paddle_driver/initialize_paddle_driver.py | 29 +++-------
 .../drivers/paddle_driver/single_device.py    |  4 +-
 tests/core/controllers/_test_trainer_fleet.py |  2 +-
 .../_test_trainer_fleet_outside.py            |  2 +-
 tests/core/controllers/test_trainer_paddle.py |  5 +-
 .../test_initialize_paddle_driver.py          | 58 ++-----------------
 7 files changed, 25 insertions(+), 83 deletions(-)

diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index b94e7bde..e5b2a06f 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -238,14 +238,16 @@ class PaddleFleetDriver(PaddleDriver):
             self.gloo_rendezvous_dir = None
 
         # Other parameter settings for the distributed environment
-        self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
+        paddle_kwargs = kwargs.get("paddle_kwargs", {})
+
+        self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {})
         check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
         # Settings of the distributed strategy used in fleet.init; see the official PaddlePaddle documentation for details
         self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
-        self.is_collective = self._fleet_kwargs.get("is_collective", True)
+        self.is_collective = self._fleet_kwargs.pop("is_collective", True)
         if not self.is_collective:
             raise NotImplementedError("FastNLP only support `collective` for distributed training now.")
-        self.role_maker = self._fleet_kwargs.get("role_maker", None)
+        self.role_maker = self._fleet_kwargs.pop("role_maker", None)
 
         if self.local_rank == 0 and not is_in_paddle_dist():
             # Since the model is always initialized when the driver is used, the program occupies some GPU memory for the model from the start; however, this part of the memory does not
diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
index f07cd47e..60e8afc0 100644
--- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
+++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
@@ -22,12 +22,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
     2. If the given `driver` is `paddle` but `device` contains multiple devices, we give a warning and automatically return the multi-card Driver
     3. If the given `driver` is `fleet` but `device` has only one device, we give a warning but still return the multi-card Driver
 
-    :param driver: the value of this parameter should be one of ["paddle", "fleet"];
+    :param driver: the type of ``driver`` to use; only ``paddle`` is supported in this function
     :param device: the format of this parameter is the same as what `Trainer` requires for its `device` parameter;
     :param model: the concrete model used for training or evaluation;
 
     :return: the constructed `Driver` instance.
     """
+    if driver != "paddle":
+        raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
     if is_in_paddle_launch_dist():
         if device is not None:
             logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
@@ -37,9 +39,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
         # TODO Currently one process corresponds to exactly one card, so an int is passed in for now
         return PaddleFleetDriver(model, device[0], True, **kwargs)
 
-    if driver not in {"paddle", "fleet"}:
-        raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")
-
     user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
     if user_visible_devices is None:
         raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
@@ -64,22 +63,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
                             " the available gpu number.")
     elif device is not None and not isinstance(device, str):
         raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
+    if isinstance(device, List):
+        return PaddleFleetDriver(model, device, **kwargs)
+    else:
+        return PaddleSingleDriver(model, device, **kwargs)
 
-    if driver == "paddle":
-        if not isinstance(device, List):
-            return PaddleSingleDriver(model, device, **kwargs)
-        else:
-            logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use"
-                        "`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter"
-                        "`driver` as `fleet`.")
-            return PaddleFleetDriver(model, device, **kwargs)
-    elif driver == "fleet":
-        if not isinstance(device, List):
-            if device == "cpu":
-                raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.")
-            logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
-                        "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should "
-                        "`driver` as `paddle`.")
-            return PaddleFleetDriver(model, [device], **kwargs)
-        else:
-            return PaddleFleetDriver(model, device, **kwargs)
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
index 630c03ee..69b58954 100644
--- a/fastNLP/core/drivers/paddle_driver/single_device.py
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -1,4 +1,5 @@
 import os
+import contextlib
 from typing import Optional, Dict, Union, Callable, Tuple
 
 from .paddle_driver import PaddleDriver
@@ -70,7 +71,8 @@ class PaddleSingleDriver(PaddleDriver):
         """
         device = get_device_from_visible(self.model_device, output_type=str)
         paddle.device.set_device(device)
-        self.model.to(device)
+        with contextlib.redirect_stdout(None):
+            self.model.to(device)
 
     def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
         if isinstance(batch, Dict) and not self.wo_auto_param_call:
diff --git a/tests/core/controllers/_test_trainer_fleet.py b/tests/core/controllers/_test_trainer_fleet.py
index 309e6eb4..1a01bb5d 100644
--- a/tests/core/controllers/_test_trainer_fleet.py
+++ b/tests/core/controllers/_test_trainer_fleet.py
@@ -76,7 +76,7 @@ def test_trainer_fleet(
     trainer.run()
 
 if __name__ == "__main__":
-    driver = "fleet"
+    driver = "paddle"
     device = [0,2,3]
     # driver = "paddle"
     # device = 2
diff --git a/tests/core/controllers/_test_trainer_fleet_outside.py b/tests/core/controllers/_test_trainer_fleet_outside.py
index d2bcbc41..1ab2e624 100644
--- a/tests/core/controllers/_test_trainer_fleet_outside.py
+++ b/tests/core/controllers/_test_trainer_fleet_outside.py
@@ -83,7 +83,7 @@ def test_trainer_fleet(
     trainer.run()
 
 if __name__ == "__main__":
-    driver = "fleet"
+    driver = "paddle"
     device = [0,2,3]
     callbacks = [
         # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py
index 895e8517..d7bfaeaf 100644
--- a/tests/core/controllers/test_trainer_paddle.py
+++ b/tests/core/controllers/test_trainer_paddle.py
@@ -24,13 +24,12 @@ class TrainPaddleConfig:
     shuffle: bool = True
     evaluate_every = 2
 
-@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
+@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
 @pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
 @pytest.mark.paddledist
 @magic_argv_env_context
 def test_trainer_paddle(
-    driver,
     device,
     callbacks,
     n_epochs=2,
@@ -56,7 +55,7 @@ def test_trainer_paddle(
     metrics = {"acc": Accuracy(backend="paddle")}
     trainer = Trainer(
         model=model,
-        driver=driver,
+        driver="paddle",
         device=device,
         optimizers=optimizers,
         train_dataloader=train_dataloader,
diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py
index e339bbcc..ad99d4a8 100644
--- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py
+++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py
@@ -21,87 +21,41 @@ def test_incorrect_driver():
     "device",
     ["cpu", "gpu:0", 0]
 )
-@pytest.mark.parametrize(
-    "driver",
-    ["paddle"]
-)
-def test_get_single_device(driver, device):
+def test_get_single_device(device):
     """
     Test initializing PaddleSingleDriver under normal circumstances
     """
 
     model = PaddleNormalModel_Classification_1(2, 100)
-    driver = initialize_paddle_driver(driver, device, model)
+    driver = initialize_paddle_driver("paddle", device, model)
     assert isinstance(driver, PaddleSingleDriver)
 
-@pytest.mark.paddle
-@pytest.mark.parametrize(
-    "device",
-    [0, 1, [1]]
-)
-@pytest.mark.parametrize(
-    "driver",
-    ["fleet"]
-)
-@magic_argv_env_context
-def test_get_fleet_2(driver, device):
-    """
-    Test fleet multi-card initialization when a single gpu is passed in
-    """
-
-    model = PaddleNormalModel_Classification_1(64, 10)
-    driver = initialize_paddle_driver(driver, device, model)
-
-    assert isinstance(driver, PaddleFleetDriver)
-
 @pytest.mark.paddle
 @pytest.mark.parametrize(
     "device",
     [[0, 2, 3], -1]
 )
-@pytest.mark.parametrize(
-    "driver",
-    ["paddle", "fleet"]
-)
 @magic_argv_env_context
-def test_get_fleet(driver, device):
+def test_get_fleet(device):
    """
     Test fleet multi-card initialization
     """
 
     model = PaddleNormalModel_Classification_1(64, 10)
-    driver = initialize_paddle_driver(driver, device, model)
+    driver = initialize_paddle_driver("paddle", device, model)
 
     assert isinstance(driver, PaddleFleetDriver)
 
-@pytest.mark.paddle
-@pytest.mark.parametrize(
-    ("driver", "device"),
-    [("fleet", "cpu")]
-)
-@magic_argv_env_context
-def test_get_fleet_cpu(driver, device):
-    """
-    Test attempting to initialize distributed training on cpu
-    """
-    model = PaddleNormalModel_Classification_1(64, 10)
-    with pytest.raises(ValueError):
-        driver = initialize_paddle_driver(driver, device, model)
-
 @pytest.mark.paddle
 @pytest.mark.parametrize(
     "device",
     [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1]
 )
-@pytest.mark.parametrize(
-    "driver",
-    ["paddle", "fleet"]
-)
@magic_argv_env_context
-def test_device_out_of_range(driver, device):
+def test_device_out_of_range(device):
     """
     Test the case where the given device is out of range
     """
     model = PaddleNormalModel_Classification_1(2, 100)
     with pytest.raises(ValueError):
-        driver = initialize_paddle_driver(driver, device, model)
+        driver = initialize_paddle_driver("paddle", device, model)
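
For reference, below is a minimal sketch of the behaviour this patch introduces, mirroring the assertions in test_initialize_paddle_driver.py. It is an illustration, not part of the patch: the toy paddle.nn.Linear model is an assumption (the real tests use their own helper models), the import paths are taken from the files touched above, and, as the diff shows, initialize_paddle_driver additionally expects the USER_CUDA_VISIBLE_DEVICES environment variable to be set and GPUs to be available.

import paddle

from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver

# Toy model for illustration only; any paddle.nn.Layer works here.
model = paddle.nn.Linear(10, 2)

# `driver` must now always be the string "paddle"; any other value raises ValueError.
# A single device (str or int) yields a PaddleSingleDriver.
single_driver = initialize_paddle_driver("paddle", "gpu:0", model)
assert isinstance(single_driver, PaddleSingleDriver)

# A list of devices (or -1 for all visible gpus) yields a PaddleFleetDriver;
# passing driver="fleet" is no longer accepted.
fleet_driver = initialize_paddle_driver("paddle", [0, 1], model)
assert isinstance(fleet_driver, PaddleFleetDriver)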