@@ -42,7 +42,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
     # priority: user > cuda
     # check the validity of `device` here in the single-machine case
     # in the distributed case it is checked via world_device
-    if user_visible_devices is not None:
+    if user_visible_devices != "":
         _could_use_device_num = len(user_visible_devices.split(","))
     elif cuda_visible_devices is not None:
         _could_use_device_num = len(cuda_visible_devices.split(","))
@@ -51,8 +51,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
     if isinstance(device, int):
         if device < 0 and device != -1:
             raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
-        if device >= _could_use_device_num:
-            raise ValueError("The gpu device that parameter `device` specifies is not existed.")
+        # if device >= _could_use_device_num:
+        #     raise ValueError("The gpu device that parameter `device` specifies is not existed.")
         device = f"gpu:{device}"
     elif isinstance(device, Sequence) and not isinstance(device, str):
         device = list(set(device))
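Together with the `_set_backend` changes further down, the new condition distinguishes "the user set no restriction" (stored as `""`) from a real `CUDA_VISIBLE_DEVICES` restriction. For orientation, a minimal standalone sketch of the resulting counting rule; the helper name and the fallback to `paddle.device.cuda.device_count()` are illustrative, not fastNLP's exact code:

```python
import os

# fastNLP keeps the user's original CUDA_VISIBLE_DEVICES under this key;
# "" means the user did not restrict visibility themselves.
USER_CUDA_VISIBLE_DEVICES = "USER_CUDA_VISIBLE_DEVICES"

def could_use_device_num() -> int:
    """How many GPUs a single-machine run may use: user setting > cuda setting."""
    user = os.getenv(USER_CUDA_VISIBLE_DEVICES)
    cuda = os.getenv("CUDA_VISIBLE_DEVICES")
    if user is not None and user != "":
        return len(user.split(","))
    if cuda is not None:
        return len(cuda.split(","))
    # neither variable set: ask paddle for the physical count
    import paddle
    return paddle.device.cuda.device_count()
```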
@@ -1,8 +1,14 @@
+import os
 from typing import Optional, Dict, Union

 from .paddle_driver import PaddleDriver
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
-from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str
+from fastNLP.core.utils import (
+    auto_param_call,
+    get_paddle_gpu_str,
+    get_paddle_device_id,
+    paddle_move_data_to_device,
+)
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
 from fastNLP.core.log import logger
@@ -86,8 +92,9 @@ class PaddleSingleDriver(PaddleDriver):
         self._test_signature_fn = model.forward

     def setup(self):
-        paddle.device.set_device(self.model_device)
-        self.model.to(self.model_device)
+        os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device))
+        paddle.device.set_device("gpu:0")
+        self.model.to("gpu:0")

     def train_step(self, batch) -> Dict:
         # if batch is a Dict, we do parameter matching for it by default; otherwise it is passed straight into `train_step` for the user to handle;
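The rewritten `setup()` pins the process to one physical card via `CUDA_VISIBLE_DEVICES` and then addresses it under its logical name. A toy illustration of that renumbering, assuming a CUDA build of paddle and at least four cards; it must run in a fresh process, because the variable has to be set before `import paddle`:

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"   # expose only physical card 3

import paddle                              # import after the env var is set

paddle.device.set_device("gpu:0")          # logical gpu:0 == physical card 3
print(paddle.device.get_device())          # -> "gpu:0"
print(paddle.device.cuda.device_count())   # -> 1
```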
@@ -116,6 +123,16 @@ class PaddleSingleDriver(PaddleDriver):
         else:
             return self._test_step(batch)

+    def move_data_to_device(self, batch: 'paddle.Tensor'):
+        r"""
+        Move data to the given device; `batch` may be a list or a dict, or a nested structure of them.
+        Note that in Paddle this can cause problems when the target differs from the configured device.
+        On a single card, since CUDA_VISIBLE_DEVICES is always restricted to one device, the data in effect always ends up on `gpu:0`.
+
+        :return: the batch object after being moved to the target device;
+        """
+        return paddle_move_data_to_device(batch, "gpu:0")
+
     def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False):
         # IteratorDataset is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
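A usage sketch for the new method: `paddle_move_data_to_device` walks nested containers, so a dict-of-lists batch comes back with every tensor on the target card (assumes a GPU build of paddle):

```python
import paddle
from fastNLP.core.utils import paddle_move_data_to_device

batch = {
    "input_ids": paddle.randint(0, 100, [2, 16]),
    "labels": [paddle.to_tensor([0]), paddle.to_tensor([1])],
}
batch = paddle_move_data_to_device(batch, "gpu:0")
print(batch["input_ids"].place)  # now a gpu:0 place
```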
@@ -272,7 +272,7 @@ def get_device_from_visible(device: Union[str, int]):
     else:
         # use USER_CUDA_VISIBLE_DEVICES to get the device the user expects
         user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
-        if user_visiblde_devices is None or user_visiblde_devices != "":
+        if user_visiblde_devices is not None and user_visiblde_devices != "":
             # not empty, meaning the user did set CUDA_VISIBLE_DEVICES
             idx = user_visiblde_devices.split(",")[idx]
         else:
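To see why the old condition was a bug: with `USER_CUDA_VISIBLE_DEVICES` unset, `is None or != ""` evaluates to `True` and the subsequent `split` raises `AttributeError` on `None`. A hedged sketch of the index translation this function performs (simplified signature; the real helper also accepts `gpu:x` strings):

```python
import os

def device_from_visible(idx: int) -> str:
    """Translate a device index from the user's point of view into the index
    paddle sees under the current (hacked) CUDA_VISIBLE_DEVICES."""
    cuda_visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
    user_visible = os.getenv("USER_CUDA_VISIBLE_DEVICES")
    if user_visible is not None and user_visible != "":
        # the user restricted visibility: idx counts within their list
        physical = user_visible.split(",")[idx]
    else:
        physical = str(idx)
    return f"gpu:{cuda_visible.index(physical)}"

# e.g. USER_CUDA_VISIBLE_DEVICES="2,3", CUDA_VISIBLE_DEVICES="3": idx 1 -> "gpu:0"
```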
@@ -122,7 +122,6 @@ class PaddleBackend(Backend):
     def move_tensor_to_device(self, tensor, device):
         # TODO if we handle it here, will it cause bugs elsewhere?
-        if is_in_paddle_dist():
-            device = get_device_from_visible(device)
+        device = get_device_from_visible(device)
         return paddle_to(tensor, device)
@@ -4,17 +4,18 @@ __all__ = [
 from typing import Any
 from functools import wraps

-from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
 from fastNLP.envs.utils import _module_available

 _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
-if _IS_TORCHMETRICS_AVAILABLE:
-    from torchmetrics import Metric as torchmetrics_Metric

 _IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
 if _IS_ALLENNLP_AVAILABLE:
     from allennlp.training.metrics import Metric as allennlp_Metric

+if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE:
+    from torchmetrics import Metric as torchmetrics_Metric

 if _NEED_IMPORT_PADDLE:
     from paddle.metric import Metric as paddle_Metric
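The moved import now sits behind both guards, so an installed torchmetrics no longer drags torch in when the user selected another backend. A self-contained sketch of the pattern, with a simplified `_module_available`; in fastNLP the `_NEED_IMPORT_TORCH` flag is derived from the chosen backend:

```python
import importlib.util

def _module_available(name: str) -> bool:
    # simplified stand-in for fastNLP.envs.utils._module_available
    return importlib.util.find_spec(name) is not None

_NEED_IMPORT_TORCH = False  # e.g. the user chose the paddle backend

# torchmetrics may well be installed, but without the torch flag it is
# skipped, so `import torchmetrics` (and hence `import torch`) never runs
if _NEED_IMPORT_TORCH and _module_available("torchmetrics"):
    from torchmetrics import Metric as torchmetrics_Metric
```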
@@ -9,6 +9,7 @@ __all__ = [
 ]

 import os
+import re
 from typing import Any, Optional, Union

 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
@@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]):
     if isinstance(device, int):
         return device

+    device = device.lower()
     if device == "cpu":
         raise ValueError("Cannot get device id from `cpu`.")

-    return paddle.device._convert_to_place(device).get_device_id()
+    match_res = re.match(r"gpu:\d+", device)
+    if not match_res:
+        raise ValueError(
+            "The device must be a string which is like 'cpu', 'gpu', 'gpu:x'"
+        )
+    device_id = device.split(':', 1)[1]
+    device_id = int(device_id)
+    return device_id

 def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,
                                data_device: Optional[str] = None) -> Any:
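A behaviour check of the new parser, written against the hunk above rather than run against the repo. One caveat worth noting: a bare `'gpu'` is rejected by the `gpu:\d+` pattern even though the error message lists it as acceptable.

```python
from fastNLP.core.utils import get_paddle_device_id

assert get_paddle_device_id(3) == 3         # ints pass straight through
assert get_paddle_device_id("GPU:1") == 1   # lower() normalises the prefix
for bad in ("cpu", "gpu", "xpu:0"):         # all three raise ValueError
    try:
        get_paddle_device_id(bad)
    except ValueError:
        pass
```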
@@ -52,21 +52,33 @@ def _set_backend():
     if backend == 'paddle':
         assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
         assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`."
-        if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \
-                and 'FLAGS_selected_gpus' not in os.environ:
-            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-            os.environ[USER_CUDA_VISIBLE_DEVICES] = ''
+        user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
+        if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
+            # in a distributed subprocess: recover the devices this process really owns from USER_CUDA_VISIBLE_DEVICES
+            selected_gpus = os.environ['FLAGS_selected_gpus'].split(',')
+            if user_visible_devices is not None and user_visible_devices != "":
+                # the user launched distributed training with CUDA_VISIBLE_DEVICES set;
+                # after set_backend, that setting lives in USER_CUDA_VISIBLE_DEVICES,
+                # from which we look up the device ids actually in use
+                user_visible_devices = user_visible_devices.split(",")
+                selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus])
+            else:
+                # set USER_CUDA_VISIBLE_DEVICES to mark that, from the user's view, all devices are visible
+                os.environ[USER_CUDA_VISIBLE_DEVICES] = ""
+            # TODO the [0] here may be a problem with multiple cards on a single node
+            os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0]
+            os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))])
+            os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))])
         elif 'CUDA_VISIBLE_DEVICES' in os.environ:
+            # in the main process, the user set CUDA_VISIBLE_DEVICES themselves;
+            # hack that setting down to a single card
             CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES']
             os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
             os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
-        elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ:
-            # TODO since fastNLP needs to hack CUDA_VISIBLE_DEVICES, the FLAGS and other paddle variables have to be adjusted accordingly @xsh
-            CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus']
-            os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES
-            os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0]
-            os.environ['FLAGS_selected_gpus'] = "0"
-            os.environ['FLAGS_selected_accelerators'] = "0"
+        else:
+            # nothing was set: restrict to a single card to avoid grabbing other cards in multi-process runs
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+            os.environ[USER_CUDA_VISIBLE_DEVICES] = ''

     elif backend == 'jittor':
         assert _module_available(backend), f"You must have {backend} available to use {backend} backend."
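A worked trace of the distributed branch above, for a hypothetical single-node launch where the user set `CUDA_VISIBLE_DEVICES=2,3` and the paddle launcher handed this worker `FLAGS_selected_gpus=1`. Note that because `selected_gpus` becomes a joined string, the `[0]` indexing and `len()` are exactly what the TODO flags as fragile when a worker owns several cards:

```python
import os

# state on entry to the worker process (values are illustrative)
os.environ["USER_CUDA_VISIBLE_DEVICES"] = "2,3"  # saved by set_backend in the parent
os.environ["FLAGS_selected_gpus"] = "1"          # logical id within the user's view

selected_gpus = os.environ["FLAGS_selected_gpus"].split(",")   # ["1"]
user = os.environ["USER_CUDA_VISIBLE_DEVICES"].split(",")      # ["2", "3"]
selected_gpus = ",".join(user[int(i)] for i in selected_gpus)  # "3" (physical id)

os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpus[0]  # worker sees only card 3
os.environ["FLAGS_selected_gpus"] = ",".join(
    str(g) for g in range(len(selected_gpus)))         # "0": paddle calls it gpu:0
```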
@@ -36,8 +36,14 @@ def set_env_on_import_torch():

 # TODO paddle may need set this
 def set_env_on_import_paddle():
-    # todo FASTNLP_GLOBAL_RANK and FASTNLP_BACKEND_LAUNCH need to be set
-    pass
+    # todo FASTNLP_GLOBAL_RANK and FASTNLP_LAUNCH_PROCESS need to be set
+    if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \
+            and "PADDLE_RANK_IN_NODE" in os.environ:
+        # environment variables of a distributed launch were detected
+        os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"]
+        # if it was not launched by fastnlp
+        if FASTNLP_DISTRIBUTED_CHECK not in os.environ:
+            os.environ[FASTNLP_BACKEND_LAUNCH] = "1"

 # TODO jittor may need set this
 def set_env_on_import_jittor():
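A usage sketch for the new hook. Hedged assumptions: the module path `fastNLP.envs.set_env_on_import` and `FASTNLP_*` constants holding their own names as values are guesses from the surrounding imports; run this before paddle is imported.

```python
import os
from fastNLP.envs.set_env_on_import import set_env_on_import_paddle

# env vars a paddle distributed launcher would set for worker 1 of 2
# ("PADDLE_TRANERS_NUM" spelled exactly as the hunk above checks it)
os.environ["PADDLE_TRANERS_NUM"] = "2"
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ["PADDLE_RANK_IN_NODE"] = "1"

set_env_on_import_paddle()
print(os.environ.get("FASTNLP_GLOBAL_RANK"))     # -> "1"
print(os.environ.get("FASTNLP_BACKEND_LAUNCH"))  # -> "1": launched by paddle, not fastNLP
```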