diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index db30517f..98655757 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -38,23 +38,19 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ if driver not in {"paddle", "fleet"}: raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") - cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") - # 优先级 user > cuda - # 判断单机情况 device 的合法性 - # 分布式情况下通过 world_device 判断 - if user_visible_devices != "": - _could_use_device_num = len(user_visible_devices.split(",")) - elif cuda_visible_devices is not None: - _could_use_device_num = len(cuda_visible_devices.split(",")) - else: - _could_use_device_num = paddle.device.cuda.device_count() + if user_visible_devices is None: + raise RuntimeError("This situation cannot happen, please report a bug to us.") + _could_use_device_num = len(user_visible_devices.split(",")) if isinstance(device, int): if device < 0 and device != -1: raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - # if device >= _could_use_device_num: - # raise ValueError("The gpu device that parameter `device` specifies is not existed.") - device = f"gpu:{device}" + if device >= _could_use_device_num: + raise ValueError("The gpu device that parameter `device` specifies is not existed.") + if device != -1: + device = f"gpu:{device}" + else: + device = list(range(_could_use_device_num)) elif isinstance(device, Sequence) and not isinstance(device, str): device = list(set(device)) for each in device: @@ -62,6 +58,9 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") elif each < 0: raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") + elif each >= _could_use_device_num: + raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" + " the available gpu number.") if len(device) == 1: # 传入了 [1] 这样的,视为单卡。 device = device[0] diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index 30d5ef3c..54ef22b6 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -1,83 +1,103 @@ +import os import pytest -from fastNLP.envs.set_backend import set_env -from fastNLP.envs.set_env_on_import import set_env_on_import_paddle - -set_env_on_import_paddle() -set_env("paddle") -import paddle +os.environ["FASTNLP_BACKEND"] = "paddle" +from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver -from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver -from tests.helpers.models.paddle_model import PaddleNormalModel_Classification +from fastNLP.envs import get_gpu_count +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from tests.helpers.utils import magic_argv_env_context +import paddle def test_incorrect_driver(): + model = PaddleNormalModel_Classification_1(2, 100) with pytest.raises(ValueError): - driver = initialize_paddle_driver("torch") + driver = initialize_paddle_driver("torch", 0, model) @pytest.mark.parametrize( "device", - ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] + ["cpu", "gpu:0", 0, [1]] ) -def test_get_single_device(device): +@pytest.mark.parametrize( + "driver", + ["paddle"] +) +def test_get_single_device(driver, device): """ 测试正常情况下初始化PaddleSingleDriver的情况 """ - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) - + model = PaddleNormalModel_Classification_1(2, 100) + driver = initialize_paddle_driver(driver, device, model) assert isinstance(driver, PaddleSingleDriver) @pytest.mark.parametrize( "device", - ["cpu", "gpu:0", [1, 2, 3], 0, "gpu:1"] + [0, 1] ) -def test_get_single_device_with_visiblde_devices(device): +@pytest.mark.parametrize( + "driver", + ["fleet"] +) +@magic_argv_env_context +def test_get_fleet_2(driver, device): """ - 测试 CUDA_VISIBLE_DEVICES 启动时初始化PaddleSingleDriver的情况 + 测试 fleet 多卡的初始化情况 """ - # TODO - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) + model = PaddleNormalModel_Classification_1(64, 10) + driver = initialize_paddle_driver(driver, device, model) - assert isinstance(driver, PaddleSingleDriver) + assert isinstance(driver, PaddleFleetDriver) @pytest.mark.parametrize( "device", - [[1, 2, 3]] + [[0, 2, 3], -1] +) +@pytest.mark.parametrize( + "driver", + ["paddle", "fleet"] ) -def test_get_fleet(device): +@magic_argv_env_context +def test_get_fleet(driver, device): """ 测试 fleet 多卡的初始化情况 """ - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) + model = PaddleNormalModel_Classification_1(64, 10) + driver = initialize_paddle_driver(driver, device, model) assert isinstance(driver, PaddleFleetDriver) @pytest.mark.parametrize( - "device", - [[1,2,3]] + ("driver", "device"), + [("fleet", "cpu")] ) -def test_get_fleet(device): +@magic_argv_env_context +def test_get_fleet_cpu(driver, device): """ - 测试 launch 启动 fleet 多卡的初始化情况 + 测试试图在 cpu 上初始化分布式训练的情况 """ - # TODO - - model = PaddleNormalModel_Classification(2, 100) - driver = initialize_paddle_driver("paddle", device, model) - - assert isinstance(driver, PaddleFleetDriver) + model = PaddleNormalModel_Classification_1(64, 10) + with pytest.raises(ValueError): + driver = initialize_paddle_driver(driver, device, model) -def test_device_out_of_range(device): +@pytest.mark.parametrize( + "device", + [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] +) +@pytest.mark.parametrize( + "driver", + ["paddle", "fleet"] +) +@magic_argv_env_context +def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 """ - pass \ No newline at end of file + model = PaddleNormalModel_Classification_1(2, 100) + with pytest.raises(ValueError): + driver = initialize_paddle_driver(driver, device, model) \ No newline at end of file