|
|
@@ -1,3 +1,5 @@ |
|
|
|
import os |
|
|
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver |
|
|
@@ -40,9 +42,14 @@ def test_get_fleet(device): |
|
|
|
""" |
|
|
|
测试 fleet 多卡的初始化情况 |
|
|
|
""" |
|
|
|
|
|
|
|
flag = False |
|
|
|
if "USER_CUDA_VISIBLE_DEVICES" not in os.environ: |
|
|
|
os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3" |
|
|
|
flag = True |
|
|
|
model = PaddleNormalModel_Classification_1(20, 10) |
|
|
|
driver = initialize_paddle_driver("paddle", device, model) |
|
|
|
if flag: |
|
|
|
del os.environ["USER_CUDA_VISIBLE_DEVICES"] |
|
|
|
|
|
|
|
assert isinstance(driver, PaddleFleetDriver) |
|
|
|
|
|
|
|