|
|
@@ -2,8 +2,11 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE |
|
|
|
if _NEED_IMPORT_PADDLE: |
|
|
|
import paddle |
|
|
|
import paddle.nn as nn |
|
|
|
from paddle.nn import Layer |
|
|
|
else: |
|
|
|
from fastNLP.core.utils.dummy_class import DummyClass as Layer |
|
|
|
|
|
|
|
class PaddleNormalModel_Classification_1(paddle.nn.Layer): |
|
|
|
class PaddleNormalModel_Classification_1(Layer): |
|
|
|
""" |
|
|
|
基础的paddle分类模型 |
|
|
|
""" |
|
|
@@ -34,7 +37,7 @@ class PaddleNormalModel_Classification_1(paddle.nn.Layer): |
|
|
|
return {"pred": x, "target": y.reshape((-1,))} |
|
|
|
|
|
|
|
|
|
|
|
class PaddleNormalModel_Classification_2(paddle.nn.Layer): |
|
|
|
class PaddleNormalModel_Classification_2(Layer): |
|
|
|
""" |
|
|
|
基础的paddle分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景 |
|
|
|
""" |
|
|
|