@@ -7,6 +7,7 @@ from .single_device import PaddleSingleDriver
 from .fleet import PaddleFleetDriver
 
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
+from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES
 from fastNLP.core.utils import is_in_paddle_launch_dist, get_paddle_gpu_str
 from fastNLP.core.log import logger
 
@@ -30,8 +31,10 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
     """
     if driver != "paddle":
         raise ValueError("When initialize PaddleDriver, parameter `driver` must be 'paddle'.")
-    user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES")
+    user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
     if is_in_paddle_launch_dist():
+        if user_visible_devices is None:
+            raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
         if device is not None:
             logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
                                 "up your script. And we will directly get the local device via environment variables.")
|
|
@@ -65,6 +68,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
+        device = [get_paddle_gpu_str(g) for g in device]
     elif device is not None and not isinstance(device, str):
         raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.")
 
     if isinstance(device, List):
         return PaddleFleetDriver(model, device, **kwargs)
     else:
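For orientation, a standalone sketch of the dispatch the last hunk touches: a list-valued `device` routes to `PaddleFleetDriver`, while anything else is assumed to fall through to the single-device driver (that branch sits after the trailing `else:` and is outside this hunk). The helper below is hypothetical and only mirrors the control flow; it is not the fastNLP implementation.

```python
from typing import List, Optional, Union

def pick_driver(device: Optional[Union[str, int, List[int]]]) -> str:
    """Hypothetical mirror of the dispatch shown in the diff."""
    # Mirrors the `elif ...: raise ValueError(...)` branch: reject unsupported types.
    if device is not None and not isinstance(device, (str, int, list)):
        raise ValueError("Parameter `device` is wrong type.")
    # A list of devices means distributed training -> fleet driver.
    if isinstance(device, list):
        return "PaddleFleetDriver"
    # A single device (or None) is assumed to use the single-device driver;
    # that branch follows the trailing `else:` and is not shown in the hunk.
    return "PaddleSingleDriver"

print(pick_driver([0, 1]))   # PaddleFleetDriver
print(pick_driver("gpu:0"))  # PaddleSingleDriver
```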