@@ -238,14 +238,16 @@ class PaddleFleetDriver(PaddleDriver): | |||||
self.gloo_rendezvous_dir = None | self.gloo_rendezvous_dir = None | ||||
# 分布式环境的其它参数设置 | # 分布式环境的其它参数设置 | ||||
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | |||||
paddle_kwargs = kwargs.get("paddle_kwargs", {}) | |||||
self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {}) | |||||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | ||||
# fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 | ||||
self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) | ||||
self.is_collective = self._fleet_kwargs.get("is_collective", True) | |||||
self.is_collective = self._fleet_kwargs.pop("is_collective", True) | |||||
if not self.is_collective: | if not self.is_collective: | ||||
raise NotImplementedError("FastNLP only support `collective` for distributed training now.") | raise NotImplementedError("FastNLP only support `collective` for distributed training now.") | ||||
self.role_maker = self._fleet_kwargs.get("role_maker", None) | |||||
self.role_maker = self._fleet_kwargs.pop("role_maker", None) | |||||
if self.local_rank == 0 and not is_in_paddle_dist(): | if self.local_rank == 0 and not is_in_paddle_dist(): | ||||
# 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有 | # 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有 | ||||
@@ -22,12 +22,14 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver | 2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver | ||||
3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver | 3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver | ||||
:param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | |||||
:param driver: 使用的 ``driver`` 类型,在这个函数中仅支持 ``paddle`` | |||||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | :param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | ||||
:param model: 训练或者评测的具体的模型; | :param model: 训练或者评测的具体的模型; | ||||
:return: 返回构造的 `Driver` 实例。 | :return: 返回构造的 `Driver` 实例。 | ||||
""" | """ | ||||
if driver != "paddle": | |||||
raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.") | |||||
if is_in_paddle_launch_dist(): | if is_in_paddle_launch_dist(): | ||||
if device is not None: | if device is not None: | ||||
logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | ||||
@@ -37,9 +39,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
# TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int | # TODO 目前一个进程仅对应一个卡,所以暂时传入一个 int | ||||
return PaddleFleetDriver(model, device[0], True, **kwargs) | return PaddleFleetDriver(model, device[0], True, **kwargs) | ||||
if driver not in {"paddle", "fleet"}: | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | |||||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | ||||
if user_visible_devices is None: | if user_visible_devices is None: | ||||
raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set " | raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` cannot be None, please check if you have set " | ||||
@@ -64,22 +63,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
" the available gpu number.") | " the available gpu number.") | ||||
elif device is not None and not isinstance(device, str): | elif device is not None and not isinstance(device, str): | ||||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | ||||
if isinstance(device, List): | |||||
return PaddleFleetDriver(model, device, **kwargs) | |||||
else: | |||||
return PaddleSingleDriver(model, device, **kwargs) | |||||
if driver == "paddle": | |||||
if not isinstance(device, List): | |||||
return PaddleSingleDriver(model, device, **kwargs) | |||||
else: | |||||
logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||||
"`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||||
"`driver` as `fleet`.") | |||||
return PaddleFleetDriver(model, device, **kwargs) | |||||
elif driver == "fleet": | |||||
if not isinstance(device, List): | |||||
if device == "cpu": | |||||
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | |||||
logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | |||||
"`driver` as `paddle`.") | |||||
return PaddleFleetDriver(model, [device], **kwargs) | |||||
else: | |||||
return PaddleFleetDriver(model, device, **kwargs) |
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
import contextlib | |||||
from typing import Optional, Dict, Union, Callable, Tuple | from typing import Optional, Dict, Union, Callable, Tuple | ||||
from .paddle_driver import PaddleDriver | from .paddle_driver import PaddleDriver | ||||
@@ -70,7 +71,8 @@ class PaddleSingleDriver(PaddleDriver): | |||||
""" | """ | ||||
device = get_device_from_visible(self.model_device, output_type=str) | device = get_device_from_visible(self.model_device, output_type=str) | ||||
paddle.device.set_device(device) | paddle.device.set_device(device) | ||||
self.model.to(device) | |||||
with contextlib.redirect_stdout(None): | |||||
self.model.to(device) | |||||
def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: | ||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | if isinstance(batch, Dict) and not self.wo_auto_param_call: | ||||
@@ -76,7 +76,7 @@ def test_trainer_fleet( | |||||
trainer.run() | trainer.run() | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
driver = "fleet" | |||||
driver = "paddle" | |||||
device = [0,2,3] | device = [0,2,3] | ||||
# driver = "paddle" | # driver = "paddle" | ||||
# device = 2 | # device = 2 | ||||
@@ -83,7 +83,7 @@ def test_trainer_fleet( | |||||
trainer.run() | trainer.run() | ||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
driver = "fleet" | |||||
driver = "paddle" | |||||
device = [0,2,3] | device = [0,2,3] | ||||
callbacks = [ | callbacks = [ | ||||
# RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | # RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | ||||
@@ -24,13 +24,12 @@ class TrainPaddleConfig: | |||||
shuffle: bool = True | shuffle: bool = True | ||||
evaluate_every = 2 | evaluate_every = 2 | ||||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) | |||||
@pytest.mark.parametrize("device", ["cpu", 1, [0, 1]]) | |||||
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | ||||
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) | @pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) | ||||
@pytest.mark.paddledist | @pytest.mark.paddledist | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_paddle( | def test_trainer_paddle( | ||||
driver, | |||||
device, | device, | ||||
callbacks, | callbacks, | ||||
n_epochs=2, | n_epochs=2, | ||||
@@ -56,7 +55,7 @@ def test_trainer_paddle( | |||||
metrics = {"acc": Accuracy(backend="paddle")} | metrics = {"acc": Accuracy(backend="paddle")} | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model, | model=model, | ||||
driver=driver, | |||||
driver="paddle", | |||||
device=device, | device=device, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
@@ -21,87 +21,41 @@ def test_incorrect_driver(): | |||||
"device", | "device", | ||||
["cpu", "gpu:0", 0] | ["cpu", "gpu:0", 0] | ||||
) | ) | ||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle"] | |||||
) | |||||
def test_get_single_device(driver, device): | |||||
def test_get_single_device(device): | |||||
""" | """ | ||||
测试正常情况下初始化 PaddleSingleDriver 的情况 | 测试正常情况下初始化 PaddleSingleDriver 的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
driver = initialize_paddle_driver(driver, device, model) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[0, 1, [1]] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["fleet"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_fleet_2(driver, device): | |||||
""" | |||||
测试 fleet 多卡的初始化情况,但传入了单个 gpu | |||||
""" | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[[0, 2, 3], -1] | [[0, 2, 3], -1] | ||||
) | ) | ||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_fleet(driver, device): | |||||
def test_get_fleet(device): | |||||
""" | """ | ||||
测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(64, 10) | model = PaddleNormalModel_Classification_1(64, 10) | ||||
driver = initialize_paddle_driver(driver, device, model) | |||||
driver = initialize_paddle_driver("paddle", device, model) | |||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | |||||
("driver", "device"), | |||||
[("fleet", "cpu")] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_fleet_cpu(driver, device): | |||||
""" | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | |||||
model = PaddleNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_paddle_driver(driver, device, model) | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | ||||
) | ) | ||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["paddle", "fleet"] | |||||
) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_device_out_of_range(driver, device): | |||||
def test_device_out_of_range(device): | |||||
""" | """ | ||||
测试传入的device超过范围的情况 | 测试传入的device超过范围的情况 | ||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver(driver, device, model) | |||||
driver = initialize_paddle_driver("paddle", device, model) |