From 5ea4f75ff873a7c845c649f5d9046aba4bcd81eb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 06:54:31 +0000 Subject: [PATCH] =?UTF-8?q?paddle=20=E7=8E=AF=E5=A2=83=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/envs/set_backend.py | 5 ++--- fastNLP/envs/set_env_on_import.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 68a28335..18cc970e 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -8,7 +8,7 @@ import sys from collections import defaultdict -from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED +from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED from fastNLP.envs.imports import SUPPORT_BACKENDS from fastNLP.envs.utils import _module_available @@ -65,8 +65,7 @@ def _set_backend(): else: # 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 os.environ[USER_CUDA_VISIBLE_DEVICES] = "" - # TODO 这里的 [0] 可能在单个节点多卡的时候有问题 - os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] + os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) elif 'CUDA_VISIBLE_DEVICES' in os.environ: diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index db978bae..1ca49289 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -36,8 +36,7 @@ def set_env_on_import_torch(): # TODO paddle may need set this def set_env_on_import_paddle(): - # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS - if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ + if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ and "PADDLE_RANK_IN_NODE" in os.environ: # 检测到了分布式环境的环境变量 os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]