|
|
@@ -8,7 +8,7 @@ import sys |
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED |
|
|
|
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED |
|
|
|
from fastNLP.envs.imports import SUPPORT_BACKENDS |
|
|
|
from fastNLP.envs.utils import _module_available |
|
|
|
|
|
|
@@ -65,8 +65,7 @@ def _set_backend(): |
|
|
|
else: |
|
|
|
# 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 |
|
|
|
os.environ[USER_CUDA_VISIBLE_DEVICES] = "" |
|
|
|
# TODO 这里的 [0] 可能在单个节点多卡的时候有问题 |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] |
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(selected_gpus) |
|
|
|
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) |
|
|
|
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) |
|
|
|
elif 'CUDA_VISIBLE_DEVICES' in os.environ: |
|
|
|