diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index ff80cb9e..77cd62c2 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -241,7 +241,6 @@ class PaddleFleetDriver(PaddleDriver): launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) launcher.launch() # 设置参数和初始化分布式环境 - reset_seed() fleet.init(self.role_maker, self.is_collective, self.strategy) self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 1dad6d97..85e17e07 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -3,6 +3,7 @@ from typing import Optional, Dict, Union from .paddle_driver import PaddleDriver from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES from fastNLP.core.utils import ( auto_param_call, get_paddle_gpu_str, @@ -92,7 +93,13 @@ class PaddleSingleDriver(PaddleDriver): self._test_signature_fn = model.forward def setup(self): - os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device)) + # The variable may be unset: os.getenv returns None where os.environ[...] would raise KeyError. + user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) + device_id = get_paddle_device_id(self.model_device) + if user_visible_devices is not None and user_visible_devices != "": + # Non-empty means the user set CUDA_VISIBLE_DEVICES + device_id = user_visible_devices.split(",")[device_id] + os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) paddle.device.set_device("gpu:0") self.model.to("gpu:0") diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index b99ae581..ebe0f6c5 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -271,10 +271,10 @@ def
get_device_from_visible(device: Union[str, int]): return idx else: # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 - user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) - if user_visiblde_devices is not None and user_visiblde_devices != "": + user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) + if user_visible_devices is not None and user_visible_devices != "": # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES - idx = user_visiblde_devices.split(",")[idx] + idx = user_visible_devices.split(",")[idx] else: idx = str(idx)