Browse Source

完善测试例的import

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
2169974903
3 changed files with 9 additions and 2 deletions
  1. +2
    -0
      tests/core/dataloaders/paddle_dataloader/test_fdl.py
  2. +2
    -0
      tests/helpers/datasets/paddle_data.py
  3. +5
    -2
      tests/helpers/models/paddle_model.py

+ 2
- 0
tests/core/dataloaders/paddle_dataloader/test_fdl.py View File

@@ -7,6 +7,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
from paddle.io import Dataset, DataLoader
import paddle
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset


class RandomDataset(Dataset):


+ 2
- 0
tests/helpers/datasets/paddle_data.py View File

@@ -4,6 +4,8 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
from paddle.io import Dataset
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset


class PaddleNormalDataset(Dataset):


+ 5
- 2
tests/helpers/models/paddle_model.py View File

@@ -2,8 +2,11 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
import paddle
import paddle.nn as nn
from paddle.nn import Layer
else:
from fastNLP.core.utils.dummy_class import DummyClass as Layer

class PaddleNormalModel_Classification_1(paddle.nn.Layer):
class PaddleNormalModel_Classification_1(Layer):
"""
基础的paddle分类模型
"""
@@ -34,7 +37,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer):
return {"pred": x, "target": y.reshape((-1,))}


class PaddleNormalModel_Classification_2(paddle.nn.Layer):
class PaddleNormalModel_Classification_2(Layer):
"""
基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景
"""


Loading…
Cancel
Save