@@ -238,14 +238,16 @@ class PaddleFleetDriver(PaddleDriver): | |||
self.gloo_rendezvous_dir = None | |||
# 分布式环境的其它参数设置 | |||
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | |||
paddle_kwargs = kwargs.get("paddle_kwargs", {}) | |||
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) | |||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | |||
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | |||
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | |||
self.is_collective = self._fleet_kwargs.get("is_collective", True) | |||
self.is_collective = self._fleet_kwargs.pop("is_collective", True) | |||
if not self.is_collective: | |||
raise NotImplementedError("FastNLP only support `collective` for distributed training now.") | |||
self.role_maker = self._fleet_kwargs.get("role_maker", None) | |||
self.role_maker = self._fleet_kwargs.pop("role_maker", None) | |||
if self.local_rank == 0 and not is_in_paddle_dist(): | |||
# 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有 | |||
@@ -22,12 +22,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver | |||
3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver | |||
:param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | |||
:param driver: 使用的 ``driver`` 类型,在这个函数中仅支持 ``paddle`` | |||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | |||
:param model: 训练或者评测的具体的模型; | |||
:return: 返回构造的 `Driver` 实例。 | |||
""" | |||
if driver != "paddle": | |||
raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.") | |||
if is_in_paddle_launch_dist(): | |||
if device is not None: | |||
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | |||
@@ -37,9 +39,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
# TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int | |||
return PaddleFleetDriver(model, device[0], True, **kwargs) | |||
if driver not in {"paddle", "fleet"}: | |||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | |||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | |||
if user_visible_devices is None: | |||
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set " | |||
@@ -64,22 +63,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||
" the available gpu number.") | |||
elif device is not None and not isinstance(device, str): | |||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||
if isinstance(device, List): | |||
return PaddleFleetDriver(model, device, **kwargs) | |||
else: | |||
return PaddleSingleDriver(model, device, **kwargs) | |||
if driver == "paddle": | |||
if not isinstance(device, List): | |||
return PaddleSingleDriver(model, device, **kwargs) | |||
else: | |||
logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
"`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||
"`driver` as `fleet`.") | |||
return PaddleFleetDriver(model, device, **kwargs) | |||
elif driver == "fleet": | |||
if not isinstance(device, List): | |||
if device == "cpu": | |||
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | |||
logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | |||
"`driver` as `paddle`.") | |||
return PaddleFleetDriver(model, [device], **kwargs) | |||
else: | |||
return PaddleFleetDriver(model, device, **kwargs) |
@@ -1,4 +1,5 @@ | |||
import os | |||
import contextlib | |||
from typing import Optional, Dict, Union, Callable, Tuple | |||
from .paddle_driver import PaddleDriver | |||
@@ -70,7 +71,8 @@ class PaddleSingleDriver(PaddleDriver): | |||
""" | |||
device = get_device_from_visible(self.model_device, output_type=str) | |||
paddle.device.set_device(device) | |||
self.model.to(device) | |||
with contextlib.redirect_stdout(None): | |||
self.model.to(device) | |||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
@@ -76,7 +76,7 @@ def test_trainer_fleet( | |||
trainer.run() | |||
if __name__ == "__main__": | |||
driver = "fleet" | |||
driver = "paddle" | |||
device = [0,2,3] | |||
# driver = "paddle" | |||
# device = 2 | |||
@@ -83,7 +83,7 @@ def test_trainer_fleet( | |||
trainer.run() | |||
if __name__ == "__main__": | |||
driver = "fleet" | |||
driver = "paddle" | |||
device = [0,2,3] | |||
callbacks = [ | |||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||
@@ -24,13 +24,12 @@ class TrainPaddleConfig: | |||
shuffle: bool = True | |||
evaluate_every = 2 | |||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) | |||
@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]]) | |||
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | |||
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) | |||
@pytest.mark.paddledist | |||
@magic_argv_env_context | |||
def test_trainer_paddle( | |||
driver, | |||
device, | |||
callbacks, | |||
n_epochs=2, | |||
@@ -56,7 +55,7 @@ def test_trainer_paddle( | |||
metrics = {"acc": Accuracy(backend="paddle")} | |||
trainer = Trainer( | |||
model=model, | |||
driver=driver, | |||
driver="paddle", | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
@@ -21,87 +21,41 @@ def test_incorrect_driver(): | |||
"device", | |||
["cpu", "gpu:0", 0] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle"] | |||
) | |||
def test_get_single_device(driver, device): | |||
def test_get_single_device(device): | |||
""" | |||
测试正常情况下初始化 PaddleSingleDriver 的情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
assert isinstance(driver, PaddleSingleDriver) | |||
@pytest.mark.paddle | |||
@pytest.mark.parametrize( | |||
"device", | |||
[0, 1, [1]] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["fleet"] | |||
) | |||
@magic_argv_env_context | |||
def test_get_fleet_2(driver, device): | |||
""" | |||
测试 fleet 多卡的初始化情况,但传入了单个 gpu | |||
""" | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
@pytest.mark.paddle | |||
@pytest.mark.parametrize( | |||
"device", | |||
[[0, 2, 3], -1] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle", "fleet"] | |||
) | |||
@magic_argv_env_context | |||
def test_get_fleet(driver, device): | |||
def test_get_fleet(device): | |||
""" | |||
测试 fleet 多卡的初始化情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
driver = initialize_paddle_driver(driver, device, model) | |||
driver = initialize_paddle_driver("paddle", device, model) | |||
assert isinstance(driver, PaddleFleetDriver) | |||
@pytest.mark.paddle | |||
@pytest.mark.parametrize( | |||
("driver", "device"), | |||
[("fleet", "cpu")] | |||
) | |||
@magic_argv_env_context | |||
def test_get_fleet_cpu(driver, device): | |||
""" | |||
测试试图在 cpu 上初始化分布式训练的情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(64, 10) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver(driver, device, model) | |||
@pytest.mark.paddle | |||
@pytest.mark.parametrize( | |||
"device", | |||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | |||
) | |||
@pytest.mark.parametrize( | |||
"driver", | |||
["paddle", "fleet"] | |||
) | |||
@magic_argv_env_context | |||
def test_device_out_of_range(driver, device): | |||
def test_device_out_of_range(device): | |||
""" | |||
测试传入的device超过范围的情况 | |||
""" | |||
model = PaddleNormalModel_Classification_1(2, 100) | |||
with pytest.raises(ValueError): | |||
driver = initialize_paddle_driver(driver, device, model) | |||
driver = initialize_paddle_driver("paddle", device, model) |