Browse Source

fleet test

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
f144bc31c3
8 changed files with 72 additions and 27 deletions
  1. +3
    -3
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  2. +20
    -3
      fastNLP/core/drivers/paddle_driver/single_device.py
  3. +1
    -1
      fastNLP/core/drivers/paddle_driver/utils.py
  4. +1
    -2
      fastNLP/core/metrics/backend/paddle_backend/backend.py
  5. +5
    -4
      fastNLP/core/metrics/utils.py
  6. +11
    -1
      fastNLP/core/utils/paddle_utils.py
  7. +23
    -11
      fastNLP/envs/set_backend.py
  8. +8
    -2
      fastNLP/envs/set_env_on_import.py

+ 3
- 3
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -42,7 +42,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
# 优先级 user > cuda
# 判断单机情况 device 的合法性
# 分布式情况下通过 world_device 判断
if user_visible_devices is not None:
if user_visible_devices != "":
_could_use_device_num = len(user_visible_devices.split(","))
elif cuda_visible_devices is not None:
_could_use_device_num = len(cuda_visible_devices.split(","))
@@ -51,8 +51,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
if isinstance(device, int):
if device < 0 and device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
# if device >= _could_use_device_num:
# raise ValueError("The gpu device that parameter `device` specifies is not existed.")
device = f"gpu:{device}"
elif isinstance(device, Sequence) and not isinstance(device, str):
device = list(set(device))


+ 20
- 3
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -1,8 +1,14 @@
import os
from typing import Optional, Dict, Union

from .paddle_driver import PaddleDriver
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str
from fastNLP.core.utils import (
auto_param_call,
get_paddle_gpu_str,
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.log import logger

@@ -86,8 +92,9 @@ class PaddleSingleDriver(PaddleDriver):
self._test_signature_fn = model.forward

def setup(self):
paddle.device.set_device(self.model_device)
self.model.to(self.model_device)
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
paddle.device.set_device("gpu:0")
self.model.to("gpu:0")

def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
@@ -116,6 +123,16 @@ class PaddleSingleDriver(PaddleDriver):
else:
return self._test_step(batch)

def move_data_to_device(self, batch: 'paddle.Tensor'):
r"""
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。
在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0`

:return: 将移动到指定机器上的 batch 对象返回;
"""
return paddle_move_data_to_device(batch, "gpu:0")

def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
# 暂时不支持IteratorDataset
assert dataloader.dataset_kind != _DatasetKind.ITER, \


+ 1
- 1
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -272,7 +272,7 @@ def get_device_from_visible(device: Union[str, int]):
else:
# 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备
user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if user_visiblde_devices is None or user_visiblde_devices != "":
if user_visiblde_devices is not None and user_visiblde_devices != "":
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES
idx = user_visiblde_devices.split(",")[idx]
else:


+ 1
- 2
fastNLP/core/metrics/backend/paddle_backend/backend.py View File

@@ -122,7 +122,6 @@ class PaddleBackend(Backend):

def move_tensor_to_device(self, tensor, device):
# TODO 如果在这里处理的话,会不会在别的地方引起bug?
if is_in_paddle_dist():
device = get_device_from_visible(device)
device = get_device_from_visible(device)
return paddle_to(tensor, device)


+ 5
- 4
fastNLP/core/metrics/utils.py View File

@@ -4,17 +4,18 @@ __all__ = [

from typing import Any
from functools import wraps
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.envs.utils import _module_available

_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric

_IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
if _IS_ALLENNLP_AVAILABLE:
from allennlp.training.metrics import Metric as allennlp_Metric

if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE:
if _IS_TORCHMETRICS_AVAILABLE:
from torchmetrics import Metric as torchmetrics_Metric

if _NEED_IMPORT_PADDLE:
from paddle.metric import Metric as paddle_Metric



+ 11
- 1
fastNLP/core/utils/paddle_utils.py View File

@@ -9,6 +9,7 @@ __all__ = [
]

import os
import re
from typing import Any, Optional, Union

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]):
if isinstance(device, int):
return device

device = device.lower()
if device == "cpu":
raise ValueError("Cannot get device id from `cpu`.")

return paddle.device._convert_to_place(device).get_device_id()
match_res = re.match(r"gpu:\d+", device)
if not match_res:
raise ValueError(
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
)
device_id = device.split(':', 1)[1]
device_id = int(device_id)

return device_id

def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
data_device: Optional[str] = None) -> Any:


+ 23
- 11
fastNLP/envs/set_backend.py View File

@@ -52,21 +52,33 @@ def _set_backend():
if backend == 'paddle':
assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`."
if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \
and 'FLAGS_selected_gpus' not in os.environ:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ[USER_CUDA_VISIBLE_DEVICES] = ''
user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备
selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
if user_visible_devices is not None and user_visible_devices != "":
# 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练
# 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中
# 我们需要从中找到真正使用的设备编号
user_visible_devices = user_visible_devices.split(",")
selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus])
else:
# 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见
os.environ[USER_CUDA_VISIBLE_DEVICES] = ""
# TODO 这里的 [0] 可能在单个节点多卡的时候有问题
os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0]
os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
elif 'CUDA_VISIBLE_DEVICES' in os.environ:
# 主进程中,用户设置了 CUDA_VISIBLE_DEVICES
# 将用户设置的 CUDA_VISIBLE_DEVICES hack 掉
CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES']
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
# TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh
CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus']
os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
os.environ['FLAGS_selected_gpus'] = "0"
os.environ['FLAGS_selected_accelerators'] = "0"
else:
# 没有设置的话限制在单卡上,防止多进程时占用别的卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ[USER_CUDA_VISIBLE_DEVICES] = ''

elif backend == 'jittor':
assert _module_available(backend), f"You must have {backend} available to use {backend} backend."


+ 8
- 2
fastNLP/envs/set_env_on_import.py View File

@@ -36,8 +36,14 @@ def set_env_on_import_torch():

# TODO paddle may need set this
def set_env_on_import_paddle():
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH
pass
# todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS
if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
and "PADDLE_RANK_IN_NODE" in os.environ:
# 检测到了分布式环境的环境变量
os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]
# 如果不是由 fastnlp 启动的
if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
os.environ[FASTNLP_BACKEND_LAUNCH] = "1"

# TODO jittor may need set this
def set_env_on_import_jittor():


Loading…
Cancel
Save