diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 73342748..acbeefec 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -1,12 +1,10 @@ import os -import shutil from typing import List, Union, Optional, Dict, Tuple, Callable from .paddle_driver import PaddleDriver from .fleet_launcher import FleetLauncher from .utils import ( _FleetWrappingModel, - get_device_from_visible, reset_seed, replace_sampler, replace_batch_sampler, @@ -17,7 +15,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.utils import ( auto_param_call, check_user_specific_params, - paddle_move_data_to_device, is_in_paddle_dist, rank_zero_rm ) @@ -609,12 +606,6 @@ class PaddleFleetDriver(PaddleDriver): def is_distributed(self): return True - def move_data_to_device(self, batch: 'paddle.Tensor'): - device = self.data_device - # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 - device = get_device_from_visible(device) - return paddle_move_data_to_device(batch, device) - @staticmethod def _check_optimizer_legality(optimizers): # paddle 存在设置分布式 optimizers 的函数,返回值为 fleet.meta_optimizers.HybridParallelOptimizer @@ -637,10 +628,8 @@ class PaddleFleetDriver(PaddleDriver): :return: 如果当前不是分布式 driver 直接返回输入的 obj 。如果当前 rank 是接收端(其 global rank 包含在了 dst 中),则返回 接收到的参数;如果是 source 端则返回发射的内容;既不是发送端、又不是接收端,则返回 None 。 """ - device = self.data_device # 因为设置了CUDA_VISIBLE_DEVICES,可能会引起错误 - device = get_device_from_visible(device) - return fastnlp_paddle_broadcast_object(obj, src, device=device, group=group) + return fastnlp_paddle_broadcast_object(obj, src, device=self.data_device, group=group) def all_gather(self, obj, group=None) -> List: """ diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index f65efd3d..48ff9de1 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -10,7 +10,7 @@ import numpy as np from .utils import _build_fp16_env, optimizer_state_to_device, DummyGradScaler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver -from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device +from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device, get_device_from_visible from fastNLP.envs import ( FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, @@ -394,7 +394,8 @@ class PaddleDriver(Driver): :return: 将移动到指定机器上的 batch 对象返回; """ - return paddle_move_data_to_device(batch, self.data_device) + device = get_device_from_visible(self.data_device) + return paddle_move_data_to_device(batch, device) @staticmethod def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 52805a97..b2546788 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -2,14 +2,14 @@ import os from typing import Optional, Dict, Union, Callable, Tuple from .paddle_driver import PaddleDriver -from .utils import replace_batch_sampler, replace_sampler, get_device_from_visible +from .utils import replace_batch_sampler, replace_sampler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES from fastNLP.core.utils import ( auto_param_call, + get_device_from_visible, get_paddle_gpu_str, get_paddle_device_id, - paddle_move_data_to_device, ) from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.samplers import ( @@ -65,8 +65,7 @@ class PaddleSingleDriver(PaddleDriver): r""" 该函数用来初始化训练环境,用于设置当前训练的设备,并将模型迁移到对应设备上。 """ - device = self.model_device - device = get_device_from_visible(device, output_type=str) + device = get_device_from_visible(self.model_device, output_type=str) paddle.device.set_device(device) self.model.to(device) @@ -121,16 +120,6 @@ class PaddleSingleDriver(PaddleDriver): else: raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") - def move_data_to_device(self, batch: 'paddle.Tensor'): - r""" - 将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 - 在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 - - :return: 将移动到指定机器上的 batch 对象返回; - """ - device = get_device_from_visible(self.data_device) - return paddle_move_data_to_device(batch, device) - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): r""" diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 6cd7b252..60d243e7 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -6,12 +6,11 @@ import inspect import numpy as np from copy import deepcopy from contextlib import ExitStack, closing -from enum import IntEnum -from typing import Dict, Optional, Union +from typing import Dict, Optional from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to -from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES +from fastNLP.core.utils import auto_param_call, paddle_to +from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS from fastNLP.core.log import logger @@ -173,40 +172,6 @@ def find_free_ports(num): return None -def get_device_from_visible(device: Union[str, int], output_type=int): - """ - 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 - 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 - - :param device: 未转化的设备名 - :param output_type: 返回值的类型 - :return: 转化后的设备id - """ - if output_type not in [int, str]: - raise ValueError("Parameter `output_type` should be one of these types: [int, str]") - if device == "cpu": - return device - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") - idx = get_paddle_device_id(device) - if cuda_visible_devices is None or cuda_visible_devices == "": - # 这个判断一般不会发生,因为 fastnlp 会为 paddle 强行注入 CUDA_VISIBLE_DEVICES - raise RuntimeError("This situation should not happen, please report us this bug.") - else: - # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 - user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) - if user_visible_devices is None: - raise RuntimeError("This situation cannot happen, please report a bug to us.") - idx = user_visible_devices.split(",")[idx] - - cuda_visible_devices_list = cuda_visible_devices.split(',') - if idx not in cuda_visible_devices_list: - raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}].") - res = cuda_visible_devices_list.index(idx) - if output_type == int: - return res - else: - return f"gpu:{res}" - def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): """ 利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 diff --git a/fastNLP/core/metrics/backend/paddle_backend/backend.py b/fastNLP/core/metrics/backend/paddle_backend/backend.py index 243c5aac..aa57bbc2 100644 --- a/fastNLP/core/metrics/backend/paddle_backend/backend.py +++ b/fastNLP/core/metrics/backend/paddle_backend/backend.py @@ -1,11 +1,10 @@ -from typing import List, Optional, Any +from typing import List, Any import numpy as np from fastNLP.core.metrics.backend import Backend -from fastNLP.core.utils.paddle_utils import paddle_to +from fastNLP.core.utils.paddle_utils import paddle_to, get_device_from_visible from fastNLP.core.metrics.utils import AggregateMethodError -from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible from fastNLP.core.drivers.paddle_driver.dist_utils import fastnlp_paddle_all_gather from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -80,7 +79,6 @@ class PaddleBackend(Backend): raise ValueError(f"tensor: {tensor} can not convert to ndarray!") def move_tensor_to_device(self, tensor, device): - # TODO 如果在这里处理的话,会不会在别的地方引起bug? device = get_device_from_visible(device) return paddle_to(tensor, device) diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 9fb538a9..648f3cae 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -2,6 +2,7 @@ __all__ = [ 'cache_results', 'is_jittor_dataset', 'jittor_collate_wraps', + 'get_device_from_visible', 'paddle_to', 'paddle_move_data_to_device', 'get_paddle_device_id', @@ -29,7 +30,7 @@ __all__ = [ from .cache_results import cache_results from .jittor_utils import is_jittor_dataset, jittor_collate_wraps -from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ +from .paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ is_in_fnlp_paddle_dist, is_in_paddle_launch_dist from .rich_progress import f_rich_progress from .torch_paddle_utils import torch_paddle_move_data_to_device diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index e4c0a8a9..b9bc26a5 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -1,4 +1,5 @@ __all__ = [ + "get_device_from_visible", "paddle_to", "paddle_move_data_to_device", "get_paddle_gpu_str", @@ -13,13 +14,45 @@ import re from typing import Any, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_BACKEND_LAUNCH +from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_BACKEND_LAUNCH, USER_CUDA_VISIBLE_DEVICES if _NEED_IMPORT_PADDLE: import paddle from .utils import apply_to_collection +def get_device_from_visible(device: Union[str, int], output_type=int): + """ + 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 + 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 + + :param device: 未转化的设备名 + :param output_type: 返回值的类型 + :return: 转化后的设备id + """ + if output_type not in [int, str]: + raise ValueError("Parameter `output_type` should be one of these types: [int, str]") + if device == "cpu": + return device + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) + if user_visible_devices is None: + raise RuntimeError("`USER_CUDA_VISIBLE_DEVICES` is None, please check if you have set " + "`FASTNLP_BACKEND` to 'paddle' before 'import fastNLP'.") + idx = get_paddle_device_id(device) + # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 + if user_visible_devices is None: + raise RuntimeError("This situation cannot happen, please report a bug to us.") + idx = user_visible_devices.split(",")[idx] + + cuda_visible_devices_list = cuda_visible_devices.split(',') + if idx not in cuda_visible_devices_list: + raise ValueError(f"Can't find your devices {idx} in CUDA_VISIBLE_DEVICES[{cuda_visible_devices}]. ") + res = cuda_visible_devices_list.index(idx) + if output_type == int: + return res + else: + return f"gpu:{res}" def paddle_to(data, device: Union[str, int]): """ @@ -33,6 +66,7 @@ def paddle_to(data, device: Union[str, int]): if device == "cpu": return data.cpu() else: + # device = get_device_from_visible(device, output_type=int) return data.cuda(get_paddle_device_id(device)) diff --git a/tests/core/controllers/_test_trainer_fleet.py b/tests/core/controllers/_test_trainer_fleet.py index f438b6de..309e6eb4 100644 --- a/tests/core/controllers/_test_trainer_fleet.py +++ b/tests/core/controllers/_test_trainer_fleet.py @@ -1,7 +1,7 @@ """ 这个文件测试用户以python -m paddle.distributed.launch 启动的情况 看看有没有用pytest执行的机会 -python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py +FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet.py """ import os import sys diff --git a/tests/core/controllers/_test_trainer_fleet_outside.py b/tests/core/controllers/_test_trainer_fleet_outside.py index e8c9a244..d2bcbc41 100644 --- a/tests/core/controllers/_test_trainer_fleet_outside.py +++ b/tests/core/controllers/_test_trainer_fleet_outside.py @@ -1,7 +1,7 @@ """ 这个文件测试用户以python -m paddle.distributed.launch 启动的情况 并且自己初始化了 fleet -python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py +FASTNLP_BACKEND=paddle python -m paddle.distributed.launch --gpus=0,2,3 _test_trainer_fleet_outside.py """ import os import sys @@ -93,5 +93,5 @@ if __name__ == "__main__": driver=driver, device=device, callbacks=callbacks, - n_epochs=30, + n_epochs=5, ) \ No newline at end of file diff --git a/tests/core/utils/test_paddle_utils.py b/tests/core/utils/test_paddle_utils.py index ba9dcf79..d86d215f 100644 --- a/tests/core/utils/test_paddle_utils.py +++ b/tests/core/utils/test_paddle_utils.py @@ -1,10 +1,40 @@ +import os + import pytest -from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_device +from fastNLP.core.utils.paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_to_device from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle - +@pytest.mark.parametrize( + ("user_visible_devices, cuda_visible_devices, device, output_type, correct"), + ( + ("0,1,2,3,4,5,6,7", "0", "cpu", str, "cpu"), + ("0,1,2,3,4,5,6,7", "0", "cpu", int, "cpu"), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:4", int, 1), + ("0,1,2,3,4,5,6,7", "3,4,5", "gpu:5", str, "gpu:2"), + ("3,4,5,6", "3,5", 0, int, 0), + ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), + ) +) +@pytest.mark.paddle +def test_get_device_from_visible(user_visible_devices, cuda_visible_devices, device, output_type, correct): + _cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + _user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices + res = get_device_from_visible(device, output_type) + assert res == correct + + # 还原环境变量 + if _cuda_visible_devices is None: + del os.environ["CUDA_VISIBLE_DEVICES"] + else: + os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices + if _user_visible_devices is None: + del os.environ["USER_CUDA_VISIBLE_DEVICES"] + else: + os.environ["USER_CUDA_VISIBLE_DEVICES"] = _user_visible_devices ############################################################################ # @@ -22,12 +52,6 @@ class TestPaddleToDevice: assert res.place.gpu_device_id() == 0 res = paddle_to(tensor, "cpu") assert res.place.is_cpu_place() - res = paddle_to(tensor, "gpu:2") - assert res.place.is_gpu_place() - assert res.place.gpu_device_id() == 2 - res = paddle_to(tensor, "gpu:1") - assert res.place.is_gpu_place() - assert res.place.gpu_device_id() == 1 ############################################################################ # @@ -64,28 +88,18 @@ class TestPaddleMoveDataToDevice: res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None) self.check_gpu(res, 0) - res = paddle_move_data_to_device(paddle_tensor, device="gpu:1", data_device=None) - self.check_gpu(res, 1) - res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu") self.check_gpu(res, 0) res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0") self.check_gpu(res, 0) - res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:1") - self.check_gpu(res, 1) - def test_list_transfer(self): """ 测试张量列表的迁移 """ paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] - res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") - assert isinstance(res, list) - for r in res: - self.check_gpu(r, 1) res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") assert isinstance(res, list) @@ -97,11 +111,6 @@ class TestPaddleMoveDataToDevice: for r in res: self.check_gpu(r, 0) - res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") - assert isinstance(res, list) - for r in res: - self.check_gpu(r, 1) - def test_tensor_tuple_transfer(self): """ 测试张量元组的迁移 @@ -109,10 +118,6 @@ class TestPaddleMoveDataToDevice: paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] paddle_tuple = tuple(paddle_list) - res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") - assert isinstance(res, tuple) - for r in res: - self.check_gpu(r, 1) res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") assert isinstance(res, tuple) @@ -124,11 +129,6 @@ class TestPaddleMoveDataToDevice: for r in res: self.check_gpu(r, 0) - res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") - assert isinstance(res, tuple) - for r in res: - self.check_gpu(r, 1) - def test_dict_transfer(self): """ 测试字典结构的迁移 @@ -173,20 +173,6 @@ class TestPaddleMoveDataToDevice: self.check_gpu(t, 0) self.check_gpu(res["dict"]["tensor"], 0) - res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") - assert isinstance(res, dict) - self.check_gpu(res["tensor"], 1) - assert isinstance(res["list"], list) - for t in res["list"]: - self.check_gpu(t, 1) - assert isinstance(res["int"], int) - assert isinstance(res["string"], str) - assert isinstance(res["dict"], dict) - assert isinstance(res["dict"]["list"], list) - for t in res["dict"]["list"]: - self.check_gpu(t, 1) - self.check_gpu(res["dict"]["tensor"], 1) - res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") assert isinstance(res, dict) self.check_cpu(res["tensor"])