|
|
@@ -13,10 +13,11 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 |
|
|
|
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset |
|
|
|
from tests.helpers.utils import magic_argv_env_context |
|
|
|
from fastNLP.core import rank_zero_rm |
|
|
|
|
|
|
|
import paddle |
|
|
|
import paddle.distributed as dist |
|
|
|
from paddle.io import DataLoader, BatchSampler |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
import paddle |
|
|
|
import paddle.distributed as dist |
|
|
|
from paddle.io import DataLoader, BatchSampler |
|
|
|
|
|
|
|
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): |
|
|
|
paddle_model = PaddleNormalModel_Classification_1(num_labels, feature_dimension) |
|
|
|