From 21699749033dc65fe2f6fd400c54cd14d7e3567b Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 07:20:54 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=B5=8B=E8=AF=95=E4=BE=8B?= =?UTF-8?q?=E7=9A=84import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/dataloaders/paddle_dataloader/test_fdl.py | 2 ++ tests/helpers/datasets/paddle_data.py | 2 ++ tests/helpers/models/paddle_model.py | 7 +++++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index 484b0daa..abed1e83 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -7,6 +7,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: from paddle.io import Dataset, DataLoader import paddle +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class RandomDataset(Dataset): diff --git a/tests/helpers/datasets/paddle_data.py b/tests/helpers/datasets/paddle_data.py index 0fa8ee83..8a8d39b1 100644 --- a/tests/helpers/datasets/paddle_data.py +++ b/tests/helpers/datasets/paddle_data.py @@ -4,6 +4,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle from paddle.io import Dataset +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class PaddleNormalDataset(Dataset): diff --git a/tests/helpers/models/paddle_model.py b/tests/helpers/models/paddle_model.py index 7a897235..d2969b8e 100644 --- a/tests/helpers/models/paddle_model.py +++ b/tests/helpers/models/paddle_model.py @@ -2,8 +2,11 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle import paddle.nn as nn + from paddle.nn import Layer +else: + from fastNLP.core.utils.dummy_class import DummyClass as Layer -class PaddleNormalModel_Classification_1(paddle.nn.Layer): +class PaddleNormalModel_Classification_1(Layer): """ 基础的paddle分类模型 """ @@ -34,7 +37,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer): return {"pred": x, "target": y.reshape((-1,))} -class PaddleNormalModel_Classification_2(paddle.nn.Layer): +class PaddleNormalModel_Classification_2(Layer): """ 基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 """