From f144bc31c3f417695a56e17aa5bb46df3f71309b Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sat, 9 Apr 2022 13:42:30 +0800 Subject: [PATCH] fleet test --- .../paddle_driver/initialize_paddle_driver.py | 6 ++-- .../drivers/paddle_driver/single_device.py | 23 +++++++++++-- fastNLP/core/drivers/paddle_driver/utils.py | 2 +- .../metrics/backend/paddle_backend/backend.py | 3 +- fastNLP/core/metrics/utils.py | 9 ++--- fastNLP/core/utils/paddle_utils.py | 12 ++++++- fastNLP/envs/set_backend.py | 34 +++++++++++++------ fastNLP/envs/set_env_on_import.py | 10 ++++-- 8 files changed, 72 insertions(+), 27 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index 0e76ceae..e362017e 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -42,7 +42,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ # 优先级 user > cuda # 判断单机情况 device 的合法性 # 分布式情况下通过 world_device 判断 - if user_visible_devices is not None: + if user_visible_devices != "": _could_use_device_num = len(user_visible_devices.split(",")) elif cuda_visible_devices is not None: _could_use_device_num = len(cuda_visible_devices.split(",")) @@ -51,8 +51,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ if isinstance(device, int): if device < 0 and device != -1: raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - if device >= _could_use_device_num: - raise ValueError("The gpu device that parameter `device` specifies is not existed.") + # if device >= _could_use_device_num: + # raise ValueError("The gpu device that parameter `device` specifies is not existed.") device = f"gpu:{device}" elif isinstance(device, Sequence) and not isinstance(device, str): device = list(set(device)) diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 849bf4d1..1dad6d97 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -1,8 +1,14 @@ +import os from typing import Optional, Dict, Union from .paddle_driver import PaddleDriver from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str +from fastNLP.core.utils import ( + auto_param_call, + get_paddle_gpu_str, + get_paddle_device_id, + paddle_move_data_to_device, +) from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator from fastNLP.core.log import logger @@ -86,8 +92,9 @@ class PaddleSingleDriver(PaddleDriver): self._test_signature_fn = model.forward def setup(self): - paddle.device.set_device(self.model_device) - self.model.to(self.model_device) + os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device)) + paddle.device.set_device("gpu:0") + self.model.to("gpu:0") def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; @@ -116,6 +123,16 @@ class PaddleSingleDriver(PaddleDriver): else: return self._test_step(batch) + def move_data_to_device(self, batch: 'paddle.Tensor'): + r""" + 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 + 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 + 在单卡时,由于 CUDA_VISIBLE_DEVICES 始终被限制在一个设备上,因此实际上只会迁移到 `gpu:0` + + :return: 将移动到指定机器上的 batch 对象返回; + """ + return paddle_move_data_to_device(batch, "gpu:0") + def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): # 暂时不支持IteratorDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 9b54a30a..b99ae581 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -272,7 +272,7 @@ def get_device_from_visible(device: Union[str, int]): else: # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 user_visiblde_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) - if user_visiblde_devices is None or user_visiblde_devices != "": + if user_visiblde_devices is not None and user_visiblde_devices != "": # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES idx = user_visiblde_devices.split(",")[idx] else: diff --git a/fastNLP/core/metrics/backend/paddle_backend/backend.py b/fastNLP/core/metrics/backend/paddle_backend/backend.py index cf7feb79..12216d4b 100644 --- a/fastNLP/core/metrics/backend/paddle_backend/backend.py +++ b/fastNLP/core/metrics/backend/paddle_backend/backend.py @@ -122,7 +122,6 @@ class PaddleBackend(Backend): def move_tensor_to_device(self, tensor, device): # TODO 如果在这里处理的话,会不会在别的地方引起bug? - if is_in_paddle_dist(): - device = get_device_from_visible(device) + device = get_device_from_visible(device) return paddle_to(tensor, device) diff --git a/fastNLP/core/metrics/utils.py b/fastNLP/core/metrics/utils.py index 1363282a..beafd6f4 100644 --- a/fastNLP/core/metrics/utils.py +++ b/fastNLP/core/metrics/utils.py @@ -4,17 +4,18 @@ __all__ = [ from typing import Any from functools import wraps -from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH from fastNLP.envs.utils import _module_available _IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') -if _IS_TORCHMETRICS_AVAILABLE: - from torchmetrics import Metric as torchmetrics_Metric - _IS_ALLENNLP_AVAILABLE = _module_available('allennlp') if _IS_ALLENNLP_AVAILABLE: from allennlp.training.metrics import Metric as allennlp_Metric +if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE: + if _IS_TORCHMETRICS_AVAILABLE: + from torchmetrics import Metric as torchmetrics_Metric + if _NEED_IMPORT_PADDLE: from paddle.metric import Metric as paddle_Metric diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index 8af6efc9..2e1bfeda 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -9,6 +9,7 @@ __all__ = [ ] import os +import re from typing import Any, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -42,10 +43,19 @@ def get_paddle_device_id(device: Union[str, int]): if isinstance(device, int): return device + device = device.lower() if device == "cpu": raise ValueError("Cannot get device id from `cpu`.") - return paddle.device._convert_to_place(device).get_device_id() + match_res = re.match(r"gpu:\d+", device) + if not match_res: + raise ValueError( + "The device must be a string which is like 'cpu', 'gpu', 'gpu:x'" + ) + device_id = device.split(':', 1)[1] + device_id = int(device_id) + + return device_id def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, data_device: Optional[str] = None) -> Any: diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index a1ac5efb..68a28335 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -52,21 +52,33 @@ def _set_backend(): if backend == 'paddle': assert _module_available(backend), f"You must have {backend} available to use {backend} backend." assert 'paddle' not in sys.modules, "You have to use `set_backend()` before `import paddle`." - if 'CUDA_VISIBLE_DEVICES' not in os.environ and 'PADDLE_RANK_IN_NODE' not in os.environ \ - and 'FLAGS_selected_gpus' not in os.environ: - os.environ['CUDA_VISIBLE_DEVICES'] = '0' - os.environ[USER_CUDA_VISIBLE_DEVICES] = '' + user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) + if 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: + # 在分布式子进程下,根据 USER_VISIBLE_DEVICES 得到进程真正占有的设备 + selected_gpus = os.environ['FLAGS_selected_gpus'].split(',') + if user_visible_devices is not None and user_visible_devices != "": + # 用户通过 CUDA_VISIBLE_DEVICES 启动了分布式训练 + # 此时经过 set_backend,用户的设置会保存在 USER_CUDA_VISIBLE_DEVICES 中 + # 我们需要从中找到真正使用的设备编号 + user_visible_devices = user_visible_devices.split(",") + selected_gpus = ",".join([user_visible_devices[int(i)] for i in selected_gpus]) + else: + # 设置 USER_CUDA_VISIBLE_DEVICES 表明用户视角中所有设备可见 + os.environ[USER_CUDA_VISIBLE_DEVICES] = "" + # TODO 这里的 [0] 可能在单个节点多卡的时候有问题 + os.environ['CUDA_VISIBLE_DEVICES'] = selected_gpus[0] + os.environ['FLAGS_selected_gpus'] = ",".join([str(g) for g in range(len(selected_gpus))]) + os.environ['FLAGS_selected_accelerators'] = ",".join([str(g) for g in range(len(selected_gpus))]) elif 'CUDA_VISIBLE_DEVICES' in os.environ: + # 主进程中,用户设置了 CUDA_VISIBLE_DEVICES + # 将用户设置的 CUDA_VISIBLE_DEVICES hack 掉 CUDA_VISIBLE_DEVICES = os.environ['CUDA_VISIBLE_DEVICES'] os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] - elif 'PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ: - # TODO 这里由于fastNLP需要hack CUDA_VISIBLE_DEVICES,因此需要相应滴修改FLAGS等paddle变量 @xsh - CUDA_VISIBLE_DEVICES = os.environ['FLAGS_selected_gpus'] - os.environ[USER_CUDA_VISIBLE_DEVICES] = CUDA_VISIBLE_DEVICES - os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES.split(',')[0] - os.environ['FLAGS_selected_gpus'] = "0" - os.environ['FLAGS_selected_accelerators'] = "0" + else: + # 没有设置的话限制在单卡上,防止多进程时占用别的卡 + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + os.environ[USER_CUDA_VISIBLE_DEVICES] = '' elif backend == 'jittor': assert _module_available(backend), f"You must have {backend} available to use {backend} backend." diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index c6828d1c..db978bae 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -36,8 +36,14 @@ def set_env_on_import_torch(): # TODO paddle may need set this def set_env_on_import_paddle(): - # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH - pass + # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_LAUNCH_PROCESS + if "PADDLE_TRANERS_NUM" in os.environ and "PADDLE_TRAINER_ID" in os.environ \ + and "PADDLE_RANK_IN_NODE" in os.environ: + # 检测到了分布式环境的环境变量 + os.environ[FASTNLP_GLOBAL_RANK] = os.environ["PADDLE_TRAINER_ID"] + # 如果不是由 fastnlp 启动的 + if FASTNLP_DISTRIBUTED_CHECK not in os.environ: + os.environ[FASTNLP_BACKEND_LAUNCH] = "1" # TODO jittor may need set this def set_env_on_import_jittor():