Browse Source

删除 initialize_paddle_driver 对参数 driver 的限制,使得它能够根据 device 初始化 driver

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
fe449bd543
7 changed files with 25 additions and 83 deletions
  1. +5
    -3
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +7
    -22
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  3. +3
    -1
      fastNLP/core/drivers/paddle_driver/single_device.py
  4. +1
    -1
      tests/core/controllers/_test_trainer_fleet.py
  5. +1
    -1
      tests/core/controllers/_test_trainer_fleet_outside.py
  6. +2
    -3
      tests/core/controllers/test_trainer_paddle.py
  7. +6
    -52
      tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py

+ 5
- 3
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -238,14 +238,16 @@ class PaddleFleetDriver(PaddleDriver):
self.gloo_rendezvous_dir = None

# 分布式环境的其它参数设置
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {})
paddle_kwargs = kwargs.get("paddle_kwargs", {})
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {})
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__)
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
self.is_collective = self._fleet_kwargs.get("is_collective", True)
self.is_collective = self._fleet_kwargs.pop("is_collective", True)
if not self.is_collective:
raise NotImplementedError("FastNLP only support `collective` for distributed training now.")
self.role_maker = self._fleet_kwargs.get("role_maker", None)
self.role_maker = self._fleet_kwargs.pop("role_maker", None)

if self.local_rank == 0 and not is_in_paddle_dist():
# 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有


+ 7
- 22
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -22,12 +22,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver
3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver

:param driver: 该参数的值应为以下之一:["paddle", "fleet"];
:param driver: 使用的 ``driver`` 类型,在这个函数中仅支持 ``paddle``
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致;
:param model: 训练或者评测的具体的模型;

:return: 返回构造的 `Driver` 实例。
"""
if driver != "paddle":
raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
if is_in_paddle_launch_dist():
if device is not None:
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
@@ -37,9 +39,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
# TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int
return PaddleFleetDriver(model, device[0], True, **kwargs)

if driver not in {"paddle", "fleet"}:
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].")

user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
if user_visible_devices is None:
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set "
@@ -64,22 +63,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
" the available gpu number.")
elif device is not None and not isinstance(device, str):
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
if isinstance(device, List):
return PaddleFleetDriver(model, device, **kwargs)
else:
return PaddleSingleDriver(model, device, **kwargs)

if driver == "paddle":
if not isinstance(device, List):
return PaddleSingleDriver(model, device, **kwargs)
else:
logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use"
"`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter"
"`driver` as `fleet`.")
return PaddleFleetDriver(model, device, **kwargs)
elif driver == "fleet":
if not isinstance(device, List):
if device == "cpu":
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.")
logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should "
"`driver` as `paddle`.")
return PaddleFleetDriver(model, [device], **kwargs)
else:
return PaddleFleetDriver(model, device, **kwargs)

+ 3
- 1
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -1,4 +1,5 @@
import os
import contextlib
from typing import Optional, Dict, Union, Callable, Tuple

from .paddle_driver import PaddleDriver
@@ -70,7 +71,8 @@ class PaddleSingleDriver(PaddleDriver):
"""
device = get_device_from_visible(self.model_device, output_type=str)
paddle.device.set_device(device)
self.model.to(device)
with contextlib.redirect_stdout(None):
self.model.to(device)

def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict:
if isinstance(batch, Dict) and not self.wo_auto_param_call:


+ 1
- 1
tests/core/controllers/_test_trainer_fleet.py View File

@@ -76,7 +76,7 @@ def test_trainer_fleet(
trainer.run()

if __name__ == "__main__":
driver = "fleet"
driver = "paddle"
device = [0,2,3]
# driver = "paddle"
# device = 2


+ 1
- 1
tests/core/controllers/_test_trainer_fleet_outside.py View File

@@ -83,7 +83,7 @@ def test_trainer_fleet(
trainer.run()

if __name__ == "__main__":
driver = "fleet"
driver = "paddle"
device = [0,2,3]
callbacks = [
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),


+ 2
- 3
tests/core/controllers/test_trainer_paddle.py View File

@@ -24,13 +24,12 @@ class TrainPaddleConfig:
shuffle: bool = True
evaluate_every = 2

@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]])
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]])
@pytest.mark.paddledist
@magic_argv_env_context
def test_trainer_paddle(
driver,
device,
callbacks,
n_epochs=2,
@@ -56,7 +55,7 @@ def test_trainer_paddle(
metrics = {"acc": Accuracy(backend="paddle")}
trainer = Trainer(
model=model,
driver=driver,
driver="paddle",
device=device,
optimizers=optimizers,
train_dataloader=train_dataloader,


+ 6
- 52
tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py View File

@@ -21,87 +21,41 @@ def test_incorrect_driver():
"device",
["cpu", "gpu:0", 0]
)
@pytest.mark.parametrize(
"driver",
["paddle"]
)
def test_get_single_device(driver, device):
def test_get_single_device(device):
"""
测试正常情况下初始化 PaddleSingleDriver 的情况
"""

model = PaddleNormalModel_Classification_1(2, 100)
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)
assert isinstance(driver, PaddleSingleDriver)

@pytest.mark.paddle
@pytest.mark.parametrize(
"device",
[0, 1, [1]]
)
@pytest.mark.parametrize(
"driver",
["fleet"]
)
@magic_argv_env_context
def test_get_fleet_2(driver, device):
"""
测试 fleet 多卡的初始化情况,但传入了单个 gpu
"""

model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)

assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.paddle
@pytest.mark.parametrize(
"device",
[[0, 2, 3], -1]
)
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
@magic_argv_env_context
def test_get_fleet(driver, device):
def test_get_fleet(device):
"""
测试 fleet 多卡的初始化情况
"""

model = PaddleNormalModel_Classification_1(64, 10)
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)

assert isinstance(driver, PaddleFleetDriver)

@pytest.mark.paddle
@pytest.mark.parametrize(
("driver", "device"),
[("fleet", "cpu")]
)
@magic_argv_env_context
def test_get_fleet_cpu(driver, device):
"""
测试试图在 cpu 上初始化分布式训练的情况
"""
model = PaddleNormalModel_Classification_1(64, 10)
with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)

@pytest.mark.paddle
@pytest.mark.parametrize(
"device",
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1]
)
@pytest.mark.parametrize(
"driver",
["paddle", "fleet"]
)
@magic_argv_env_context
def test_device_out_of_range(driver, device):
def test_device_out_of_range(device):
"""
测试传入的device超过范围的情况
"""
model = PaddleNormalModel_Classification_1(2, 100)
with pytest.raises(ValueError):
driver = initialize_paddle_driver(driver, device, model)
driver = initialize_paddle_driver("paddle", device, model)

Loading…
Cancel
Save