Browse Source

修复但å单卡的设备逻辑

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
49d18f3683
3 changed files with 10 additions and 5 deletions
  1. +0
    -1
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +7
    -1
      fastNLP/core/drivers/paddle_driver/single_device.py
  3. +3
    -3
      fastNLP/core/drivers/paddle_driver/utils.py

+ 0
- 1
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -241,7 +241,6 @@ class PaddleFleetDriver(PaddleDriver):
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
launcher.launch()
# 设置参数和初始化分布式环境
reset_seed()
fleet.init(self.role_maker, self.is_collective, self.strategy)
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID"))
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM"))


+ 7
- 1
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -3,6 +3,7 @@ from typing import Optional, Dict, Union

from .paddle_driver import PaddleDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.utils import (
auto_param_call,
get_paddle_gpu_str,
@@ -92,7 +93,12 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward

def setup(self):
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES]
device_id = get_paddle_device_id(self.model_device)
if user_visible_devices is not None and user_visible_devices != "":
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
device_id = user_visible_devices.split(",")[device_id]
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")



+ 3
- 3
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -271,10 +271,10 @@ def get_device_from_visible(device: Union[str, int]):
return idx
else:
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visiblde_devices is not None and user_visiblde_devices != "":
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visible_devices is not None and user_visible_devices != "":
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
idx = user_visiblde_devices.split(",")[idx]
idx = user_visible_devices.split(",")[idx]
else:
idx = str(idx)



Loading…
Cancel
Save