|
|
@@ -3,6 +3,7 @@ from typing import Optional, Dict, Union |
|
|
|
|
|
|
|
from .paddle_driver import PaddleDriver |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES |
|
|
|
from fastNLP.core.utils import ( |
|
|
|
auto_param_call, |
|
|
|
get_paddle_gpu_str, |
|
|
@@ -92,7 +93,12 @@ class PaddleSingleDriver(PaddleDriver): |
|
|
|
self._test_signature_fn = model.forward |
|
|
|
|
|
|
|
def setup(self): |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(get_paddle_device_id(self.model_device)) |
|
|
|
user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] |
|
|
|
device_id = get_paddle_device_id(self.model_device) |
|
|
|
if user_visible_devices is not None and user_visible_devices != "": |
|
|
|
# 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES |
|
|
|
device_id = user_visible_devices.split(",")[device_id] |
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) |
|
|
|
paddle.device.set_device("gpu:0") |
|
|
|
self.model.to("gpu:0") |
|
|
|
|
|
|
|