You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

paddle_model.py 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
  2. if _NEED_IMPORT_PADDLE:
  3. import paddle
  4. import paddle.nn as nn
  5. from paddle.nn import Layer
  6. else:
  7. from fastNLP.core.utils.dummy_class import DummyClass as Layer
  8. class PaddleNormalModel_Classification_1(Layer):
  9. """
  10. 基础的 paddle 分类模型
  11. """
  12. def __init__(self, num_labels, feature_dimension):
  13. super(PaddleNormalModel_Classification_1, self).__init__()
  14. self.num_labels = num_labels
  15. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
  16. self.ac1 = nn.ReLU()
  17. self.linear2 = nn.Linear(in_features=64, out_features=32)
  18. self.ac2 = nn.ReLU()
  19. self.output = nn.Linear(in_features=32, out_features=num_labels)
  20. self.loss_fn = nn.CrossEntropyLoss()
  21. def forward(self, x):
  22. x = self.ac1(self.linear1(x))
  23. x = self.ac2(self.linear2(x))
  24. x = self.output(x)
  25. return x
  26. def train_step(self, x, y):
  27. x = self(x)
  28. return {"loss": self.loss_fn(x, y)}
  29. def evaluate_step(self, x, y):
  30. x = self(x)
  31. return {"pred": x, "target": y.reshape((-1,))}
  32. class PaddleNormalModel_Classification_2(Layer):
  33. """
  34. 基础的 paddle 分类模型,只实现 forward 函数测试用户自己初始化了分布式的场景
  35. """
  36. def __init__(self, num_labels, feature_dimension):
  37. super(PaddleNormalModel_Classification_2, self).__init__()
  38. self.num_labels = num_labels
  39. self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
  40. self.ac1 = nn.ReLU()
  41. self.linear2 = nn.Linear(in_features=64, out_features=32)
  42. self.ac2 = nn.ReLU()
  43. self.output = nn.Linear(in_features=32, out_features=num_labels)
  44. self.loss_fn = nn.CrossEntropyLoss()
  45. def forward(self, x, y):
  46. x = self.ac1(self.linear1(x))
  47. x = self.ac2(self.linear2(x))
  48. x = self.output(x)
  49. return {"loss": self.loss_fn(x, y), "pred": x, "target": y.reshape((-1,))}